omninexus committed
Commit 9d3ac6e · verified · 1 Parent(s): 68add72

Upload 36 files

Files changed (36)
  1. Upsample/__init__.py +1 -0
  2. Upsample/__pycache__/__init__.cpython-38.pyc +0 -0
  3. Upsample/__pycache__/arch_utils.cpython-38.pyc +0 -0
  4. Upsample/__pycache__/model.cpython-38.pyc +0 -0
  5. Upsample/__pycache__/rrdbnet_arch.cpython-38.pyc +0 -0
  6. Upsample/__pycache__/utils.cpython-38.pyc +0 -0
  7. Upsample/arch_utils.py +197 -0
  8. Upsample/model.py +93 -0
  9. Upsample/rrdbnet_arch.py +121 -0
  10. Upsample/utils.py +135 -0
  11. janus/.DS_Store +0 -0
  12. janus/__init__.py +31 -0
  13. janus/__pycache__/__init__.cpython-38.pyc +0 -0
  14. janus/models/__init__.py +28 -0
  15. janus/models/__pycache__/__init__.cpython-38.pyc +0 -0
  16. janus/models/__pycache__/clip_encoder.cpython-38.pyc +0 -0
  17. janus/models/__pycache__/image_processing_vlm.cpython-38.pyc +0 -0
  18. janus/models/__pycache__/modeling_vlm.cpython-38.pyc +0 -0
  19. janus/models/__pycache__/processing_vlm.cpython-38.pyc +0 -0
  20. janus/models/__pycache__/projector.cpython-38.pyc +0 -0
  21. janus/models/__pycache__/siglip_vit.cpython-38.pyc +0 -0
  22. janus/models/__pycache__/vq_model.cpython-38.pyc +0 -0
  23. janus/models/clip_encoder.py +122 -0
  24. janus/models/image_processing_vlm.py +208 -0
  25. janus/models/modeling_vlm.py +272 -0
  26. janus/models/processing_vlm.py +418 -0
  27. janus/models/projector.py +100 -0
  28. janus/models/siglip_vit.py +681 -0
  29. janus/models/vq_model.py +527 -0
  30. janus/utils/__init__.py +18 -0
  31. janus/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  32. janus/utils/__pycache__/conversation.cpython-38.pyc +0 -0
  33. janus/utils/__pycache__/io.cpython-38.pyc +0 -0
  34. janus/utils/conversation.py +365 -0
  35. janus/utils/io.py +89 -0
  36. weights/RealESRGAN_x2.pth +3 -0
Upsample/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import RealESRGAN
Upsample/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (213 Bytes).
Upsample/__pycache__/arch_utils.cpython-38.pyc ADDED
Binary file (7.14 kB).
Upsample/__pycache__/model.cpython-38.pyc ADDED
Binary file (3.11 kB).
Upsample/__pycache__/rrdbnet_arch.cpython-38.pyc ADDED
Binary file (4.47 kB).
Upsample/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.05 kB).
Upsample/arch_utils.py ADDED
@@ -0,0 +1,197 @@
+ import math
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+ from torch.nn import init as init
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ @torch.no_grad()
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+     """Initialize network weights.
+
+     Args:
+         module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+         scale (float): Scale initialized weights, especially for residual
+             blocks. Default: 1.
+         bias_fill (float): The value to fill bias. Default: 0
+         kwargs (dict): Other arguments for initialization function.
+     """
+     if not isinstance(module_list, list):
+         module_list = [module_list]
+     for module in module_list:
+         for m in module.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, _BatchNorm):
+                 init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+
+
+ def make_layer(basic_block, num_basic_block, **kwarg):
+     """Make layers by stacking the same blocks.
+
+     Args:
+         basic_block (nn.module): nn.module class for basic block.
+         num_basic_block (int): number of blocks.
+
+     Returns:
+         nn.Sequential: Stacked blocks in nn.Sequential.
+     """
+     layers = []
+     for _ in range(num_basic_block):
+         layers.append(basic_block(**kwarg))
+     return nn.Sequential(*layers)
+
+
+ class ResidualBlockNoBN(nn.Module):
+     """Residual block without BN.
+
+     It has a style of:
+         ---Conv-ReLU-Conv-+-
+          |________________|
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+             Default: 64.
+         res_scale (float): Residual scale. Default: 1.
+         pytorch_init (bool): If set to True, use pytorch default init,
+             otherwise, use default_init_weights. Default: False.
+     """
+
+     def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+         super(ResidualBlockNoBN, self).__init__()
+         self.res_scale = res_scale
+         self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+         self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+         self.relu = nn.ReLU(inplace=True)
+
+         if not pytorch_init:
+             default_init_weights([self.conv1, self.conv2], 0.1)
+
+     def forward(self, x):
+         identity = x
+         out = self.conv2(self.relu(self.conv1(x)))
+         return identity + out * self.res_scale
+
+
+ class Upsample(nn.Sequential):
+     """Upsample module.
+
+     Args:
+         scale (int): Scale factor. Supported scales: 2^n and 3.
+         num_feat (int): Channel number of intermediate features.
+     """
+
+     def __init__(self, scale, num_feat):
+         m = []
+         if (scale & (scale - 1)) == 0:  # scale = 2^n
+             for _ in range(int(math.log(scale, 2))):
+                 m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                 m.append(nn.PixelShuffle(2))
+         elif scale == 3:
+             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+             m.append(nn.PixelShuffle(3))
+         else:
+             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+         super(Upsample, self).__init__(*m)
+
+
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+     """Warp an image or feature map with optical flow.
+
+     Args:
+         x (Tensor): Tensor with size (n, c, h, w).
+         flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+         interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+         padding_mode (str): 'zeros' or 'border' or 'reflection'.
+             Default: 'zeros'.
+         align_corners (bool): Before pytorch 1.3, the default value is
+             align_corners=True. After pytorch 1.3, the default value is
+             align_corners=False. Here, we use True as the default.
+
+     Returns:
+         Tensor: Warped image or feature map.
+     """
+     assert x.size()[-2:] == flow.size()[1:3]
+     _, _, h, w = x.size()
+     # create mesh grid
+     grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+     grid.requires_grad = False
+
+     vgrid = grid + flow
+     # scale grid to [-1,1]
+     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+     # TODO, what if align_corners=False
+     return output
+
+
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+     """Resize a flow according to ratio or shape.
+
+     Args:
+         flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+         size_type (str): 'ratio' or 'shape'.
+         sizes (list[int | float]): the ratio for resizing or the final output
+             shape.
+             1) The order of ratio should be [ratio_h, ratio_w]. For
+             downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+             < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+             ratio > 1.0).
+             2) The order of output_size should be [out_h, out_w].
+         interp_mode (str): The mode of interpolation for resizing.
+             Default: 'bilinear'.
+         align_corners (bool): Whether align corners. Default: False.
+
+     Returns:
+         Tensor: Resized flow.
+     """
+     _, _, flow_h, flow_w = flow.size()
+     if size_type == 'ratio':
+         output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+     elif size_type == 'shape':
+         output_h, output_w = sizes[0], sizes[1]
+     else:
+         raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+     input_flow = flow.clone()
+     ratio_h = output_h / flow_h
+     ratio_w = output_w / flow_w
+     input_flow[:, 0, :, :] *= ratio_w
+     input_flow[:, 1, :, :] *= ratio_h
+     resized_flow = F.interpolate(
+         input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+     return resized_flow
+
+
+ # TODO: may write a cpp file
+ def pixel_unshuffle(x, scale):
+     """ Pixel unshuffle.
+
+     Args:
+         x (Tensor): Input feature with shape (b, c, hh, hw).
+         scale (int): Downsample ratio.
+
+     Returns:
+         Tensor: the pixel unshuffled feature.
+     """
+     b, c, hh, hw = x.size()
+     out_channel = c * (scale**2)
+     assert hh % scale == 0 and hw % scale == 0
+     h = hh // scale
+     w = hw // scale
+     x_view = x.view(b, c, h, scale, w, scale)
+     return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
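
The two reshaping helpers above are easiest to read through their shape contracts. A minimal editorial sketch (not part of the upload, assuming only PyTorch and the file above):

import torch
from Upsample.arch_utils import Upsample, pixel_unshuffle

# pixel_unshuffle folds space into channels: (b, c, h*s, w*s) -> (b, c*s*s, h, w)
x = torch.randn(1, 3, 64, 64)
print(pixel_unshuffle(x, scale=2).shape)    # torch.Size([1, 12, 32, 32])

# Upsample does the opposite via Conv2d + PixelShuffle: spatial size grows by `scale`
up = Upsample(scale=2, num_feat=8)
print(up(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 32, 32])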
Upsample/model.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import torch
+ from torch.nn import functional as F
+ from PIL import Image
+ import numpy as np
+ import cv2
+ from huggingface_hub import hf_hub_url, hf_hub_download
+
+ from .rrdbnet_arch import RRDBNet
+ from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
+     unpad_image
+
+ HF_MODELS = {
+     2: dict(
+         repo_id='sberbank-ai/Real-ESRGAN',
+         filename='RealESRGAN_x2.pth',
+     ),
+     4: dict(
+         repo_id='sberbank-ai/Real-ESRGAN',
+         filename='RealESRGAN_x4.pth',
+     ),
+     8: dict(
+         repo_id='sberbank-ai/Real-ESRGAN',
+         filename='RealESRGAN_x8.pth',
+     ),
+ }
+
+
+ class RealESRGAN:
+     def __init__(self, device, scale=4):
+         self.device = device
+         self.scale = scale
+         self.model = RRDBNet(
+             num_in_ch=3, num_out_ch=3, num_feat=64,
+             num_block=23, num_grow_ch=32, scale=scale
+         )
+
+     def load_weights(self, model_path, download=True):
+         if not os.path.exists(model_path) and download:
+             assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
+             config = HF_MODELS[self.scale]
+             cache_dir = os.path.dirname(model_path)
+             local_filename = os.path.basename(model_path)
+             config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
+             htr = hf_hub_download(repo_id=config['repo_id'], cache_dir=cache_dir, local_dir=cache_dir,
+                                   filename=config['filename'])
+             print(htr)
+             # cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
+             print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
+
+         loadnet = torch.load(model_path)
+         if 'params' in loadnet:
+             self.model.load_state_dict(loadnet['params'], strict=True)
+         elif 'params_ema' in loadnet:
+             self.model.load_state_dict(loadnet['params_ema'], strict=True)
+         else:
+             self.model.load_state_dict(loadnet, strict=True)
+         self.model.eval()
+         self.model.to(self.device)
+
+     # @torch.cuda.amp.autocast()
+     def predict(self, lr_image, batch_size=4, patches_size=192,
+                 padding=24, pad_size=15):
+         torch.autocast(device_type=self.device.type)
+         scale = self.scale
+         device = self.device
+         lr_image = np.array(lr_image)
+         lr_image = pad_reflect(lr_image, pad_size)
+
+         patches, p_shape = split_image_into_overlapping_patches(
+             lr_image, patch_size=patches_size, padding_size=padding
+         )
+         img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
+
+         with torch.no_grad():
+             res = self.model(img[0:batch_size])
+             for i in range(batch_size, img.shape[0], batch_size):
+                 res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
+
+         sr_image = res.permute((0, 2, 3, 1)).cpu().clamp_(0, 1)
+         np_sr_image = sr_image.numpy()
+
+         padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
+         scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
+         np_sr_image = stich_together(
+             np_sr_image, padded_image_shape=padded_size_scaled,
+             target_shape=scaled_image_shape, padding_size=padding * scale
+         )
+         sr_img = (np_sr_image * 255).astype(np.uint8)
+         sr_img = unpad_image(sr_img, pad_size * scale)
+         sr_img = Image.fromarray(sr_img)
+
+         return sr_img
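
For context, a typical call sequence for this class looks like the editorial sketch below (not part of the upload; the input and output paths are placeholders, while weights/RealESRGAN_x2.pth is the checkpoint added in this commit):

import torch
from PIL import Image
from Upsample import RealESRGAN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RealESRGAN(device, scale=2)
model.load_weights('weights/RealESRGAN_x2.pth', download=False)  # use the bundled x2 weights

lr_image = Image.open('inputs/example.png').convert('RGB')  # placeholder input path
sr_image = model.predict(lr_image)                          # PIL.Image, 2x the input resolution
sr_image.save('outputs/example_x2.png')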
Upsample/rrdbnet_arch.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+
+ from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
+
+
+ class ResidualDenseBlock(nn.Module):
+     """Residual Dense Block.
+
+     Used in RRDB block in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat=64, num_grow_ch=32):
+         super(ResidualDenseBlock, self).__init__()
+         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     """Residual in Residual Dense Block.
+
+     Used in RRDB-Net in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat, num_grow_ch=32):
+         super(RRDB, self).__init__()
+         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+     def forward(self, x):
+         out = self.rdb1(x)
+         out = self.rdb2(out)
+         out = self.rdb3(out)
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return out * 0.2 + x
+
+
+ class RRDBNet(nn.Module):
+     """Networks consisting of Residual in Residual Dense Block, which is used
+     in ESRGAN.
+
+     ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+     We extend ESRGAN for scale x2 and scale x1.
+     Note: This is one option for scale 1, scale 2 in RRDBNet.
+     We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
+     and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+     Args:
+         num_in_ch (int): Channel number of inputs.
+         num_out_ch (int): Channel number of outputs.
+         num_feat (int): Channel number of intermediate features.
+             Default: 64
+         num_block (int): Block number in the trunk network. Defaults: 23
+         num_grow_ch (int): Channels for each growth. Default: 32.
+     """
+
+     def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+         super(RRDBNet, self).__init__()
+         self.scale = scale
+         if scale == 2:
+             num_in_ch = num_in_ch * 4
+         elif scale == 1:
+             num_in_ch = num_in_ch * 16
+         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         # upsample
+         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         if scale == 8:
+             self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+         if self.scale == 2:
+             feat = pixel_unshuffle(x, scale=2)
+         elif self.scale == 1:
+             feat = pixel_unshuffle(x, scale=4)
+         else:
+             feat = x
+         feat = self.conv_first(feat)
+         body_feat = self.conv_body(self.body(feat))
+         feat = feat + body_feat
+         # upsample
+         feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         if self.scale == 8:
+             feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+         return out
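
One non-obvious detail: the trunk always upsamples by 4, so for scale=2 the input is first pixel-unshuffled (3 -> 12 channels, half the spatial size) and the end-to-end factor still matches `scale`. A small editorial shape check (not part of the upload):

import torch
from Upsample.rrdbnet_arch import RRDBNet

# num_block=1 keeps the check light; the real model uses 23 blocks
net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=1, num_grow_ch=32, scale=2)
with torch.no_grad():
    out = net(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 64, 64]) -> x2 overall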
Upsample/utils.py ADDED
@@ -0,0 +1,135 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+ import os
+ import io
+
+
+ def pad_reflect(image, pad_size):
+     imsize = image.shape
+     height, width = imsize[:2]
+     new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
+     new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
+
+     new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
+     new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
+     new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1)  # left
+     new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1)  # right
+
+     return new_img
+
+
+ def unpad_image(image, pad_size):
+     return image[pad_size:-pad_size, pad_size:-pad_size, :]
+
+
+ def process_array(image_array, expand=True):
+     """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
+
+     image_batch = image_array / 255.0
+     if expand:
+         image_batch = np.expand_dims(image_batch, axis=0)
+     return image_batch
+
+
+ def process_output(output_tensor):
+     """ Transforms the 4-dimensional output tensor into a suitable image format. """
+
+     sr_img = output_tensor.clip(0, 1) * 255
+     sr_img = np.uint8(sr_img)
+     return sr_img
+
+
+ def pad_patch(image_patch, padding_size, channel_last=True):
+     """ Pads image_patch with padding_size edge values. """
+
+     if channel_last:
+         return np.pad(
+             image_patch,
+             ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
+             'edge',
+         )
+     else:
+         return np.pad(
+             image_patch,
+             ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
+             'edge',
+         )
+
+
+ def unpad_patches(image_patches, padding_size):
+     return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
+
+
+ def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
+     """ Splits the image into partially overlapping patches.
+     The patches overlap by padding_size pixels.
+     Pads the image twice:
+         - first to have a size multiple of the patch size,
+         - then to have equal padding at the borders.
+     Args:
+         image_array: numpy array of the input image.
+         patch_size: size of the patches from the original image (without padding).
+         padding_size: size of the overlapping area.
+     """
+
+     xmax, ymax, _ = image_array.shape
+     x_remainder = xmax % patch_size
+     y_remainder = ymax % patch_size
+
+     # modulo here is to avoid extending of patch_size instead of 0
+     x_extend = (patch_size - x_remainder) % patch_size
+     y_extend = (patch_size - y_remainder) % patch_size
+
+     # make sure the image is divisible into regular patches
+     extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
+
+     # add padding around the image to simplify computations
+     padded_image = pad_patch(extended_image, padding_size, channel_last=True)
+
+     xmax, ymax, _ = padded_image.shape
+     patches = []
+
+     x_lefts = range(padding_size, xmax - padding_size, patch_size)
+     y_tops = range(padding_size, ymax - padding_size, patch_size)
+
+     for x in x_lefts:
+         for y in y_tops:
+             x_left = x - padding_size
+             y_top = y - padding_size
+             x_right = x + patch_size + padding_size
+             y_bottom = y + patch_size + padding_size
+             patch = padded_image[x_left:x_right, y_top:y_bottom, :]
+             patches.append(patch)
+
+     return np.array(patches), padded_image.shape
+
+
+ def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
+     """ Reconstruct the image from overlapping patches.
+     After scaling, shapes and padding should be scaled too.
+     Args:
+         patches: patches obtained with split_image_into_overlapping_patches
+         padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
+         target_shape: shape of the final image
+         padding_size: size of the overlapping area.
+     """
+
+     xmax, ymax, _ = padded_image_shape
+     patches = unpad_patches(patches, padding_size)
+     patch_size = patches.shape[1]
+     n_patches_per_row = ymax // patch_size
+
+     complete_image = np.zeros((xmax, ymax, 3))
+
+     row = -1
+     col = 0
+     for i in range(len(patches)):
+         if i % n_patches_per_row == 0:
+             row += 1
+             col = 0
+         complete_image[
+             row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
+         ] = patches[i]
+         col += 1
+     return complete_image[0: target_shape[0], 0: target_shape[1], :]
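
The split/stitch pair is meant to round-trip exactly when no scaling happens in between, which is a cheap way to sanity-check the padding bookkeeping. An editorial sketch (not part of the upload):

import numpy as np
from Upsample.utils import split_image_into_overlapping_patches, stich_together

img = np.random.randint(0, 256, (100, 120, 3), dtype=np.uint8)
patches, padded_shape = split_image_into_overlapping_patches(img, patch_size=48, padding_size=8)
restored = stich_together(patches, padded_image_shape=padded_shape,
                          target_shape=img.shape, padding_size=8)
print(np.array_equal(restored.astype(np.uint8), img))  # True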
janus/.DS_Store ADDED
Binary file (6.15 kB).
janus/__init__.py ADDED
@@ -0,0 +1,31 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+ # check if python version is above 3.10
+ import sys
+
+ if sys.version_info >= (3, 10):
+     print("Python version is above 3.10, patching the collections module.")
+     # Monkey patch collections
+     import collections
+     import collections.abc
+
+     for type_name in collections.abc.__all__:
+         setattr(collections, type_name, getattr(collections.abc, type_name))
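
The patch above re-exposes the ABCs (Mapping, Sequence, ...) on the collections module, from which Python 3.10 removed them; the likely motivation is the attrdict dependency used in janus/models/modeling_vlm.py, which still imports them from collections. A small editorial illustration (not part of the upload):

# On Python >= 3.10, before the patch:
import collections
print(hasattr(collections, "Mapping"))   # False

import janus                             # importing the package runs the patch as a side effect
print(hasattr(collections, "Mapping"))   # True, so `from collections import Mapping` works again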
janus/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (433 Bytes).
janus/models/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ from .image_processing_vlm import VLMImageProcessor
+ from .modeling_vlm import MultiModalityCausalLM
+ from .processing_vlm import VLChatProcessor
+
+ __all__ = [
+     "VLMImageProcessor",
+     "VLChatProcessor",
+     "MultiModalityCausalLM",
+ ]
janus/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (391 Bytes).
janus/models/__pycache__/clip_encoder.cpython-38.pyc ADDED
Binary file (2.74 kB).
janus/models/__pycache__/image_processing_vlm.cpython-38.pyc ADDED
Binary file (4.98 kB).
janus/models/__pycache__/modeling_vlm.cpython-38.pyc ADDED
Binary file (7.1 kB).
janus/models/__pycache__/processing_vlm.cpython-38.pyc ADDED
Binary file (11.1 kB).
janus/models/__pycache__/projector.cpython-38.pyc ADDED
Binary file (2.23 kB).
janus/models/__pycache__/siglip_vit.cpython-38.pyc ADDED
Binary file (18.4 kB).
janus/models/__pycache__/vq_model.cpython-38.pyc ADDED
Binary file (12.5 kB).
janus/models/clip_encoder.py ADDED
@@ -0,0 +1,122 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ from typing import Dict, List, Literal, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms
+ from einops import rearrange
+
+ from janus.models.siglip_vit import create_siglip_vit
+
+
+ class CLIPVisionTower(nn.Module):
+     def __init__(
+         self,
+         model_name: str = "siglip_large_patch16_384",
+         image_size: Union[Tuple[int, int], int] = 336,
+         select_feature: str = "patch",
+         select_layer: int = -2,
+         select_layers: list = None,
+         ckpt_path: str = "",
+         pixel_mean: Optional[List[float]] = None,
+         pixel_std: Optional[List[float]] = None,
+         **kwargs,
+     ):
+         super().__init__()
+
+         self.model_name = model_name
+         self.select_feature = select_feature
+         self.select_layer = select_layer
+         self.select_layers = select_layers
+
+         vision_tower_params = {
+             "model_name": model_name,
+             "image_size": image_size,
+             "ckpt_path": ckpt_path,
+             "select_layer": select_layer,
+         }
+         vision_tower_params.update(kwargs)
+         self.vision_tower, self.forward_kwargs = self.build_vision_tower(
+             vision_tower_params
+         )
+
+         if pixel_mean is not None and pixel_std is not None:
+             image_norm = torchvision.transforms.Normalize(
+                 mean=pixel_mean, std=pixel_std
+             )
+         else:
+             image_norm = None
+
+         self.image_norm = image_norm
+
+     def build_vision_tower(self, vision_tower_params):
+         if self.model_name.startswith("siglip"):
+             self.select_feature = "same"
+             vision_tower = create_siglip_vit(**vision_tower_params)
+             forward_kwargs = dict()
+
+         elif self.model_name.startswith("sam"):
+             vision_tower = create_sam_vit(**vision_tower_params)
+             forward_kwargs = dict()
+
+         else:  # huggingface
+             from transformers import CLIPVisionModel
+
+             vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
+             forward_kwargs = dict(output_hidden_states=True)
+
+         return vision_tower, forward_kwargs
+
+     def feature_select(self, image_forward_outs):
+         if isinstance(image_forward_outs, torch.Tensor):
+             # the output is already the self.select_layer's features
+             image_features = image_forward_outs
+         else:
+             image_features = image_forward_outs.hidden_states[self.select_layer]
+
+         if self.select_feature == "patch":
+             # if the output has cls_token
+             image_features = image_features[:, 1:]
+         elif self.select_feature == "cls_patch":
+             image_features = image_features
+         elif self.select_feature == "same":
+             image_features = image_features
+
+         else:
+             raise ValueError(f"Unexpected select feature: {self.select_feature}")
+         return image_features
+
+     def forward(self, images):
+         """
+
+         Args:
+             images (torch.Tensor): [b, 3, H, W]
+
+         Returns:
+             image_features (torch.Tensor): [b, n_patch, d]
+         """
+
+         if self.image_norm is not None:
+             images = self.image_norm(images)
+
+         image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
+         image_features = self.feature_select(image_forward_outs)
+         return image_features
janus/models/image_processing_vlm.py ADDED
@@ -0,0 +1,208 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ from typing import List, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torchvision
+ import torchvision.transforms.functional
+ from PIL import Image
+ from transformers import AutoImageProcessor, PretrainedConfig
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+ from transformers.image_utils import to_numpy_array
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
+ IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
+
+
+ def expand2square(pil_img, background_color):
+     width, height = pil_img.size
+     if width == height:
+         return pil_img
+     elif width > height:
+         result = Image.new(pil_img.mode, (width, width), background_color)
+         result.paste(pil_img, (0, (width - height) // 2))
+         return result
+     else:
+         result = Image.new(pil_img.mode, (height, height), background_color)
+         result.paste(pil_img, ((height - width) // 2, 0))
+         return result
+
+
+ class VLMImageProcessorConfig(PretrainedConfig):
+     model_type = "deepseek_vlm"
+     image_size: int
+     min_size: int
+     image_mean: Union[Tuple[float, float, float], List[float]]
+     image_std: Union[Tuple[float, float, float], List[float]]
+     rescale_factor: float
+     do_normalize: bool
+
+     def __init__(
+         self,
+         image_size: int,
+         min_size: int = 14,
+         image_mean: Union[Tuple[float, float, float], List[float]] = (
+             0.48145466,
+             0.4578275,
+             0.40821073,
+         ),
+         image_std: Union[Tuple[float, float, float], List[float]] = (
+             0.26862954,
+             0.26130258,
+             0.27577711,
+         ),
+         rescale_factor: float = 1.0 / 255.0,
+         do_normalize: bool = True,
+         **kwargs,
+     ):
+         self.image_size = image_size
+         self.min_size = min_size
+         self.image_mean = image_mean
+         self.image_std = image_std
+         self.rescale_factor = rescale_factor
+         self.do_normalize = do_normalize
+
+         super().__init__(**kwargs)
+
+
+ class VLMImageProcessor(BaseImageProcessor):
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         image_size: int,
+         min_size: int = 14,
+         image_mean: Union[Tuple[float, float, float], List[float]] = (
+             0.48145466,
+             0.4578275,
+             0.40821073,
+         ),
+         image_std: Union[Tuple[float, float, float], List[float]] = (
+             0.26862954,
+             0.26130258,
+             0.27577711,
+         ),
+         rescale_factor: float = 1.0 / 255.0,
+         do_normalize: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.image_size = image_size
+         self.rescale_factor = rescale_factor
+         self.image_mean = image_mean
+         self.image_std = image_std
+         self.min_size = min_size
+         self.do_normalize = do_normalize
+
+         if image_mean is None:
+             self.background_color = (127, 127, 127)
+         else:
+             self.background_color = tuple([int(x * 255) for x in image_mean])
+
+     def resize(self, pil_img: Image) -> np.ndarray:
+         """
+
+         Args:
+             pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
+
+         Returns:
+             x (np.ndarray): [3, self.image_size, self.image_size]
+         """
+
+         width, height = pil_img.size
+         max_size = max(width, height)
+
+         size = [
+             max(int(height / max_size * self.image_size), self.min_size),
+             max(int(width / max_size * self.image_size), self.min_size),
+         ]
+
+         if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
+             print(f"orig size = {pil_img.size}, new size = {size}")
+             raise ValueError("Invalid size!")
+
+         pil_img = torchvision.transforms.functional.resize(
+             pil_img,
+             size,
+             interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
+             antialias=True,
+         )
+
+         pil_img = expand2square(pil_img, self.background_color)
+         x = to_numpy_array(pil_img)
+
+         # [H, W, 3] -> [3, H, W]
+         x = np.transpose(x, (2, 0, 1))
+
+         return x
+
+     def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
+         # resize and pad to [self.image_size, self.image_size]
+         # then convert from [H, W, 3] to [3, H, W]
+         images: List[np.ndarray] = [self.resize(image) for image in images]
+
+         # rescale from [0, 255] -> [0, 1]
+         images = [
+             self.rescale(
+                 image=image,
+                 scale=self.rescale_factor,
+                 input_data_format="channels_first",
+             )
+             for image in images
+         ]
+
+         # normalize
+         if self.do_normalize:
+             images = [
+                 self.normalize(
+                     image=image,
+                     mean=self.image_mean,
+                     std=self.image_std,
+                     input_data_format="channels_first",
+                 )
+                 for image in images
+             ]
+
+         data = {"pixel_values": images}
+         return BatchFeature(data=data, tensor_type=return_tensors)
+
+     @property
+     def default_shape(self):
+         return [3, self.image_size, self.image_size]
+
+
+ AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
+
+
+ if __name__ == "__main__":
+     image_processor = VLMImageProcessor(
+         image_size=1024,
+         image_mean=IMAGENET_INCEPTION_MEAN,
+         image_std=IMAGENET_INCEPTION_STD,
+         do_normalize=True,
+     )
+ )
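
As a quick orientation: the processor scales the longer side to image_size, pads to a square with the mean color, rescales to [0, 1] and normalizes, so an RGB input of any aspect ratio comes out as [3, image_size, image_size]. An editorial sketch (not part of the upload; the 384 resolution and the dummy image are arbitrary choices for illustration):

import numpy as np
from PIL import Image
from janus.models.image_processing_vlm import (
    VLMImageProcessor, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,
)

processor = VLMImageProcessor(
    image_size=384,
    image_mean=IMAGENET_INCEPTION_MEAN,
    image_std=IMAGENET_INCEPTION_STD,
)
img = Image.fromarray(np.zeros((240, 320, 3), dtype=np.uint8))  # 320x240 dummy RGB image
batch = processor([img], return_tensors="pt")
print(batch.pixel_values.shape)  # torch.Size([1, 3, 384, 384])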
janus/models/modeling_vlm.py ADDED
@@ -0,0 +1,272 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ import torch
+ from attrdict import AttrDict
+ from einops import rearrange
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     LlamaConfig,
+     LlamaForCausalLM,
+     PreTrainedModel,
+ )
+ from transformers.configuration_utils import PretrainedConfig
+
+ from janus.models.clip_encoder import CLIPVisionTower
+ from janus.models.projector import MlpProjector
+
+
+ class vision_head(torch.nn.Module):
+     def __init__(self, params):
+         super().__init__()
+         self.output_mlp_projector = torch.nn.Linear(
+             params.n_embed, params.image_token_embed
+         )
+         self.vision_activation = torch.nn.GELU()
+         self.vision_head = torch.nn.Linear(
+             params.image_token_embed, params.image_token_size
+         )
+
+     def forward(self, x):
+         x = self.output_mlp_projector(x)
+         x = self.vision_activation(x)
+         x = self.vision_head(x)
+         return x
+
+
+ def model_name_to_cls(cls_name):
+     if "MlpProjector" in cls_name:
+         cls = MlpProjector
+
+     elif "CLIPVisionTower" in cls_name:
+         cls = CLIPVisionTower
+
+     elif "VQ" in cls_name:
+         from janus.models.vq_model import VQ_models
+
+         cls = VQ_models[cls_name]
+     elif "vision_head" in cls_name:
+         cls = vision_head
+     else:
+         raise ValueError(f"class_name {cls_name} is invalid.")
+
+     return cls
+
+
+ class VisionConfig(PretrainedConfig):
+     model_type = "vision"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class AlignerConfig(PretrainedConfig):
+     model_type = "aligner"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class GenVisionConfig(PretrainedConfig):
+     model_type = "gen_vision"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class GenAlignerConfig(PretrainedConfig):
+     model_type = "gen_aligner"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class GenHeadConfig(PretrainedConfig):
+     model_type = "gen_head"
+     cls: str = ""
+     params: AttrDict = {}
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.cls = kwargs.get("cls", "")
+         if not isinstance(self.cls, str):
+             self.cls = self.cls.__name__
+
+         self.params = AttrDict(kwargs.get("params", {}))
+
+
+ class MultiModalityConfig(PretrainedConfig):
+     model_type = "multi_modality"
+     vision_config: VisionConfig
+     aligner_config: AlignerConfig
+
+     gen_vision_config: GenVisionConfig
+     gen_aligner_config: GenAlignerConfig
+     gen_head_config: GenHeadConfig
+
+     language_config: LlamaConfig
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         vision_config = kwargs.get("vision_config", {})
+         self.vision_config = VisionConfig(**vision_config)
+
+         aligner_config = kwargs.get("aligner_config", {})
+         self.aligner_config = AlignerConfig(**aligner_config)
+
+         gen_vision_config = kwargs.get("gen_vision_config", {})
+         self.gen_vision_config = GenVisionConfig(**gen_vision_config)
+
+         gen_aligner_config = kwargs.get("gen_aligner_config", {})
+         self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
+
+         gen_head_config = kwargs.get("gen_head_config", {})
+         self.gen_head_config = GenHeadConfig(**gen_head_config)
+
+         language_config = kwargs.get("language_config", {})
+         if isinstance(language_config, LlamaConfig):
+             self.language_config = language_config
+         else:
+             self.language_config = LlamaConfig(**language_config)
+
+
+ class MultiModalityPreTrainedModel(PreTrainedModel):
+     config_class = MultiModalityConfig
+     base_model_prefix = "multi_modality"
+     _no_split_modules = []
+     _skip_keys_device_placement = "past_key_values"
+
+
+ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
+     def __init__(self, config: MultiModalityConfig):
+         super().__init__(config)
+
+         vision_config = config.vision_config
+         vision_cls = model_name_to_cls(vision_config.cls)
+         self.vision_model = vision_cls(**vision_config.params)
+
+         aligner_config = config.aligner_config
+         aligner_cls = model_name_to_cls(aligner_config.cls)
+         self.aligner = aligner_cls(aligner_config.params)
+
+         gen_vision_config = config.gen_vision_config
+         gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
+         self.gen_vision_model = gen_vision_cls()
+
+         gen_aligner_config = config.gen_aligner_config
+         gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
+         self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
+
+         gen_head_config = config.gen_head_config
+         gen_head_cls = model_name_to_cls(gen_head_config.cls)
+         self.gen_head = gen_head_cls(gen_head_config.params)
+
+         self.gen_embed = torch.nn.Embedding(
+             gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
+         )
+
+         language_config = config.language_config
+         self.language_model = LlamaForCausalLM(language_config)
+
+     def prepare_inputs_embeds(
+         self,
+         input_ids: torch.LongTensor,
+         pixel_values: torch.FloatTensor,
+         images_seq_mask: torch.LongTensor,
+         images_emb_mask: torch.LongTensor,
+         **kwargs,
+     ):
+         """
+
+         Args:
+             input_ids (torch.LongTensor): [b, T]
+             pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
+             images_seq_mask (torch.BoolTensor): [b, T]
+             images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
+
+             assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
+
+         Returns:
+             input_embeds (torch.Tensor): [b, T, D]
+         """
+
+         bs, n = pixel_values.shape[0:2]
+         images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+         # [b x n, T2, D]
+         images_embeds = self.aligner(self.vision_model(images))
+
+         # [b x n, T2, D] -> [b, n x T2, D]
+         images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+         # [b, n, T2] -> [b, n x T2]
+         images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+
+         # [b, T, D]
+         input_ids[input_ids < 0] = 0  # ignore the image embeddings
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+         # replace with the image embeddings
+         inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
+
+         return inputs_embeds
+
+     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+         return self.gen_aligner(self.gen_embed(image_ids))
+
+
+ AutoConfig.register("vision", VisionConfig)
+ AutoConfig.register("aligner", AlignerConfig)
+ AutoConfig.register("gen_vision", GenVisionConfig)
+ AutoConfig.register("gen_aligner", GenAlignerConfig)
+ AutoConfig.register("gen_head", GenHeadConfig)
+ AutoConfig.register("multi_modality", MultiModalityConfig)
+ AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
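
For orientation, understanding-side inference with these classes usually follows the upstream Janus README pattern sketched below. This is an editorial sketch, not part of the upload: the checkpoint id, the force_batchify argument and the load_pil_images helper are taken from the upstream repo (janus/utils/io.py and the rest of processing_vlm.py are included in this commit but not shown above), "./image.png" is a placeholder, and a CUDA device is assumed.

import torch
from transformers import AutoModelForCausalLM
from janus.models import VLChatProcessor
from janus.utils.io import load_pil_images

model_path = "deepseek-ai/Janus-1.3B"   # assumed checkpoint id
processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

conversation = [
    {"role": "User", "content": "<image_placeholder>\nDescribe this image.", "images": ["./image.png"]},
    {"role": "Assistant", "content": ""},
]
pil_images = load_pil_images(conversation)
prepare_inputs = processor(conversations=conversation, images=pil_images, force_batchify=True).to(vl_gpt.device)

# image features are spliced into the text embeddings at the <image_placeholder> positions
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=128,
    do_sample=False,
    use_cache=True,
)
print(tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True))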
janus/models/processing_vlm.py ADDED
@@ -0,0 +1,418 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ from dataclasses import dataclass
+ from typing import Dict, List
+
+ import torch
+ from PIL.Image import Image
+ from transformers import LlamaTokenizerFast
+ from transformers.processing_utils import ProcessorMixin
+
+ from janus.models.image_processing_vlm import VLMImageProcessor
+ from janus.utils.conversation import get_conv_template
+
+
+ class DictOutput(object):
+     def keys(self):
+         return self.__dict__.keys()
+
+     def __getitem__(self, item):
+         return self.__dict__[item]
+
+     def __setitem__(self, key, value):
+         self.__dict__[key] = value
+
+
+ @dataclass
+ class VLChatProcessorOutput(DictOutput):
+     sft_format: str
+     input_ids: torch.Tensor
+     pixel_values: torch.Tensor
+     num_image_tokens: torch.IntTensor
+
+     def __len__(self):
+         return len(self.input_ids)
+
+
+ @dataclass
+ class BatchedVLChatProcessorOutput(DictOutput):
+     sft_format: List[str]
+     input_ids: torch.Tensor
+     pixel_values: torch.Tensor
+     attention_mask: torch.Tensor
+     images_seq_mask: torch.BoolTensor
+     images_emb_mask: torch.BoolTensor
+
+     def to(self, device, dtype=torch.bfloat16):
+         self.input_ids = self.input_ids.to(device)
+         self.attention_mask = self.attention_mask.to(device)
+         self.images_seq_mask = self.images_seq_mask.to(device)
+         self.images_emb_mask = self.images_emb_mask.to(device)
+         self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
+         return self
+
+
+ class VLChatProcessor(ProcessorMixin):
+     image_processor_class = "AutoImageProcessor"
+     tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+     attributes = ["image_processor", "tokenizer"]
+
+     system_prompt = (
+         "You are a helpful language and vision assistant. "
+         "You are able to understand the visual content that the user provides, "
+         "and assist the user with a variety of tasks using natural language."
+     )
+
+     def __init__(
+         self,
+         image_processor: VLMImageProcessor,
+         tokenizer: LlamaTokenizerFast,
+         image_tag: str = "<image_placeholder>",
+         image_start_tag: str = "<begin_of_image>",
+         image_end_tag: str = "<end_of_image>",
+         pad_tag: str = "<|▁pad▁|>",
+         num_image_tokens: int = 576,
+         add_special_token: bool = False,
+         sft_format: str = "deepseek",
+         mask_prompt: bool = True,
+         ignore_id: int = -100,
+         **kwargs,
+     ):
+         self.image_processor = image_processor
+         self.tokenizer = tokenizer
+
+         image_id = self.tokenizer.vocab.get(image_tag)
+         if image_id is None:
+             special_tokens = [image_tag]
+             special_tokens_dict = {"additional_special_tokens": special_tokens}
+             self.tokenizer.add_special_tokens(special_tokens_dict)
+             print(f"Add image tag = {image_tag} to the tokenizer")
+
+         self.image_tag = image_tag
+         self.image_start_tag = image_start_tag
+         self.image_end_tag = image_end_tag
+         self.pad_tag = pad_tag
+
+         self.num_image_tokens = num_image_tokens
+         self.add_special_token = add_special_token
+         self.sft_format = sft_format
+         self.mask_prompt = mask_prompt
+         self.ignore_id = ignore_id
+
+         super().__init__(
+             image_processor,
+             tokenizer,
+             image_tag,
+             num_image_tokens,
+             add_special_token,
+             sft_format,
+             mask_prompt,
+             ignore_id,
+             **kwargs,
+         )
+
+     def new_chat_template(self):
+         conv = get_conv_template(self.sft_format)
+         conv.set_system_message(self.system_prompt)
+         return conv
+
+     def apply_sft_template_for_multi_turn_prompts(
+         self,
+         conversations: List[Dict[str, str]],
+         sft_format: str = "deepseek",
+         system_prompt: str = "",
+     ):
+         """
+         Applies the SFT template to conversation.
+
+         An example of conversation:
+         conversation = [
+             {
+                 "role": "User",
+                 "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
+                 "images": [
+                     "./multi-images/attribute_comparison_1.png",
+                     "./multi-images/attribute_comparison_2.png"
+                 ]
+             },
+             {
+                 "role": "Assistant",
+                 "content": ""
+             }
+         ]
+
+         Args:
+             conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
+             sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
+             system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
+
+         Returns:
+             sft_prompt (str): The formatted text.
+         """
+
+         conv = get_conv_template(sft_format)
+         conv.set_system_message(system_prompt)
+         for message in conversations:
+             conv.append_message(message["role"], message["content"].strip())
+         sft_prompt = conv.get_prompt().strip()
+
+         return sft_prompt
+
+     @property
+     def image_token(self):
+         return self.image_tag
+
+     @property
+     def image_id(self):
+         image_id = self.tokenizer.vocab.get(self.image_tag)
+         return image_id
+
+     @property
+     def image_start_id(self):
+         image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+         return image_start_id
+
+     @property
+     def image_end_id(self):
+         image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+         return image_end_id
+
+     @property
+     def image_start_token(self):
+         return self.image_start_tag
+
+     @property
+     def image_end_token(self):
+         return self.image_end_tag
+
+     @property
+     def pad_id(self):
+         pad_id = self.tokenizer.vocab.get(self.pad_tag)
+         # pad_id = self.tokenizer.pad_token_id
+         # if pad_id is None:
+         #     pad_id = self.tokenizer.eos_token_id
+
+         return pad_id
+
+     def add_image_token(
+         self,
+         image_indices: List[int],
+         input_ids: torch.LongTensor,
+     ):
+         """
+
+         Args:
+             image_indices (List[int]): [index_0, index_1, ..., index_j]
+             input_ids (torch.LongTensor): [N]
+
+         Returns:
+             input_ids (torch.LongTensor): [N + image tokens]
+             num_image_tokens (torch.IntTensor): [n_images]
+         """
+
+         input_slices = []
+
+         start = 0
+         for index in image_indices:
+             if self.add_special_token:
+                 end = index + 1
+             else:
+                 end = index
+
+             # original text tokens
+             input_slices.append(input_ids[start:end])
+
+             # add boi, image tokens, eoi and set the mask as False
+             input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
+             input_slices.append(
+                 self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
+             )
+             input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
+             start = index + 1
+
+         # the left part
+         input_slices.append(input_ids[start:])
+
+         # concat all slices
+         input_ids = torch.cat(input_slices, dim=0)
+         num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
+
+         return input_ids, num_image_tokens
+
+     def process_one(
+         self,
+         prompt: str = None,
+         conversations: List[Dict[str, str]] = None,
+         images: List[Image] = None,
+         **kwargs,
+     ):
+         """
+
+         Args:
+             prompt (str): the formatted prompt;
+             conversations (List[Dict]): conversations with a list of messages;
+             images (List[ImageType]): the list of images;
+             **kwargs:
+
+         Returns:
+             outputs (BaseProcessorOutput): the output of the processor,
+                 - input_ids (torch.LongTensor): [N + image tokens]
+                 - target_ids (torch.LongTensor): [N + image tokens]
+                 - images (torch.FloatTensor): [n_images, 3, H, W]
+                 - image_id (int): the id of the image token
+                 - num_image_tokens (List[int]): the number of image tokens
+         """
+
+         assert (
+             prompt is None or conversations is None
+         ), "prompt and conversations cannot be used at the same time."
+
+         if prompt is None:
+             # apply sft format
+             sft_format = self.apply_sft_template_for_multi_turn_prompts(
+                 conversations=conversations,
+                 sft_format=self.sft_format,
+                 system_prompt=self.system_prompt,
+             )
+         else:
+             sft_format = prompt
+
+         # tokenize
+         input_ids = self.tokenizer.encode(sft_format)
+         input_ids = torch.LongTensor(input_ids)
+
+         # add image tokens to the input_ids
+         image_token_mask: torch.BoolTensor = input_ids == self.image_id
+         image_indices = image_token_mask.nonzero()
+         input_ids, num_image_tokens = self.add_image_token(
+             image_indices=image_indices,
+             input_ids=input_ids,
+         )
+
+         # load images
+         images_outputs = self.image_processor(images, return_tensors="pt")
+
+         prepare = VLChatProcessorOutput(
+             sft_format=sft_format,
+             input_ids=input_ids,
+             pixel_values=images_outputs.pixel_values,
+             num_image_tokens=num_image_tokens,
+         )
+
+         return prepare
321
+
322
+ def __call__(
323
+ self,
324
+ *,
325
+ prompt: str = None,
326
+ conversations: List[Dict[str, str]] = None,
327
+ images: List[Image] = None,
328
+ force_batchify: bool = True,
329
+ **kwargs,
330
+ ):
331
+ """
332
+
333
+ Args:
334
+ prompt (str): the formatted prompt;
335
+ conversations (List[Dict]): conversations with a list of messages;
336
+ images (List[ImageType]): the list of images;
337
+ force_batchify (bool): force batchify the inputs;
338
+ **kwargs:
339
+
340
+ Returns:
341
+ outputs (BaseProcessorOutput): the output of the processor,
342
+ - input_ids (torch.LongTensor): [N + image tokens]
343
+ - images (torch.FloatTensor): [n_images, 3, H, W]
344
+ - image_id (int): the id of the image token
345
+ - num_image_tokens (List[int]): the number of image tokens
346
+ """
347
+
348
+ prepare = self.process_one(
349
+ prompt=prompt, conversations=conversations, images=images
350
+ )
351
+
352
+ if force_batchify:
353
+ prepare = self.batchify([prepare])
354
+
355
+ return prepare
356
+
357
+ def batchify(
358
+ self, prepare_list: List[VLChatProcessorOutput]
359
+ ) -> BatchedVLChatProcessorOutput:
360
+ """
361
+ Preprocesses the inputs for multimodal inference.
362
+
363
+ Args:
364
+ prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
365
+
366
+ Returns:
367
+ BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
368
+ """
369
+
370
+ batch_size = len(prepare_list)
371
+ sft_format = []
372
+ n_images = []
373
+ seq_lens = []
374
+ for prepare in prepare_list:
375
+ n_images.append(len(prepare.num_image_tokens))
376
+ seq_lens.append(len(prepare))
377
+
378
+ input_token_max_len = max(seq_lens)
379
+ max_n_images = max(1, max(n_images))
380
+
381
+ batched_input_ids = torch.full(
382
+ (batch_size, input_token_max_len), self.pad_id
383
+ ).long() # FIXME
384
+ batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
385
+ batched_pixel_values = torch.zeros(
386
+ (batch_size, max_n_images, *self.image_processor.default_shape)
387
+ ).float()
388
+ batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
389
+ batched_images_emb_mask = torch.zeros(
390
+ (batch_size, max_n_images, self.num_image_tokens)
391
+ ).bool()
392
+
393
+ for i, prepare in enumerate(prepare_list):
394
+ input_ids = prepare.input_ids
395
+ seq_len = len(prepare)
396
+ n_image = len(prepare.num_image_tokens)
397
+ # left-padding
398
+ batched_attention_mask[i, -seq_len:] = 1
399
+ batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
400
+ batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
401
+
402
+ if n_image > 0:
403
+ batched_pixel_values[i, :n_image] = prepare.pixel_values
404
+ for j, n_image_tokens in enumerate(prepare.num_image_tokens):
405
+ batched_images_emb_mask[i, j, :n_image_tokens] = True
406
+
407
+ sft_format.append(prepare.sft_format)
408
+
409
+ batched_prepares = BatchedVLChatProcessorOutput(
410
+ input_ids=batched_input_ids,
411
+ attention_mask=batched_attention_mask,
412
+ pixel_values=batched_pixel_values,
413
+ images_seq_mask=batched_images_seq_mask,
414
+ images_emb_mask=batched_images_emb_mask,
415
+ sft_format=sft_format,
416
+ )
417
+
418
+ return batched_prepares
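
A minimal usage sketch of the chat processor defined above; `processor` is assumed to be an already-constructed instance of this class (built from a VLMImageProcessor and a LlamaTokenizerFast), and "./demo.png" is a placeholder image path, not part of this commit:

from PIL import Image

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": ["./demo.png"],  # placeholder path
    },
    {"role": "Assistant", "content": ""},
]
pil_images = [Image.open(p).convert("RGB") for p in conversation[0]["images"]]

# __call__ is keyword-only; with force_batchify=True it returns a
# BatchedVLChatProcessorOutput with left-padded input_ids, attention_mask,
# pixel_values [1, n_images, 3, H, W], images_seq_mask and images_emb_mask.
batch = processor(conversations=conversation, images=pil_images, force_batchify=True)
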
janus/models/projector.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from attrdict import AttrDict
25
+
26
+
27
+ class MlpProjector(nn.Module):
28
+ def __init__(self, cfg):
29
+ super().__init__()
30
+
31
+ self.cfg = cfg
32
+
33
+ if cfg.projector_type == "identity":
34
+ modules = nn.Identity()
35
+
36
+ elif cfg.projector_type == "linear":
37
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
38
+
39
+ elif cfg.projector_type == "mlp_gelu":
40
+ mlp_depth = cfg.get("depth", 1)
41
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
45
+ modules = nn.Sequential(*modules)
46
+
47
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
48
+ mlp_depth = cfg.get("depth", 1)
49
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
50
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
51
+
52
+ modules = []
53
+ for _ in range(1, mlp_depth):
54
+ modules.append(nn.GELU())
55
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
56
+ modules = nn.Sequential(*modules)
57
+
58
+ else:
59
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
60
+
61
+ self.layers = modules
62
+
63
+ def forward(
64
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
65
+ ):
66
+ """
67
+
68
+ Args:
69
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
70
+ it comes from the hybrid vision encoder and x = (high_res_x, low_res_x);
71
+ otherwise it is the feature from the single vision encoder.
72
+
73
+ Returns:
74
+ x (torch.Tensor): [b, s, c]
75
+ """
76
+
77
+ if isinstance(x_or_tuple, tuple):
78
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
79
+ high_x, low_x = x_or_tuple
80
+ high_x = self.high_up_proj(high_x)
81
+ low_x = self.low_up_proj(low_x)
82
+ x = torch.concat([high_x, low_x], dim=-1)
83
+ else:
84
+ x = x_or_tuple
85
+
86
+ return self.layers(x)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ cfg = AttrDict(
91
+ input_dim=1024,
92
+ n_embed=2048,
93
+ depth=2,
94
+ projector_type="low_high_hybrid_split_mlp_gelu",
95
+ )
96
+ inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
97
+
98
+ m = MlpProjector(cfg)
99
+ out = m(inputs)
100
+ print(out.shape)
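
Complementing the hybrid example in the __main__ block above, a minimal sketch of the single-encoder "mlp_gelu" path (the dimensions below are illustrative only):

import torch
from attrdict import AttrDict

cfg = AttrDict(input_dim=1024, n_embed=2048, depth=2, projector_type="mlp_gelu")
proj = MlpProjector(cfg)

feats = torch.rand(4, 576, 1024)  # [batch, seq, input_dim] from a single vision encoder
out = proj(feats)                 # Linear -> GELU -> Linear, no tuple handling
print(out.shape)                  # torch.Size([4, 576, 2048])
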
janus/models/siglip_vit.py ADDED
@@ -0,0 +1,681 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
21
+ import math
22
+ import warnings
23
+ from dataclasses import dataclass
24
+ from functools import partial
25
+ from typing import (
26
+ Callable,
27
+ Dict,
28
+ Final,
29
+ List,
30
+ Literal,
31
+ Optional,
32
+ Sequence,
33
+ Set,
34
+ Tuple,
35
+ Type,
36
+ Union,
37
+ )
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from timm.layers import (
43
+ AttentionPoolLatent,
44
+ DropPath,
45
+ LayerType,
46
+ Mlp,
47
+ PatchDropout,
48
+ PatchEmbed,
49
+ resample_abs_pos_embed,
50
+ )
51
+ from timm.models._manipulate import checkpoint_seq, named_apply
52
+
53
+
54
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
55
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
56
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
57
+ def norm_cdf(x):
58
+ # Computes standard normal cumulative distribution function
59
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
60
+
61
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
62
+ warnings.warn(
63
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
64
+ "The distribution of values may be incorrect.",
65
+ stacklevel=2,
66
+ )
67
+
68
+ with torch.no_grad():
69
+ # Values are generated by using a truncated uniform distribution and
70
+ # then using the inverse CDF for the normal distribution.
71
+ # Get upper and lower cdf values
72
+ l = norm_cdf((a - mean) / std) # noqa: E741
73
+ u = norm_cdf((b - mean) / std)
74
+
75
+ # Uniformly fill tensor with values from [l, u], then translate to
76
+ # [2l-1, 2u-1].
77
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
78
+
79
+ # Use inverse cdf transform for normal distribution to get truncated
80
+ # standard normal
81
+ tensor.erfinv_()
82
+
83
+ # Transform to proper mean, std
84
+ tensor.mul_(std * math.sqrt(2.0))
85
+ tensor.add_(mean)
86
+
87
+ # Clamp to ensure it's in the proper range
88
+ tensor.clamp_(min=a, max=b)
89
+ return tensor
90
+
91
+
92
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
93
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
94
+ r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
95
+ convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
96
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
97
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
98
+ with values outside :math:`[a, b]` redrawn until they are within
99
+ the bounds. The method used for generating the random values works
100
+ best when :math:`a \leq \text{mean} \leq b`.
101
+ Args:
102
+ tensor: an n-dimensional `torch.Tensor`
103
+ mean: the mean of the normal distribution
104
+ std: the standard deviation of the normal distribution
105
+ a: the minimum cutoff value
106
+ b: the maximum cutoff value
107
+ Examples:
108
+ >>> w = torch.empty(3, 5)
109
+ >>> nn.init.trunc_normal_(w)
110
+ """
111
+
112
+ with torch.no_grad():
113
+ dtype = tensor.dtype
114
+ tensor_fp32 = tensor.float()
115
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
116
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
117
+ tensor.copy_(tensor_dtype)
118
+
119
+
120
+ def init_weights(self):
121
+ if self.pos_embed is not None:
122
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
123
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
124
+
125
+
126
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
127
+ """ViT weight initialization, original timm impl (for reproducibility)"""
128
+ if isinstance(module, nn.Linear):
129
+ trunc_normal_(module.weight, std=0.02)
130
+ if module.bias is not None:
131
+ nn.init.zeros_(module.bias)
132
+ elif hasattr(module, "init_weights"):
133
+ module.init_weights()
134
+
135
+
136
+ class Attention(nn.Module):
137
+ fused_attn: Final[bool]
138
+
139
+ def __init__(
140
+ self,
141
+ dim: int,
142
+ num_heads: int = 8,
143
+ qkv_bias: bool = False,
144
+ qk_norm: bool = False,
145
+ attn_drop: float = 0.0,
146
+ proj_drop: float = 0.0,
147
+ norm_layer: nn.Module = nn.LayerNorm,
148
+ ) -> None:
149
+ super().__init__()
150
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
151
+ self.num_heads = num_heads
152
+ self.head_dim = dim // num_heads
153
+ self.scale = self.head_dim**-0.5
154
+ # self.fused_attn = use_fused_attn()
155
+ self.fused_attn = True
156
+
157
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
158
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
159
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
160
+ self.attn_drop = nn.Dropout(attn_drop)
161
+ self.proj = nn.Linear(dim, dim)
162
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ B, N, C = x.shape
166
+ qkv = (
167
+ self.qkv(x)
168
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
169
+ .permute(2, 0, 3, 1, 4)
170
+ )
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = q @ k.transpose(-2, -1)
184
+ attn = attn.softmax(dim=-1)
185
+ attn = self.attn_drop(attn)
186
+ x = attn @ v
187
+
188
+ x = x.transpose(1, 2).reshape(B, N, C)
189
+ x = self.proj(x)
190
+ x = self.proj_drop(x)
191
+ return x
192
+
193
+
194
+ class LayerScale(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim: int,
198
+ init_values: float = 1e-5,
199
+ inplace: bool = False,
200
+ ) -> None:
201
+ super().__init__()
202
+ self.inplace = inplace
203
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
204
+
205
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
206
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
207
+
208
+
209
+ class Block(nn.Module):
210
+ def __init__(
211
+ self,
212
+ dim: int,
213
+ num_heads: int,
214
+ mlp_ratio: float = 4.0,
215
+ qkv_bias: bool = False,
216
+ qk_norm: bool = False,
217
+ proj_drop: float = 0.0,
218
+ attn_drop: float = 0.0,
219
+ init_values: Optional[float] = None,
220
+ drop_path: float = 0.0,
221
+ act_layer: nn.Module = nn.GELU,
222
+ norm_layer: nn.Module = nn.LayerNorm,
223
+ mlp_layer: nn.Module = Mlp,
224
+ ) -> None:
225
+ super().__init__()
226
+ self.norm1 = norm_layer(dim)
227
+ self.attn = Attention(
228
+ dim,
229
+ num_heads=num_heads,
230
+ qkv_bias=qkv_bias,
231
+ qk_norm=qk_norm,
232
+ attn_drop=attn_drop,
233
+ proj_drop=proj_drop,
234
+ norm_layer=norm_layer,
235
+ )
236
+ self.ls1 = (
237
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
238
+ )
239
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
240
+
241
+ self.norm2 = norm_layer(dim)
242
+ self.mlp = mlp_layer(
243
+ in_features=dim,
244
+ hidden_features=int(dim * mlp_ratio),
245
+ act_layer=act_layer,
246
+ drop=proj_drop,
247
+ )
248
+ self.ls2 = (
249
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
250
+ )
251
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
252
+
253
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
254
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
255
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
256
+ return x
257
+
258
+
259
+ class VisionTransformer(nn.Module):
260
+ """Vision Transformer
261
+
262
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
263
+ - https://arxiv.org/abs/2010.11929
264
+ """
265
+
266
+ dynamic_img_size: Final[bool]
267
+
268
+ def __init__(
269
+ self,
270
+ img_size: Union[int, Tuple[int, int]] = 224,
271
+ patch_size: Union[int, Tuple[int, int]] = 16,
272
+ in_chans: int = 3,
273
+ num_classes: int = 1000,
274
+ global_pool: Literal["", "avg", "token", "map"] = "token",
275
+ embed_dim: int = 768,
276
+ depth: int = 12,
277
+ num_heads: int = 12,
278
+ mlp_ratio: float = 4.0,
279
+ qkv_bias: bool = True,
280
+ qk_norm: bool = False,
281
+ init_values: Optional[float] = None,
282
+ class_token: bool = True,
283
+ no_embed_class: bool = False,
284
+ reg_tokens: int = 0,
285
+ pre_norm: bool = False,
286
+ fc_norm: Optional[bool] = None,
287
+ dynamic_img_size: bool = False,
288
+ dynamic_img_pad: bool = False,
289
+ drop_rate: float = 0.0,
290
+ pos_drop_rate: float = 0.0,
291
+ patch_drop_rate: float = 0.0,
292
+ proj_drop_rate: float = 0.0,
293
+ attn_drop_rate: float = 0.0,
294
+ drop_path_rate: float = 0.0,
295
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
296
+ embed_layer: Callable = PatchEmbed,
297
+ norm_layer: Optional[LayerType] = None,
298
+ act_layer: Optional[LayerType] = None,
299
+ block_fn: Type[nn.Module] = Block,
300
+ mlp_layer: Type[nn.Module] = Mlp,
301
+ ignore_head: bool = False,
302
+ ) -> None:
303
+ """
304
+ Args:
305
+ img_size: Input image size.
306
+ patch_size: Patch size.
307
+ in_chans: Number of image input channels.
308
+ num_classes: Number of classes for classification head.
309
+ global_pool: Type of global pooling for final sequence (default: 'token').
310
+ embed_dim: Transformer embedding dimension.
311
+ depth: Depth of transformer.
312
+ num_heads: Number of attention heads.
313
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
314
+ qkv_bias: Enable bias for qkv projections if True.
315
+ init_values: Layer-scale init values (layer-scale enabled if not None).
316
+ class_token: Use class token.
317
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
318
+ reg_tokens: Number of register tokens.
319
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
320
+ drop_rate: Head dropout rate.
321
+ pos_drop_rate: Position embedding dropout rate.
322
+ attn_drop_rate: Attention dropout rate.
323
+ drop_path_rate: Stochastic depth rate.
324
+ weight_init: Weight initialization scheme.
325
+ embed_layer: Patch embedding layer.
326
+ norm_layer: Normalization layer.
327
+ act_layer: MLP activation layer.
328
+ block_fn: Transformer block layer.
329
+ """
330
+ super().__init__()
331
+ assert global_pool in ("", "avg", "token", "map")
332
+ assert class_token or global_pool != "token"
333
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
334
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
335
+ # act_layer = get_act_layer(act_layer) or nn.GELU
336
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
337
+ act_layer = nn.GELU
338
+
339
+ self.num_classes = num_classes
340
+ self.global_pool = global_pool
341
+ self.num_features = self.embed_dim = (
342
+ embed_dim # num_features for consistency with other models
343
+ )
344
+ self.num_prefix_tokens = 1 if class_token else 0
345
+ self.num_prefix_tokens += reg_tokens
346
+ self.num_reg_tokens = reg_tokens
347
+ self.has_class_token = class_token
348
+ self.no_embed_class = (
349
+ no_embed_class # don't embed prefix positions (includes reg)
350
+ )
351
+ self.dynamic_img_size = dynamic_img_size
352
+ self.grad_checkpointing = False
353
+ self.ignore_head = ignore_head
354
+
355
+ embed_args = {}
356
+ if dynamic_img_size:
357
+ # flatten deferred until after pos embed
358
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
359
+ self.patch_embed = embed_layer(
360
+ img_size=img_size,
361
+ patch_size=patch_size,
362
+ in_chans=in_chans,
363
+ embed_dim=embed_dim,
364
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
365
+ dynamic_img_pad=dynamic_img_pad,
366
+ **embed_args,
367
+ )
368
+ num_patches = self.patch_embed.num_patches
369
+
370
+ self.cls_token = (
371
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
372
+ )
373
+ self.reg_token = (
374
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
375
+ )
376
+ embed_len = (
377
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
378
+ )
379
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
380
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
381
+ if patch_drop_rate > 0:
382
+ self.patch_drop = PatchDropout(
383
+ patch_drop_rate,
384
+ num_prefix_tokens=self.num_prefix_tokens,
385
+ )
386
+ else:
387
+ self.patch_drop = nn.Identity()
388
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
389
+
390
+ dpr = [
391
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
392
+ ] # stochastic depth decay rule
393
+ self.blocks = nn.Sequential(
394
+ *[
395
+ block_fn(
396
+ dim=embed_dim,
397
+ num_heads=num_heads,
398
+ mlp_ratio=mlp_ratio,
399
+ qkv_bias=qkv_bias,
400
+ qk_norm=qk_norm,
401
+ init_values=init_values,
402
+ proj_drop=proj_drop_rate,
403
+ attn_drop=attn_drop_rate,
404
+ drop_path=dpr[i],
405
+ norm_layer=norm_layer,
406
+ act_layer=act_layer,
407
+ mlp_layer=mlp_layer,
408
+ )
409
+ for i in range(depth)
410
+ ]
411
+ )
412
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
413
+
414
+ # Classifier Head
415
+ if global_pool == "map":
416
+ AttentionPoolLatent.init_weights = init_weights
417
+ self.attn_pool = AttentionPoolLatent(
418
+ self.embed_dim,
419
+ num_heads=num_heads,
420
+ mlp_ratio=mlp_ratio,
421
+ norm_layer=norm_layer,
422
+ )
423
+ else:
424
+ self.attn_pool = None
425
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
426
+ self.head_drop = nn.Dropout(drop_rate)
427
+ self.head = (
428
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
429
+ )
430
+
431
+ if weight_init != "skip":
432
+ self.init_weights(weight_init)
433
+
434
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
435
+ assert mode in ("jax", "jax_nlhb", "moco", "")
436
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
437
+ trunc_normal_(self.pos_embed, std=0.02)
438
+ if self.cls_token is not None:
439
+ nn.init.normal_(self.cls_token, std=1e-6)
440
+ named_apply(init_weights_vit_timm, self)
441
+
442
+ @torch.jit.ignore
443
+ def no_weight_decay(self) -> Set:
444
+ return {"pos_embed", "cls_token", "dist_token"}
445
+
446
+ @torch.jit.ignore
447
+ def group_matcher(self, coarse: bool = False) -> Dict:
448
+ return dict(
449
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
450
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
451
+ )
452
+
453
+ @torch.jit.ignore
454
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
455
+ self.grad_checkpointing = enable
456
+
457
+ @torch.jit.ignore
458
+ def get_classifier(self) -> nn.Module:
459
+ return self.head
460
+
461
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
462
+ self.num_classes = num_classes
463
+ if global_pool is not None:
464
+ assert global_pool in ("", "avg", "token", "map")
465
+ if global_pool == "map" and self.attn_pool is None:
466
+ assert (
467
+ False
468
+ ), "Cannot currently add attention pooling in reset_classifier()."
469
+ elif global_pool != "map" and self.attn_pool is not None:
470
+ self.attn_pool = None # remove attention pooling
471
+ self.global_pool = global_pool
472
+ self.head = (
473
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
474
+ )
475
+
476
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
477
+ if self.dynamic_img_size:
478
+ B, H, W, C = x.shape
479
+ pos_embed = resample_abs_pos_embed(
480
+ self.pos_embed,
481
+ (H, W),
482
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
483
+ )
484
+ x = x.view(B, -1, C)
485
+ else:
486
+ pos_embed = self.pos_embed
487
+
488
+ to_cat = []
489
+ if self.cls_token is not None:
490
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
491
+ if self.reg_token is not None:
492
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
493
+
494
+ if self.no_embed_class:
495
+ # deit-3, updated JAX (big vision)
496
+ # position embedding does not overlap with class token, add then concat
497
+ x = x + pos_embed
498
+ if to_cat:
499
+ x = torch.cat(to_cat + [x], dim=1)
500
+ else:
501
+ # original timm, JAX, and deit vit impl
502
+ # pos_embed has entry for class token, concat then add
503
+ if to_cat:
504
+ x = torch.cat(to_cat + [x], dim=1)
505
+ x = x + pos_embed
506
+
507
+ return self.pos_drop(x)
508
+
509
+ def _intermediate_layers(
510
+ self,
511
+ x: torch.Tensor,
512
+ n: Union[int, Sequence] = 1,
513
+ ) -> List[torch.Tensor]:
514
+ outputs, num_blocks = [], len(self.blocks)
515
+ take_indices = set(
516
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
517
+ )
518
+
519
+ # forward pass
520
+ x = self.patch_embed(x)
521
+ x = self._pos_embed(x)
522
+ x = self.patch_drop(x)
523
+ x = self.norm_pre(x)
524
+ for i, blk in enumerate(self.blocks):
525
+ x = blk(x)
526
+ if i in take_indices:
527
+ outputs.append(x)
528
+
529
+ return outputs
530
+
531
+ def get_intermediate_layers(
532
+ self,
533
+ x: torch.Tensor,
534
+ n: Union[int, Sequence] = 1,
535
+ reshape: bool = False,
536
+ return_prefix_tokens: bool = False,
537
+ norm: bool = False,
538
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
539
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
540
+ Inspired by DINO / DINOv2 interface
541
+ """
542
+ # take last n blocks if n is an int; if n is a sequence, select by matching indices
543
+ outputs = self._intermediate_layers(x, n)
544
+ if norm:
545
+ outputs = [self.norm(out) for out in outputs]
546
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
547
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
548
+
549
+ if reshape:
550
+ grid_size = self.patch_embed.grid_size
551
+ outputs = [
552
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
553
+ .permute(0, 3, 1, 2)
554
+ .contiguous()
555
+ for out in outputs
556
+ ]
557
+
558
+ if return_prefix_tokens:
559
+ return tuple(zip(outputs, prefix_tokens))
560
+ return tuple(outputs)
561
+
562
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
563
+ x = self.patch_embed(x)
564
+ x = self._pos_embed(x)
565
+ x = self.patch_drop(x)
566
+ x = self.norm_pre(x)
567
+ if self.grad_checkpointing and not torch.jit.is_scripting():
568
+ x = checkpoint_seq(self.blocks, x)
569
+ else:
570
+ x = self.blocks(x)
571
+ x = self.norm(x)
572
+ return x
573
+
574
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
575
+ if self.attn_pool is not None:
576
+ x = self.attn_pool(x)
577
+ elif self.global_pool == "avg":
578
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
579
+ elif self.global_pool:
580
+ x = x[:, 0] # class token
581
+ x = self.fc_norm(x)
582
+ x = self.head_drop(x)
583
+ return x if pre_logits else self.head(x)
584
+
585
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
586
+ x = self.forward_features(x)
587
+ if not self.ignore_head:
588
+ x = self.forward_head(x)
589
+ return x
590
+
591
+
592
+ @dataclass
593
+ class SigLIPVisionCfg:
594
+ width: int = 1152
595
+ layers: Union[Tuple[int, int, int, int], int] = 27
596
+ heads: int = 16
597
+ patch_size: int = 14
598
+ image_size: Union[Tuple[int, int], int] = 336
599
+ global_pool: str = "map"
600
+ mlp_ratio: float = 3.7362
601
+ class_token: bool = False
602
+ num_classes: int = 0
603
+ use_checkpoint: bool = False
604
+
605
+
606
+ SigLIP_MODEL_CONFIG = {
607
+ "siglip_so400m_patch14_384": {
608
+ "image_size": 336,
609
+ "patch_size": 14,
610
+ "width": 1152,
611
+ "layers": 27,
612
+ "heads": 16,
613
+ "mlp_ratio": 3.7362,
614
+ "global_pool": "map",
615
+ "use_checkpoint": False,
616
+ },
617
+ "siglip_so400m_patch14_224": {
618
+ "image_size": 224,
619
+ "patch_size": 14,
620
+ "width": 1152,
621
+ "layers": 27,
622
+ "heads": 16,
623
+ "mlp_ratio": 3.7362,
624
+ "global_pool": "map",
625
+ "use_checkpoint": False,
626
+ },
627
+ "siglip_large_patch16_384": {
628
+ "image_size": 384,
629
+ "patch_size": 16,
630
+ "width": 1024,
631
+ "layers": 24,
632
+ "heads": 16,
633
+ "mlp_ratio": 4,
634
+ "global_pool": "map",
635
+ "use_checkpoint": False,
636
+ },
637
+ }
638
+
639
+
640
+ def create_siglip_vit(
641
+ model_name: str = "siglip_so400m_patch14_384",
642
+ image_size: int = 384,
643
+ select_layer: int = -1,
644
+ ckpt_path: str = "",
645
+ **kwargs,
646
+ ):
647
+ assert (
648
+ model_name in SigLIP_MODEL_CONFIG.keys()
649
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
650
+
651
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
652
+
653
+ if select_layer <= 0:
654
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
655
+ else:
656
+ layers = min(vision_cfg.layers, select_layer)
657
+
658
+ model = VisionTransformer(
659
+ img_size=image_size,
660
+ patch_size=vision_cfg.patch_size,
661
+ embed_dim=vision_cfg.width,
662
+ depth=layers,
663
+ num_heads=vision_cfg.heads,
664
+ mlp_ratio=vision_cfg.mlp_ratio,
665
+ class_token=vision_cfg.class_token,
666
+ global_pool=vision_cfg.global_pool,
667
+ ignore_head=kwargs.get("ignore_head", True),
668
+ weight_init=kwargs.get("weight_init", "skip"),
669
+ num_classes=0,
670
+ )
671
+
672
+ if ckpt_path:
673
+ state_dict = torch.load(ckpt_path, map_location="cpu")
674
+
675
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
676
+ print(
677
+ f"SigLIP-ViT restored from {ckpt_path},\n"
678
+ f"\tincompatible_keys: {incompatible_keys}."
679
+ )
680
+
681
+ return model
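
A minimal sketch of using create_siglip_vit as a pure feature extractor; the batch size and the shape noted in the comments are illustrative for the default "siglip_so400m_patch14_384" config:

import torch

vit = create_siglip_vit(model_name="siglip_so400m_patch14_384", image_size=384)
vit.eval()

images = torch.rand(2, 3, 384, 384)  # [B, 3, H, W], assumed already normalized
with torch.no_grad():
    feats = vit(images)  # ignore_head defaults to True, so this returns token features
print(feats.shape)       # approximately [2, num_patches, 1152] for this config
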
janus/models/vq_model.py ADDED
@@ -0,0 +1,527 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+
21
+ from dataclasses import dataclass, field
22
+ from typing import List
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from functools import partial
29
+
30
+
31
+ @dataclass
32
+ class ModelArgs:
33
+ codebook_size: int = 16384
34
+ codebook_embed_dim: int = 8
35
+ codebook_l2_norm: bool = True
36
+ codebook_show_usage: bool = True
37
+ commit_loss_beta: float = 0.25
38
+ entropy_loss_ratio: float = 0.0
39
+
40
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
41
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
42
+ z_channels: int = 256
43
+ dropout_p: float = 0.0
44
+
45
+
46
+ class Encoder(nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_channels=3,
50
+ ch=128,
51
+ ch_mult=(1, 1, 2, 2, 4),
52
+ num_res_blocks=2,
53
+ norm_type="group",
54
+ dropout=0.0,
55
+ resamp_with_conv=True,
56
+ z_channels=256,
57
+ ):
58
+ super().__init__()
59
+ self.num_resolutions = len(ch_mult)
60
+ self.num_res_blocks = num_res_blocks
61
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
62
+
63
+ # downsampling
64
+ in_ch_mult = (1,) + tuple(ch_mult)
65
+ self.conv_blocks = nn.ModuleList()
66
+ for i_level in range(self.num_resolutions):
67
+ conv_block = nn.Module()
68
+ # res & attn
69
+ res_block = nn.ModuleList()
70
+ attn_block = nn.ModuleList()
71
+ block_in = ch * in_ch_mult[i_level]
72
+ block_out = ch * ch_mult[i_level]
73
+ for _ in range(self.num_res_blocks):
74
+ res_block.append(
75
+ ResnetBlock(
76
+ block_in, block_out, dropout=dropout, norm_type=norm_type
77
+ )
78
+ )
79
+ block_in = block_out
80
+ if i_level == self.num_resolutions - 1:
81
+ attn_block.append(AttnBlock(block_in, norm_type))
82
+ conv_block.res = res_block
83
+ conv_block.attn = attn_block
84
+ # downsample
85
+ if i_level != self.num_resolutions - 1:
86
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
87
+ self.conv_blocks.append(conv_block)
88
+
89
+ # middle
90
+ self.mid = nn.ModuleList()
91
+ self.mid.append(
92
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
93
+ )
94
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
95
+ self.mid.append(
96
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
97
+ )
98
+
99
+ # end
100
+ self.norm_out = Normalize(block_in, norm_type)
101
+ self.conv_out = nn.Conv2d(
102
+ block_in, z_channels, kernel_size=3, stride=1, padding=1
103
+ )
104
+
105
+ def forward(self, x):
106
+ h = self.conv_in(x)
107
+ # downsampling
108
+ for i_level, block in enumerate(self.conv_blocks):
109
+ for i_block in range(self.num_res_blocks):
110
+ h = block.res[i_block](h)
111
+ if len(block.attn) > 0:
112
+ h = block.attn[i_block](h)
113
+ if i_level != self.num_resolutions - 1:
114
+ h = block.downsample(h)
115
+
116
+ # middle
117
+ for mid_block in self.mid:
118
+ h = mid_block(h)
119
+
120
+ # end
121
+ h = self.norm_out(h)
122
+ h = nonlinearity(h)
123
+ h = self.conv_out(h)
124
+ return h
125
+
126
+
127
+ class Decoder(nn.Module):
128
+ def __init__(
129
+ self,
130
+ z_channels=256,
131
+ ch=128,
132
+ ch_mult=(1, 1, 2, 2, 4),
133
+ num_res_blocks=2,
134
+ norm_type="group",
135
+ dropout=0.0,
136
+ resamp_with_conv=True,
137
+ out_channels=3,
138
+ ):
139
+ super().__init__()
140
+ self.num_resolutions = len(ch_mult)
141
+ self.num_res_blocks = num_res_blocks
142
+
143
+ block_in = ch * ch_mult[self.num_resolutions - 1]
144
+ # z to block_in
145
+ self.conv_in = nn.Conv2d(
146
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
147
+ )
148
+
149
+ # middle
150
+ self.mid = nn.ModuleList()
151
+ self.mid.append(
152
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
153
+ )
154
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
155
+ self.mid.append(
156
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
157
+ )
158
+
159
+ # upsampling
160
+ self.conv_blocks = nn.ModuleList()
161
+ for i_level in reversed(range(self.num_resolutions)):
162
+ conv_block = nn.Module()
163
+ # res & attn
164
+ res_block = nn.ModuleList()
165
+ attn_block = nn.ModuleList()
166
+ block_out = ch * ch_mult[i_level]
167
+ for _ in range(self.num_res_blocks + 1):
168
+ res_block.append(
169
+ ResnetBlock(
170
+ block_in, block_out, dropout=dropout, norm_type=norm_type
171
+ )
172
+ )
173
+ block_in = block_out
174
+ if i_level == self.num_resolutions - 1:
175
+ attn_block.append(AttnBlock(block_in, norm_type))
176
+ conv_block.res = res_block
177
+ conv_block.attn = attn_block
178
+ # upsample
179
+ if i_level != 0:
180
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
181
+ self.conv_blocks.append(conv_block)
182
+
183
+ # end
184
+ self.norm_out = Normalize(block_in, norm_type)
185
+ self.conv_out = nn.Conv2d(
186
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
187
+ )
188
+
189
+ @property
190
+ def last_layer(self):
191
+ return self.conv_out.weight
192
+
193
+ def forward(self, z):
194
+ # z to block_in
195
+ h = self.conv_in(z)
196
+
197
+ # middle
198
+ for mid_block in self.mid:
199
+ h = mid_block(h)
200
+
201
+ # upsampling
202
+ for i_level, block in enumerate(self.conv_blocks):
203
+ for i_block in range(self.num_res_blocks + 1):
204
+ h = block.res[i_block](h)
205
+ if len(block.attn) > 0:
206
+ h = block.attn[i_block](h)
207
+ if i_level != self.num_resolutions - 1:
208
+ h = block.upsample(h)
209
+
210
+ # end
211
+ h = self.norm_out(h)
212
+ h = nonlinearity(h)
213
+ h = self.conv_out(h)
214
+ return h
215
+
216
+
217
+ class VectorQuantizer(nn.Module):
218
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
219
+ super().__init__()
220
+ self.n_e = n_e
221
+ self.e_dim = e_dim
222
+ self.beta = beta
223
+ self.entropy_loss_ratio = entropy_loss_ratio
224
+ self.l2_norm = l2_norm
225
+ self.show_usage = show_usage
226
+
227
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
228
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
229
+ if self.l2_norm:
230
+ self.embedding.weight.data = F.normalize(
231
+ self.embedding.weight.data, p=2, dim=-1
232
+ )
233
+ if self.show_usage:
234
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
235
+
236
+ def forward(self, z):
237
+ # reshape z -> (batch, height, width, channel) and flatten
238
+ z = torch.einsum("b c h w -> b h w c", z).contiguous()
239
+ z_flattened = z.view(-1, self.e_dim)
240
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
241
+
242
+ if self.l2_norm:
243
+ z = F.normalize(z, p=2, dim=-1)
244
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
245
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
246
+ else:
247
+ embedding = self.embedding.weight
248
+
249
+ d = (
250
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
251
+ + torch.sum(embedding**2, dim=1)
252
+ - 2
253
+ * torch.einsum(
254
+ "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
255
+ )
256
+ )
257
+
258
+ min_encoding_indices = torch.argmin(d, dim=1)
259
+ z_q = embedding[min_encoding_indices].view(z.shape)
260
+ perplexity = None
261
+ min_encodings = None
262
+ vq_loss = None
263
+ commit_loss = None
264
+ entropy_loss = None
265
+
266
+ # compute loss for embedding
267
+ if self.training:
268
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
269
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
270
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
271
+
272
+ # preserve gradients
273
+ z_q = z + (z_q - z).detach()
274
+
275
+ # reshape back to match original input shape
276
+ z_q = torch.einsum("b h w c -> b c h w", z_q)
277
+
278
+ return (
279
+ z_q,
280
+ (vq_loss, commit_loss, entropy_loss),
281
+ (perplexity, min_encodings, min_encoding_indices),
282
+ )
283
+
284
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
285
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
286
+ if self.l2_norm:
287
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
288
+ else:
289
+ embedding = self.embedding.weight
290
+ z_q = embedding[indices] # (b*h*w, c)
291
+
292
+ if shape is not None:
293
+ if channel_first:
294
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
295
+ # reshape back to match original input shape
296
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
297
+ else:
298
+ z_q = z_q.view(shape)
299
+ return z_q
300
+
301
+
302
+ class ResnetBlock(nn.Module):
303
+ def __init__(
304
+ self,
305
+ in_channels,
306
+ out_channels=None,
307
+ conv_shortcut=False,
308
+ dropout=0.0,
309
+ norm_type="group",
310
+ ):
311
+ super().__init__()
312
+ self.in_channels = in_channels
313
+ out_channels = in_channels if out_channels is None else out_channels
314
+ self.out_channels = out_channels
315
+ self.use_conv_shortcut = conv_shortcut
316
+
317
+ self.norm1 = Normalize(in_channels, norm_type)
318
+ self.conv1 = nn.Conv2d(
319
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
320
+ )
321
+ self.norm2 = Normalize(out_channels, norm_type)
322
+ self.dropout = nn.Dropout(dropout)
323
+ self.conv2 = nn.Conv2d(
324
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
325
+ )
326
+
327
+ if self.in_channels != self.out_channels:
328
+ if self.use_conv_shortcut:
329
+ self.conv_shortcut = nn.Conv2d(
330
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
331
+ )
332
+ else:
333
+ self.nin_shortcut = nn.Conv2d(
334
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
335
+ )
336
+
337
+ def forward(self, x):
338
+ h = x
339
+ h = self.norm1(h)
340
+ h = nonlinearity(h)
341
+ h = self.conv1(h)
342
+ h = self.norm2(h)
343
+ h = nonlinearity(h)
344
+ h = self.dropout(h)
345
+ h = self.conv2(h)
346
+
347
+ if self.in_channels != self.out_channels:
348
+ if self.use_conv_shortcut:
349
+ x = self.conv_shortcut(x)
350
+ else:
351
+ x = self.nin_shortcut(x)
352
+ return x + h
353
+
354
+
355
+ class AttnBlock(nn.Module):
356
+ def __init__(self, in_channels, norm_type="group"):
357
+ super().__init__()
358
+ self.norm = Normalize(in_channels, norm_type)
359
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
360
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
361
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
362
+ self.proj_out = nn.Conv2d(
363
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
364
+ )
365
+
366
+ def forward(self, x):
367
+ h_ = x
368
+ h_ = self.norm(h_)
369
+ q = self.q(h_)
370
+ k = self.k(h_)
371
+ v = self.v(h_)
372
+
373
+ # compute attention
374
+ b, c, h, w = q.shape
375
+ q = q.reshape(b, c, h * w)
376
+ q = q.permute(0, 2, 1) # b,hw,c
377
+ k = k.reshape(b, c, h * w) # b,c,hw
378
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
379
+ w_ = w_ * (int(c) ** (-0.5))
380
+ w_ = F.softmax(w_, dim=2)
381
+
382
+ # attend to values
383
+ v = v.reshape(b, c, h * w)
384
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
385
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
386
+ h_ = h_.reshape(b, c, h, w)
387
+
388
+ h_ = self.proj_out(h_)
389
+
390
+ return x + h_
391
+
392
+
393
+ def nonlinearity(x):
394
+ # swish
395
+ return x * torch.sigmoid(x)
396
+
397
+
398
+ def Normalize(in_channels, norm_type="group"):
399
+ assert norm_type in ["group", "batch"]
400
+ if norm_type == "group":
401
+ return nn.GroupNorm(
402
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
403
+ )
404
+ elif norm_type == "batch":
405
+ return nn.SyncBatchNorm(in_channels)
406
+
407
+
408
+ class Upsample(nn.Module):
409
+ def __init__(self, in_channels, with_conv):
410
+ super().__init__()
411
+ self.with_conv = with_conv
412
+ if self.with_conv:
413
+ self.conv = nn.Conv2d(
414
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
415
+ )
416
+
417
+ def forward(self, x):
418
+ if x.dtype != torch.float32:
419
+ x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
420
+ torch.bfloat16
421
+ )
422
+ else:
423
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
424
+
425
+ if self.with_conv:
426
+ x = self.conv(x)
427
+ return x
428
+
429
+
430
+ class Downsample(nn.Module):
431
+ def __init__(self, in_channels, with_conv):
432
+ super().__init__()
433
+ self.with_conv = with_conv
434
+ if self.with_conv:
435
+ # no asymmetric padding in torch conv, must do it ourselves
436
+ self.conv = nn.Conv2d(
437
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
438
+ )
439
+
440
+ def forward(self, x):
441
+ if self.with_conv:
442
+ pad = (0, 1, 0, 1)
443
+ x = F.pad(x, pad, mode="constant", value=0)
444
+ x = self.conv(x)
445
+ else:
446
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
447
+ return x
448
+
449
+
450
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
451
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
452
+ flat_affinity /= temperature
453
+ probs = F.softmax(flat_affinity, dim=-1)
454
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
455
+ if loss_type == "softmax":
456
+ target_probs = probs
457
+ else:
458
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
459
+ avg_probs = torch.mean(target_probs, dim=0)
460
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
461
+ sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
462
+ loss = sample_entropy - avg_entropy
463
+ return loss
464
+
465
+
466
+ class VQModel(nn.Module):
467
+ def __init__(self, config: ModelArgs):
468
+ super().__init__()
469
+ self.config = config
470
+ self.encoder = Encoder(
471
+ ch_mult=config.encoder_ch_mult,
472
+ z_channels=config.z_channels,
473
+ dropout=config.dropout_p,
474
+ )
475
+ self.decoder = Decoder(
476
+ ch_mult=config.decoder_ch_mult,
477
+ z_channels=config.z_channels,
478
+ dropout=config.dropout_p,
479
+ )
480
+
481
+ self.quantize = VectorQuantizer(
482
+ config.codebook_size,
483
+ config.codebook_embed_dim,
484
+ config.commit_loss_beta,
485
+ config.entropy_loss_ratio,
486
+ config.codebook_l2_norm,
487
+ config.codebook_show_usage,
488
+ )
489
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
490
+ self.post_quant_conv = nn.Conv2d(
491
+ config.codebook_embed_dim, config.z_channels, 1
492
+ )
493
+
494
+ def encode(self, x):
495
+ h = self.encoder(x)
496
+ h = self.quant_conv(h)
497
+ quant, emb_loss, info = self.quantize(h)
498
+ return quant, emb_loss, info
499
+
500
+ def decode(self, quant):
501
+ quant = self.post_quant_conv(quant)
502
+ dec = self.decoder(quant)
503
+ return dec
504
+
505
+ def decode_code(self, code_b, shape=None, channel_first=True):
506
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
507
+ dec = self.decode(quant_b)
508
+ return dec
509
+
510
+ def forward(self, input):
511
+ quant, diff, _ = self.encode(input)
512
+ dec = self.decode(quant)
513
+ return dec, diff
514
+
515
+
516
+ #################################################################################
517
+ # VQ Model Configs #
518
+ #################################################################################
519
+ def VQ_16(**kwargs):
520
+ return VQModel(
521
+ ModelArgs(
522
+ encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
523
+ )
524
+ )
525
+
526
+
527
+ VQ_models = {"VQ-16": VQ_16}
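
A minimal round-trip sketch for the VQ-16 tokenizer defined above (the batch size and the 256x256 resolution are illustrative; with five resolution levels the spatial downsampling factor is 16):

import torch

model = VQ_models["VQ-16"]()  # ModelArgs defaults: 16384 codes, embedding dim 8
model.eval()

x = torch.rand(1, 3, 256, 256) * 2 - 1  # image tensor in [-1, 1]
with torch.no_grad():
    quant, emb_loss, info = model.encode(x)   # quant: [1, 8, 16, 16]
    codes = info[2]                           # flat codebook indices, length 16 * 16
    recon = model.decode_code(codes, shape=(1, 8, 16, 16))  # [1, 3, 256, 256]
print(recon.shape)
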
janus/utils/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
janus/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (174 Bytes). View file
 
janus/utils/__pycache__/conversation.cpython-38.pyc ADDED
Binary file (7.5 kB). View file
 
janus/utils/__pycache__/io.cpython-38.pyc ADDED
Binary file (2.06 kB). View file
 
janus/utils/conversation.py ADDED
@@ -0,0 +1,365 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ """
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+ """
+
+ import dataclasses
+ from enum import IntEnum, auto
+ from typing import Dict, List
+
+
+ class SeparatorStyle(IntEnum):
+     """Separator styles."""
+
+     ADD_COLON_SINGLE = auto()
+     ADD_COLON_TWO = auto()
+     ADD_COLON_SPACE_SINGLE = auto()
+     NO_COLON_SINGLE = auto()
+     NO_COLON_TWO = auto()
+     ADD_NEW_LINE_SINGLE = auto()
+     LLAMA2 = auto()
+     CHATGLM = auto()
+     CHATML = auto()
+     CHATINTERN = auto()
+     DOLLY = auto()
+     RWKV = auto()
+     PHOENIX = auto()
+     ROBIN = auto()
+     DeepSeek = auto()
+     PLAIN = auto()
+     ALIGNMENT = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that manages prompt templates and keeps all conversation history."""
+
+     # The name of this template
+     name: str
+     # The template of the system prompt
+     system_template: str = "{system_message}"
+     # The system message
+     system_message: str = ""
+     # The names of two roles
+     roles: List[str] = (("USER", "ASSISTANT"),)
+     # All messages. Each item is (role, message).
+     messages: List[List[str]] = ()
+     # The number of few shot examples
+     offset: int = 0
+     # The separator style and configurations
+     sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+     sep: str = "\n"
+     sep2: str = None
+     # Stop criteria (the default one is EOS token)
+     stop_str: str = None
+     # Stops generation if meeting any token in this list
+     stop_token_ids: List[int] = None
+
+     def get_prompt(self) -> str:
+         """Get the prompt for generation."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+
+         if self.sep_style == SeparatorStyle.DeepSeek:
+             seps = [self.sep, self.sep2]
+             if system_prompt == "" or system_prompt is None:
+                 ret = ""
+             else:
+                 ret = system_prompt + seps[0]
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+             return ret
+         elif self.sep_style == SeparatorStyle.LLAMA2:
+             seps = [self.sep, self.sep2]
+             if self.system_message:
+                 ret = system_prompt
+             else:
+                 ret = "[INST] "
+             for i, (role, message) in enumerate(self.messages):
+                 tag = self.roles[i % 2]
+                 if message:
+                     if type(message) is tuple: # multimodal message
+                         message, _ = message
+                     if i == 0:
+                         ret += message + " "
+                     else:
+                         ret += tag + " " + message + seps[i % 2]
+                 else:
+                     ret += tag
+             return ret
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += message + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         elif self.sep_style == SeparatorStyle.ALIGNMENT:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(self.messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i % 2 == 0:
+                         ret += "<image>\n" + seps[i % 2]
+                     else:
+                         ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+             return ret
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+     def get_prompt_for_current_round(self, content=None):
+         """Get current round formatted question prompt during sft training"""
+         if self.sep_style == SeparatorStyle.PLAIN:
+             formatted_question = "<image>\n"
+         elif self.sep_style == SeparatorStyle.DeepSeek:
+             formatted_question = (
+                 f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
+             )
+         else:
+             raise ValueError(f"Unsupported sep_style: {self.sep_style}")
+         return formatted_question
+
+     def set_system_message(self, system_message: str):
+         """Set the system message."""
+         self.system_message = system_message
+
+     def append_message(self, role: str, message: str):
+         """Append a new message."""
+         self.messages.append([role, message])
+
+     def reset_message(self):
+         """Reset a new message."""
+         self.messages = []
+
+     def update_last_message(self, message: str):
+         """Update the last output.
+
+         The last message is typically set to be None when constructing the prompt,
+         so we need to update it in-place after getting the response from a model.
+         """
+         self.messages[-1][1] = message
+
+     def to_gradio_chatbot(self):
+         """Convert the conversation to gradio chatbot format."""
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def to_openai_api_messages(self):
+         """Convert the conversation to OpenAI chat completion format."""
+         system_prompt = self.system_template.format(system_message=self.system_message)
+         ret = [{"role": "system", "content": system_prompt}]
+
+         for i, (_, msg) in enumerate(self.messages[self.offset :]):
+             if i % 2 == 0:
+                 ret.append({"role": "user", "content": msg})
+             else:
+                 if msg is not None:
+                     ret.append({"role": "assistant", "content": msg})
+         return ret
+
+     def copy(self):
+         return Conversation(
+             name=self.name,
+             system_template=self.system_template,
+             system_message=self.system_message,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             stop_str=self.stop_str,
+             stop_token_ids=self.stop_token_ids,
+         )
+
+     def dict(self):
+         return {
+             "template_name": self.name,
+             "system_message": self.system_message,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+         }
+
+
+ # A global registry for all conversation templates
+ conv_templates: Dict[str, Conversation] = {}
+
+
+ def register_conv_template(template: Conversation, override: bool = False):
+     """Register a new conversation template."""
+     if not override:
+         assert (
+             template.name not in conv_templates
+         ), f"{template.name} has been registered."
+
+     conv_templates[template.name] = template
+
+
+ def get_conv_template(name: str) -> Conversation:
+     """Get a conversation template."""
+     return conv_templates[name].copy()
+
+
+ # llava_llama2 template
+ register_conv_template(
+     Conversation(
+         name="llava_llama2",
+         system_message="You are a helpful language and vision assistant. "
+         "You are able to understand the visual content that the user provides, "
+         "and assist the user with a variety of tasks using natural language.",
+         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+         roles=("[INST]", "[/INST]"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.LLAMA2,
+         sep=" ",
+         sep2=" </s><s>",
+         stop_token_ids=[2],
+     )
+ )
+
+ # llama2 template
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+ register_conv_template(
+     Conversation(
+         name="llama-2",
+         system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+         roles=("[INST]", "[/INST]"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.LLAMA2,
+         sep=" ",
+         sep2=" </s><s>",
+         stop_token_ids=[2],
+     )
+ )
+
+
+ # deepseek template
+ register_conv_template(
+     Conversation(
+         name="deepseek_old",
+         system_template="{system_message}",
+         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+         # "thinking step by step to be sure you get the right answer.",
+         system_message="",
+         roles=("User", "Assistant"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeek,
+         sep="\n\n",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["User:", "<|end▁of▁sentence|>"],
+     )
+ )
+ register_conv_template(
+     Conversation(
+         name="deepseek",
+         system_template="{system_message}",
+         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
+         # "thinking step by step to be sure you get the right answer.",
+         system_message="",
+         roles=("<|User|>", "<|Assistant|>"),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.DeepSeek,
+         sep="\n\n",
+         sep2="<|end▁of▁sentence|>",
+         stop_token_ids=[100001],
+         stop_str=["<|User|>", "<|end▁of▁sentence|>"]
+     )
+ )
+
+ register_conv_template(
+     Conversation(
+         name="plain",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.PLAIN,
+         sep="",
+         sep2="",
+         stop_token_ids=[2],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ register_conv_template(
+     Conversation(
+         name="alignment",
+         system_template="",
+         system_message="",
+         roles=("", ""),
+         messages=(),
+         offset=0,
+         sep_style=SeparatorStyle.ALIGNMENT,
+         sep="",
+         sep2="",
+         stop_token_ids=[2],
+         stop_str=["</s>"],
+     )
+ )
+
+
+ if __name__ == "__main__":
+     # print("Llama-2 template:")
+     # conv = get_conv_template("llama-2")
+     # conv.set_system_message("You are a helpful, respectful and honest assistant.")
+     # conv.append_message(conv.roles[0], "Hello!")
+     # conv.append_message(conv.roles[1], "Hi!")
+     # conv.append_message(conv.roles[0], "How are you?")
+     # conv.append_message(conv.roles[1], None)
+     # print(conv.get_prompt())
+
+     # print("\n")
+
+     print("deepseek template:")
+     conv = get_conv_template("deepseek")
+     conv.append_message(conv.roles[0], "Hello!")
+     conv.append_message(conv.roles[1], "Hi! This is Tony.")
+     conv.append_message(conv.roles[0], "Who are you?")
+     conv.append_message(conv.roles[1], "I am a helpful assistant.")
+     conv.append_message(conv.roles[0], "How are you?")
+     conv.append_message(conv.roles[1], None)
+     print(conv.get_prompt())
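As a quick illustration of the "deepseek" template registered in the file above, the short sketch below shows the string it renders for a single open-ended turn. It assumes the module is importable as janus.utils.conversation from the repository root, and the user text is a made-up placeholder.

from janus.utils.conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Describe this image.")  # user turn (placeholder text)
conv.append_message(conv.roles[1], None)                    # leave the assistant turn open
prompt = conv.get_prompt()
# With an empty system message, the DeepSeek separator style produces:
# "<|User|>: Describe this image.\n\n<|Assistant|>:"
# Generation is then expected to stop on stop_str ("<|User|>") or on token id 100001.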
janus/utils/io.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) 2023-2024 DeepSeek.
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
+ # this software and associated documentation files (the "Software"), to deal in
+ # the Software without restriction, including without limitation the rights to
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+ # the Software, and to permit persons to whom the Software is furnished to do so,
+ # subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ import json
+ from typing import Dict, List
+
+ import PIL.Image
+ import torch
+ import base64
+ import io
+ from transformers import AutoModelForCausalLM
+
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
+
+
+ def load_pretrained_model(model_path: str):
+     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+     tokenizer = vl_chat_processor.tokenizer
+
+     vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
+         model_path, trust_remote_code=True
+     )
+     vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+     return tokenizer, vl_chat_processor, vl_gpt
+
+
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
+     """
+
+     Support file path or base64 images.
+
+     Args:
+         conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
+             [
+                 {
+                     "role": "User",
+                     "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
+                     "images": ["./examples/table_datasets.png"]
+                 },
+                 {"role": "Assistant", "content": ""},
+             ]
+
+     Returns:
+         pil_images (List[PIL.Image.Image]): the list of PIL images.
+
+     """
+
+     pil_images = []
+
+     for message in conversations:
+         if "images" not in message:
+             continue
+
+         for image_data in message["images"]:
+             if image_data.startswith("data:image"):
+                 # Image data is in base64 format
+                 _, image_data = image_data.split(",", 1)
+                 image_bytes = base64.b64decode(image_data)
+                 pil_img = PIL.Image.open(io.BytesIO(image_bytes))
+             else:
+                 # Image data is a file path
+                 pil_img = PIL.Image.open(image_data)
+             pil_img = pil_img.convert("RGB")
+             pil_images.append(pil_img)
+
+     return pil_images
+
+
+ def load_json(filepath):
+     with open(filepath, "r") as f:
+         data = json.load(f)
+     return data
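The two helpers above are the usual entry points for inference. A minimal, hedged usage sketch follows; the checkpoint id and image path are illustrative placeholders (the path is borrowed from the docstring example), and a CUDA device is required because load_pretrained_model calls .cuda().

from janus.utils.io import load_pil_images, load_pretrained_model

# Hypothetical checkpoint location; substitute a local directory or Hub id hosting the model.
tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model("deepseek-ai/Janus-Pro-7B")

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": ["./examples/table_datasets.png"],  # placeholder path from the docstring
    },
    {"role": "Assistant", "content": ""},
]

pil_images = load_pil_images(conversation)  # one RGB PIL.Image per entry in "images"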
weights/RealESRGAN_x2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899
+ size 67061725
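The entry above is a Git LFS pointer for the x2 Real-ESRGAN checkpoint rather than the binary weights themselves. As a rough sketch of how such a checkpoint is typically consumed with the RealESRGAN class exported by Upsample/__init__.py, assuming that class follows the common load_weights/predict wrapper interface (an assumption here, since the body of Upsample/model.py is not shown in this section):

import torch
from PIL import Image
from Upsample import RealESRGAN  # re-exported in Upsample/__init__.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sr_model = RealESRGAN(device, scale=2)              # scale=2 matches RealESRGAN_x2.pth
sr_model.load_weights("weights/RealESRGAN_x2.pth")  # the LFS-tracked file above
upscaled = sr_model.predict(Image.open("input.png").convert("RGB"))  # hypothetical input image
upscaled.save("input_x2.png")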