# Page header captured with the file (not code):
#   author: hpoghos
#   commit: ca6b1f0 (verified) - "Update i2v_enhance/thirdparty/VFI/Trainer.py"
# Adapted from https://github.com/MCG-NJU/EMA-VFI/blob/main/Trainer.py
import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from i2v_enhance.thirdparty.VFI.model.loss import *
from i2v_enhance.thirdparty.VFI.config import *
class Model:
    """Training / inference wrapper around the frame-interpolation network.

    The network, its architecture arguments and the log name are read from
    ``MODEL_CONFIG`` (brought in by the star import of the VFI config module).
    The wrapper owns the optimizer and the Laplacian loss and offers several
    inference entry points with optional flip-based test-time augmentation
    (TTA).

    All inference methods concatenate ``img0`` and ``img1`` along dim 1 and
    split them back with ``[:, :3]`` / ``[:, 3:6]``, i.e. they expect
    3-channel frames.
    """

    def __init__(self, local_rank):
        """Build the network, optimizer and loss.

        Args:
            local_rank: Device index handed to ``DistributedDataParallel``.
                Pass ``-1`` to keep the bare, non-distributed network.
        """
        backbonetype, multiscaletype = MODEL_CONFIG['MODEL_TYPE']
        backbonecfg, multiscalecfg = MODEL_CONFIG['MODEL_ARCH']
        self.net = multiscaletype(backbonetype(**backbonecfg), **multiscalecfg)
        self.name = MODEL_CONFIG['LOGNAME']
        # NOTE: the network is intentionally not moved to the GPU here;
        # callers decide when to call device() / unload().

        # train
        self.optimG = AdamW(self.net.parameters(), lr=2e-4, weight_decay=1e-4)
        self.lap = LapLoss()
        if local_rank != -1:
            self.net = DDP(self.net, device_ids=[local_rank], output_device=local_rank)

    @staticmethod
    def _flip_hw(x):
        """Flip the last two dims of ``x`` (H and W for CxHxW / BxCxHxW)."""
        return x.flip(-2, -1)

    def train(self):
        """Put the network into training mode."""
        self.net.train()

    def eval(self):
        """Put the network into evaluation mode."""
        self.net.eval()

    def device(self):
        """Move the network to the default CUDA device."""
        self.net.to(torch.device("cuda"))

    def unload(self):
        """Move the network back to the CPU."""
        self.net.to(torch.device("cpu"))

    def load_model(self, name=None, rank=0):
        """Load network weights from a checkpoint file.

        Args:
            name: Path handed to ``torch.load``. Defaults to ``self.name``.
            rank: Process rank; nothing is loaded when ``rank > 0``.
        """
        def convert(param):
            # Keep only entries whose key contains "module." (the prefix DDP
            # adds when saving), drop that marker, and skip the 'attn_mask'
            # and 'HW' entries.
            return {
                k.replace("module.", ""): v
                for k, v in param.items()
                if "module." in k and 'attn_mask' not in k and 'HW' not in k
            }

        if rank > 0:
            return
        if name is None:
            name = self.name
        # NOTE(review): torch.load unpickles the file - only load checkpoints
        # from a trusted source.
        # map_location='cpu' lets a GPU-saved checkpoint load on any host;
        # load_state_dict copies the values onto the parameters' own device.
        state = torch.load(f'{name}', map_location='cpu')
        # convert() strips the "module." prefix, so when the network is
        # wrapped in DDP the weights must go into the wrapped module.
        target = self.net.module if isinstance(self.net, DDP) else self.net
        target.load_state_dict(convert(state))

    def save_model(self, rank=0):
        """Save the network state dict to ``ckpt/<self.name>.pkl`` on rank 0."""
        if rank == 0:
            import os  # local import: only needed on the save path

            # torch.save does not create missing parent directories.
            os.makedirs('ckpt', exist_ok=True)
            torch.save(self.net.state_dict(), f'ckpt/{self.name}.pkl')

    @torch.no_grad()
    def hr_inference(self, img0, img1, TTA=False, down_scale=1.0, timestep=0.5, fast_TTA=False):
        """Interpolate a frame, estimating flow on a rescaled copy of the input.

        Args:
            img0, img1: Input frames, concatenated along dim 1.
            TTA: Average the prediction with the one obtained from the
                H/W-flipped input (two forward passes).
            down_scale: Scale factor applied to the input before flow
                estimation; flow and mask are resized back by ``1/down_scale``.
            timestep: Interpolation time passed to ``calculate_flow``.
            fast_TTA: Like ``TTA`` but runs the flipped copy in the same
                batch (one forward pass). Takes precedence over ``TTA``.

        Returns:
            Prediction of shape BxCxHxW.
        """
        def infer(imgs):
            img0, img1 = imgs[:, :3], imgs[:, 3:6]
            imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)

            flow, mask = self.net.calculate_flow(imgs_down, timestep)

            # The flow values are multiplied by 1/down_scale in addition to
            # the spatial resize (presumably because displacements are
            # expressed in pixels of the rescaled frame - TODO confirm).
            flow = F.interpolate(flow, scale_factor=1 / down_scale, mode="bilinear", align_corners=False) * (1 / down_scale)
            mask = F.interpolate(mask, scale_factor=1 / down_scale, mode="bilinear", align_corners=False)

            af, _ = self.net.feature_bone(img0, img1)
            return self.net.coraseWarp_and_Refine(imgs, af, flow, mask)

        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            # First half of the stacked batch is the original input, second
            # half its flipped copy; split at the real batch size so this
            # works for any batch size, not only 1.
            batch = imgs.shape[0]
            stacked = torch.cat((imgs, self._flip_hw(imgs)), 0)
            preds = infer(stacked)
            return (preds[:batch] + self._flip_hw(preds[batch:])) / 2.

        if not TTA:
            return infer(imgs)
        return (infer(imgs) + self._flip_hw(infer(self._flip_hw(imgs)))) / 2

    @torch.no_grad()
    def inference(self, img0, img1, TTA=False, timestep=0.5, fast_TTA=False):
        """Interpolate a frame with a plain forward pass of the network.

        Args:
            img0, img1: Input frames, concatenated along dim 1.
            TTA: Average with the prediction from the H/W-flipped input.
            timestep: Interpolation time forwarded to the network.
            fast_TTA: Flip-TTA in a single batched forward pass. Takes
                precedence over ``TTA``.

        Returns:
            Prediction of shape BxCxHxW.
        """
        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            # Split at the real batch size so any batch size is handled.
            batch = imgs.shape[0]
            stacked = torch.cat((imgs, self._flip_hw(imgs)), 0)
            _, _, _, preds = self.net(stacked, timestep=timestep)
            return (preds[:batch] + self._flip_hw(preds[batch:])) / 2.

        _, _, _, pred = self.net(imgs, timestep=timestep)
        if not TTA:
            return pred
        _, _, _, pred2 = self.net(self._flip_hw(imgs), timestep=timestep)
        return (pred + self._flip_hw(pred2)) / 2

    @torch.no_grad()
    def multi_inference(self, img0, img1, TTA=False, down_scale=1.0, time_list=(), fast_TTA=False):
        """Run the feature backbone once and predict frames at several timesteps.

        Args:
            img0, img1: Input frames, concatenated along dim 1.
            TTA: Average with the prediction from the H/W-flipped input.
            down_scale: When not 1.0, flow is estimated on a rescaled copy
                of the input and resized back by ``1/down_scale``.
            time_list: Non-empty sequence of timesteps.
            fast_TTA: Flip-TTA in a single batched pass. Takes precedence
                over ``TTA``.

        Returns:
            A list with one CxHxW tensor per timestep; only the first sample
            of the batch is returned.

        Raises:
            AssertionError: If ``time_list`` is empty.
        """
        assert len(time_list) > 0, 'Time_list should not be empty!'

        def infer(imgs):
            img0, img1 = imgs[:, :3], imgs[:, 3:6]
            af, mf = self.net.feature_bone(img0, img1)
            imgs_down = None
            if down_scale != 1.0:
                imgs_down = F.interpolate(imgs, scale_factor=down_scale, mode="bilinear", align_corners=False)
                afd, mfd = self.net.feature_bone(imgs_down[:, :3], imgs_down[:, 3:6])

            pred_list = []
            for timestep in time_list:
                if imgs_down is None:
                    flow, mask = self.net.calculate_flow(imgs, timestep, af, mf)
                else:
                    flow, mask = self.net.calculate_flow(imgs_down, timestep, afd, mfd)
                    flow = F.interpolate(flow, scale_factor=1 / down_scale, mode="bilinear", align_corners=False) * (1 / down_scale)
                    mask = F.interpolate(mask, scale_factor=1 / down_scale, mode="bilinear", align_corners=False)

                # Refinement always uses the full-resolution images/features.
                pred_list.append(self.net.coraseWarp_and_Refine(imgs, af, flow, mask))
            return pred_list

        imgs = torch.cat((img0, img1), 1)
        if fast_TTA:
            # Index `batch` is the flipped copy of sample 0 (index 1 only
            # happens to be that copy when the batch size is 1).
            batch = imgs.shape[0]
            stacked = torch.cat((imgs, self._flip_hw(imgs)), 0)
            return [(p[0] + self._flip_hw(p[batch])) / 2 for p in infer(stacked)]

        preds = infer(imgs)
        if not TTA:
            return [p[0] for p in preds]

        flip_pred = infer(self._flip_hw(imgs))
        return [(p[0] + self._flip_hw(fp[0])) / 2 for p, fp in zip(preds, flip_pred)]

    def update(self, imgs, gt, learning_rate=0, training=True):
        """Run one optimization step (or a gradient-free forward pass).

        Args:
            imgs: Network input.
            gt: Ground-truth frame compared against the predictions.
            learning_rate: Value written to every optimizer param group.
            training: When true, perform a training step; otherwise switch
                the network to eval mode and only run a forward pass.

        Returns:
            ``(pred, loss)``; ``loss`` is the literal ``0`` when not training.
        """
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate

        if not training:
            self.eval()
            with torch.no_grad():
                _, _, _, pred = self.net(imgs)
            return pred, 0

        self.train()
        _, _, merged, pred = self.net(imgs)
        loss_l1 = (self.lap(pred, gt)).mean()
        # Each intermediate merged prediction contributes at half weight.
        for merge in merged:
            loss_l1 += (self.lap(merge, gt)).mean() * 0.5

        self.optimG.zero_grad()
        loss_l1.backward()
        self.optimG.step()
        return pred, loss_l1