# pose_demo_01 / utils.py
# Page metadata captured with this copy: author Maksym-Lysyi, commit 4daa026 ("add arrows"), 18.1 kB.
import cv2
import numpy as np
import matplotlib.pyplot as plt
from dtaidistance import dtw
from easy_ViTPose.inference import VitInference
import os
import requests
from pathlib import Path
from datetime import timedelta
from scipy.signal import savgol_filter
from scipy.stats import mstats
def predict_keypoints_vitpose(
    video_path,
    model_path,
    model_name,
    detector_path,
    display_video=False
):
    """Run ViTPose on every frame of a video and collect the keypoints of entry 0.

    Each frame is resized to 1280x720 and converted BGR -> RGB before inference.
    Only the result stored under key ``0`` of the per-frame inference output is
    kept; frames whose output has no key ``0`` contribute nothing.

    NOTE(review): because such frames are skipped, the index of a row in the
    returned array is not guaranteed to equal the index of the video frame it
    came from -- confirm that downstream frame alignment tolerates this.

    Args:
        video_path: Path of the video to analyse.
        model_path: Path to the ViTPose weights.
        model_name: Model size/name forwarded to ``VitInference``.
        detector_path: Path to the YOLO detector weights.
        display_video: When true, show an annotated preview window; pressing
            ``q`` stops processing early.

    Returns:
        ``np.ndarray`` built from the list of per-frame keypoint arrays.
    """
    model = VitInference(
        model=model_path,
        yolo=detector_path,
        model_name=model_name,
        det_class=None,
        dataset=None,
        yolo_size=320,
        is_video=False,
        single_pose=False,
        yolo_step=1
    )
    cap = cv2.VideoCapture(video_path)
    detection_results = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print(f"Keypoints were extracted from {video_path}")
                break
            frame = cv2.resize(frame, (1280, 720))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_keypoints = model.inference(frame)
            if 0 in frame_keypoints:
                detection_results.append(frame_keypoints[0])
            if display_video:
                # draw() output is reversed on the channel axis for imshow.
                preview = model.draw(False, False, 0.5)[..., ::-1]
                cv2.imshow('preview', preview)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always free the capture handle, not only when a preview was shown.
        cap.release()
        if display_video:
            cv2.destroyAllWindows()
    return np.array(detection_results)
def get_point_list_vitpose(detection_results):
    """Return the detections with the last entry of the third axis removed.

    The input is converted with ``np.array`` and must be indexable with three
    axes; everything except the final column of the last axis is kept.
    """
    detections = np.array(detection_results)
    coordinates_only = detections[:, :, :-1]
    return coordinates_only
def get_edge_groups(connections):
    """Build ``[end, shared, end]`` keypoint triplets from connected limb pairs.

    Every ordered pair of connections that share an endpoint and together span
    exactly three distinct keypoints yields one triplet whose middle element is
    the shared keypoint. Duplicate triplets are dropped (first occurrence
    wins) and the result is sorted.

    Args:
        connections: Sequence of two-element connections that support ``+``
            concatenation (lists or tuples).

    Returns:
        Sorted list of three-element lists.
    """
    triplets = []
    for first in connections:
        for second in connections:
            first_a, first_b = first
            second_a, second_b = second
            shares_endpoint = (
                second_a == first_a
                or second_a == first_b
                or second_b == first_a
                or second_b == first_b
            )
            if not shares_endpoint:
                continue
            # A connection paired with itself spans two points and is dropped here.
            if len(set(first + second)) != 3:
                continue
            shared = int(list(set(first) & set(second))[0])
            ends = list(set(first) ^ set(second))
            triplets.append([ends[0], shared, ends[1]])

    # dict.fromkeys keeps the first occurrence of each triplet, in order.
    unique_triplets = [list(key) for key in dict.fromkeys(tuple(t) for t in triplets)]
    unique_triplets.sort()
    return unique_triplets
