import os
from datetime import timedelta
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
from dtaidistance import dtw
from easy_ViTPose.inference import VitInference
from scipy.signal import savgol_filter
from scipy.stats import mstats

def predict_keypoints_vitpose(
    video_path,
    model_path,
    model_name,
    detector_path,
    display_video=False
):
    """Run ViTPose on every frame of a video and collect the keypoints."""
    model = VitInference(
        model=model_path,
        yolo=detector_path,
        model_name=model_name,
        det_class=None,
        dataset=None,
        yolo_size=320,
        is_video=False,
        single_pose=False,
        yolo_step=1
    )
    cap = cv2.VideoCapture(video_path)
    detection_results = []
    while True:
        ret, frame = cap.read()
        if not ret:
            print(f"Keypoints were extracted from {video_path}")
            break
        frame = cv2.resize(frame, (1280, 720))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_keypoints = model.inference(frame)
        # Keep only the first detected person (id 0).
        if 0 in frame_keypoints:
            detection_results.append(frame_keypoints[0])
        if display_video:
            frame = model.draw(False, False, 0.5)[..., ::-1]  # back to BGR for imshow
            cv2.imshow('preview', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    cap.release()
    if display_video:
        cv2.destroyAllWindows()
    return np.array(detection_results)

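# A minimal usage sketch. The model paths and the "b" model-size name match
# what check_and_download_models() below fetches; the input video path is a
# hypothetical example:
# keypoints = predict_keypoints_vitpose(
#     video_path="videos/student.mp4",
#     model_path="models/vitpose-b-wholebody.pth",
#     model_name="b",
#     detector_path="models/yolov8s.pt",
# )
# Each entry is one frame's keypoint array; judging by the [::-1] reversal in
# align_points() below, rows appear to be (y, x, confidence) per keypoint.
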
def get_point_list_vitpose(detection_results):
    # Drop the confidence column, keeping only the coordinate pairs.
    return np.array(detection_results)[:, :, :-1]

def get_edge_groups(connections):
    """Find all (edge, center, edge) keypoint triplets of adjacent connections."""
    all_pairs = []
    for i in range(len(connections)):
        pairs = []
        init_con = connections[i]
        for k in range(len(connections)):
            if k == i:
                continue  # do not pair a connection with itself
            candidat_con = connections[k]
            # Two connections are adjacent if they share a keypoint.
            if candidat_con[0] in init_con or candidat_con[1] in init_con:
                pairs.append([init_con, candidat_con])
        all_pairs.append(pairs)
    all_point_for_edges = []
    for set_of_pairs in all_pairs:
        clean_pairs = []
        for pair in set_of_pairs:
            pair_a = pair[0]
            pair_b = pair[1]
            # Exactly three distinct keypoints means the pair forms an angle.
            if len(set(pair_a + pair_b)) == 3:
                center = int(list(set(pair_a) & set(pair_b))[0])
                edges = list(set(pair_a) ^ set(pair_b))
                points_for_edge = [edges[0], center, edges[1]]
                clean_pairs.append(points_for_edge)
        all_point_for_edges.extend(clean_pairs)
    # Deduplicate the triplets, then sort for a stable output order.
    unique_set = set()
    unique_list = []
    for sublist in all_point_for_edges:
        sublist_tuple = tuple(sublist)
        if sublist_tuple not in unique_set:
            unique_set.add(sublist_tuple)
            unique_list.append(sublist)
    unique_list.sort()
    return unique_list

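# A worked example on a toy three-keypoint chain: with
# connections = [(0, 1), (1, 2)] the two connections share keypoint 1, so the
# only (edge, center, edge) triplet is 0-1-2:
# get_edge_groups([(0, 1), (1, 2)])  # -> [[0, 1, 2]]
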
def calculate_angle(A, B, C):
    """Return the angle ABC in degrees and the shorter of the two limb lengths."""
    A = np.round(np.array(A), decimals=3)
    B = np.round(np.array(B), decimals=3)
    C = np.round(np.array(C), decimals=3)
    BA = A - B
    BC = C - B
    cosine_angle = np.dot(BA, BC) / (np.linalg.norm(BA) * np.linalg.norm(BC))
    cosine_angle = np.clip(cosine_angle, -1, 1)
    angle = np.arccos(cosine_angle)
    if np.isnan(angle):
        print(f"Invalid angle calculation.\n{A} \n{B} \n{C}")
    minimum = np.min(np.array((np.linalg.norm(BA), np.linalg.norm(BC))))
    return np.degrees(angle), minimum

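# Sanity check: with B at the origin, A on the x-axis, and C on the y-axis,
# the angle is 90 degrees and the shorter limb length (|BA| = 1) is returned:
# calculate_angle([1, 0], [0, 0], [0, 2])  # -> (90.0, 1.0)
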
def compute_all_angles(keypoints, edge_groups):
    all_angles = []
    for group in edge_groups:
        A = keypoints[group[0]]
        B = keypoints[group[1]]
        C = keypoints[group[2]]
        angle, minimum = calculate_angle(A, B, C)
        all_angles.append([angle, minimum])
    return np.array(all_angles)

def xy2phi(points_result, connections):
    """Convert per-frame keypoint coordinates into per-frame joint angles."""
    edge_groups = get_edge_groups(connections)
    new_array = np.zeros((points_result.shape[0], len(edge_groups), 1))
    for idx, frame in enumerate(points_result):
        all_angles = compute_all_angles(keypoints=frame, edge_groups=edge_groups)[:, 0]
        new_array[idx, :, :] = all_angles.reshape((len(edge_groups), 1))
    return new_array

def get_series(point_list, edge_groups):
    """Build one angle time series per edge group across all frames."""
    list_of_series = []
    for edge_group in edge_groups:
        keypoint_1, keypoint_2, keypoint_3 = edge_group
        relevant_point_list = point_list[:, (keypoint_1, keypoint_2, keypoint_3), :]
        series = []
        for frame in relevant_point_list:
            angle, _ = calculate_angle(frame[0, :], frame[1, :], frame[2, :])
            series.append(angle)
        list_of_series.append(series)
    return np.array(list_of_series)

def plot_serieses(series_1, series_2):
    plt.figure(dpi=150, figsize=(12, 5))
    plt.plot(series_1, label='Video #1', lw=1)
    plt.plot(series_2, label='Video #2', lw=1)
    plt.axis("on")
    plt.grid(True)
    plt.xlabel("frames")
    plt.ylabel("angles")
    plt.legend()

def z_score_normalization(serieses, axis_for_znorm=1):
    serieses_mean = np.mean(serieses, axis=axis_for_znorm, keepdims=True)
    serieses_std = np.std(serieses, axis=axis_for_znorm, keepdims=True)
    serieses_normalized = (serieses - serieses_mean) / serieses_std
    return serieses_normalized

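# Usage note: with the output of get_series(), shaped (n_series, n_frames),
# the default axis=1 z-normalizes each angle series over time independently,
# which makes the teacher and student series comparable before DTW. A sketch,
# assuming both series arrays were already extracted:
# serieses_teacher_norm = z_score_normalization(serieses_teacher)
# serieses_student_norm = z_score_normalization(serieses_student)
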
def get_dtw_mean_path(serieses_teacher, serieses_student, dtw_mean, dtw_filter):
    """Align teacher and student frames by averaging DTW paths over all angle series."""
    # One DTW warping path per angle series.
    list_of_paths = []
    for idx in range(len(serieses_teacher)):
        series_teacher = np.array(serieses_teacher[idx])
        series_student = np.array(serieses_student[idx])
        _, paths = dtw.warping_paths(series_teacher, series_student, window=50)
        path = dtw.best_path(paths)
        list_of_paths.append(path)
    all_dtw_tuples = []
    for path in list_of_paths:
        all_dtw_tuples.extend(path)
    # For every student frame, take the winsorized mean of all teacher frames
    # that the individual paths mapped to it.
    mean_path = []
    for student_frame in range(len(serieses_student[0])):
        frames_from_teacher = []
        for teacher_student_pair in all_dtw_tuples:
            if teacher_student_pair[1] == student_frame:
                frames_from_teacher.append(teacher_student_pair[0])
        mean_path.append((int(mstats.winsorize(np.array(frames_from_teacher), limits=[dtw_mean, dtw_mean]).mean()), student_frame))
    path_array = np.array(mean_path)
    # polyorder=0 makes the Savitzky-Golay filter act as a moving average.
    smoothed_data = savgol_filter(path_array, window_length=dtw_filter, polyorder=0, axis=0)
    path_array = np.array(smoothed_data).astype(int)
    alignments = np.unique(path_array, axis=0)  # TODO: check if this is correct
    return alignments

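# A usage sketch under assumed settings: dtw_mean=0.1 trims the top and bottom
# 10% of candidate teacher frames before averaging, and dtw_filter=15 is a
# hypothetical smoothing window length. Each row of the result pairs a teacher
# frame index with a student frame index:
# alignments = get_dtw_mean_path(serieses_teacher_norm, serieses_student_norm,
#                                dtw_mean=0.1, dtw_filter=15)
# for frame_teacher, frame_student in alignments:
#     ...  # compare the two poses at this aligned pair of frames
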
def modify_student_frame(
    detection_result_teacher,
    detection_result_student,
    detection_result_teacher_angles,
    detection_result_student_angles,
    video_teacher,
    video_student,
    alignment_frames,
    edge_groups,
    connections,
    thresholds,
    previously_triggered,
    previously_triggered_2,
    trigger_state,
    show_arrows,
    text_dictionary,
):
    """Draw correction hints on the student frame wherever the student's joint
    angles deviate from the teacher's beyond the per-angle thresholds."""
    arrows_bgr = (175, 75, 190)
    arrows_sz = 3
    skeleton_bgr = (0, 0, 255)
    skeleton_sz = 3
    frame_copy = video_student[alignment_frames[1]].copy()
    frame_teacher_copy = video_teacher[alignment_frames[0]]
    # Absolute per-angle error between the aligned teacher and student frames.
    frame_errors = np.abs(detection_result_teacher_angles[alignment_frames[0]] - detection_result_student_angles[alignment_frames[1]])
    edge_groups_as_keys = [tuple(group) for group in edge_groups]
    edge_groups2errors = dict(zip(edge_groups_as_keys, frame_errors))
    edge_groups2thresholds = dict(zip(edge_groups_as_keys, thresholds))
    edge_groups_relevant = [edge_group[1:] for edge_group in edge_groups]
    text_info = []
    triggered_connections = []
    triggered_connections2 = []
    for connection in connections:
        edges_for_given_connection = [edge for edge in edge_groups2errors if connection[0] in edge or connection[1] in edge]
        for edge in edges_for_given_connection:
            check_threshold = edge_groups2errors[edge] > edge_groups2thresholds[edge]
            # Only trust the error if every keypoint of the edge is confident.
            check_certain = True
            for keypoint in edge:
                prob = detection_result_student[alignment_frames[1], keypoint, -1]
                if prob < 0.7:
                    check_certain = False
            relevant_plane = [connection[0], connection[1]] in edge_groups_relevant or [connection[1], connection[0]] in edge_groups_relevant
            if check_threshold and check_certain and relevant_plane:
                point1, point2, point2_t = align_points(
                    detection_result_student,
                    detection_result_teacher,
                    alignment_frames,
                    edge
                )
                arrow = get_arrow_direction(point2, point2_t)
                if trigger_state == "one":
                    # Trigger immediately, on the first check that fails.
                    _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)
                    if show_arrows:
                        _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz)
                    if (connection[0], connection[1]) in text_dictionary:
                        text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
                    if (connection[1], connection[0]) in text_dictionary:
                        text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
                if trigger_state == "two":
                    # Require the error to persist over two consecutive checks.
                    triggered_connections.append((connection[0], connection[1]))
                    if (connection[0], connection[1]) in previously_triggered:
                        _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)
                        if show_arrows:
                            _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz)
                        if (connection[0], connection[1]) in text_dictionary:
                            text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
                        if (connection[1], connection[0]) in text_dictionary:
                            text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
                if trigger_state == "three":
                    # Require the error to persist over three consecutive checks.
                    triggered_connections.append((connection[0], connection[1]))
                    if (connection[0], connection[1]) in previously_triggered:
                        triggered_connections2.append((connection[0], connection[1]))
                    if (connection[0], connection[1]) in previously_triggered_2:
                        _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)
                        if show_arrows:
                            _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz)
                        if (connection[0], connection[1]) in text_dictionary:
                            text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
                        if (connection[1], connection[0]) in text_dictionary:
                            text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
    return frame_copy, frame_teacher_copy, list(set(triggered_connections)), list(set(triggered_connections2)), text_info

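# A sketch of the intended per-frame loop (variable names are assumptions):
# the triggered lists returned by one call feed the next call's
# previously_triggered arguments, which is what implements the persistence
# behind trigger_state "two" and "three".
# prev, prev2 = [], []
# for align in alignments:
#     frame_s, frame_t, prev, prev2, texts = modify_student_frame(
#         ..., alignment_frames=align,
#         previously_triggered=prev, previously_triggered_2=prev2, ...
#     )
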
def get_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    video = []
    while True:
        ret, frame = cap.read()
        if not ret:
            print(f"Video {video_path} was loaded")
            break
        frame = cv2.resize(frame, (1280, 720))
        video.append(frame)
    cap.release()
    return np.array(video)

def download_file(url, save_path):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(save_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)

def check_and_download_models():
    # vit_model_s_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-s-wholebody.pth?download=true"
    vit_model_b_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-b-wholebody.pth?download=true"
    # vit_model_l_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-l-wholebody.pth?download=true"
    yolo_model_url = "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/yolov8/yolov8s.pt?download=true"
    # vit_model_s_path = "models/vitpose-s-wholebody.pth"
    vit_model_b_path = "models/vitpose-b-wholebody.pth"
    # vit_model_l_path = "models/vitpose-l-wholebody.pth"
    yolo_model_path = "models/yolov8s.pt"
    # Path(os.path.dirname(vit_model_s_path)).mkdir(parents=True, exist_ok=True)
    Path(os.path.dirname(vit_model_b_path)).mkdir(parents=True, exist_ok=True)
    # Path(os.path.dirname(vit_model_l_path)).mkdir(parents=True, exist_ok=True)
    Path(os.path.dirname(yolo_model_path)).mkdir(parents=True, exist_ok=True)
    # if not os.path.exists(vit_model_s_path):
    #     print("Downloading ViT-Pose-s model...")
    #     download_file(vit_model_s_url, vit_model_s_path)
    #     print("ViT-Pose-s model was downloaded.")
    if not os.path.exists(vit_model_b_path):
        print("Downloading ViT-Pose-b model...")
        download_file(vit_model_b_url, vit_model_b_path)
        print("ViT-Pose-b model was downloaded.")
    # if not os.path.exists(vit_model_l_path):
    #     print("Downloading ViT-Pose-l model...")
    #     download_file(vit_model_l_url, vit_model_l_path)
    #     print("ViT-Pose-l model was downloaded.")
    if not os.path.exists(yolo_model_path):
        print("Downloading YOLO model...")
        download_file(yolo_model_url, yolo_model_path)
        print("YOLO model was downloaded.")

def generate_output_video(teacher_frames, student_frames, timestamp_str):
    # Resize both streams and place them side by side (teacher left, student right).
    teacher_frames_resized = np.array([cv2.resize(frame, (1280, 720)) for frame in teacher_frames])
    student_frames_resized = np.array([cv2.resize(frame, (1280, 720)) for frame in student_frames])
    concat_video = np.concatenate((teacher_frames_resized, student_frames_resized), axis=2)
    root_dir = "videos"
    os.makedirs(root_dir, exist_ok=True)
    video_path = f"{root_dir}/pose_{timestamp_str}.mp4"
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), 30, (1280 * 2, 720))
    for frame in concat_video:
        out.write(frame)
    out.release()
    return video_path

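# Hypothetical end-of-run call; timestamp_str is assumed to come from
# something like datetime.now().strftime("%Y-%m-%d_%H-%M-%S"):
# video_path = generate_output_video(teacher_frames, student_frames,
#                                    "2024-01-01_12-00-00")
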
def generate_log(all_text_summaries):
    # Deduplicate the (comment, frame, arrow) records and sort them by frame.
    all_text_summaries_clean = list(set(all_text_summaries))
    all_text_summaries_clean.sort(key=lambda x: x[1])
    general_summary = []
    for log in all_text_summaries_clean:
        comment, frame, arrow = log
        total_seconds = frame / 30  # the output video is written at 30 fps
        timestamp = str(timedelta(seconds=total_seconds))
        if "." not in timestamp:
            timestamp += ".000000"  # whole seconds print without microseconds
        # Strip the hours and trim microseconds to hundredths: "M:SS.ff".
        general_summary.append(f"{comment}. Direction: {arrow}. Video time: {timestamp[3:-4]}")
    general_summary = "\n".join(general_summary)
    return general_summary

def write_log(
    timestamp_str,
    dtw_mean,
    dtw_filter,
    angles_sensitive,
    angles_common,
    angles_insensitive,
    trigger_state,
    general_summary
):
    logs_dir = "logs"
    os.makedirs(logs_dir, exist_ok=True)
    log_path = f"{logs_dir}/log_{timestamp_str}.txt"
    content = f"""
Settings:
Dynamic Time Warping:
- Winsorize mean: {dtw_mean}
- Savitzky-Golay Filter: {dtw_filter}
Thresholds:
- Sensitive: {angles_sensitive}
- Standard: {angles_common}
- Insensitive: {angles_insensitive}
Patience:
- trigger count: {trigger_state}
Error logs:
{general_summary}
"""
    with open(log_path, "w") as file:
        file.write(content)
    print(f"log {log_path} was created.")
    return log_path

def angle_between(v1, v2):
    # Signed angle from v1 to v2, in radians.
    return np.arctan2(v2[1], v2[0]) - np.arctan2(v1[1], v1[0])

def align_points(detection_result_student, detection_result_teacher, alignment_frames, edge):
    # Drop the confidence column and reverse each keypoint's coordinate pair
    # so the points can be drawn with OpenCV's (x, y) convention.
    point0 = detection_result_student[alignment_frames[1], edge[0], :-1].astype(int)[::-1]
    point1 = detection_result_student[alignment_frames[1], edge[1], :-1].astype(int)[::-1]
    point2 = detection_result_student[alignment_frames[1], edge[2], :-1].astype(int)[::-1]
    point0_t = detection_result_teacher[alignment_frames[0], edge[0], :-1].astype(int)[::-1]
    point1_t = detection_result_teacher[alignment_frames[0], edge[1], :-1].astype(int)[::-1]
    point2_t = detection_result_teacher[alignment_frames[0], edge[2], :-1].astype(int)[::-1]
    # Translate the teacher edge so both edges share the same origin point.
    translation = point0 - point0_t
    point0_t += translation
    point1_t += translation
    point2_t += translation
    # Rotate the teacher edge around point0 so the first limbs coincide;
    # point2_t then shows where the student's second limb should point.
    BsA = point1 - point0
    BtA = point1_t - point0
    theta = angle_between(BtA, BsA)
    R = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])
    point1_t = np.dot(R, (point1_t - point0).T).T + point0
    point2_t = np.dot(R, (point2_t - point0).T).T + point0
    point2_t = point2_t.astype(int)
    return point1, point2, point2_t

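# A quick check of the rotation step in isolation: rotating the unit x-vector
# by theta = 90 degrees about the origin lands on the unit y-vector:
# theta = np.pi / 2
# R = np.array([[np.cos(theta), -np.sin(theta)],
#               [np.sin(theta), np.cos(theta)]])
# np.round(R @ np.array([1, 0]))  # -> array([0., 1.])
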
def get_arrow_direction(A, B):
    # Map the displacement from point A to point B onto one of eight arrows.
    # Note the swapped arctan2 arguments: the angle is measured from the
    # positive direction of the second coordinate.
    translation_vector = B - A
    angle_deg = np.degrees(np.arctan2(translation_vector[0], translation_vector[1]))
    match angle_deg:
        case angle if -22.5 <= angle < 22.5:
            arrow = "⬆"
        case angle if 22.5 <= angle < 67.5:
            arrow = "⬈"
        case angle if 67.5 <= angle < 112.5:
            arrow = "➡"
        case angle if 112.5 <= angle < 157.5:
            arrow = "⬊"
        case angle if 157.5 <= angle or angle < -157.5:
            arrow = "⬇"
        case angle if -157.5 <= angle < -112.5:
            arrow = "⬋"
        case angle if -112.5 <= angle < -67.5:
            arrow = "⬅"
        case angle if -67.5 <= angle < -22.5:
            arrow = "⬉"
        case _:
            arrow = ""
    return arrow

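# For example, a displacement of (10, 0) gives arctan2(10, 0) = 90 degrees,
# which falls in the [67.5, 112.5) bucket:
# get_arrow_direction(np.array([0, 0]), np.array([10, 0]))  # -> "➡"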