# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from i2v_enhance.thirdparty.VFI.model.loss import *
from i2v_enhance.thirdparty.VFI.config import *

class Model:
    def __init__(self, local_rank):
        backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
        backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
        self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
        self.name = MODEL_CONFIG['LOGNAME']
        # self.device()

        # train
        self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
        self.lap = LapLoss()
        if local_rank != -1:
            # Wrap for distributed training; local_rank == -1 means single-process.
            self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)
    def train(self):
        self.net.train()

    def eval(self):
        self.net.eval()

    def device(self):
        self.net.to(torch.device("cuda"))

    def unload(self):
        self.net.to(torch.device("cpu"))
    def load_model(self, name=None, rank=0):
        def convert(param):
            # Strip the "module." prefix that DDP adds to parameter names, and
            # drop the shape-dependent 'attn_mask'/'HW' entries.
            return {
                k.replace("module.", ""): v
                for k, v in param.items()
                if "module." in k and 'attn_mask' not in k and 'HW' not in k
            }
        if rank <= 0:
            if name is None:
                name = self.name
            # self.net.load_state_dict(convert(torch.load(f'ckpt/{name}.pkl')))
            self.net.load_state_dict(convert(torch.load(f'{name}')))

    def save_model(self, rank=0):
        if rank == 0:
            torch.save(self.net.state_dict(), f'ckpt/{self.name}.pkl')
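    # Checkpoint I/O sketch (hypothetical path): load_model() expects a raw
    # state_dict, possibly saved from a DDP-wrapped model, which is why convert()
    # strips the "module." prefix before loading:
    #   model = Model(local_rank=-1)
    #   model.load_model(name='ckpt/ours.pkl', rank=0)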
    @torch.no_grad()
    def hr_inference(self, img0, img1, TTA=False, down_scale=1.0, timestep=0.5, fast_TTA=False):
        '''
        Infer with down-scaled flow.
        Note: returns BxCxHxW.
        '''
        def infer(imgs):
            img0, img1 = imgs[:, :3], imgs[:, 3:6]
            # Estimate flow and mask at reduced resolution, then upsample back;
            # the flow values are rescaled along with the spatial dimensions.
            imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
            flow, mask = self.net.calculate_flow(imgs_down, timestep)
            flow = F.interpolate(flow, scale_factor=1 / down_scale, mode="bilinear", align_corners=False) * (1 / down_scale)
            mask = F.interpolate(mask, scale_factor=1 / down_scale, mode="bilinear", align_corners=False)

            af, _ = self.net.feature_bone(img0, img1)
            pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
            return pred

        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            # Batch the pair with its flipped copy so the network runs only once.
            imgs_ = imgs.flip(2).flip(3)
            input = torch.cat((imgs, imgs_), 0)
            preds = infer(input)
            return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.

        if not TTA:
            return infer(imgs)
        else:
            return (infer(imgs) + infer(imgs.flip(2).flip(3)).flip(2).flip(3)) / 2
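    # Usage sketch for hr_inference (hypothetical tensors; assumes the model has
    # been moved to the GPU via device(), and that H*down_scale and W*down_scale
    # are integers so the down/up round trip restores the original size):
    #   img0 = torch.rand(1, 3, 512, 512, device='cuda')
    #   img1 = torch.rand(1, 3, 512, 512, device='cuda')
    #   mid = model.hr_inference(img0, img1, down_scale=0.5, timestep=0.5)  # 1x3x512x512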
    @torch.no_grad()
    def inference(self, img0, img1, TTA=False, timestep=0.5, fast_TTA=False):
        '''
        Note: returns BxCxHxW.
        '''
        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            imgs_ = imgs.flip(2).flip(3)
            input = torch.cat((imgs, imgs_), 0)
            _, _, _, preds = self.net(input, timestep=timestep)
            return (preds[0] + preds[1].flip(1).flip(2)).unsqueeze(0) / 2.

        _, _, _, pred = self.net(imgs, timestep=timestep)
        if not TTA:
            return pred
        else:
            _, _, _, pred2 = self.net(imgs.flip(2).flip(3), timestep=timestep)
            return (pred + pred2.flip(2).flip(3)) / 2
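    # Usage sketch (hypothetical tensors): fast_TTA batches the pair with its
    # flipped copy so the network runs once instead of twice:
    #   mid = model.inference(img0, img1, fast_TTA=True, timestep=0.5)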
    @torch.no_grad()
    def multi_inference(self, img0, img1, TTA=False, down_scale=1.0, time_list=[], fast_TTA=False):
        '''
        Run the backbone once, then predict frames at multiple timesteps.
        Note: returns a list of [CxHxW].
        '''
        assert len(time_list) > 0, 'time_list should not be empty!'

        def infer(imgs):
            img0, img1 = imgs[:, :3], imgs[:, 3:6]
            # Extract backbone features once and reuse them for every timestep.
            af, mf = self.net.feature_bone(img0, img1)
            imgs_down = None
            if down_scale != 1.0:
                imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
                afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])

            pred_list = []
            for timestep in time_list:
                if imgs_down is None:
                    flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
                else:
                    # Flow is estimated at reduced resolution and upsampled back.
                    flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
                    flow = F.interpolate(flow, scale_factor=1 / down_scale, mode="bilinear", align_corners=False) * (1 / down_scale)
                    mask = F.interpolate(mask, scale_factor=1 / down_scale, mode="bilinear", align_corners=False)

                pred = self.net.coraseWarp_and_Refine(imgs, af, flow, mask)
                pred_list.append(pred)
            return pred_list

        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            imgs_ = imgs.flip(2).flip(3)
            input = torch.cat((imgs, imgs_), 0)
            preds_lst = infer(input)
            return [(preds_lst[i][0] + preds_lst[i][1].flip(1).flip(2)) / 2 for i in range(len(time_list))]

        preds = infer(imgs)
        if not TTA:
            return [preds[i][0] for i in range(len(time_list))]
        else:
            flip_pred = infer(imgs.flip(2).flip(3))
            return [(preds[i][0] + flip_pred[i][0].flip(1).flip(2)) / 2 for i in range(len(time_list))]
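    # Usage sketch (hypothetical): interpolate three evenly spaced frames between
    # img0 and img1 with a single backbone pass:
    #   frames = model.multi_inference(img0, img1, time_list=[0.25, 0.5, 0.75])
    #   # frames is a list of three CxHxW tensors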
    def update(self, imgs, gt, learning_rate=0, training=True):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()

        if training:
            flow, mask, merged, pred = self.net(imgs)
            # Laplacian pyramid loss on the final prediction, plus a down-weighted
            # term for each intermediate merged frame.
            loss_l1 = (self.lap(pred, gt)).mean()
            for merge in merged:
                loss_l1 += (self.lap(merge, gt)).mean() * 0.5

            self.optimG.zero_grad()
            loss_l1.backward()
            self.optimG.step()
            return pred, loss_l1
        else:
            with torch.no_grad():
                flow, mask, merged, pred = self.net(imgs)
                return pred, 0
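
# Training-step sketch (hypothetical loader; `imgs` is the two input frames
# concatenated along the channel dimension and `gt` the ground-truth middle
# frame, matching what update() expects; a real run would schedule the
# learning rate per step rather than keeping it constant):
#   model = Model(local_rank=-1)
#   model.device()
#   for imgs, gt in loader:
#       pred, loss = model.update(imgs, gt, learning_rate=2e-4, training=True)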