def calculate_angle(A, B, C):
    """Return the angle at ``B`` (degrees) and the shorter of the two arm lengths.

    All three points are rounded to three decimals first. The cosine is clipped
    to ``[-1, 1]`` before ``arccos``. If the result is NaN (for example when an
    arm has zero length) a diagnostic line is printed and NaN is returned.

    Returns:
        Tuple ``(angle_in_degrees, min(|BA|, |BC|))``.
    """
    a_pt = np.round(np.array(A), decimals=3)
    b_pt = np.round(np.array(B), decimals=3)
    c_pt = np.round(np.array(C), decimals=3)

    arm_ba = a_pt - b_pt
    arm_bc = c_pt - b_pt
    len_ba = np.linalg.norm(arm_ba)
    len_bc = np.linalg.norm(arm_bc)

    cosine = np.clip(np.dot(arm_ba, arm_bc) / (len_ba * len_bc), -1, 1)
    angle_rad = np.arccos(cosine)
    if np.isnan(angle_rad):
        print(f"Invalid angle calculation.\n{a_pt} \n{b_pt} \n{c_pt}")

    shortest_arm = np.array((len_ba, len_bc)).min()
    return np.degrees(angle_rad), shortest_arm
def compute_all_angels(keypoints, edge_groups):
    """Evaluate ``calculate_angle`` for every keypoint triplet of one frame.

    Args:
        keypoints: Indexable collection of points for a single frame.
        edge_groups: Iterable of triplets; elements 0, 1 and 2 of each are used
            as indices into ``keypoints`` (element 1 is the vertex).

    Returns:
        Array with one ``[angle, shorter_arm_length]`` row per triplet.
    """
    rows = []
    for group in edge_groups:
        first, vertex, last = (keypoints[group[position]] for position in range(3))
        rows.append(list(calculate_angle(first, vertex, last)))
    return np.array(rows)
def xy2phi(points_result, connections):
    """Convert per-frame keypoints into per-frame joint angles.

    Args:
        points_result: Array whose first axis enumerates frames.
        connections: Connections handed to ``get_edge_groups``.

    Returns:
        Array of shape ``(n_frames, n_edge_groups, 1)`` holding the angle
        column of ``compute_all_angels`` for each frame.
    """
    edge_groups = get_edge_groups(connections)
    group_count = len(edge_groups)
    angles = np.zeros((points_result.shape[0], group_count, 1))
    for frame_idx, frame_points in enumerate(points_result):
        frame_table = compute_all_angels(keypoints=frame_points, edge_groups=edge_groups)
        # Column 0 is the angle; column 1 (arm length) is not used here.
        angles[frame_idx, :, :] = frame_table[:, 0].reshape((group_count, 1))
    return angles
def get_series(point_list, edge_groups):
    """Build one angle-over-frames series per keypoint triplet.

    Args:
        point_list: Array indexed as ``[frame, keypoint, coordinate]``.
        edge_groups: Iterable of three-element keypoint index triplets; the
            middle index is the angle's vertex.

    Returns:
        Array with one row per triplet and one angle (degrees) per frame.
    """
    list_of_series = []
    for keypoint_1, keypoint_2, keypoint_3 in edge_groups:
        triplet_track = point_list[:, (keypoint_1, keypoint_2, keypoint_3), :]
        list_of_series.append(
            [
                calculate_angle(points[0, :], points[1, :], points[2, :])[0]
                for points in triplet_track
            ]
        )
    return np.array(list_of_series)
def plot_serieses(series_1, series_2):
    """Plot two series on a new 12x5 figure labelled 'Video #1' and 'Video #2'.

    The figure is created through ``pyplot`` and left open; nothing is
    returned or shown by this function itself.
    """
    plt.figure(dpi=150, figsize=(12, 5))
    labelled_series = ((series_1, 'Video #1'), (series_2, 'Video #2'))
    for values, label in labelled_series:
        plt.plot(values, label=label, lw=1)
    plt.axis("on")
    plt.grid(True)
    plt.xlabel("frames")
    plt.ylabel("angles")
    plt.legend()
