Maksym-Lysyi committed
Commit 4daa026 · 1 Parent(s): 5079645

add arrows

Files changed (3)
  1. app.py +4 -1
  2. main_func.py +5 -2
  3. utils.py +107 -17
app.py CHANGED
@@ -59,6 +59,9 @@ with gr.Blocks() as demo:
 
     trigger_state = gr.Radio(value="one", choices=["one", "two", "three"], label="Trigger Count")
 
+    gr.Markdown("#### Plot arrows:")
+    show_arrows = gr.Checkbox(label="If True, arrows will be plotted on the video")
+
     input_teacher = gr.Video(show_share_button=False, show_download_button=False, sources=["upload"], label="Teacher's Video")
     input_student = gr.Video(show_share_button=False, show_download_button=False, sources=["upload"], label="Student's Video")
 
@@ -141,7 +144,7 @@ with gr.Blocks() as demo:
 
     gr_button.click(
         fn=video_identity,
-        inputs=[dtw_mean, dtw_filter, angles_sensitive, angles_common, angles_insensitive, trigger_state, input_teacher, input_student],
+        inputs=[dtw_mean, dtw_filter, angles_sensitive, angles_common, angles_insensitive, trigger_state, show_arrows, input_teacher, input_student],
         outputs=[output_merged, general_log, text_log]
     )
 
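
For orientation, a stripped-down Gradio sketch (not the real app.py; the button label and output component are hypothetical, and only the checkbox is wired through) showing how the new show_arrows value travels through click(inputs=...) and arrives in the callback as a plain Python bool:

import gradio as gr

def video_identity(show_arrows):
    # Gradio delivers the checkbox state as a plain Python bool.
    return f"arrows enabled: {show_arrows}"

with gr.Blocks() as demo:
    gr.Markdown("#### Plot arrows:")
    show_arrows = gr.Checkbox(label="If True, arrows will be plotted on the video")
    result = gr.Textbox(label="Result")
    gr.Button("Run").click(fn=video_identity, inputs=[show_arrows], outputs=[result])

demo.launch()
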
main_func.py CHANGED
@@ -11,7 +11,7 @@ from utils import (
     get_dtw_mean_path,
     generate_output_video,
     generate_log,
-    write_log,
+    write_log
 )
 
 from config import (
@@ -30,6 +30,7 @@ def video_identity(
     angles_common,
     angles_insensitive,
     trigger_state,
+    show_arrows,
     video_teacher,
     video_student
 ):
@@ -91,6 +92,7 @@ def video_identity(
     for idx, alignment in enumerate(alignments):
 
         frame_student_out, frame_teacher_out, trigger_1, trigger_2, text_info_summary = modify_student_frame(
+            detection_result_teacher=detection_result_teacher,
             detection_result_student=detection_result_student,
             detection_result_teacher_angles=detection_result_teacher_angles,
             detection_result_student_angles=detection_result_student_angles,
@@ -103,12 +105,13 @@
             previously_trigered=trigger_1,
             previously_trigered_2=trigger_2,
             triger_state=trigger_state,
+            show_arrows=show_arrows,
             text_dictionary=EDGE_GROUPS_FOR_SUMMARY
         )
 
         save_teacher_frames.append(frame_teacher_out)
         save_student_frames.append(frame_student_out)
-        all_text_summaries.extend([(log, idx) for log in text_info_summary])
+        all_text_summaries.extend([(log, idx, arrow) for (log, arrow) in text_info_summary])
 
     # ======================================================================================
     # create files for downloading and displaying.
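
The bookkeeping change above is easiest to see on toy data: modify_student_frame now returns (comment, arrow) pairs, and the loop tags each pair with the frame index so generate_log can unpack three fields per entry. A minimal illustration (the strings and index are made up, not real pipeline output):

text_info_summary = [("Left elbow angle differs", "⬈"), ("Right knee angle differs", "⬇")]
idx = 42  # frame index from the alignment loop

all_text_summaries = []
all_text_summaries.extend([(log, idx, arrow) for (log, arrow) in text_info_summary])
# -> [("Left elbow angle differs", 42, "⬈"), ("Right knee angle differs", 42, "⬇")]
# generate_log() later unpacks each entry as: comment, frame, arrow = log
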
utils.py CHANGED
@@ -10,6 +10,7 @@ from datetime import timedelta
 from scipy.signal import savgol_filter
 from scipy.stats import mstats
 
+
 def predict_keypoints_vitpose(
     video_path,
     model_path,
@@ -229,6 +230,7 @@ def get_dtw_mean_path(serieses_teacher, serieses_student, dtw_mean, dtw_filter):
 
 
 def modify_student_frame(
+    detection_result_teacher,
     detection_result_student,
     detection_result_teacher_angles,
     detection_result_student_angles,
@@ -241,9 +243,14 @@ def modify_student_frame(
     previously_trigered,
     previously_trigered_2,
     triger_state,
+    show_arrows,
     text_dictionary,
 ):
-
+    arrows_bgr = (175, 75, 190)
+    arrows_sz = 3
+    skeleton_bgr = (0, 0, 255)
+    skeleton_sz = 3
+
     frame_copy = video_student[alignment_frames[1]]
     frame_teacher_copy = video_teacher[alignment_frames[0]]
     frame_errors = np.abs(detection_result_teacher_angles[alignment_frames[0]] - detection_result_student_angles[alignment_frames[1]])
@@ -272,22 +279,27 @@ def modify_student_frame(
 
             if check_threshold and check_certain and relevant_plane:
 
-                point1 = detection_result_student[:, :, :-1][alignment_frames[1]][connection[0]]
-                point2 = detection_result_student[:, :, :-1][alignment_frames[1]][connection[1]]
-
-                point1 = np.array(point1).astype(int)
-                point2 = np.array(point2).astype(int)
-
-                point1 = [point1[1], point1[0]]
-                point2 = [ point2[1], point2[0]]
+                point1, point2, point2_t = align_points(
+                    detection_result_student,
+                    detection_result_teacher,
+                    alignment_frames,
+                    edge
+                )
 
+                arrow = get_arrow_direction(point2, point2_t)
 
                 if triger_state == "one":
 
-                    _ = cv2.line(frame_copy, point1, point2, (0, 0, 255), 10)
+                    _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)
+
+                    if show_arrows:
+                        _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz)
 
                     if (connection[0], connection[1]) in text_dictionary:
-                        text_info.append(text_dictionary[(connection[0], connection[1])])
+                        text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
+
+                    if (connection[1], connection[0]) in text_dictionary:
+                        text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
 
                 if triger_state == "two":
 
@@ -295,10 +307,16 @@
 
                     if (connection[0], connection[1]) in previously_trigered:
 
-                        _ = cv2.line(frame_copy, point1, point2, (0, 0, 255), 10)
+                        _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)
+
+                        if show_arrows:
+                            _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz)
 
                         if (connection[0], connection[1]) in text_dictionary:
-                            text_info.append(text_dictionary[(connection[0], connection[1])])
+                            text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
+
+                        if (connection[1], connection[0]) in text_dictionary:
+                            text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
 
                 if triger_state == "three":
 
@@ -310,10 +328,16 @@
 
                     if (connection[0], connection[1]) in previously_trigered_2:
 
-                        _ = cv2.line(frame_copy, point1, point2, (0, 0, 255), 10)
+                        _ = cv2.line(frame_copy, point1, point2, skeleton_bgr, skeleton_sz)
+
+                        if show_arrows:
+                            _ = cv2.arrowedLine(frame_copy, point2, point2_t, arrows_bgr, arrows_sz)
 
                         if (connection[0], connection[1]) in text_dictionary:
-                            text_info.append(text_dictionary[(connection[0], connection[1])])
+                            text_info.append((text_dictionary[(connection[0], connection[1])], arrow))
+
+                        if (connection[1], connection[0]) in text_dictionary:
+                            text_info.append((text_dictionary[(connection[1], connection[0])], arrow))
 
     return frame_copy, frame_teacher_copy, list(set(trigered_connections)), list(set(trigered_connections2)), text_info
 
@@ -412,9 +436,9 @@ def generate_log(all_text_summaries):
 
     general_summary = []
     for log in all_text_summaries_clean:
-        comment, frame = log
+        comment, frame, arrow = log
         total_seconds = frame / 30
-        general_summary.append(f"{comment}. Video time: {str(timedelta(seconds=total_seconds))[3:-4]}")
+        general_summary.append(f"{comment}. Direction: {arrow}. Video time: {str(timedelta(seconds=total_seconds))[3:-4]}")
 
     general_summary = "\n".join(general_summary)
 
@@ -465,3 +489,69 @@ Error logs:
     print(f"log {log_path} was created.")
 
     return log_path
+
+
+def angle_between(v1, v2):
+    return np.arctan2(v2[1], v2[0]) - np.arctan2(v1[1], v1[0])
+
+
+def align_points(detection_result_student, detection_result_teacher, alignment_frames, edge):
+
+    point0 = detection_result_student[alignment_frames[1], edge[0], :-1].astype(int)[::-1]
+    point1 = detection_result_student[alignment_frames[1], edge[1], :-1].astype(int)[::-1]
+    point2 = detection_result_student[alignment_frames[1], edge[2], :-1].astype(int)[::-1]
+
+    point0_t = detection_result_teacher[alignment_frames[0], edge[0], :-1].astype(int)[::-1]
+    point1_t = detection_result_teacher[alignment_frames[0], edge[1], :-1].astype(int)[::-1]
+    point2_t = detection_result_teacher[alignment_frames[0], edge[2], :-1].astype(int)[::-1]
+
+    translation = point0 - point0_t
+
+    point0_t += translation
+    point1_t += translation
+    point2_t += translation
+
+    BsA = point1 - point0
+    BtA = point1_t - point0
+
+    theta = angle_between(BtA, BsA)
+
+    R = np.array([
+        [np.cos(theta), -np.sin(theta)],
+        [np.sin(theta), np.cos(theta)]
+    ])
+
+    point1_t = np.dot(R, (point1_t - point0).T).T + point0
+    point2_t = np.dot(R, (point2_t - point0).T).T + point0
+
+    point2_t = point2_t.astype(int)
+
+    return point1, point2, point2_t
+
+
+def get_arrow_direction(A, B):
+
+    translation_vector = B - A
+    angle_deg = np.degrees(np.arctan2(translation_vector[0], translation_vector[1]))
+
+    match angle_deg:
+        case angle if -22.5 <= angle < 22.5:
+            arrow = "⬆"
+        case angle if 22.5 <= angle < 67.5:
+            arrow = "⬈"
+        case angle if 67.5 <= angle < 112.5:
+            arrow = "➡"
+        case angle if 112.5 <= angle < 157.5:
+            arrow = "⬊"
+        case angle if 157.5 <= angle or angle < -157.5:
+            arrow = "⬇"
+        case angle if -157.5 <= angle < -112.5:
+            arrow = "⬋"
+        case angle if -112.5 <= angle < -67.5:
+            arrow = "⬅"
+        case angle if -67.5 <= angle < -22.5:
+            arrow = "⬉"
+        case _:
+            arrow = ""
+
+    return arrow
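
The two new helpers work as a pair: align_points translates the teacher's keypoints so the shared joint edge[0] coincides with the student's, rotates them so the first bone edge[0]-edge[1] overlaps the student's, and returns point2_t, i.e. where the student's end joint edge[2] would sit if it matched the teacher; get_arrow_direction then buckets the offset from the student's end joint to that target into one of eight glyphs via arctan2 (the match/case syntax requires Python 3.10+). Below is a hedged usage sketch with toy keypoints; the coordinate values and joint indices are invented, and which glyph corresponds to which on-screen direction depends on the axis convention of the real detections:

import numpy as np
from utils import align_points, get_arrow_direction  # helpers added in this commit

# Toy detections shaped (num_frames, num_keypoints, 3): two coordinates plus a score,
# matching how align_points slices [..., :-1] and flips the coordinate pair with [::-1].
det_student = np.zeros((1, 3, 3))
det_teacher = np.zeros((1, 3, 3))
det_student[0] = [[100, 100, 1.0], [100, 150, 1.0], [100, 200, 1.0]]  # straight limb
det_teacher[0] = [[90, 90, 1.0], [90, 140, 1.0], [40, 140, 1.0]]      # end joint bent away

alignment_frames = (0, 0)  # (teacher frame index, student frame index)
edge = (0, 1, 2)           # invented joint indices: base -> middle -> end of a limb

point1, point2, point2_t = align_points(det_student, det_teacher, alignment_frames, edge)
# For this toy geometry: point2 == [200, 100] (student's end joint) and
# point2_t == [150, 50] (teacher's end joint after translation + rotation).

print(get_arrow_direction(point2, point2_t))  # prints "⬋" for this toy offset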