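import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import warnings

from basicsr.archs.arch_util import flow_warp
from basicsr.archs.basicvsr_arch import ConvResidualBlocks
from basicsr.archs.spynet_arch import SpyNet
from basicsr.ops.dcn import ModulatedDeformConvPack
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class BasicVSRPlusPlus(nn.Module):
    """BasicVSR++ network structure.

    Supports either x4 upsampling or same-size output. Since DCN is used in
    this model, it can only be used with CUDA enabled. If CUDA is not enabled,
    feature alignment will be skipped. Besides, we adopt the official DCN
    implementation, and the torch version needs to be higher than 1.9.

    ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``

    Args:
        mid_channels (int, optional): Channel number of the intermediate
            features. Default: 64.
        num_blocks (int, optional): The number of residual blocks in each
            propagation branch. Default: 7.
        max_residue_magnitude (int): The maximum magnitude of the offset
            residue (Eq. 6 in paper). Default: 10.
        is_low_res_input (bool, optional): Whether the input is low-resolution
            or not. If False, the output resolution is equal to the input
            resolution. Default: True.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        cpu_cache_length (int, optional): When the length of the sequence is
            larger than this value, the intermediate features are sent to the
            CPU. This saves GPU memory but slows down inference. You can
            increase this number if you have a GPU with large memory.
            Default: 100.
    """

    def __init__(self,
                 mid_channels=64,
                 num_blocks=7,
                 max_residue_magnitude=10,
                 is_low_res_input=True,
                 spynet_path=None,
                 cpu_cache_length=100):

        super().__init__()
        self.mid_channels = mid_channels
        self.is_low_res_input = is_low_res_input
        self.cpu_cache_length = cpu_cache_length

        # optical flow
        self.spynet = SpyNet(spynet_path)

        # feature extraction module
        if is_low_res_input:
            self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
        else:
            # two stride-2 convolutions bring a full-resolution input down by x4
            self.feat_extract = nn.Sequential(
                nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
                ConvResidualBlocks(mid_channels, mid_channels, 5))

        # propagation branches
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
        for i, module in enumerate(modules):
            if torch.cuda.is_available():
                self.deform_align[module] = SecondOrderDeformableAlignment(
                    2 * mid_channels,
                    mid_channels,
                    3,
                    padding=1,
                    deformable_groups=16,
                    max_residue_magnitude=max_residue_magnitude)
            # branch i sees the spatial features, the i earlier branches and its own propagated feature
            self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)

        # upsampling module
        self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)

        self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # check if the sequence is augmented by flipping
        self.is_mirror_extended = False

        if len(self.deform_align) > 0:
            self.is_with_alignment = True
        else:
            self.is_with_alignment = False
            warnings.warn('Deformable alignment module is not added. '
                          'Probably your CUDA is not configured correctly. DCN can only '
                          'be used with CUDA enabled. Alignment is skipped now.')

    def check_if_mirror_extended(self, lqs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
        """

        if lqs.size(1) % 2 == 0:
            lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
            if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
                self.is_mirror_extended = True

    def compute_flow(self, lqs):
        """Compute optical flow using SPyNet for feature alignment.

        Note that if the input is a mirror-extended sequence, 'flows_forward'
        is not needed, since it is equal to 'flows_backward.flip(1)'.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).

        Returns:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
                (current to previous). 'flows_backward' corresponds to the flows used for backward-time \
                propagation (current to next).
        """

        n, t, c, h, w = lqs.size()
        lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = flows_backward.flip(1)
        else:
            flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)

        if self.cpu_cache:
            flows_backward = flows_backward.cpu()
            flows_forward = flows_forward.cpu()

        return flows_forward, flows_backward

    def propagate(self, feats, flows, module_name):
        """Propagate the latent features throughout the sequence.

        Args:
            feats (dict[str, list[Tensor]]): Features from previous branches.
                Each component is a list of tensors with shape (n, c, h, w).
            flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
            module_name (str): The name of the propagation branch. Can be
                'backward_1', 'forward_1', 'backward_2' or 'forward_2'.

        Returns:
            dict(list[tensor]): A dictionary containing all the propagated \
                features. Each key in the dictionary corresponds to a \
                propagation branch, which is represented by a list of tensors.
        """

        n, t, _, h, w = flows.size()

        frame_idx = range(0, t + 1)
        flow_idx = range(-1, t)
        mapping_idx = list(range(0, len(feats['spatial'])))
        mapping_idx += mapping_idx[::-1]

        if 'backward' in module_name:
            frame_idx = frame_idx[::-1]
            flow_idx = frame_idx

        feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
        for i, idx in enumerate(frame_idx):
            feat_current = feats['spatial'][mapping_idx[idx]]
            if self.cpu_cache:
                feat_current = feat_current.cuda()
                feat_prop = feat_prop.cuda()
            # second-order deformable alignment
            if i > 0 and self.is_with_alignment:
                flow_n1 = flows[:, flow_idx[i], :, :, :]
                if self.cpu_cache:
                    flow_n1 = flow_n1.cuda()

                cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))

                # initialize second-order features
                feat_n2 = torch.zeros_like(feat_prop)
                flow_n2 = torch.zeros_like(flow_n1)
                cond_n2 = torch.zeros_like(cond_n1)

                if i > 1:  # second-order features
                    feat_n2 = feats[module_name][-2]
                    if self.cpu_cache:
                        feat_n2 = feat_n2.cuda()

                    flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
                    if self.cpu_cache:
                        flow_n2 = flow_n2.cuda()

                    # compose the two first-order flows into a flow that skips one frame
                    flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
                    cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))

                # flow-guided deformable convolution
                cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
                feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
                feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)

            # concatenate and residual blocks
            feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
            if self.cpu_cache:
                feat = [f.cuda() for f in feat]

            feat = torch.cat(feat, dim=1)
            feat_prop = feat_prop + self.backbone[module_name](feat)
            feats[module_name].append(feat_prop)

            if self.cpu_cache:
                feats[module_name][-1] = feats[module_name][-1].cpu()
                torch.cuda.empty_cache()

        if 'backward' in module_name:
            feats[module_name] = feats[module_name][::-1]

        return feats

    def upsample(self, lqs, feats):
        """Compute the output image given the features.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).
            feats (dict): The features from the propagation branches.

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w), or
                (n, t, c, h, w) when ``is_low_res_input`` is False.
        """

        outputs = []
        num_outputs = len(feats['spatial'])

        mapping_idx = list(range(0, num_outputs))
        mapping_idx += mapping_idx[::-1]

        for i in range(0, lqs.size(1)):
            hr = [feats[k].pop(0) for k in feats if k != 'spatial']
            hr.insert(0, feats['spatial'][mapping_idx[i]])
            hr = torch.cat(hr, dim=1)
            if self.cpu_cache:
                hr = hr.cuda()

            hr = self.reconstruction(hr)
            hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
            hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
            hr = self.lrelu(self.conv_hr(hr))
            hr = self.conv_last(hr)
            if self.is_low_res_input:
                hr += self.img_upsample(lqs[:, i, :, :, :])
            else:
                hr += lqs[:, i, :, :, :]

            if self.cpu_cache:
                hr = hr.cpu()
                torch.cuda.empty_cache()

            outputs.append(hr)

        return torch.stack(outputs, dim=1)

    def forward(self, lqs):
        """Forward function for BasicVSR++.

        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w), or
                (n, t, c, h, w) when ``is_low_res_input`` is False.
        """

        n, t, c, h, w = lqs.size()

        # whether to cache the features in CPU
        self.cpu_cache = True if t > self.cpu_cache_length else False

        if self.is_low_res_input:
            lqs_downsample = lqs.clone()
        else:
            lqs_downsample = F.interpolate(
                lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lqs)

        feats = {}
        # compute spatial features
        if self.cpu_cache:
            feats['spatial'] = []
            for i in range(0, t):
                feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
                feats['spatial'].append(feat)
                torch.cuda.empty_cache()
        else:
            feats_ = self.feat_extract(lqs.view(-1, c, h, w))
            h, w = feats_.shape[2:]
            feats_ = feats_.view(n, t, -1, h, w)
            feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]

        # compute optical flow using the low-res inputs
        assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
            'The height and width of low-res inputs must be at least 64, '
            f'but got {lqs_downsample.size(3)} and {lqs_downsample.size(4)}.')
        flows_forward, flows_backward = self.compute_flow(lqs_downsample)

        # feature propagation
        for iter_ in [1, 2]:
            for direction in ['backward', 'forward']:
                module = f'{direction}_{iter_}'

                feats[module] = []

                if direction == 'backward':
                    flows = flows_backward
                elif flows_forward is not None:
                    flows = flows_forward
                else:
                    flows = flows_backward.flip(1)

                feats = self.propagate(feats, flows, module)
                if self.cpu_cache:
                    del flows
                    torch.cuda.empty_cache()

        return self.upsample(lqs, feats)


class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
    """Second-order deformable alignment module.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        deformable_groups (int): Number of groups the offsets and masks are
            split into.
        bias (bool): Same as nn.Conv2d.
        max_residue_magnitude (int): The maximum magnitude of the offset
            residue (Eq. 6 in paper). Default: 10.
    """

    def __init__(self, *args, **kwargs):
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)

        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)

        # input: extra features (3 * out_channels) plus two 2-channel flows
        # output: 27 * deformable_groups = offsets o1, o2 and mask, 9 * deformable_groups each
        self.conv_offset = nn.Sequential(
            nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
        )

        self.init_offset()

    def init_offset(self):

        def _constant_init(module, val, bias=0):
            if hasattr(module, 'weight') and module.weight is not None:
                nn.init.constant_(module.weight, val)
            if hasattr(module, 'bias') and module.bias is not None:
                nn.init.constant_(module.bias, bias)

        # zero-init the last layer so the initial offsets equal the optical flow
        _constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)

        # offset: bounded residue added on top of the (y, x)-ordered flow
        offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
        offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
        offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
        offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
        offset = torch.cat([offset_1, offset_2], dim=1)

        # mask
        mask = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                             self.dilation, mask)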
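

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the model above. It restates, in pure
# Python, the index bookkeeping of ``BasicVSRPlusPlus.propagate`` so the
# visiting order of each branch can be inspected without torch or a GPU.
# The helper name ``_propagation_schedule`` is hypothetical. It assumes that
# ``len(feats['spatial'])`` equals the number of frames, as ``forward`` builds
# it, so ``mapping_idx[idx] == idx``. It also ignores ``is_with_alignment``:
# when alignment is disabled, ``propagate`` never reads ``flows`` at all.
# ---------------------------------------------------------------------------
def _propagation_schedule(num_frames, module_name):
    """Return (frame, flow_n1, flow_n2) triples in the order propagate visits them.

    ``frame`` is the frame whose feature is refined at that step, ``flow_n1``
    indexes the flow used for first-order alignment, and ``flow_n2`` indexes
    the flow that propagate composes with ``flow_n1`` for the second-order
    term. ``None`` means the term is not read from ``flows`` at that step
    (alignment is skipped at the first step, second-order terms are
    zero-filled at the second).
    """
    num_flows = num_frames - 1  # ``t`` inside propagate is the number of flows
    frame_idx = list(range(0, num_flows + 1))
    flow_idx = list(range(-1, num_flows))
    if 'backward' in module_name:
        # Backward branches visit frames last-to-first and reuse the frame
        # indices as flow indices, exactly as propagate does.
        frame_idx = frame_idx[::-1]
        flow_idx = frame_idx

    schedule = []
    for i, idx in enumerate(frame_idx):
        flow_n1 = flow_idx[i] if i > 0 else None  # no propagated feature yet at i == 0
        flow_n2 = flow_idx[i - 1] if i > 1 else None  # second order needs two predecessors
        schedule.append((idx, flow_n1, flow_n2))
    return schedule


# Example with 4 frames (3 flows):
#   _propagation_schedule(4, 'forward_1')
#   -> [(0, None, None), (1, 0, None), (2, 1, 0), (3, 2, 1)]
#   _propagation_schedule(4, 'backward_1')
#   -> [(3, None, None), (2, 2, None), (1, 1, 2), (0, 0, 1)]
# In the backward branch, flow 2 is SPyNet(frame 2, frame 3), i.e. the
# "current to next" flow described in compute_flow's docstring.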