def z_score_normalization(serieses, axis_for_znorm=1):
    """Z-score normalise ``serieses`` along ``axis_for_znorm``.

    A slice with zero standard deviation (a constant series) is normalised to
    all zeros instead of producing NaN/inf from a division by zero.

    Args:
        serieses: Array-like of values.
        axis_for_znorm: Axis along which mean and standard deviation are taken.

    Returns:
        Array of the same shape as the input with zero mean along the axis and,
        for non-constant slices, unit standard deviation.
    """
    serieses_mean = np.mean(serieses, axis=axis_for_znorm, keepdims=True)
    serieses_std = np.std(serieses, axis=axis_for_znorm, keepdims=True)
    # Dividing by 1 leaves the (all-zero) centred values of a constant slice untouched.
    safe_std = np.where(serieses_std == 0, 1.0, serieses_std)
    serieses_normalized = (serieses - serieses_mean) / safe_std
    return serieses_normalized
def get_dtw_mean_path(serieses_teacher, serieses_student, dtw_mean, dtw_filter):
    """Merge per-series DTW paths into one teacher/student frame alignment.

    For every series pair a DTW warping path (window 50) is computed. For each
    student frame the teacher frames matched to it across all paths are
    winsorized, averaged and truncated to an int. The resulting path is
    smoothed with a zero-order Savitzky-Golay filter, cast to int and
    de-duplicated.

    Args:
        serieses_teacher: Sequence of teacher series.
        serieses_student: Sequence of student series, indexed in parallel with
            ``serieses_teacher``; the length of its first series defines the
            number of student frames.
        dtw_mean: Fraction passed as both the lower and upper winsorize limit.
        dtw_filter: ``window_length`` handed to ``savgol_filter``.

    Returns:
        Array of unique ``(teacher_frame, student_frame)`` rows as produced by
        ``np.unique(..., axis=0)``.
    """
    # Group matched teacher frames by student frame in a single pass instead
    # of rescanning every path tuple for every student frame.
    teacher_frames_by_student = {}
    for idx, raw_teacher in enumerate(serieses_teacher):
        series_teacher = np.array(raw_teacher)
        series_student = np.array(serieses_student[idx])
        _, paths = dtw.warping_paths(series_teacher, series_student, window=50)
        for step in dtw.best_path(paths):
            teacher_frames_by_student.setdefault(step[1], []).append(step[0])

    mean_path = []
    for student_frame in range(len(serieses_student[0])):
        matched_teacher_frames = np.array(teacher_frames_by_student.get(student_frame, []))
        clipped = mstats.winsorize(matched_teacher_frames, limits=[dtw_mean, dtw_mean])
        mean_path.append((int(clipped.mean()), student_frame))

    path_array = np.array(mean_path)
    smoothed_data = savgol_filter(path_array, window_length=dtw_filter, polyorder=0, axis=0)
    path_array = np.array(smoothed_data).astype(int)
    alignments = np.unique(path_array, axis=0)  # TODO check if this correct
    return alignments
def _mark_student_error(
    frame,
    point1,
    point2,
    point2_t,
    arrow,
    connection,
    show_arrows,
    text_dictionary,
    text_info,
):
    """Draw one flagged segment (plus optional correction arrow) and collect its text."""
    arrows_bgr = (175, 75, 190)
    arrows_sz = 3
    skeleton_bgr = (0, 0, 255)
    skeleton_sz = 3
    cv2.line(frame, point1, point2, skeleton_bgr, skeleton_sz)
    if show_arrows:
        cv2.arrowedLine(frame, point2, point2_t, arrows_bgr, arrows_sz)
    # The dictionary may be keyed by the connection in either orientation.
    for key in ((connection[0], connection[1]), (connection[1], connection[0])):
        if key in text_dictionary:
            text_info.append((text_dictionary[key], arrow))


