giulio98 committed
Commit 77cc3a2 · verified · 1 Parent(s): 1141bb2

Create sde_ve_scheduler.py

Files changed (1):
1. scheduler/sde_ve_scheduler.py +285 -0
scheduler/sde_ve_scheduler.py ADDED
@@ -0,0 +1,285 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Mean averaged `prev_sample` over previous timesteps.
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor


class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.

    This class inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers, such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 2000):
            The number of diffusion steps to train the model.
        snr (`float`, defaults to 0.15):
            A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
        sigma_min (`float`, defaults to 0.01):
            The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror
            the distribution of the data.
        sigma_max (`float`, defaults to 1348.0):
            The maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`, defaults to 1e-5):
            The end value of sampling, where timesteps decrease progressively from 1 to epsilon.
        correct_steps (`int`, defaults to 1):
            The number of correction steps performed on a produced sample.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values
        self.timesteps = None

        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
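        # The grid above spans the full training schedule (`num_train_timesteps` points); callers such as a
        # predictor-corrector sampling loop are expected to call `set_timesteps` and then `set_sigmas` again
        # with the desired `num_inference_steps` before sampling.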

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(
        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
    ):
        """
        Sets the continuous timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, *optional*):
                The final timestep value (overrides value given during scheduler instantiation).
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
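        # Continuous times run backwards from t=1 (pure noise) down to t=sampling_eps; e.g. with
        # num_inference_steps=5 and sampling_eps=1e-5 this gives roughly [1.0, 0.75, 0.5, 0.25, 1e-5],
        # so index 0 corresponds to the largest noise level.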

    def set_sigmas(
        self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
    ):
        """
        Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
        of the `drift` and `diffusion` components of the sample update.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, *optional*):
                The initial noise scale value (overrides value given during scheduler instantiation).
            sigma_max (`float`, *optional*):
                The final noise scale value (overrides value given during scheduler instantiation).
            sampling_eps (`float`, *optional*):
                The final timestep value (overrides value given during scheduler instantiation).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
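        # The list comprehension above defines the sigmas actually used: a geometric interpolation
        # sigma(t) = sigma_min * (sigma_max / sigma_min) ** t, so sigma(sampling_eps) ~= sigma_min and
        # sigma(1) = sigma_max. `discrete_sigmas` holds the same geometric sequence indexed by integer step.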

    def get_adjacent_sigma(self, timesteps, t):
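        # Returns the next-lower noise level sigma_{i-1} for each step index i, with 0 substituted at the
        # boundary i == 0; the predictor uses sqrt(sigma_i**2 - sigma_{i-1}**2) as its diffusion coefficient.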
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def step_pred(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
                If `return_dict` is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep = timestep * torch.ones(
            sample.shape[0], device=sample.device
        )  # torch.repeat_interleave(timestep, sample.shape[0])
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be on the same device, so we use the device of `discrete_sigmas` (cpu by default)
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output
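        # This implements the ancestral/reverse-diffusion predictor for the VE SDE:
        # x_{i-1} mean = x_i + (sigma_i**2 - sigma_{i-1}**2) * score(x_i, sigma_i)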

        # equation 6: sample noise for the diffusion term
        noise = randn_tensor(
            sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
        )
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
        making the prediction for the previous timestep.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, a [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise
                a tuple is returned where the first element is the sample tensor.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = randn_tensor(
            sample.shape, layout=sample.layout, generator=generator, device=sample.device
        ).to(sample.device)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
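        # Langevin corrector step size from the predictor-corrector samplers in Song et al. (2021):
        # step_size = 2 * (snr * ||z|| / ||score||)**2, so a larger `snr` yields larger corrector steps.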
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
        # self.repeat_scalar(step_size, sample.shape[0])

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        while len(step_size.shape) < len(sample.shape):
            step_size = step_size.unsqueeze(-1)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        timesteps = timesteps.to(original_samples.device)
        sigmas = self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** timesteps
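        # VE perturbation kernel: x_t = x_0 + sigma(t) * z with z ~ N(0, I); `timesteps` is expected to hold
        # continuous values in [sampling_eps, 1] (not integer indices), matching the sigmas defined in `set_sigmas`.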
        noise = (
            noise * sigmas[:, None, None, None]
            if noise is not None
            else torch.randn_like(original_samples) * sigmas[:, None, None, None]
        )
        noisy_samples = noise + original_samples
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
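

if __name__ == "__main__":
    # Minimal smoke-test sketch of the predictor-corrector sampling loop (not part of the scheduler API).
    # A random-valued `dummy_score_model` stands in for a trained score network, so the output is meaningless,
    # but the loop exercises `set_timesteps`, `set_sigmas`, `step_correct`, and `step_pred` end to end.
    def dummy_score_model(x, t):
        # a trained model would return an estimate of grad_x log p_t(x)
        return torch.randn_like(x)

    scheduler = ScoreSdeVeScheduler()
    num_inference_steps = 10
    scheduler.set_timesteps(num_inference_steps)
    scheduler.set_sigmas(num_inference_steps)

    # start from pure noise scaled to the initial (maximum) noise level
    sample = torch.randn(1, 3, 32, 32) * scheduler.init_noise_sigma

    for t in scheduler.timesteps:
        # corrector (Langevin) steps
        for _ in range(scheduler.config.correct_steps):
            score = dummy_score_model(sample, t)
            sample = scheduler.step_correct(score, sample).prev_sample
        # predictor (reverse SDE) step
        score = dummy_score_model(sample, t)
        output = scheduler.step_pred(score, t, sample)
        sample, sample_mean = output.prev_sample, output.prev_sample_mean

    print("final sample shape:", sample_mean.shape)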