svjack commited on
Commit
1a5df61
·
verified ·
1 Parent(s): 67b24c1

Create video_to_sketch_script.py

Browse files
Files changed (1) hide show
  1. video_to_sketch_script.py +110 -0
video_to_sketch_script.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ pip install gradio huggingface_hub torch==1.11.0 torchvision==0.12.0 pytorchvideo==0.1.5 pyav==11.4.1
3
+
4
+ huggingface-cli download \
5
+ --repo-type dataset svjack/video-dataset-Lily-Bikini-organized \
6
+ --local-dir video-dataset-Lily-Bikini-organized
7
+
8
+ python video_to_sketch_script.py video-dataset-Lily-Bikini-organized video-dataset-Lily-Bikini-sketch-organized --copy_others
9
+ '''
10
+
11
+ import gc
12
+ import os
13
+ import shutil
14
+ import argparse
15
+ import numpy as np
16
+ import torch
17
+ from huggingface_hub import hf_hub_download
18
+ from PIL.Image import Resampling
19
+ from pytorchvideo.data.encoded_video import EncodedVideo
20
+ from pytorchvideo.transforms.functional import uniform_temporal_subsample
21
+ from torchvision.io import write_video
22
+ from torchvision.transforms.functional import resize
23
+ from tqdm import tqdm
24
+
25
+ from modeling import Generator
26
+
27
+ MAX_DURATION = 60
28
+ OUT_FPS = 18
29
+ DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
30
+
31
+ # Load the model
32
+ model = Generator(3, 1, 3)
33
+ weights_path = hf_hub_download("nateraw/image-2-line-drawing", "pytorch_model.bin")
34
+ model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
35
+ model.eval()
36
+
37
+ def process_one_second(vid, start_sec, out_fps):
38
+ """Process one second of a video at a given fps
39
+ Args:
40
+ vid (_type_): A pytorchvideo.EncodedVideo instance containing the video to process
41
+ start_sec (_type_): The second to start processing at
42
+ out_fps (_type_): The fps to output the video at
43
+ Returns:
44
+ np.array: The processed video as a numpy array with shape (T, H, W, C)
45
+ """
46
+ # C, T, H, W
47
+ video_arr = vid.get_clip(start_sec, start_sec + 1)["video"]
48
+ # C, T, H, W where T == frames per second
49
+ x = uniform_temporal_subsample(video_arr, out_fps)
50
+ # C, T, H, W where H has been scaled to 256 (This will probably be no bueno on vertical vids but whatever)
51
+ x = resize(x, 256, Resampling.BICUBIC)
52
+ # C, T, H, W -> T, C, H, W (basically T acts as batch size now)
53
+ x = x.permute(1, 0, 2, 3)
54
+
55
+ with torch.no_grad():
56
+ # T, 1, H, W
57
+ out = model(x)
58
+
59
+ # T, C, H, W -> T, H, W, C Rescaled to 0-255
60
+ out = out.permute(0, 2, 3, 1).clip(0, 1) * 255
61
+ # Greyscale -> RGB
62
+ out = out.repeat(1, 1, 1, 3)
63
+ return out
64
+
65
+ def process_video(input_video_path, output_video_path):
66
+ start_sec = 0
67
+ vid = EncodedVideo.from_path(input_video_path)
68
+ duration = min(MAX_DURATION, int(vid.duration))
69
+ for i in tqdm(range(duration), desc="Processing video"):
70
+ video = process_one_second(vid, start_sec=i + start_sec, out_fps=OUT_FPS)
71
+ gc.collect()
72
+ if i == 0:
73
+ video_all = video
74
+ else:
75
+ video_all = np.concatenate((video_all, video))
76
+
77
+ write_video(output_video_path, video_all, fps=OUT_FPS)
78
+
79
+ def copy_non_video_files(input_path, output_path):
80
+ """Copy non-video files and directories from input path to output path."""
81
+ for item in os.listdir(input_path):
82
+ item_path = os.path.join(input_path, item)
83
+ if not item.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
84
+ dest_path = os.path.join(output_path, item)
85
+ if os.path.isdir(item_path):
86
+ shutil.copytree(item_path, dest_path)
87
+ else:
88
+ shutil.copy2(item_path, dest_path)
89
+
90
+ def main(input_path, output_path, copy_others=False):
91
+ if not os.path.exists(output_path):
92
+ os.makedirs(output_path)
93
+
94
+ if copy_others:
95
+ copy_non_video_files(input_path, output_path)
96
+
97
+ for video_name in os.listdir(input_path):
98
+ if video_name.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
99
+ input_video_path = os.path.join(input_path, video_name)
100
+ output_video_path = os.path.join(output_path, video_name)
101
+ process_video(input_video_path, output_video_path)
102
+
103
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser(description="Process videos to convert them into sketch videos.")
105
+ parser.add_argument("input_path", type=str, help="Path to the input directory containing videos.")
106
+ parser.add_argument("output_path", type=str, help="Path to the output directory for processed videos.")
107
+ parser.add_argument("--copy_others", action="store_true", help="Copy non-video files and directories from input to output.")
108
+
109
+ args = parser.parse_args()
110
+ main(args.input_path, args.output_path, args.copy_others)