def modify_student_frame(
    detection_result_teacher,
    detection_result_student,
    detection_result_teacher_angles,
    detection_result_student_angles,
    video_teacher,
    video_student,
    alignment_frames,
    edge_groups,
    connections,
    thresholds,
    previously_trigered,
    previously_trigered_2,
    triger_state,
    show_arrows,
    text_dictionary,
):
    """Annotate one aligned student frame with the joints that deviate from the teacher.

    For every connection, each edge group touching one of its keypoints is
    flagged when (a) the absolute teacher/student angle difference exceeds the
    group's threshold, (b) none of the group's student keypoints has a
    confidence below 0.7 and (c) the connection equals the last two keypoints
    of some edge group (in either order). What happens to a flagged connection
    depends on ``triger_state``:

    * ``"one"``   -- drawn immediately.
    * ``"two"``   -- recorded; drawn only if also in ``previously_trigered``.
    * ``"three"`` -- recorded; additionally recorded in the second list if in
      ``previously_trigered``, and drawn only if also in
      ``previously_trigered_2``.

    The student frame is copied before drawing so the caller's
    ``video_student`` is not modified.

    Args:
        detection_result_teacher: Teacher detections indexed
            ``[frame, keypoint, value]``.
        detection_result_student: Student detections, same layout; the last
            value of each keypoint is read as its confidence.
        detection_result_teacher_angles: Per-frame teacher angles.
        detection_result_student_angles: Per-frame student angles.
        video_teacher: Teacher frames.
        video_student: Student frames.
        alignment_frames: ``(teacher_frame_index, student_frame_index)``.
        edge_groups: Keypoint triplets, parallel to the angle arrays and
            ``thresholds``.
        connections: Two-keypoint connections to inspect.
        thresholds: Per-edge-group angle thresholds.
        previously_trigered: Connections flagged on the previous call.
        previously_trigered_2: Second-level list returned by the previous call.
        triger_state: ``"one"``, ``"two"`` or ``"three"``.
        show_arrows: Whether to draw the correction arrow.
        text_dictionary: Maps ``(keypoint, keypoint)`` to a feedback message.

    Returns:
        ``(student_frame, teacher_frame, trigered_connections,
        trigered_connections2, text_info)`` where both connection lists are
        de-duplicated and ``text_info`` holds ``(message, arrow)`` tuples.
    """
    teacher_idx = alignment_frames[0]
    student_idx = alignment_frames[1]

    # np.array copies, so the overlays never leak into the source video and
    # cannot pile up when one student frame is aligned to several teacher frames.
    frame_copy = np.array(video_student[student_idx])
    frame_teacher_copy = video_teacher[teacher_idx]

    frame_errors = np.abs(
        detection_result_teacher_angles[teacher_idx]
        - detection_result_student_angles[student_idx]
    )
    edge_groups_as_keys = [tuple(group) for group in edge_groups]
    edge_groups2errors = dict(zip(edge_groups_as_keys, frame_errors))
    edge_groups2thresholds = dict(zip(edge_groups_as_keys, thresholds))
    edge_groups_relevant = [edge_group[1:] for edge_group in edge_groups]
    student_confidence = detection_result_student[student_idx, :, -1]

    text_info = []
    trigered_connections = []
    trigered_connections2 = []

    for connection in connections:
        connection_key = (connection[0], connection[1])
        # Depends on the connection only, so it is evaluated once per connection.
        relevant_plane = (
            [connection[0], connection[1]] in edge_groups_relevant
            or [connection[1], connection[0]] in edge_groups_relevant
        )
        edges_for_given_connection = [
            edge
            for edge in edge_groups2errors
            if connection[0] in edge or connection[1] in edge
        ]
        for edge in edges_for_given_connection:
            check_threshold = edge_groups2errors[edge] > edge_groups2thresholds[edge]
            check_certain = not any(student_confidence[keypoint] < 0.7 for keypoint in edge)
            if not (check_threshold and check_certain and relevant_plane):
                continue

            point1, point2, point2_t = align_points(
                detection_result_student,
                detection_result_teacher,
                alignment_frames,
                edge
            )
            arrow = get_arrow_direction(point2, point2_t)

            if triger_state == "one":
                should_draw = True
            elif triger_state == "two":
                trigered_connections.append(connection_key)
                should_draw = connection_key in previously_trigered
            elif triger_state == "three":
                trigered_connections.append(connection_key)
                should_draw = False
                if connection_key in previously_trigered:
                    trigered_connections2.append(connection_key)
                    should_draw = connection_key in previously_trigered_2
            else:
                should_draw = False

            if should_draw:
                _mark_student_error(
                    frame_copy,
                    point1,
                    point2,
                    point2_t,
                    arrow,
                    connection,
                    show_arrows,
                    text_dictionary,
                    text_info,
                )

    return (
        frame_copy,
        frame_teacher_copy,
        list(set(trigered_connections)),
        list(set(trigered_connections2)),
        text_info,
    )
