Spaces:
Running
Running
import datetime | |
from utils import ( | |
predict_keypoints_vitpose, | |
get_edge_groups, | |
get_series, | |
z_score_normalization, | |
modify_student_frame, | |
get_video_frames, | |
check_and_download_models, | |
get_dtw_mean_path, | |
generate_output_video, | |
generate_log, | |
write_log | |
) | |
from config import ( | |
CONNECTIONS_VIT_FULL, | |
CONNECTIONS_FOR_ERROR, | |
EDGE_GROUPS_FOR_ERRORS, | |
EDGE_GROUPS_FOR_SUMMARY, | |
get_thresholds | |
) | |
def video_identity(
    dtw_mean,
    dtw_filter,
    angles_sensitive,
    angles_common,
    angles_insensitive,
    trigger_state,
    show_arrows,
    video_teacher,
    video_student
):
    """Compare a student's video against a teacher's and build the outputs.

    The pipeline runs keypoint detection on both videos, converts the
    keypoints into series via ``get_series``, aligns the two videos with
    ``get_dtw_mean_path``, annotates every aligned frame pair through
    ``modify_student_frame`` and finally writes a combined video and a log.

    Args:
        dtw_mean: Forwarded to ``get_dtw_mean_path`` and recorded by ``write_log``.
        dtw_filter: Forwarded to ``get_dtw_mean_path`` and recorded by ``write_log``.
        angles_sensitive: Forwarded to ``get_thresholds`` and ``write_log``.
        angles_common: Forwarded to ``get_thresholds`` and ``write_log``.
        angles_insensitive: Forwarded to ``get_thresholds`` and ``write_log``.
        trigger_state: Forwarded to ``modify_student_frame`` (as ``triger_state``)
            and recorded by ``write_log``.
        show_arrows: Forwarded to ``modify_student_frame``.
        video_teacher: Teacher video, passed as ``video_path`` to the keypoint
            detector and to ``get_video_frames``.
        video_student: Student video, used the same way as ``video_teacher``.

    Returns:
        A tuple ``(video_path, general_summary, log_path)`` holding the results
        of ``generate_output_video``, ``generate_log`` and ``write_log``.
    """

    def _series_from(detections, edge_groups):
        # The last entry along the third axis is dropped before building the
        # series (presumably a per-keypoint confidence score - TODO confirm).
        return get_series(detections[:, :, :-1], edge_groups)

    # --------------------------------------------------------------------------
    # Step 1: keypoint detection on both videos (same model setup for each).
    # --------------------------------------------------------------------------
    check_and_download_models()
    vitpose_options = {
        "model_path": "models/vitpose-b-wholebody.pth",
        "model_name": "b",
        "detector_path": "models/yolov8s.pt",
    }
    keypoints_teacher = predict_keypoints_vitpose(video_path=video_teacher, **vitpose_options)
    keypoints_student = predict_keypoints_vitpose(video_path=video_student, **vitpose_options)

    # --------------------------------------------------------------------------
    # Step 2: turn keypoints into series and normalize those used for DTW.
    # --------------------------------------------------------------------------
    angles_teacher = _series_from(keypoints_teacher, EDGE_GROUPS_FOR_ERRORS).T
    angles_student = _series_from(keypoints_student, EDGE_GROUPS_FOR_ERRORS).T

    dtw_edge_groups = get_edge_groups(CONNECTIONS_VIT_FULL)
    dtw_series_teacher = _series_from(keypoints_teacher, dtw_edge_groups)
    dtw_series_student = _series_from(keypoints_student, dtw_edge_groups)
    dtw_series_teacher = z_score_normalization(dtw_series_teacher)
    dtw_series_student = z_score_normalization(dtw_series_student)

    # --------------------------------------------------------------------------
    # Step 3: frame-to-frame alignment of the two videos with DTW.
    # --------------------------------------------------------------------------
    frame_alignments = get_dtw_mean_path(
        dtw_series_teacher, dtw_series_student, dtw_mean, dtw_filter
    )

    # --------------------------------------------------------------------------
    # Step 4: annotate frames, align playback speed and collect error messages.
    # --------------------------------------------------------------------------
    triggered_primary = []
    triggered_secondary = []
    teacher_frames_out = []
    student_frames_out = []
    collected_summaries = []

    frames_teacher = get_video_frames(video_teacher)
    frames_student = get_video_frames(video_student)
    error_thresholds = get_thresholds(angles_sensitive, angles_common, angles_insensitive)

    for frame_index, aligned_pair in enumerate(frame_alignments):
        # The trigger lists returned by one iteration are fed into the next.
        (
            student_frame,
            teacher_frame,
            triggered_primary,
            triggered_secondary,
            frame_summary,
        ) = modify_student_frame(
            detection_result_teacher=keypoints_teacher,
            detection_result_student=keypoints_student,
            detection_result_teacher_angles=angles_teacher,
            detection_result_student_angles=angles_student,
            video_teacher=frames_teacher,
            video_student=frames_student,
            alignment_frames=aligned_pair,
            edge_groups=EDGE_GROUPS_FOR_ERRORS,
            connections=CONNECTIONS_FOR_ERROR,
            thresholds=error_thresholds,
            previously_trigered=triggered_primary,
            previously_trigered_2=triggered_secondary,
            triger_state=trigger_state,
            show_arrows=show_arrows,
            text_dictionary=EDGE_GROUPS_FOR_SUMMARY,
        )
        teacher_frames_out.append(teacher_frame)
        student_frames_out.append(student_frame)
        # Tag every (log, arrow) entry with the index of the alignment step.
        for log, arrow in frame_summary:
            collected_summaries.append((log, frame_index, arrow))

    # --------------------------------------------------------------------------
    # Step 5: create the files offered for display and download.
    # --------------------------------------------------------------------------
    current_time = datetime.datetime.now()
    timestamp_str = current_time.strftime("%Y_%m-%d_%H_%M_%S")

    video_path = generate_output_video(teacher_frames_out, student_frames_out, timestamp_str)
    general_summary = generate_log(collected_summaries)
    log_path = write_log(
        timestamp_str,
        dtw_mean,
        dtw_filter,
        angles_sensitive,
        angles_common,
        angles_insensitive,
        trigger_state,
        general_summary,
    )

    return video_path, general_summary, log_path