# pose_demo_01 / main_func.py
# Maksym-Lysyi's picture
# add arrows
# 4daa026
import datetime
from utils import (
predict_keypoints_vitpose,
get_edge_groups,
get_series,
z_score_normalization,
modify_student_frame,
get_video_frames,
check_and_download_models,
get_dtw_mean_path,
generate_output_video,
generate_log,
write_log
)
from config import (
CONNECTIONS_VIT_FULL,
CONNECTIONS_FOR_ERROR,
EDGE_GROUPS_FOR_ERRORS,
EDGE_GROUPS_FOR_SUMMARY,
get_thresholds
)
def video_identity(
    dtw_mean,
    dtw_filter,
    angles_sensitive,
    angles_common,
    angles_insensitive,
    trigger_state,
    show_arrows,
    video_teacher,
    video_student
):
    """Compare a student's video with a teacher's and produce the result files.

    The pipeline runs in five stages:

    1. Make sure the models are available, then detect keypoints on both
       videos with ``predict_keypoints_vitpose``.
    2. Turn the keypoints into series with ``get_series`` -- once with
       ``EDGE_GROUPS_FOR_ERRORS`` (used for error marking) and once with the
       edge groups derived from ``CONNECTIONS_VIT_FULL`` (z-score normalized
       and used for alignment).
    3. Compute the frame-to-frame alignment with ``get_dtw_mean_path``.
    4. For every aligned frame pair call ``modify_student_frame`` and collect
       the resulting frames and text notes.
    5. Build the output video, the summary and the log file.

    Args:
        dtw_mean: Forwarded to ``get_dtw_mean_path`` and recorded by
            ``write_log``.
        dtw_filter: Forwarded to ``get_dtw_mean_path`` and recorded by
            ``write_log``.
        angles_sensitive: Forwarded to ``get_thresholds`` and ``write_log``.
        angles_common: Forwarded to ``get_thresholds`` and ``write_log``.
        angles_insensitive: Forwarded to ``get_thresholds`` and ``write_log``.
        trigger_state: Forwarded to ``modify_student_frame`` (as
            ``triger_state``) and recorded by ``write_log``.
        show_arrows: Forwarded to ``modify_student_frame``.
        video_teacher: Teacher video; passed as ``video_path`` to the keypoint
            detector and to ``get_video_frames`` (presumably a file path --
            confirm against the caller).
        video_student: Student video; used the same way as ``video_teacher``.

    Returns:
        A tuple ``(video_path, general_summary, log_path)`` holding the
        results of ``generate_output_video``, ``generate_log`` and
        ``write_log`` respectively.
    """
    # ----------------------------------------------------------------------
    # Stage 1: keypoint detection on both videos.
    # ----------------------------------------------------------------------
    check_and_download_models()

    # Both videos are processed with exactly the same model settings.
    vitpose_settings = {
        "model_path": "models/vitpose-b-wholebody.pth",
        "model_name": "b",
        "detector_path": "models/yolov8s.pt",
    }
    detection_result_teacher = predict_keypoints_vitpose(
        video_path=video_teacher, **vitpose_settings
    )
    detection_result_student = predict_keypoints_vitpose(
        video_path=video_student, **vitpose_settings
    )

    # ----------------------------------------------------------------------
    # Stage 2: keypoints -> angle series, plus normalized series for DTW.
    # ----------------------------------------------------------------------
    # The last entry along the third axis is dropped before building series
    # (presumably a per-keypoint confidence score -- TODO confirm).
    teacher_points = detection_result_teacher[:, :, :-1]
    student_points = detection_result_student[:, :, :-1]

    teacher_angles = get_series(teacher_points, EDGE_GROUPS_FOR_ERRORS).T
    student_angles = get_series(student_points, EDGE_GROUPS_FOR_ERRORS).T

    dtw_edge_groups = get_edge_groups(CONNECTIONS_VIT_FULL)
    teacher_series = get_series(teacher_points, dtw_edge_groups)
    student_series = get_series(student_points, dtw_edge_groups)
    teacher_series = z_score_normalization(teacher_series)
    student_series = z_score_normalization(student_series)

    # ----------------------------------------------------------------------
    # Stage 3: frame-to-frame mean alignment computed with DTW.
    # ----------------------------------------------------------------------
    alignments = get_dtw_mean_path(teacher_series, student_series, dtw_mean, dtw_filter)

    # ----------------------------------------------------------------------
    # Stage 4: visual marks on the student's video, speed alignment and
    # collection of the error notes.
    # ----------------------------------------------------------------------
    # Trigger state is threaded through consecutive modify_student_frame calls.
    trigger_1 = []
    trigger_2 = []
    save_teacher_frames = []
    save_student_frames = []
    all_text_summaries = []

    teacher_frames = get_video_frames(video_teacher)
    student_frames = get_video_frames(video_student)
    error_thresholds = get_thresholds(angles_sensitive, angles_common, angles_insensitive)

    # Arguments that stay the same for every aligned frame pair.
    invariant_kwargs = {
        "detection_result_teacher": detection_result_teacher,
        "detection_result_student": detection_result_student,
        "detection_result_teacher_angles": teacher_angles,
        "detection_result_student_angles": student_angles,
        "video_teacher": teacher_frames,
        "video_student": student_frames,
        "edge_groups": EDGE_GROUPS_FOR_ERRORS,
        "connections": CONNECTIONS_FOR_ERROR,
        "thresholds": error_thresholds,
        "triger_state": trigger_state,
        "show_arrows": show_arrows,
        "text_dictionary": EDGE_GROUPS_FOR_SUMMARY,
    }

    for idx, alignment in enumerate(alignments):
        (
            student_frame,
            teacher_frame,
            trigger_1,
            trigger_2,
            frame_notes,
        ) = modify_student_frame(
            alignment_frames=alignment,
            previously_trigered=trigger_1,
            previously_trigered_2=trigger_2,
            **invariant_kwargs
        )
        save_teacher_frames.append(teacher_frame)
        save_student_frames.append(student_frame)
        # Tag every note with the index of the aligned pair it came from.
        all_text_summaries.extend((log, idx, arrow) for (log, arrow) in frame_notes)

    # ----------------------------------------------------------------------
    # Stage 5: create the files for downloading and displaying.
    # ----------------------------------------------------------------------
    # The same timestamp string is handed to both the video and the log writer.
    timestamp_str = datetime.datetime.now().strftime("%Y_%m-%d_%H_%M_%S")
    video_path = generate_output_video(save_teacher_frames, save_student_frames, timestamp_str)
    general_summary = generate_log(all_text_summaries)
    log_path = write_log(
        timestamp_str,
        dtw_mean,
        dtw_filter,
        angles_sensitive,
        angles_common,
        angles_insensitive,
        trigger_state,
        general_summary
    )

    return video_path, general_summary, log_path