def get_video_frames(video_path):
    """Load every frame of a video, resized to 1280x720, into one array.

    Args:
        video_path: Path of the video file to read.

    Returns:
        ``np.ndarray`` stacking all frames in read order (empty if no frame
        could be read).
    """
    cap = cv2.VideoCapture(video_path)
    video = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print(f"Video {video_path} was loaded")
                break
            video.append(cv2.resize(frame, (1280, 720)))
    finally:
        # Free the decoder/file handle even if reading or resizing fails.
        cap.release()
    return np.array(video)
def download_file(url, save_path, timeout=30):
    """Stream ``url`` to ``save_path`` in 8 KiB chunks.

    Data is written to ``<save_path>.part`` and moved into place only after the
    whole body arrived, so an interrupted download never leaves a truncated
    file at ``save_path``.

    Args:
        url: Address to download.
        save_path: Destination file path.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes before raising.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    partial_path = f"{save_path}.part"
    completed = False
    try:
        # The context manager returns the streamed connection to the pool.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        os.replace(partial_path, save_path)
        completed = True
    finally:
        if not completed and os.path.exists(partial_path):
            os.remove(partial_path)
def check_and_download_models():
    """Download the ViT-Pose-b and YOLOv8s weights into ``models/`` if absent.

    The target directories are created first; each model is then fetched with
    ``download_file`` only when its file does not exist yet.
    """
    # Only the "b" whole-body checkpoint is fetched. The "s" and "l" variants
    # live next to it upstream (vitpose-s-wholebody.pth / vitpose-l-wholebody.pth).
    models = (
        (
            "ViT-Pose-b",
            "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/torch/wholebody/vitpose-b-wholebody.pth?download=true",
            "models/vitpose-b-wholebody.pth",
        ),
        (
            "YOLO",
            "https://huggingface.co/JunkyByte/easy_ViTPose/resolve/main/yolov8/yolov8s.pt?download=true",
            "models/yolov8s.pt",
        ),
    )

    for _, _, model_path in models:
        Path(os.path.dirname(model_path)).mkdir(parents=True, exist_ok=True)

    for label, model_url, model_path in models:
        if os.path.exists(model_path):
            continue
        print(f"Downloading {label} model...")
        download_file(model_url, model_path)
        print(f"{label} model was downloaded.")
def generate_output_video(teacher_frames, student_frames, timestamp_str):
    """Write a side-by-side (teacher | student) MP4 and return its path.

    Frames are resized to 1280x720, joined horizontally and written one pair at
    a time at 30 fps with the ``mp4v`` codec, so the whole video is never held
    in memory as additional copies.

    Args:
        teacher_frames: Sized sequence of teacher frames.
        student_frames: Sized sequence of student frames, same length.
        timestamp_str: Suffix used in the output file name.

    Returns:
        Path of the written file, ``videos/pose_<timestamp_str>.mp4``.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    frame_size = (1280, 720)
    teacher_count = len(teacher_frames)
    student_count = len(student_frames)
    if teacher_count == 0 or teacher_count != student_count:
        raise ValueError(
            "teacher_frames and student_frames must be non-empty and of equal "
            f"length (got {teacher_count} and {student_count})"
        )

    root_dir = "videos"
    os.makedirs(root_dir, exist_ok=True)
    video_path = f"{root_dir}/pose_{timestamp_str}.mp4"
    out = cv2.VideoWriter(
        video_path,
        cv2.VideoWriter_fourcc(*'mp4v'),
        30,
        (frame_size[0] * 2, frame_size[1]),
    )
    try:
        for teacher_frame, student_frame in zip(teacher_frames, student_frames):
            teacher_resized = cv2.resize(np.asarray(teacher_frame), frame_size)
            student_resized = cv2.resize(np.asarray(student_frame), frame_size)
            # Axis 1 is the width of a single (height, width, channels) frame.
            out.write(np.concatenate((teacher_resized, student_resized), axis=1))
    finally:
        # Finalise the container even if a frame fails to encode.
        out.release()
    return video_path
def generate_log(all_text_summaries):
    """Render de-duplicated feedback entries as one line per entry, ordered by frame.

    The frame number is converted to a video time assuming 30 frames per second
    and printed as ``M:SS.cc`` (centiseconds truncated). Whole-second
    timestamps are formatted like any other (e.g. ``0:01.00``) and minutes are
    not limited to one digit.

    Args:
        all_text_summaries: Iterable of hashable ``(comment, frame, arrow)``
            tuples.

    Returns:
        The log lines joined with newlines (empty string for no entries).
    """
    # Ties on the frame are broken by comment and arrow so the output does not
    # depend on set iteration order.
    entries = sorted(
        set(all_text_summaries),
        key=lambda entry: (entry[1], entry[0], entry[2]),
    )
    general_summary = []
    for comment, frame, arrow in entries:
        elapsed = timedelta(seconds=frame / 30)
        total_centiseconds = elapsed // timedelta(milliseconds=10)
        minutes, remainder = divmod(total_centiseconds, 6000)
        seconds, centiseconds = divmod(remainder, 100)
        video_time = f"{minutes}:{seconds:02d}.{centiseconds:02d}"
        general_summary.append(f"{comment}. Direction: {arrow}. Video time: {video_time}")
    return "\n".join(general_summary)
def write_log(
    timestamp_str,
    dtw_mean,
    dtw_filter,
    angles_sensitive,
    angles_common,
    angles_insensitive,
    trigger_state,
    general_summary
):
    """Write the run settings and error summary to ``logs/log_<timestamp_str>.txt``.

    The file is written as UTF-8 because the summary can contain arrow glyphs
    that many platform default encodings cannot represent.

    Args:
        timestamp_str: Suffix used in the log file name.
        dtw_mean: Value reported as "Winsorize mean".
        dtw_filter: Value reported as "Savitzky-Golay Filter".
        angles_sensitive: Value reported as the sensitive threshold.
        angles_common: Value reported as the standard threshold.
        angles_insensitive: Value reported as the insensitive threshold.
        trigger_state: Value reported as the trigger count.
        general_summary: Text placed under "Error logs".

    Returns:
        Path of the written log file.
    """
    logs_dir = "logs"
    os.makedirs(logs_dir, exist_ok=True)
    log_path = f"{logs_dir}/log_{timestamp_str}.txt"
    content = f"""
Settings:
Dynamic Time Warping:
- Winsorize mean: {dtw_mean}
- Savitzky-Golay Filter: {dtw_filter}
Thresholds:
- Sensitive: {angles_sensitive}
- Standart: {angles_common}
- Insensitive: {angles_insensitive}
Patience:
- trigger count: {trigger_state}
Error logs:
{general_summary}
"""
    with open(log_path, "w", encoding="utf-8") as file:
        file.write(content)
    print(f"log {log_path} was created.")
    return log_path
def angle_between(v1, v2):
    """Return the signed rotation (radians) that turns ``v1`` onto ``v2``'s direction.

    Computed as the difference of the two ``arctan2`` headings, using element 0
    as x and element 1 as y; the result is not wrapped into ``[-pi, pi]``.
    """
    heading_from = np.arctan2(v1[1], v1[0])
    heading_to = np.arctan2(v2[1], v2[0])
    return heading_to - heading_from
def align_points(detection_result_student, detection_result_teacher, alignment_frames, edge):
    """Overlay the teacher's three-point joint onto the student's first segment.

    The teacher triplet is translated so its first point coincides with the
    student's first point, then rotated about that point so the teacher's first
    segment points the same way as the student's. All coordinates are truncated
    to int and reversed (last coordinate first) before use.

    Args:
        detection_result_student: Student detections indexed
            ``[frame, keypoint, value]``; the last value is dropped.
        detection_result_teacher: Teacher detections, same layout.
        alignment_frames: ``(teacher_frame_index, student_frame_index)``.
        edge: Three keypoint indices.

    Returns:
        ``(student_point_1, student_point_2, aligned_teacher_point_2)`` as int
        arrays.
    """
    teacher_frame = alignment_frames[0]
    student_frame = alignment_frames[1]

    def _point(detections, frame_idx, keypoint):
        # Drop the trailing value, truncate to int, reverse the coordinate order.
        return detections[frame_idx, keypoint, :-1].astype(int)[::-1]

    point0 = _point(detection_result_student, student_frame, edge[0])
    point1 = _point(detection_result_student, student_frame, edge[1])
    point2 = _point(detection_result_student, student_frame, edge[2])
    teacher0 = _point(detection_result_teacher, teacher_frame, edge[0])
    teacher1 = _point(detection_result_teacher, teacher_frame, edge[1])
    teacher2 = _point(detection_result_teacher, teacher_frame, edge[2])

    # Move the teacher joint so that its first point sits on the student's.
    shift = point0 - teacher0
    teacher1 = teacher1 + shift
    teacher2 = teacher2 + shift

    student_segment = point1 - point0
    teacher_segment = teacher1 - point0
    theta = angle_between(teacher_segment, student_segment)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([
        [cos_t, -sin_t],
        [sin_t, cos_t]
    ])

    point2_t = (np.dot(rotation, (teacher2 - point0).T).T + point0).astype(int)
    return point1, point2, point2_t
def get_arrow_direction(A, B):
    """Map the displacement from ``A`` to ``B`` onto one of eight arrow glyphs.

    The angle is ``degrees(arctan2(d[0], d[1]))`` with ``d = B - A``, split into
    45-degree sectors centred on 0, +-45, +-90, +-135 and 180. An angle that
    fits no sector (NaN) yields an empty string.

    NOTE(review): a displacement with only a positive second component maps to
    the up arrow. If the inputs are image coordinates with y growing downward,
    that arrow points opposite to the on-screen motion -- confirm intent.
    """
    displacement = B - A
    angle_deg = np.degrees(np.arctan2(displacement[0], displacement[1]))

    sectors = (
        (-22.5, 22.5, "⬆"),
        (22.5, 67.5, "⬈"),
        (67.5, 112.5, "➡"),
        (112.5, 157.5, "⬊"),
        (-157.5, -112.5, "⬋"),
        (-112.5, -67.5, "⬅"),
        (-67.5, -22.5, "⬉"),
    )
    for lower, upper, glyph in sectors:
        if lower <= angle_deg < upper:
            return glyph
    # The remaining sector wraps around +-180 degrees.
    if angle_deg >= 157.5 or angle_deg < -157.5:
        return "⬇"
    return ""