yangwang825 committed on
Commit ed027b5 · verified
1 Parent(s): 73576f3

Upload EcapaTdnnForSequenceClassification

angular_loss.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class Loss(nn.modules.loss._Loss):
6
+ """Inherit this class to implement custom loss."""
7
+
8
+ def __init__(self, **kwargs):
9
+ super(Loss, self).__init__(**kwargs)
10
+
11
+
12
+ class AdditiveMarginSoftmaxLoss(Loss):
13
+ """Computes Additive Margin Softmax (CosFace) Loss
14
+
15
+ Paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition
16
+
17
+ Args:
18
+ scale: scale value for cosine angle
19
+ margin: margin value added to cosine angle
20
+ """
21
+
22
+ def __init__(self, scale=30.0, margin=0.2):
23
+ super().__init__()
24
+
25
+ self.eps = 1e-7
26
+ self.scale = scale
27
+ self.margin = margin
28
+
29
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
30
+ # Extract the logits corresponding to the true class
31
+ logits_target = logits[torch.arange(logits.size(0)), labels] # Faster indexing
32
+ numerator = self.scale * (logits_target - self.margin) # Apply additive margin
33
+ # Exclude the target logits from denominator calculation
34
+ logits = logits.scatter(1, labels.unsqueeze(1), float('-inf')) # Mask target class without mutating the caller's logits in place
35
+ denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * logits), dim=1)
36
+ # Compute final loss
37
+ loss = -torch.log(torch.exp(numerator) / denominator)
38
+ return loss.mean()
39
+
40
+
41
+ class AdditiveAngularMarginSoftmaxLoss(Loss):
42
+ """Computes Additive Angular Margin Softmax (ArcFace) Loss
43
+
44
+ Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition
45
+
46
+ Args:
47
+ scale: scale value for cosine angle
48
+ margin: margin value added to cosine angle
49
+ """
50
+
51
+ def __init__(self, scale=20.0, margin=1.35):
52
+ super().__init__()
53
+
54
+ self.eps = 1e-7
55
+ self.scale = scale
56
+ self.margin = margin
57
+
58
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
59
+ numerator = self.scale * torch.cos(
60
+ torch.acos(torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps))
61
+ + self.margin
62
+ )
63
+ excl = torch.cat(
64
+ [torch.cat((logits[i, :y], logits[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels)], dim=0
65
+ )
66
+ denominator = torch.exp(numerator) + torch.sum(torch.exp(self.scale * excl), dim=1)
67
+ L = numerator - torch.log(denominator)
68
+ return -torch.mean(L)
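Both losses assume the incoming `logits` are cosine similarities in [-1, 1] (normalized embeddings times normalized class weights), which is the contract `SpeakerDecoder` aims for when `angular=True`. A minimal usage sketch, not part of the committed files; the shapes are illustrative, and the scale/margin values mirror `"angular_scale": 30` and `"angular_margin": 0.2` from config.json:

```python
# Illustrative sketch (not part of the commit): both losses expect
# cosine-similarity logits, i.e. L2-normalized embeddings times
# L2-normalized class weights.
import torch
import torch.nn.functional as F

batch_size, emb_dim, num_classes = 4, 192, 10        # illustrative shapes
embeddings = F.normalize(torch.randn(batch_size, emb_dim), dim=1)
class_weights = F.normalize(torch.randn(num_classes, emb_dim), dim=1)
cosine_logits = embeddings @ class_weights.t()        # values in [-1, 1]
labels = torch.randint(0, num_classes, (batch_size,))

loss_fct = AdditiveAngularMarginSoftmaxLoss(scale=30.0, margin=0.2)
loss = loss_fct(cosine_logits, labels)
```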
audio_processing.py ADDED
@@ -0,0 +1,413 @@
1
+ import math
2
+ from packaging import version
3
+ from dataclasses import dataclass
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+
8
+ try:
9
+ import torchaudio
10
+ import torchaudio.functional
11
+ import torchaudio.transforms
12
+
13
+ TORCHAUDIO_VERSION = version.parse(torchaudio.__version__)
14
+ TORCHAUDIO_VERSION_MIN = version.parse('0.5')
15
+
16
+ HAVE_TORCHAUDIO = True
17
+ except ModuleNotFoundError:
18
+ HAVE_TORCHAUDIO = False
19
+
20
+ from .module import NeuralModule
21
+ from .features import FilterbankFeatures, FilterbankFeaturesTA
22
+ from .spectrogram_augment import SpecCutout, SpecAugment
23
+
24
+
25
+ class AudioPreprocessor(NeuralModule, ABC):
26
+ """
27
+ An interface for Neural Modules that performs audio pre-processing,
28
+ transforming the wav files to features.
29
+ """
30
+
31
+ def __init__(self, win_length, hop_length):
32
+ super().__init__()
33
+
34
+ self.win_length = win_length
35
+ self.hop_length = hop_length
36
+
37
+ self.torch_windows = {
38
+ 'hann': torch.hann_window,
39
+ 'hamming': torch.hamming_window,
40
+ 'blackman': torch.blackman_window,
41
+ 'bartlett': torch.bartlett_window,
42
+ 'ones': torch.ones,
43
+ None: torch.ones,
44
+ }
45
+
46
+ # Normally, when you call to(dtype) on a torch.nn.Module, all
47
+ # floating point parameters and buffers will change to that
48
+ # dtype, rather than being float32. The AudioPreprocessor
49
+ # classes, uniquely, don't actually have any parameters or
50
+ # buffers from what I see. In addition, we want the input to
51
+ # the preprocessor to be float32, but need to create the
52
+ # output in appropriate precision. We have this empty tensor
53
+ # here just to detect which dtype tensor this module should
54
+ # output at the end of execution.
55
+ self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)
56
+
57
+ @torch.no_grad()
58
+ def forward(self, input_signal, length):
59
+ processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
60
+ processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
61
+ return processed_signal, processed_length
62
+
63
+ @abstractmethod
64
+ def get_features(self, input_signal, length):
65
+ # Called by forward(). Subclasses should implement this.
66
+ pass
67
+
68
+
69
+ class AudioToMelSpectrogramPreprocessor(AudioPreprocessor):
70
+ """Featurizer module that converts wavs to mel spectrograms.
71
+
72
+ Args:
73
+ sample_rate (int): Sample rate of the input audio data.
74
+ Defaults to 16000
75
+ window_size (float): Size of window for fft in seconds
76
+ Defaults to 0.02
77
+ window_stride (float): Stride of window for fft in seconds
78
+ Defaults to 0.01
79
+ n_window_size (int): Size of window for fft in samples
80
+ Defaults to None. Use one of window_size or n_window_size.
81
+ n_window_stride (int): Stride of window for fft in samples
82
+ Defaults to None. Use one of window_stride or n_window_stride.
83
+ window (str): Windowing function for fft. can be one of ['hann',
84
+ 'hamming', 'blackman', 'bartlett']
85
+ Defaults to "hann"
86
+ normalize (str): Can be one of ['per_feature', 'all_features']; all
87
+ other options disable feature normalization. 'all_features'
88
+ normalizes the entire spectrogram to be mean 0 with std 1.
89
+ 'per_feature' normalizes per channel / freq instead.
90
+ Defaults to "per_feature"
91
+ n_fft (int): Length of FT window. If None, it uses the smallest power
92
+ of 2 that is larger than n_window_size.
93
+ Defaults to None
94
+ preemph (float): Amount of pre emphasis to add to audio. Can be
95
+ disabled by passing None.
96
+ Defaults to 0.97
97
+ features (int): Number of mel spectrogram freq bins to output.
98
+ Defaults to 64
99
+ lowfreq (int): Lower bound on mel basis in Hz.
100
+ Defaults to 0
101
+ highfreq (int): Upper bound on mel basis in Hz.
102
+ Defaults to None
103
+ log (bool): Log features.
104
+ Defaults to True
105
+ log_zero_guard_type(str): Need to avoid taking the log of zero. There
106
+ are two options: "add" or "clamp".
107
+ Defaults to "add".
108
+ log_zero_guard_value(float, or str): Add or clamp requires the number
109
+ to add with or clamp to. log_zero_guard_value can either be a float
110
+ or "tiny" or "eps". torch.finfo is used if "tiny" or "eps" is
111
+ passed.
112
+ Defaults to 2**-24.
113
+ dither (float): Amount of white-noise dithering.
114
+ Defaults to 1e-5
115
+ pad_to (int): Ensures that the output size of the time dimension is
116
+ a multiple of pad_to.
117
+ Defaults to 16
118
+ frame_splicing (int): Defaults to 1
119
+ exact_pad (bool): If True, sets stft center to False and adds padding, such that num_frames = audio_length
120
+ // hop_length. Defaults to False.
121
+ pad_value (float): The value that shorter mels are padded with.
122
+ Defaults to 0
123
+ mag_power (float): The power that the linear spectrogram is raised to
124
+ prior to multiplication with mel basis.
125
+ Defaults to 2 for a power spec
126
+ rng : Random number generator
127
+ nb_augmentation_prob (float) : Probability with which narrowband augmentation would be applied to
128
+ samples in the batch.
129
+ Defaults to 0.0
130
+ nb_max_freq (int) : Frequency above which all frequencies will be masked for narrowband augmentation.
131
+ Defaults to 4000
132
+ use_torchaudio: Whether to use the `torchaudio` implementation.
133
+ mel_norm: Normalization used for mel filterbank weights.
134
+ Defaults to 'slaney' (area normalization)
135
+ stft_exact_pad: Deprecated argument, kept for compatibility with older checkpoints.
136
+ stft_conv: Deprecated argument, kept for compatibility with older checkpoints.
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ sample_rate=16000,
142
+ window_size=0.02,
143
+ window_stride=0.01,
144
+ n_window_size=None,
145
+ n_window_stride=None,
146
+ window="hann",
147
+ normalize="per_feature",
148
+ n_fft=None,
149
+ preemph=0.97,
150
+ features=64,
151
+ lowfreq=0,
152
+ highfreq=None,
153
+ log=True,
154
+ log_zero_guard_type="add",
155
+ log_zero_guard_value=2**-24,
156
+ dither=1e-5,
157
+ pad_to=16,
158
+ frame_splicing=1,
159
+ exact_pad=False,
160
+ pad_value=0,
161
+ mag_power=2.0,
162
+ rng=None,
163
+ nb_augmentation_prob=0.0,
164
+ nb_max_freq=4000,
165
+ use_torchaudio: bool = False,
166
+ mel_norm="slaney",
167
+ stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
168
+ stft_conv=False, # Deprecated arguments; kept for config compatibility
169
+ ):
170
+ super().__init__(n_window_size, n_window_stride)
171
+
172
+ self._sample_rate = sample_rate
173
+ if window_size and n_window_size:
174
+ raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
175
+ if window_stride and n_window_stride:
176
+ raise ValueError(
177
+ f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
178
+ )
179
+ if window_size:
180
+ n_window_size = int(window_size * self._sample_rate)
181
+ if window_stride:
182
+ n_window_stride = int(window_stride * self._sample_rate)
183
+
184
+ # Given the long and similar argument list, point to the class and instantiate it by reference
185
+ if not use_torchaudio:
186
+ featurizer_class = FilterbankFeatures
187
+ else:
188
+ featurizer_class = FilterbankFeaturesTA
189
+ self.featurizer = featurizer_class(
190
+ sample_rate=self._sample_rate,
191
+ n_window_size=n_window_size,
192
+ n_window_stride=n_window_stride,
193
+ window=window,
194
+ normalize=normalize,
195
+ n_fft=n_fft,
196
+ preemph=preemph,
197
+ nfilt=features,
198
+ lowfreq=lowfreq,
199
+ highfreq=highfreq,
200
+ log=log,
201
+ log_zero_guard_type=log_zero_guard_type,
202
+ log_zero_guard_value=log_zero_guard_value,
203
+ dither=dither,
204
+ pad_to=pad_to,
205
+ frame_splicing=frame_splicing,
206
+ exact_pad=exact_pad,
207
+ pad_value=pad_value,
208
+ mag_power=mag_power,
209
+ rng=rng,
210
+ nb_augmentation_prob=nb_augmentation_prob,
211
+ nb_max_freq=nb_max_freq,
212
+ mel_norm=mel_norm,
213
+ stft_exact_pad=stft_exact_pad, # Deprecated arguments; kept for config compatibility
214
+ stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
215
+ )
216
+
217
+ def get_features(self, input_signal, length):
218
+ return self.featurizer(input_signal, length)
219
+
220
+ @property
221
+ def filter_banks(self):
222
+ return self.featurizer.filter_banks
223
+
224
+
225
+ class AudioToMFCCPreprocessor(AudioPreprocessor):
226
+ """Preprocessor that converts wavs to MFCCs.
227
+ Uses torchaudio.transforms.MFCC.
228
+
229
+ Args:
230
+ sample_rate: The sample rate of the audio.
231
+ Defaults to 16000.
232
+ window_size: Size of window for fft in seconds. Used to calculate the
233
+ win_length arg for mel spectrogram.
234
+ Defaults to 0.02
235
+ window_stride: Stride of window for fft in seconds. Used to calculate
236
+ the hop_length arg for the mel spectrogram.
237
+ Defaults to 0.01
238
+ n_window_size: Size of window for fft in samples
239
+ Defaults to None. Use one of window_size or n_window_size.
240
+ n_window_stride: Stride of window for fft in samples
241
+ Defaults to None. Use one of window_stride or n_window_stride.
242
+ window: Windowing function for fft. can be one of ['hann',
243
+ 'hamming', 'blackman', 'bartlett', 'none', 'null'].
244
+ Defaults to 'hann'
245
+ n_fft: Length of FT window. If None, it uses the smallest power of 2
246
+ that is larger than n_window_size.
247
+ Defaults to None
248
+ lowfreq (int): Lower bound on mel basis in Hz.
249
+ Defaults to 0
250
+ highfreq (int): Upper bound on mel basis in Hz.
251
+ Defaults to None
252
+ n_mels: Number of mel filterbanks.
253
+ Defaults to 64
254
+ n_mfcc: Number of coefficients to retain
255
+ Defaults to 64
256
+ dct_type: Type of discrete cosine transform to use
257
+ norm: Type of norm to use
258
+ log: Whether to use log-mel spectrograms instead of db-scaled.
259
+ Defaults to True.
260
+ """
261
+
262
+ def __init__(
263
+ self,
264
+ sample_rate=16000,
265
+ window_size=0.02,
266
+ window_stride=0.01,
267
+ n_window_size=None,
268
+ n_window_stride=None,
269
+ window='hann',
270
+ n_fft=None,
271
+ lowfreq=0.0,
272
+ highfreq=None,
273
+ n_mels=64,
274
+ n_mfcc=64,
275
+ dct_type=2,
276
+ norm='ortho',
277
+ log=True,
278
+ ):
279
+ self._sample_rate = sample_rate
280
+ if not HAVE_TORCHAUDIO:
281
+ print('Could not import torchaudio. Some features might not work.')
282
+
283
+ raise ModuleNotFoundError(
284
+ "torchaudio is not installed but is necessary for "
285
+ "AudioToMFCCPreprocessor. We recommend you try "
286
+ "building it from source for the PyTorch version you have."
287
+ )
288
+ if window_size and n_window_size:
289
+ raise ValueError(f"{self} received both window_size and " f"n_window_size. Only one should be specified.")
290
+ if window_stride and n_window_stride:
291
+ raise ValueError(
292
+ f"{self} received both window_stride and " f"n_window_stride. Only one should be specified."
293
+ )
294
+ # Get win_length (n_window_size) and hop_length (n_window_stride)
295
+ if window_size:
296
+ n_window_size = int(window_size * self._sample_rate)
297
+ if window_stride:
298
+ n_window_stride = int(window_stride * self._sample_rate)
299
+
300
+ super().__init__(n_window_size, n_window_stride)
301
+
302
+ mel_kwargs = {}
303
+
304
+ mel_kwargs['f_min'] = lowfreq
305
+ mel_kwargs['f_max'] = highfreq
306
+ mel_kwargs['n_mels'] = n_mels
307
+
308
+ mel_kwargs['n_fft'] = n_fft or 2 ** math.ceil(math.log2(n_window_size))
309
+
310
+ mel_kwargs['win_length'] = n_window_size
311
+ mel_kwargs['hop_length'] = n_window_stride
312
+
313
+ # Set window_fn. None defaults to torch.ones.
314
+ window_fn = self.torch_windows.get(window, None)
315
+ if window_fn is None:
316
+ raise ValueError(
317
+ f"Window argument for AudioProcessor is invalid: {window}."
318
+ f"For no window function, use 'ones' or None."
319
+ )
320
+ mel_kwargs['window_fn'] = window_fn
321
+
322
+ # Use torchaudio's implementation of MFCCs as featurizer
323
+ self.featurizer = torchaudio.transforms.MFCC(
324
+ sample_rate=self._sample_rate,
325
+ n_mfcc=n_mfcc,
326
+ dct_type=dct_type,
327
+ norm=norm,
328
+ log_mels=log,
329
+ melkwargs=mel_kwargs,
330
+ )
331
+
332
+ def get_features(self, input_signal, length):
333
+ features = self.featurizer(input_signal)
334
+ seq_len = torch.ceil(length.to(torch.float32) / self.hop_length).to(dtype=torch.long)
335
+ return features, seq_len
336
+
337
+
338
+ class SpectrogramAugmentation(NeuralModule):
339
+ """
340
+ Performs time and freq cuts in one of two ways.
341
+ SpecAugment zeroes out vertical and horizontal sections as described in
342
+ SpecAugment (https://arxiv.org/abs/1904.08779). Arguments for use with
343
+ SpecAugment are `freq_masks`, `time_masks`, `freq_width`, and `time_width`.
344
+ SpecCutout zeroes out rectangular regions as described in Cutout
345
+ (https://arxiv.org/abs/1708.04552). Arguments for use with Cutout are
346
+ `rect_masks`, `rect_freq`, and `rect_time`.
347
+
348
+ Args:
349
+ freq_masks (int): how many frequency segments should be cut.
350
+ Defaults to 0.
351
+ time_masks (int): how many time segments should be cut
352
+ Defaults to 0.
353
+ freq_width (int): maximum number of frequencies to be cut in one
354
+ segment.
355
+ Defaults to 10.
356
+ time_width (int): maximum number of time steps to be cut in one
357
+ segment
358
+ Defaults to 10.
359
+ rect_masks (int): how many rectangular masks should be cut
360
+ Defaults to 0.
361
+ rect_freq (int): maximum size of cut rectangles along the frequency
362
+ dimension
363
+ Defaults to 5.
364
+ rect_time (int): maximum size of cut rectangles along the time
365
+ dimension
366
+ Defaults to 25.
367
+ use_numba_spec_augment: use numba code for Spectrogram augmentation
368
+ use_vectorized_spec_augment: use vectorized code for Spectrogram augmentation
369
+
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ freq_masks=0,
375
+ time_masks=0,
376
+ freq_width=10,
377
+ time_width=10,
378
+ rect_masks=0,
379
+ rect_time=5,
380
+ rect_freq=20,
381
+ rng=None,
382
+ mask_value=0.0,
383
+ use_vectorized_spec_augment: bool = True,
384
+ ):
385
+ super().__init__()
386
+
387
+ if rect_masks > 0:
388
+ self.spec_cutout = SpecCutout(
389
+ rect_masks=rect_masks,
390
+ rect_time=rect_time,
391
+ rect_freq=rect_freq,
392
+ rng=rng,
393
+ )
394
+ # self.spec_cutout.to(self._device)
395
+ else:
396
+ self.spec_cutout = lambda input_spec: input_spec
397
+ if freq_masks + time_masks > 0:
398
+ self.spec_augment = SpecAugment(
399
+ freq_masks=freq_masks,
400
+ time_masks=time_masks,
401
+ freq_width=freq_width,
402
+ time_width=time_width,
403
+ rng=rng,
404
+ mask_value=mask_value,
405
+ use_vectorized_code=use_vectorized_spec_augment,
406
+ )
407
+ else:
408
+ self.spec_augment = lambda input_spec, length: input_spec
409
+
410
+ def forward(self, input_spec, length):
411
+ augmented_spec = self.spec_cutout(input_spec=input_spec)
412
+ augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
413
+ return augmented_spec
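For orientation, a minimal sketch of how the preprocessor is meant to be called on raw waveforms. It is not part of the committed files; the batch size and one-second duration are illustrative, and the default librosa-based `FilterbankFeatures` path is assumed:

```python
# Illustrative sketch (not part of the commit), assuming librosa is installed
# so the default FilterbankFeatures path can build its mel basis.
import torch

preprocessor = AudioToMelSpectrogramPreprocessor(sample_rate=16000, features=64)
waveforms = torch.randn(2, 16000)            # [batch, samples]: one second at 16 kHz
lengths = torch.tensor([16000, 12000])       # valid samples per batch element

mel, mel_lengths = preprocessor(input_signal=waveforms, length=lengths)
# mel: [batch, 64, frames] log-mel features, padded to a multiple of pad_to (16)
# mel_lengths: number of valid frames per batch element
```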
config.json CHANGED
@@ -1,11 +1,14 @@
1
  {
2
- "_attn_implementation_autoset": true,
3
  "angular": true,
4
  "angular_margin": 0.2,
5
  "angular_scale": 30,
 
 
 
6
  "attention_channels": 128,
7
  "auto_map": {
8
- "AutoConfig": "configuration_ecapa_tdnn.EcapaTdnnConfig"
 
9
  },
10
  "bos_token_id": 1,
11
  "decoder_config": {
@@ -2577,6 +2580,7 @@
2577
  },
2578
  "time_masks": 5,
2579
  "time_width": 0.03,
 
2580
  "transformers_version": "4.48.3",
2581
  "use_torchaudio": true,
2582
  "use_vectorized_spec_augment": true,
 
1
  {
 
2
  "angular": true,
3
  "angular_margin": 0.2,
4
  "angular_scale": 30,
5
+ "architectures": [
6
+ "EcapaTdnnForSequenceClassification"
7
+ ],
8
  "attention_channels": 128,
9
  "auto_map": {
10
+ "AutoConfig": "configuration_ecapa_tdnn.EcapaTdnnConfig",
11
+ "AutoModelForAudioClassification": "modeling_ecapa_tdnn.EcapaTdnnForSequenceClassification"
12
  },
13
  "bos_token_id": 1,
14
  "decoder_config": {
 
2580
  },
2581
  "time_masks": 5,
2582
  "time_width": 0.03,
2583
+ "torch_dtype": "float32",
2584
  "transformers_version": "4.48.3",
2585
  "use_torchaudio": true,
2586
  "use_vectorized_spec_augment": true,
conv_asr.py ADDED
@@ -0,0 +1,189 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .module import NeuralModule
8
+ from .tdnn_attention import (
9
+ StatsPoolLayer,
10
+ AttentivePoolLayer,
11
+ TdnnModule,
12
+ TdnnSeModule,
13
+ TdnnSeRes2NetModule,
14
+ init_weights
15
+ )
16
+
17
+
18
+ class EcapaTdnnEncoder(NeuralModule):
19
+ """
20
+ Modified ECAPA Encoder layer without Res2Net module for faster training and inference which achieves
21
+ better numbers on speaker diarization tasks
22
+ Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
23
+
24
+ input:
25
+ feat_in: input feature shape (mel spec feature shape)
26
+ filters: list of filter shapes for SE_TDNN modules
27
+ kernel_sizes: list of kernel shapes for SE_TDNN modules
28
+ dilations: list of dilations for group conv se layer
29
+ scale: scale value to group wider conv channels (default: 8)
30
+
31
+ output:
32
+ outputs : encoded output
33
+ output_length: masked output lengths
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ feat_in: int,
39
+ filters: list,
40
+ kernel_sizes: list,
41
+ dilations: list,
42
+ scale: int = 8,
43
+ res2net: bool = False,
44
+ res2net_scale: int = 8,
45
+ init_mode: str = 'xavier_uniform',
46
+ ):
47
+ super().__init__()
48
+ self.layers = nn.ModuleList()
49
+ self.layers.append(TdnnModule(feat_in, filters[0], kernel_size=kernel_sizes[0], dilation=dilations[0]))
50
+
51
+ for i in range(len(filters) - 2):
52
+ if res2net:
53
+ self.layers.append(
54
+ TdnnSeRes2NetModule(
55
+ filters[i],
56
+ filters[i + 1],
57
+ group_scale=scale,
58
+ se_channels=128,
59
+ kernel_size=kernel_sizes[i + 1],
60
+ dilation=dilations[i + 1],
61
+ res2net_scale=res2net_scale,
62
+ )
63
+ )
64
+ else:
65
+ self.layers.append(
66
+ TdnnSeModule(
67
+ filters[i],
68
+ filters[i + 1],
69
+ group_scale=scale,
70
+ se_channels=128,
71
+ kernel_size=kernel_sizes[i + 1],
72
+ dilation=dilations[i + 1],
73
+ )
74
+ )
75
+ self.feature_agg = TdnnModule(filters[-1], filters[-1], kernel_sizes[-1], dilations[-1])
76
+ self.apply(lambda x: init_weights(x, mode=init_mode))
77
+
78
+ def forward(self, audio_signal, length=None):
79
+ x = audio_signal
80
+ outputs = []
81
+
82
+ for layer in self.layers:
83
+ x = layer(x, length=length)
84
+ outputs.append(x)
85
+
86
+ x = torch.cat(outputs[1:], dim=1)
87
+ x = self.feature_agg(x)
88
+ return x, length
89
+
90
+
91
+ class SpeakerDecoder(NeuralModule):
92
+ """
93
+ Speaker Decoder creates the final neural layers that maps from the outputs
94
+ of Jasper Encoder to the embedding layer followed by speaker based softmax loss.
95
+
96
+ Args:
97
+ feat_in (int): Number of channels being input to this module
98
+ num_classes (int): Number of unique speakers in dataset
99
+ emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embeddings
100
+ from the 1st of these layers). Defaults to [1024, 1024]
101
+ pool_mode (str) : Pooling strategy type. options are 'xvector','tap', 'attention'
102
+ Defaults to 'xvector (mean and variance)'
103
+ tap (temporal average pooling: just mean)
104
+ attention (attention based pooling)
105
+ init_mode (str): Describes how neural network parameters are
106
+ initialized. Options are ['xavier_uniform', 'xavier_normal',
107
+ 'kaiming_uniform','kaiming_normal'].
108
+ Defaults to "xavier_uniform".
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ feat_in: int,
114
+ num_classes: int,
115
+ emb_sizes: Optional[Union[int, list]] = 256,
116
+ pool_mode: str = 'xvector',
117
+ angular: bool = False,
118
+ attention_channels: int = 128,
119
+ init_mode: str = "xavier_uniform",
120
+ ):
121
+ super().__init__()
122
+ self.angular = angular
123
+ self.emb_id = 2
124
+ bias = False if self.angular else True
125
+ emb_sizes = [emb_sizes] if type(emb_sizes) is int else emb_sizes
126
+
127
+ self._num_classes = num_classes
128
+ self.pool_mode = pool_mode.lower()
129
+ if self.pool_mode == 'xvector' or self.pool_mode == 'tap':
130
+ self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
131
+ affine_type = 'linear'
132
+ elif self.pool_mode == 'attention':
133
+ self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
134
+ affine_type = 'conv'
135
+
136
+ shapes = [self._pooling.feat_in]
137
+ for size in emb_sizes:
138
+ shapes.append(int(size))
139
+
140
+ emb_layers = []
141
+ for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
142
+ layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
143
+ emb_layers.append(layer)
144
+
145
+ self.emb_layers = nn.ModuleList(emb_layers)
146
+
147
+ self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)
148
+
149
+ self.apply(lambda x: init_weights(x, mode=init_mode))
150
+
151
+ def affine_layer(
152
+ self,
153
+ inp_shape,
154
+ out_shape,
155
+ learn_mean=True,
156
+ affine_type='conv',
157
+ ):
158
+ if affine_type == 'conv':
159
+ layer = nn.Sequential(
160
+ nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
161
+ nn.Conv1d(inp_shape, out_shape, kernel_size=1),
162
+ )
163
+
164
+ else:
165
+ layer = nn.Sequential(
166
+ nn.Linear(inp_shape, out_shape),
167
+ nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
168
+ nn.ReLU(),
169
+ )
170
+
171
+ return layer
172
+
173
+ def forward(self, encoder_output, length=None):
174
+ pool = self._pooling(encoder_output, length)
175
+ embs = []
176
+
177
+ for layer in self.emb_layers:
178
+ pool, emb = layer(pool), layer[: self.emb_id](pool)
179
+ embs.append(emb)
180
+
181
+ pool = pool.squeeze(-1)
182
+ if self.angular:
183
+ for W in self.final.parameters():
184
+ W = F.normalize(W, p=2, dim=1)
185
+ pool = F.normalize(pool, p=2, dim=1)
186
+
187
+ out = self.final(pool)
188
+
189
+ return out, embs[-1].squeeze(-1)
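A shape-level sketch of how the two modules chain together. The filter/kernel/dilation lists and the 10-class decoder are illustrative placeholders rather than the committed configuration, and the `tdnn_attention` helpers from the same repository are assumed to be importable:

```python
# Illustrative sketch (not part of the commit): chaining encoder and decoder
# directly on mel features. Filter/kernel/dilation values are placeholders.
import torch

encoder = EcapaTdnnEncoder(
    feat_in=64,
    filters=[512, 512, 512, 512, 1536],   # concatenated SE-block outputs (3 x 512) must equal the last entry
    kernel_sizes=[5, 3, 3, 3, 1],
    dilations=[1, 2, 3, 4, 1],
)
decoder = SpeakerDecoder(
    feat_in=1536, num_classes=10, emb_sizes=192,
    pool_mode='attention', angular=True,
)

mel = torch.randn(2, 64, 100)                    # [batch, n_mels, frames]
lengths = torch.tensor([100, 80])
encoded, lengths = encoder(mel, length=lengths)  # [batch, 1536, frames]
logits, embeddings = decoder(encoded, lengths)   # [batch, 10], [batch, 192]
```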
features.py ADDED
@@ -0,0 +1,560 @@
1
+ import math
2
+ import random
3
+ from typing import Optional, Union, Tuple
4
+
5
+ import librosa
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ try:
12
+ import torchaudio
13
+
14
+ HAVE_TORCHAUDIO = True
15
+ except ModuleNotFoundError:
16
+ HAVE_TORCHAUDIO = False
17
+
18
+ CONSTANT = 1e-5
19
+
20
+
21
+ def normalize_batch(x, seq_len, normalize_type):
22
+ x_mean = None
23
+ x_std = None
24
+ if normalize_type == "per_feature":
25
+ batch_size = x.shape[0]
26
+ max_time = x.shape[2]
27
+
28
+ # When doing stream capture to a graph, item() is not allowed
29
+ # because it calls cudaStreamSynchronize(). Therefore, we are
30
+ # sacrificing some error checking when running with cuda graphs.
31
+ if (
32
+ torch.cuda.is_available()
33
+ and not torch.cuda.is_current_stream_capturing()
34
+ and torch.any(seq_len == 1).item()
35
+ ):
36
+ raise ValueError(
37
+ "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
38
+ "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
39
+ "feature (ex. at least `hop_length` for Mel Spectrograms)."
40
+ )
41
+ time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
42
+ valid_mask = time_steps < seq_len.unsqueeze(1)
43
+ x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
44
+ x_mean_denominator = valid_mask.sum(axis=1)
45
+ x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)
46
+
47
+ # Subtract 1 in the denominator to correct for the bias.
48
+ x_std = torch.sqrt(
49
+ torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
50
+ / (x_mean_denominator.unsqueeze(1) - 1.0)
51
+ )
52
+ # make sure x_std is not zero
53
+ x_std += CONSTANT
54
+ return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
55
+ elif normalize_type == "all_features":
56
+ x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
57
+ x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
58
+ for i in range(x.shape[0]):
59
+ x_mean[i] = x[i, :, : seq_len[i].item()].mean()
60
+ x_std[i] = x[i, :, : seq_len[i].item()].std()
61
+ # make sure x_std is not zero
62
+ x_std += CONSTANT
63
+ return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
64
+ elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
65
+ x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
66
+ x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
67
+ return (
68
+ (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
69
+ x_mean,
70
+ x_std,
71
+ )
72
+ else:
73
+ return x, x_mean, x_std
74
+
75
+
76
+ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Tensor, fill_value=0.0) -> torch.Tensor:
77
+ """
78
+ Fill spectrogram values outside the length with `fill_value`
79
+
80
+ Args:
81
+ spectrogram: Tensor with shape [B, C, L] containing batched spectrograms
82
+ spectrogram_len: Tensor with shape [B] containing the sequence length of each batch element
83
+ fill_value: value to fill with, 0.0 by default
84
+
85
+ Returns:
86
+ cleaned spectrogram, tensor with shape equal to `spectrogram`
87
+ """
88
+ device = spectrogram.device
89
+ batch_size, _, max_len = spectrogram.shape
90
+ mask = torch.arange(max_len, device=device)[None, :] >= spectrogram_len[:, None]
91
+ mask = mask.unsqueeze(1).expand_as(spectrogram)
92
+ return spectrogram.masked_fill(mask, fill_value)
93
+
94
+
95
+ def splice_frames(x, frame_splicing):
96
+ """Stacks frames together across feature dim
97
+
98
+ input is batch_size, feature_dim, num_frames
99
+ output is batch_size, feature_dim*frame_splicing, num_frames
100
+
101
+ """
102
+ seq = [x]
103
+ for n in range(1, frame_splicing):
104
+ seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
105
+ return torch.cat(seq, dim=1)
106
+
107
+
108
+ @torch.jit.script_if_tracing
109
+ def make_seq_mask_like(
110
+ lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True
111
+ ) -> torch.Tensor:
112
+ """
113
+
114
+ Args:
115
+ lengths: Tensor with shape [B] containing the sequence length of each batch element
116
+ like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
117
+ length in the time dimension of this Tensor.
118
+ time_dim: Time dimension of the `shape_tensor` and the resulting mask. Zero-based.
119
+ valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.
120
+
121
+ Returns:
122
+ A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
123
+ vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
124
+ the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
125
+ `time_dim == -1`, mask will have shape `[3, 1, 5]`.
126
+ """
127
+ # Mask with shape [B, T]
128
+ mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1))
129
+ # [B, T] -> [B, *, T] where * is any number of singleton dimensions to expand to like tensor
130
+ for _ in range(like.dim() - mask.dim()):
131
+ mask = mask.unsqueeze(1)
132
+ # If needed, transpose time dim
133
+ if time_dim != -1 and time_dim != mask.dim() - 1:
134
+ mask = mask.transpose(-1, time_dim)
135
+ # Maybe invert the padded vs. valid token values
136
+ if not valid_ones:
137
+ mask = ~mask
138
+ return mask
139
+
140
+
141
+ class FilterbankFeatures(nn.Module):
142
+ """Featurizer that converts wavs to Mel Spectrograms.
143
+ See AudioToMelSpectrogramPreprocessor for args.
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ sample_rate=16000,
149
+ n_window_size=320,
150
+ n_window_stride=160,
151
+ window="hann",
152
+ normalize="per_feature",
153
+ n_fft=None,
154
+ preemph=0.97,
155
+ nfilt=64,
156
+ lowfreq=0,
157
+ highfreq=None,
158
+ log=True,
159
+ log_zero_guard_type="add",
160
+ log_zero_guard_value=2**-24,
161
+ dither=CONSTANT,
162
+ pad_to=16,
163
+ max_duration=16.7,
164
+ frame_splicing=1,
165
+ exact_pad=False,
166
+ pad_value=0,
167
+ mag_power=2.0,
168
+ use_grads=False,
169
+ rng=None,
170
+ nb_augmentation_prob=0.0,
171
+ nb_max_freq=4000,
172
+ mel_norm="slaney",
173
+ stft_exact_pad=False, # Deprecated arguments; kept for config compatibility
174
+ stft_conv=False, # Deprecated arguments; kept for config compatibility
175
+ ):
176
+ super().__init__()
177
+ if stft_conv or stft_exact_pad:
178
+ print(
179
+ "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
180
+ "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
181
+ "as needed."
182
+ )
183
+ if exact_pad and n_window_stride % 2 == 1:
184
+ raise NotImplementedError(
185
+ f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
186
+ "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
187
+ )
188
+ self.log_zero_guard_value = log_zero_guard_value
189
+ if (
190
+ n_window_size is None
191
+ or n_window_stride is None
192
+ or not isinstance(n_window_size, int)
193
+ or not isinstance(n_window_stride, int)
194
+ or n_window_size <= 0
195
+ or n_window_stride <= 0
196
+ ):
197
+ raise ValueError(
198
+ f"{self} got an invalid value for either n_window_size or "
199
+ f"n_window_stride. Both must be positive ints."
200
+ )
201
+
202
+ self.win_length = n_window_size
203
+ self.hop_length = n_window_stride
204
+ self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
205
+ self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
206
+ self.exact_pad = exact_pad
207
+
208
+ if exact_pad:
209
+ print("STFT using exact pad")
210
+ torch_windows = {
211
+ 'hann': torch.hann_window,
212
+ 'hamming': torch.hamming_window,
213
+ 'blackman': torch.blackman_window,
214
+ 'bartlett': torch.bartlett_window,
215
+ 'none': None,
216
+ }
217
+ window_fn = torch_windows.get(window, None)
218
+ window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
219
+ self.register_buffer("window", window_tensor)
220
+
221
+ self.normalize = normalize
222
+ self.log = log
223
+ self.dither = dither
224
+ self.frame_splicing = frame_splicing
225
+ self.nfilt = nfilt
226
+ self.preemph = preemph
227
+ self.pad_to = pad_to
228
+ highfreq = highfreq or sample_rate / 2
229
+
230
+ filterbanks = torch.tensor(
231
+ librosa.filters.mel(
232
+ sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
233
+ ),
234
+ dtype=torch.float,
235
+ ).unsqueeze(0)
236
+ self.register_buffer("fb", filterbanks)
237
+
238
+ # Calculate maximum sequence length
239
+ max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
240
+ max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
241
+ self.max_length = max_length + max_pad
242
+ self.pad_value = pad_value
243
+ self.mag_power = mag_power
244
+
245
+ # We want to avoid taking the log of zero
246
+ # There are two options: either adding or clamping to a small value
247
+ if log_zero_guard_type not in ["add", "clamp"]:
248
+ raise ValueError(
249
+ f"{self} received {log_zero_guard_type} for the "
250
+ f"log_zero_guard_type parameter. It must be either 'add' or "
251
+ f"'clamp'."
252
+ )
253
+
254
+ self.use_grads = use_grads
255
+ if not use_grads:
256
+ self.forward = torch.no_grad()(self.forward)
257
+ self._rng = random.Random() if rng is None else rng
258
+ self.nb_augmentation_prob = nb_augmentation_prob
259
+ if self.nb_augmentation_prob > 0.0:
260
+ if nb_max_freq >= sample_rate / 2:
261
+ self.nb_augmentation_prob = 0.0
262
+ else:
263
+ self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
264
+
265
+ # log_zero_guard_value is the small value we want to use; we support
266
+ # an actual number, or "tiny", or "eps"
267
+ self.log_zero_guard_type = log_zero_guard_type
268
+
269
+ def stft(self, x):
270
+ return torch.stft(
271
+ x,
272
+ n_fft=self.n_fft,
273
+ hop_length=self.hop_length,
274
+ win_length=self.win_length,
275
+ center=False if self.exact_pad else True,
276
+ window=self.window.to(dtype=torch.float),
277
+ return_complex=True,
278
+ )
279
+
280
+ def log_zero_guard_value_fn(self, x):
281
+ if isinstance(self.log_zero_guard_value, str):
282
+ if self.log_zero_guard_value == "tiny":
283
+ return torch.finfo(x.dtype).tiny
284
+ elif self.log_zero_guard_value == "eps":
285
+ return torch.finfo(x.dtype).eps
286
+ else:
287
+ raise ValueError(
288
+ f"{self} received {self.log_zero_guard_value} for the "
289
+ f"log_zero_guard_type parameter. It must be either a "
290
+ f"number, 'tiny', or 'eps'"
291
+ )
292
+ else:
293
+ return self.log_zero_guard_value
294
+
295
+ def get_seq_len(self, seq_len):
296
+ # When stft_pad_amount is None, assume center=True, i.e. torch.stft pads n_fft // 2 samples on each side
297
+ pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
298
+ seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
299
+ return seq_len.to(dtype=torch.long)
300
+
301
+ @property
302
+ def filter_banks(self):
303
+ return self.fb
304
+
305
+ def forward(self, x, seq_len, linear_spec=False):
306
+ seq_len = self.get_seq_len(seq_len)
307
+
308
+ if self.stft_pad_amount is not None:
309
+ x = torch.nn.functional.pad(
310
+ x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
311
+ ).squeeze(1)
312
+
313
+ # dither (only in training mode for eval determinism)
314
+ if self.training and self.dither > 0:
315
+ x += self.dither * torch.randn_like(x)
316
+
317
+ # do preemphasis
318
+ if self.preemph is not None:
319
+ x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)
320
+
321
+ # disable autocast to get full range of stft values
322
+ with torch.amp.autocast(x.device.type, enabled=False):
323
+ x = self.stft(x)
324
+
325
+ # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
326
+ # guard is needed for sqrt if grads are passed through
327
+ guard = 0 if not self.use_grads else CONSTANT
328
+ x = torch.view_as_real(x)
329
+ x = torch.sqrt(x.pow(2).sum(-1) + guard)
330
+
331
+ if self.training and self.nb_augmentation_prob > 0.0:
332
+ for idx in range(x.shape[0]):
333
+ if self._rng.random() < self.nb_augmentation_prob:
334
+ x[idx, self._nb_max_fft_bin :, :] = 0.0
335
+
336
+ # get power spectrum
337
+ if self.mag_power != 1.0:
338
+ x = x.pow(self.mag_power)
339
+
340
+ # return plain spectrogram if required
341
+ if linear_spec:
342
+ return x, seq_len
343
+
344
+ # dot with filterbank energies
345
+ x = torch.matmul(self.fb.to(x.dtype), x)
346
+ # log features if required
347
+ if self.log:
348
+ if self.log_zero_guard_type == "add":
349
+ x = torch.log(x + self.log_zero_guard_value_fn(x))
350
+ elif self.log_zero_guard_type == "clamp":
351
+ x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
352
+ else:
353
+ raise ValueError("log_zero_guard_type was not understood")
354
+
355
+ # frame splicing if required
356
+ if self.frame_splicing > 1:
357
+ x = splice_frames(x, self.frame_splicing)
358
+
359
+ # normalize if required
360
+ if self.normalize:
361
+ x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)
362
+
363
+ # mask to zero any values beyond seq_len in batch, pad to multiple of `pad_to` (for efficiency)
364
+ max_len = x.size(-1)
365
+ mask = torch.arange(max_len, device=x.device)
366
+ mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
367
+ x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
368
+ del mask
369
+ pad_to = self.pad_to
370
+ if pad_to == "max":
371
+ x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
372
+ elif pad_to > 0:
373
+ pad_amt = x.size(-1) % pad_to
374
+ if pad_amt != 0:
375
+ x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
376
+ return x, seq_len
377
+
378
+
379
+ class FilterbankFeaturesTA(nn.Module):
380
+ """
381
+ Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction.
382
+
383
+ See `AudioToMelSpectrogramPreprocessor` for args.
384
+
385
+ """
386
+
387
+ def __init__(
388
+ self,
389
+ sample_rate: int = 16000,
390
+ n_window_size: int = 320,
391
+ n_window_stride: int = 160,
392
+ normalize: Optional[str] = "per_feature",
393
+ nfilt: int = 64,
394
+ n_fft: Optional[int] = None,
395
+ preemph: float = 0.97,
396
+ lowfreq: float = 0,
397
+ highfreq: Optional[float] = None,
398
+ log: bool = True,
399
+ log_zero_guard_type: str = "add",
400
+ log_zero_guard_value: Union[float, str] = 2**-24,
401
+ dither: float = 1e-5,
402
+ window: str = "hann",
403
+ pad_to: int = 0,
404
+ pad_value: float = 0.0,
405
+ mel_norm="slaney",
406
+ # Seems like no one uses these options anymore. Don't complicate the code by supporting them.
407
+ use_grads: bool = False, # Deprecated arguments; kept for config compatibility
408
+ max_duration: float = 16.7, # Deprecated arguments; kept for config compatibility
409
+ frame_splicing: int = 1, # Deprecated arguments; kept for config compatibility
410
+ exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
411
+ nb_augmentation_prob: float = 0.0, # Deprecated arguments; kept for config compatibility
412
+ nb_max_freq: int = 4000, # Deprecated arguments; kept for config compatibility
413
+ mag_power: float = 2.0, # Deprecated arguments; kept for config compatibility
414
+ rng: Optional[random.Random] = None, # Deprecated arguments; kept for config compatibility
415
+ stft_exact_pad: bool = False, # Deprecated arguments; kept for config compatibility
416
+ stft_conv: bool = False, # Deprecated arguments; kept for config compatibility
417
+ ):
418
+ super().__init__()
419
+ if not HAVE_TORCHAUDIO:
420
+ raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}")
421
+
422
+ # Make sure log zero guard is supported, if given as a string
423
+ supported_log_zero_guard_strings = {"eps", "tiny"}
424
+ if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings:
425
+ raise ValueError(
426
+ f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}"
427
+ )
428
+
429
+ # Copied from `AudioPreprocessor` due to the ad-hoc structuring of the Mel Spec extractor class
430
+ self.torch_windows = {
431
+ 'hann': torch.hann_window,
432
+ 'hamming': torch.hamming_window,
433
+ 'blackman': torch.blackman_window,
434
+ 'bartlett': torch.bartlett_window,
435
+ 'ones': torch.ones,
436
+ None: torch.ones,
437
+ }
438
+
439
+ # Ensure we can look up the window function
440
+ if window not in self.torch_windows:
441
+ raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")
442
+
443
+ self.win_length = n_window_size
444
+ self.hop_length = n_window_stride
445
+ self._sample_rate = sample_rate
446
+ self._normalize_strategy = normalize
447
+ self._use_log = log
448
+ self._preemphasis_value = preemph
449
+ self.log_zero_guard_type = log_zero_guard_type
450
+ self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
451
+ self.dither = dither
452
+ self.pad_to = pad_to
453
+ self.pad_value = pad_value
454
+ self.n_fft = n_fft
455
+ self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
456
+ sample_rate=self._sample_rate,
457
+ win_length=self.win_length,
458
+ hop_length=self.hop_length,
459
+ n_mels=nfilt,
460
+ window_fn=self.torch_windows[window],
461
+ mel_scale="slaney",
462
+ norm=mel_norm,
463
+ n_fft=n_fft,
464
+ f_max=highfreq,
465
+ f_min=lowfreq,
466
+ wkwargs={"periodic": False},
467
+ )
468
+
469
+ @property
470
+ def filter_banks(self):
471
+ """Matches the analogous class"""
472
+ return self._mel_spec_extractor.mel_scale.fb
473
+
474
+ def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
475
+ if isinstance(self.log_zero_guard_value, float):
476
+ return self.log_zero_guard_value
477
+ return getattr(torch.finfo(dtype), self.log_zero_guard_value)
478
+
479
+ def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
480
+ if self.training and self.dither > 0.0:
481
+ noise = torch.randn_like(signals) * self.dither
482
+ signals = signals + noise
483
+ return signals
484
+
485
+ def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
486
+ if self._preemphasis_value is not None:
487
+ padded = torch.nn.functional.pad(signals, (1, 0))
488
+ signals = signals - self._preemphasis_value * padded[:, :-1]
489
+ return signals
490
+
491
+ def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
492
+ out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
493
+ return out_lengths
494
+
495
+ def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
496
+ # Only apply during training; else need to capture dynamic shape for exported models
497
+ if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
498
+ return features
499
+ pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
500
+ return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)
501
+
502
+ def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
503
+ if self._use_log:
504
+ zero_guard = self._resolve_log_zero_guard_value(features.dtype)
505
+ if self.log_zero_guard_type == "add":
506
+ features = features + zero_guard
507
+ elif self.log_zero_guard_type == "clamp":
508
+ features = features.clamp(min=zero_guard)
509
+ else:
510
+ raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
511
+ features = features.log()
512
+ return features
513
+
514
+ def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor:
515
+ # Complex FFT needs to be done in single precision
516
+ with torch.amp.autocast('cuda', enabled=False):
517
+ features = self._mel_spec_extractor(waveform=signals)
518
+ return features
519
+
520
+ def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
521
+ # For consistency, this function always does a masked fill even if not normalizing.
522
+ mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False)
523
+ features = features.masked_fill(mask, 0.0)
524
+ # Maybe don't normalize
525
+ if self._normalize_strategy is None:
526
+ return features
527
+ # Use the log zero guard for the sqrt zero guard
528
+ guard_value = self._resolve_log_zero_guard_value(features.dtype)
529
+ if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features":
530
+ # 'all_features' reduces over each sample; 'per_feature' reduces over each channel
531
+ reduce_dim = 2
532
+ if self._normalize_strategy == "all_features":
533
+ reduce_dim = [1, 2]
534
+ # [B, D, T] -> [B, D, 1] or [B, 1, 1]
535
+ means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1))
536
+ stds = (
537
+ features.sub(means)
538
+ .masked_fill(mask, 0.0)
539
+ .pow(2.0)
540
+ .sum(dim=reduce_dim, keepdim=True) # [B, D, T] -> [B, D, 1] or [B, 1, 1]
541
+ .div(lengths.view(-1, 1, 1) - 1) # divide by N - 1 (Bessel's correction, unbiased estimator)
542
+ .clamp(min=guard_value) # avoid sqrt(0)
543
+ .sqrt()
544
+ )
545
+ features = (features - means) / (stds + eps)
546
+ else:
547
+ # Deprecating constant std/mean
548
+ raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}")
549
+ features = features.masked_fill(mask, 0.0)
550
+ return features
551
+
552
+ def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
553
+ feature_lengths = self._compute_output_lengths(input_lengths=length)
554
+ signals = self._apply_dithering(signals=input_signal)
555
+ signals = self._apply_preemphasis(signals=signals)
556
+ features = self._extract_spectrograms(signals=signals)
557
+ features = self._apply_log(features=features)
558
+ features = self._apply_normalization(features=features, lengths=feature_lengths)
559
+ features = self._apply_pad_to(features=features)
560
+ return features, feature_lengths
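Both featurizers map a number of samples to a number of frames in the same way; a small worked check with the default 320/160 window and hop, illustrative and not part of the committed files:

```python
# Worked example (not part of the commit) of the output-length arithmetic.
import math
import torch

win, hop = 320, 160                          # n_window_size, n_window_stride defaults
n_fft = 2 ** math.ceil(math.log2(win))       # 512, as FilterbankFeatures computes it

num_samples = torch.tensor([16000])          # one second at 16 kHz
# FilterbankFeatures.get_seq_len (center=True pads n_fft // 2 on each side):
frames_librosa = (num_samples + 2 * (n_fft // 2) - n_fft) // hop + 1    # -> 101
# FilterbankFeaturesTA._compute_output_lengths:
frames_torchaudio = num_samples // hop + 1                              # -> 101
```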
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5828920fac146b843150400b269fa8fe51e3f3c5922b6c1882db45a43480dc92
3
  size 26039912
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4da89b0b6d405974f1e332bdc9945fae76222d7ddf0f955653fba9a00cca0339
3
  size 26039912
modeling_ecapa_tdnn.py ADDED
@@ -0,0 +1,150 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Union, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from transformers import PreTrainedModel
8
+ from transformers.utils import ModelOutput
9
+
10
+ from .configuration_ecapa_tdnn import EcapaTdnnConfig
11
+ from .audio_processing import AudioToMelSpectrogramPreprocessor
12
+ from .audio_processing import SpectrogramAugmentation
13
+ from .conv_asr import EcapaTdnnEncoder, SpeakerDecoder
14
+ from .angular_loss import AdditiveMarginSoftmaxLoss, AdditiveAngularMarginSoftmaxLoss
15
+
16
+
17
+ @dataclass
18
+ class EcapaTdnnBaseModelOutput(ModelOutput):
19
+
20
+ encoder_outputs: torch.FloatTensor = None
21
+ extract_features: torch.FloatTensor = None
22
+ output_lengths: torch.FloatTensor = None
23
+
24
+
25
+ @dataclass
26
+ class EcapaTdnnSequenceClassifierOutput(ModelOutput):
27
+
28
+ loss: torch.FloatTensor = None
29
+ logits: torch.FloatTensor = None
30
+ embeddings: torch.FloatTensor = None
31
+
32
+
33
+ class EcapaTdnnPreTrainedModel(PreTrainedModel):
34
+
35
+ config_class = EcapaTdnnConfig
36
+ base_model_prefix = "ecapa_tdnn"
37
+ main_input_name = "input_values"
38
+
39
+ def _init_weights(self, module):
40
+ """Initialize the weights"""
41
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
42
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
43
+ if module.bias is not None:
44
+ module.bias.data.zero_()
45
+ elif isinstance(module, nn.Conv2d):
46
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
47
+ if module.bias is not None:
48
+ module.bias.data.zero_()
49
+ elif isinstance(module, nn.LayerNorm):
50
+ module.bias.data.zero_()
51
+ module.weight.data.fill_(1.0)
52
+ elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
53
+ nn.init.constant_(module.weight, 1)
54
+ nn.init.constant_(module.bias, 0)
55
+
56
+ @property
57
+ def num_weights(self):
58
+ """
59
+ Utility property that returns the total number of parameters of NeuralModule.
60
+ """
61
+ return self._num_weights()
62
+
63
+ @torch.jit.ignore
64
+ def _num_weights(self):
65
+ num: int = 0
66
+ for p in self.parameters():
67
+ if p.requires_grad:
68
+ num += p.numel()
69
+ return num
70
+
71
+
72
+ class EcapaTdnnModel(EcapaTdnnPreTrainedModel):
73
+
74
+ def __init__(self, config: EcapaTdnnConfig):
75
+ super().__init__(config)
76
+ self.config = config
77
+
78
+ self.preprocessor = AudioToMelSpectrogramPreprocessor(**config.mel_spectrogram_config)
79
+ self.spec_augment = SpectrogramAugmentation(**config.spectrogram_augmentation_config)
80
+ self.encoder = EcapaTdnnEncoder(**config.encoder_config)
81
+
82
+ # Initialize weights and apply final processing
83
+ self.post_init()
84
+
85
+ def forward(
86
+ self,
87
+ input_values: Optional[torch.Tensor],
88
+ attention_mask: Optional[torch.Tensor] = None,
89
+ ) -> Union[Tuple, EcapaTdnnBaseModelOutput]:
90
+ if attention_mask is None:
91
+ attention_mask = torch.ones_like(input_values).to(input_values)
92
+ lengths = attention_mask.sum(dim=1).long()
93
+ extract_features, output_lengths = self.preprocessor(input_values, lengths)
94
+ if self.training:
95
+ extract_features = self.spec_augment(extract_features, output_lengths)
96
+ encoder_outputs, output_lengths = self.encoder(extract_features, output_lengths)
97
+
98
+ return EcapaTdnnBaseModelOutput(
99
+ encoder_outputs=encoder_outputs,
100
+ extract_features=extract_features,
101
+ output_lengths=output_lengths,
102
+ )
103
+
104
+
105
+ class EcapaTdnnForSequenceClassification(EcapaTdnnPreTrainedModel):
106
+
107
+ def __init__(self, config: EcapaTdnnConfig):
108
+ super().__init__(config)
109
+
110
+ self.ecapa_tdnn = EcapaTdnnModel(config)
111
+ self.classifier = SpeakerDecoder(**config.decoder_config)
112
+
113
+ if config.objective == 'additive_angular_margin':
114
+ self.loss_fct = AdditiveAngularMarginSoftmaxLoss(**config.objective_config)
115
+ elif config.objective == 'additive_margin':
116
+ self.loss_fct = AdditiveMarginSoftmaxLoss(**config.objective_config)
117
+ elif config.objective == 'cross_entropy':
118
+ self.loss_fct = nn.CrossEntropyLoss(**config.objective_config)
119
+
120
+ self.init_weights()
121
+
122
+ def freeze_base_model(self):
123
+ for param in self.ecapa_tdnn.parameters():
124
+ param.requires_grad = False
125
+
126
+ def forward(
127
+ self,
128
+ input_values: Optional[torch.Tensor],
129
+ attention_mask: Optional[torch.Tensor] = None,
130
+ labels: Optional[torch.Tensor] = None,
131
+ ) -> Union[Tuple, EcapaTdnnSequenceClassifierOutput]:
132
+ ecapa_tdnn_outputs = self.ecapa_tdnn(
133
+ input_values,
134
+ attention_mask,
135
+ )
136
+ logits, output_embeddings = self.classifier(
137
+ ecapa_tdnn_outputs.encoder_outputs,
138
+ ecapa_tdnn_outputs.output_lengths
139
+ )
140
+ logits = logits.view(-1, self.config.num_labels)
141
+
142
+ loss = None
143
+ if labels is not None:
144
+ loss = self.loss_fct(logits, labels.view(-1))
145
+
146
+ return EcapaTdnnSequenceClassifierOutput(
147
+ loss=loss,
148
+ logits=logits,
149
+ embeddings=output_embeddings,
150
+ )
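Putting the pieces together, a hedged sketch of a training-style call that returns a loss alongside logits and embeddings. The `model` instance is assumed to be an `EcapaTdnnForSequenceClassification`, e.g. loaded as sketched after the config.json diff above; shapes are illustrative:

```python
# Illustrative sketch (not part of the commit): a forward pass with labels.
# `model` is assumed to be an EcapaTdnnForSequenceClassification instance,
# e.g. loaded with AutoModelForAudioClassification and trust_remote_code=True.
import torch

waveforms = torch.randn(2, 16000)                          # [batch, samples]
labels = torch.randint(0, model.config.num_labels, (2,))

outputs = model(input_values=waveforms, labels=labels)
outputs.loss        # loss selected by config.objective (angular margin or cross-entropy)
outputs.logits      # [batch, num_labels]
outputs.embeddings  # speaker embeddings returned by SpeakerDecoder
```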
module.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ import torch.nn as nn
+
+
+ class NeuralModule(nn.Module):
+
+     @property
+     def num_weights(self):
+         """
+         Utility property that returns the total number of parameters of NeuralModule.
+         """
+         return self._num_weights()
+
+     @torch.jit.ignore
+     def _num_weights(self):
+         num: int = 0
+         for p in self.parameters():
+             if p.requires_grad:
+                 num += p.numel()
+         return num
+
+     def freeze(self) -> None:
+         r"""
+         Freeze all params for inference.
+
+         This method sets `requires_grad` to False for all parameters of the module.
+         It also stores the original `requires_grad` state of each parameter in a dictionary,
+         so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`.
+         """
+         grad_map = {}
+
+         for pname, param in self.named_parameters():
+             # Store the original grad state
+             grad_map[pname] = param.requires_grad
+             # Freeze the parameter
+             param.requires_grad = False
+
+         # Store the frozen grad map
+         if not hasattr(self, '_frozen_grad_map'):
+             self._frozen_grad_map = grad_map
+         else:
+             self._frozen_grad_map.update(grad_map)
+
+         self.eval()
+
+     def unfreeze(self, partial: bool = False) -> None:
+         """
+         Unfreeze all parameters for training.
+
+         Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
+         The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
+         previously unfrozen prior to `freeze()`.
+
+         Example:
+             Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.
+
+             ```python
+             model.encoder.freeze()  # Freezes all parameters in the encoder explicitly
+             ```
+
+             During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method.
+             This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called,
+             we should keep the encoder parameters frozen.
+
+             ```python
+             model.freeze()  # Freezes all parameters in the model; encoder remains frozen
+             ```
+
+             Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling
+             `unfreeze(partial=True)`.
+
+             ```python
+             model.unfreeze(partial=True)  # Unfreezes only the decoder; encoder remains frozen
+             ```
+
+         Args:
+             partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen
+                 when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`.
+         """
+         if partial and not hasattr(self, '_frozen_grad_map'):
+             raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")
+
+         for pname, param in self.named_parameters():
+             if not partial:
+                 # Unfreeze all parameters
+                 param.requires_grad = True
+             else:
+                 # Unfreeze only parameters that were previously frozen
+
+                 # Check if the parameter was frozen
+                 if pname in self._frozen_grad_map:
+                     param.requires_grad = self._frozen_grad_map[pname]
+                 else:
+                     # Log a warning if the parameter was not found in the frozen grad map
+                     print(
+                         f"Parameter {pname} not found in list of previously frozen parameters. "
+                         f"Unfreezing this parameter."
+                     )
+                     param.requires_grad = True
+
+         # Clean up the frozen grad map
+         if hasattr(self, '_frozen_grad_map'):
+             delattr(self, '_frozen_grad_map')
+
+         self.train()
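A short, hypothetical sketch of how `freeze()` and `unfreeze(partial=True)` interact on a toy `NeuralModule` subclass; `TinyModel` and its submodule names are made up purely for illustration.

```python
import torch.nn as nn

class TinyModel(NeuralModule):           # hypothetical toy module
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 2)

model = TinyModel()
print(model.num_weights)                 # trainable parameter count

# Simulate an encoder that was frozen beforehand.
for p in model.encoder.parameters():
    p.requires_grad = False

model.freeze()                           # records per-parameter grad states, then freezes everything
model.unfreeze(partial=True)             # decoder becomes trainable again; encoder stays frozen
model.unfreeze()                         # unconditional unfreeze of every parameter
```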
spectrogram_augment.py ADDED
@@ -0,0 +1,223 @@
+ import random
+ from typing import Union
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+
+ class SpecAugment(nn.Module):
+     """
+     Zeroes out (cuts) random continuous horizontal or
+     vertical segments of the spectrogram as described in
+     SpecAugment (https://arxiv.org/abs/1904.08779).
+
+     params:
+         freq_masks - how many frequency segments should be cut
+         time_masks - how many time segments should be cut
+         freq_width - maximum number of frequencies to be cut in one segment
+         time_width - maximum number of time steps to be cut in one segment.
+             Can be a positive integer or a float value in the range [0, 1].
+             If a positive integer value, defines the maximum number of time steps
+             to be cut in one segment.
+             If a float value, defines the maximum percentage of timesteps that
+             are cut adaptively.
+         use_vectorized_code - GPU-based implementation with batched masking and GPU rng,
+             setting it to False reverts to the legacy implementation.
+             Fast implementation is inspired by torchaudio:
+             https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/functional/functional.py#L816
+     """
+
+     FREQ_AXIS = 1  # Frequency axis in the spectrogram tensor
+     TIME_AXIS = 2  # Time axis in the spectrogram tensor
+
+     def __init__(
+         self,
+         freq_masks: int = 0,
+         time_masks: int = 0,
+         freq_width: int = 10,
+         time_width: Union[int, float] = 10,
+         rng: random.Random = None,
+         mask_value: float = 0.0,
+         use_vectorized_code: bool = True,
+     ):
+         super().__init__()
+
+         self._rng = random.Random() if rng is None else rng
+
+         self.freq_masks = freq_masks
+         self.time_masks = time_masks
+
+         self.freq_width = freq_width
+         self.time_width = time_width
+
+         self.mask_value = mask_value
+         self.use_vectorized_code = use_vectorized_code
+
+         if isinstance(time_width, int):
+             self.adaptive_temporal_width = False
+         else:
+             if time_width > 1.0 or time_width < 0.0:
+                 raise ValueError("If `time_width` is a float value, must be in range [0, 1]")
+
+             self.adaptive_temporal_width = True
+
+     @torch.no_grad()
+     def forward(self, input_spec, length):
+         if self.use_vectorized_code:
+             return self._forward_vectorized(input_spec, length)
+         else:
+             return self._forward_legacy(input_spec, length)
+
+     def _forward_legacy(self, input_spec, length):
+         batch_size, num_freq_bins, _ = input_spec.shape
+         # Move lengths to CPU before repeated indexing
+         lengths_cpu = length.cpu().numpy()
+         # Generate a numpy boolean mask. `True` elements represent where the input spec will be augmented.
+         fill_mask: np.array = np.full(shape=input_spec.shape, fill_value=False)
+         freq_start_upper_bound = num_freq_bins - self.freq_width
+         # Choose different mask ranges for each element of the batch
+         for idx in range(batch_size):
+             # Set freq masking
+             for _ in range(self.freq_masks):
+                 start = self._rng.randint(0, freq_start_upper_bound)
+                 width = self._rng.randint(0, self.freq_width)
+                 fill_mask[idx, start : start + width, :] = True
+
+             # Derive time width, sometimes based on a percentage of the input length.
+             if self.adaptive_temporal_width:
+                 time_max_width = max(1, int(lengths_cpu[idx] * self.time_width))
+             else:
+                 time_max_width = self.time_width
+             time_start_upper_bound = max(1, lengths_cpu[idx] - time_max_width)
+
+             # Set time masking
+             for _ in range(self.time_masks):
+                 start = self._rng.randint(0, time_start_upper_bound)
+                 width = self._rng.randint(0, time_max_width)
+                 fill_mask[idx, :, start : start + width] = True
+         # Bring the mask to device and fill spec
+         fill_mask = torch.from_numpy(fill_mask).to(input_spec.device)
+         masked_spec = input_spec.masked_fill(mask=fill_mask, value=self.mask_value)
+         return masked_spec
+
+     def _forward_vectorized(self, input_spec: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
+         # time masks
+         input_spec = self._apply_masks(
+             input_spec=input_spec,
+             num_masks=self.time_masks,
+             length=length,
+             width=self.time_width,
+             axis=self.TIME_AXIS,
+             mask_value=self.mask_value,
+         )
+         # freq masks
+         input_spec = self._apply_masks(
+             input_spec=input_spec,
+             num_masks=self.freq_masks,
+             length=length,
+             width=self.freq_width,
+             axis=self.FREQ_AXIS,
+             mask_value=self.mask_value,
+         )
+         return input_spec
+
+     def _apply_masks(
+         self,
+         input_spec: torch.Tensor,
+         num_masks: int,
+         length: torch.Tensor,
+         width: Union[int, float],
+         mask_value: float,
+         axis: int,
+     ) -> torch.Tensor:
+
+         assert axis in (
+             self.FREQ_AXIS,
+             self.TIME_AXIS,
+         ), f"Axis can only be equal to frequency ({self.FREQ_AXIS}) or time ({self.TIME_AXIS}). Received: {axis=}"
+         assert not (
+             isinstance(width, float) and axis == self.FREQ_AXIS
+         ), "Float width supported only with time axis."
+
+         batch_size = input_spec.shape[0]
+         axis_length = input_spec.shape[axis]
+
+         # If width is float then it is transformed into a tensor
+         if axis == self.TIME_AXIS and isinstance(width, float):
+             width = torch.clamp(width * length, max=axis_length).unsqueeze(1)
+
+         # Generate [0-1) random numbers and then scale the tensors.
+         # Use float32 dtype for begin/end mask markers before they are quantized to long.
+         mask_width = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32) * width
+         mask_width = mask_width.long()
+         mask_start = torch.rand((batch_size, num_masks), device=input_spec.device, dtype=torch.float32)
+
+         if axis == self.TIME_AXIS:
+             # length can only be used for the time axis
+             mask_start = mask_start * (length.unsqueeze(1) - mask_width)
+         else:
+             mask_start = mask_start * (axis_length - mask_width)
+
+         mask_start = mask_start.long()
+         mask_end = mask_start + mask_width
+
+         # Create mask values using vectorized indexing
+         indices = torch.arange(axis_length, device=input_spec.device)
+         # Create a mask_tensor with all the indices.
+         # The mask_tensor shape is (batch_size, num_masks, axis_length).
+         mask_tensor = (indices >= mask_start.unsqueeze(-1)) & (indices < mask_end.unsqueeze(-1))
+
+         # Reduce masks to one mask
+         mask_tensor = mask_tensor.any(dim=1)
+
+         # Create a final mask that aligns with the full tensor
+         mask = torch.zeros_like(input_spec, dtype=torch.bool)
+         if axis == self.TIME_AXIS:
+             mask_ranges = mask_tensor[:, None, :]
+         else:  # axis == self.FREQ_AXIS
+             mask_ranges = mask_tensor[:, :, None]
+         mask[:, :, :] = mask_ranges
+
+         # Apply the mask value
+         return input_spec.masked_fill(mask=mask, value=mask_value)
+
+
+ class SpecCutout(nn.Module):
+     """
+     Zeroes out (cuts) random rectangles in the spectrogram
+     as described in (https://arxiv.org/abs/1708.04552).
+
+     params:
+         rect_masks - how many rectangular masks should be cut
+         rect_freq - maximum size of cut rectangles along the frequency dimension
+         rect_time - maximum size of cut rectangles along the time dimension
+     """
+
+     def __init__(self, rect_masks=0, rect_time=5, rect_freq=20, rng=None):
+         super(SpecCutout, self).__init__()
+
+         self._rng = random.Random() if rng is None else rng
+
+         self.rect_masks = rect_masks
+         self.rect_time = rect_time
+         self.rect_freq = rect_freq
+
+     @torch.no_grad()
+     def forward(self, input_spec):
+         sh = input_spec.shape
+
+         for idx in range(sh[0]):
+             for i in range(self.rect_masks):
+                 rect_x = self._rng.randint(0, sh[1] - self.rect_freq)
+                 rect_y = self._rng.randint(0, sh[2] - self.rect_time)
+
+                 w_x = self._rng.randint(0, self.rect_freq)
+                 w_y = self._rng.randint(0, self.rect_time)
+
+                 input_spec[idx, rect_x : rect_x + w_x, rect_y : rect_y + w_y] = 0.0
+
+         return input_spec
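A brief sketch of applying `SpecAugment` to a dummy mel-spectrogram batch; the mask counts and widths below are illustrative values, not ones prescribed by this repository.

```python
import torch

spec = torch.randn(4, 80, 300)                 # [batch, freq_bins, time_steps]
lengths = torch.tensor([300, 250, 180, 120])   # valid frames per example

augment = SpecAugment(freq_masks=2, time_masks=5, freq_width=27, time_width=0.05)
masked = augment(spec, lengths)                # same shape, with random bands set to mask_value
print(masked.shape)                            # torch.Size([4, 80, 300])
```

In the model code above, this augmentation is applied only when `self.training` is True, so inference features are left untouched.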
tdnn_attention.py ADDED
@@ -0,0 +1,620 @@
+ import math
+ from typing import List, Optional
+
+ from numpy import inf
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.init import _calculate_correct_fan
+
+
+ class StatsPoolLayer(nn.Module):
+     """Statistics and time average pooling (TAP) layer
+
+     This computes mean and, optionally, standard deviation statistics across the time dimension.
+
+     Args:
+         feat_in: Input features with shape [B, D, T]
+         pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
+             average pooling, i.e., mean)
+         eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
+         unbiased: Whether to use the unbiased estimator for the standard deviation when using 'xvector' mode. The
+             default for torch.Tensor.std() is True.
+
+     Returns:
+         Pooled statistics with shape [B, D].
+
+     Raises:
+         ValueError if an unsupported pooling mode is specified.
+     """
+
+     def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, unbiased: bool = True):
+         super().__init__()
+         supported_modes = {"xvector", "tap"}
+         if pool_mode not in supported_modes:
+             raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
+         self.pool_mode = pool_mode
+         self.feat_in = feat_in
+         self.eps = eps
+         self.unbiased = unbiased
+         if self.pool_mode == 'xvector':
+             # Mean + std
+             self.feat_in *= 2
+
+     def forward(self, encoder_output, length=None):
+         if length is None:
+             mean = encoder_output.mean(dim=-1)  # Time Axis
+             if self.pool_mode == 'xvector':
+                 correction = 1 if self.unbiased else 0
+                 std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
+                 pooled = torch.cat([mean, std], dim=-1)
+             else:
+                 pooled = mean
+         else:
+             mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False)
+             encoder_output = encoder_output.masked_fill(mask, 0.0)
+             # [B, D, T] -> [B, D]
+             means = encoder_output.mean(dim=-1)
+             # Re-scale to get padded means
+             means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
+             if self.pool_mode == "xvector":
+                 correction = 1 if self.unbiased else 0
+                 stds = (
+                     encoder_output.sub(means.unsqueeze(-1))
+                     .masked_fill(mask, 0.0)
+                     .pow(2.0)
+                     .sum(-1)  # [B, D, T] -> [B, D]
+                     .div(length.view(-1, 1).sub(correction))
+                     .clamp(min=self.eps)
+                     .sqrt()
+                 )
+                 pooled = torch.cat((means, stds), dim=-1)
+             else:
+                 pooled = means
+         return pooled
+
+
+ class AttentivePoolLayer(nn.Module):
+     """
+     Attention pooling layer for pooling speaker embeddings
+     Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
+     inputs:
+         inp_filters: input feature channel length from encoder
+         attention_channels: intermediate attention channel size
+         kernel_size: kernel_size for TDNN and attention conv1d layers (default: 1)
+         dilation: dilation size for TDNN and attention conv1d layers (default: 1)
+     """
+
+     def __init__(
+         self,
+         inp_filters: int,
+         attention_channels: int = 128,
+         kernel_size: int = 1,
+         dilation: int = 1,
+         eps: float = 1e-10,
+     ):
+         super().__init__()
+
+         self.feat_in = 2 * inp_filters
+
+         self.attention_layer = nn.Sequential(
+             TdnnModule(inp_filters * 3, attention_channels, kernel_size=kernel_size, dilation=dilation),
+             nn.Tanh(),
+             nn.Conv1d(
+                 in_channels=attention_channels,
+                 out_channels=inp_filters,
+                 kernel_size=kernel_size,
+                 dilation=dilation,
+             ),
+         )
+         self.eps = eps
+
+     def forward(self, x, length=None):
+         max_len = x.size(2)
+
+         if length is None:
+             length = torch.ones(x.shape[0], device=x.device)
+
+         mask, num_values = lens_to_mask(length, max_len=max_len, device=x.device)
+
+         # encoder statistics
+         mean, std = get_statistics_with_mask(x, mask / num_values)
+         mean = mean.unsqueeze(2).repeat(1, 1, max_len)
+         std = std.unsqueeze(2).repeat(1, 1, max_len)
+         attn = torch.cat([x, mean, std], dim=1)
+
+         # attention statistics
+         attn = self.attention_layer(attn)  # attention pass
+         attn = attn.masked_fill(mask == 0, -inf)
+         alpha = F.softmax(attn, dim=2)  # attention values, α
+         mu, sg = get_statistics_with_mask(x, alpha)  # µ and ∑
+
+         # gather
+         return torch.cat((mu, sg), dim=1).unsqueeze(2)
+
+
+ class TdnnModule(nn.Module):
+     """
+     Time Delayed Neural Module (TDNN) - 1D
+     input:
+         inp_filters: input filter channels for conv layer
+         out_filters: output filter channels for conv layer
+         kernel_size: kernel weight size for conv layer
+         dilation: dilation for conv layer
+         stride: stride for conv layer
+         padding: padding for conv layer (default None: chooses padding value such that input and output feature shape matches)
+     output:
+         tdnn layer output
+     """
+
+     def __init__(
+         self,
+         inp_filters: int,
+         out_filters: int,
+         kernel_size: int = 1,
+         dilation: int = 1,
+         stride: int = 1,
+         groups: int = 1,
+         padding: int = None,
+     ):
+         super().__init__()
+         if padding is None:
+             padding = get_same_padding(kernel_size, stride=stride, dilation=dilation)
+
+         self.conv_layer = nn.Conv1d(
+             in_channels=inp_filters,
+             out_channels=out_filters,
+             kernel_size=kernel_size,
+             dilation=dilation,
+             groups=groups,
+             padding=padding,
+         )
+
+         self.activation = nn.ReLU()
+         self.bn = nn.BatchNorm1d(out_filters)
+
+     def forward(self, x, length=None):
+         x = self.conv_layer(x)
+         x = self.activation(x)
+         return self.bn(x)
+
+
+ class MaskedSEModule(nn.Module):
+     """
+     Squeeze and Excite module implementation with conv1d layers
+     input:
+         inp_filters: input filter channel size
+         se_filters: intermediate squeeze and excite channel output and input size
+         out_filters: output filter channel size
+         kernel_size: kernel_size for both conv1d layers
+         dilation: dilation size for both conv1d layers
+
+     output:
+         squeeze and excite layer output
+     """
+
+     def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
+         super().__init__()
+         self.se_layer = nn.Sequential(
+             nn.Conv1d(
+                 inp_filters,
+                 se_filters,
+                 kernel_size=kernel_size,
+                 dilation=dilation,
+             ),
+             nn.ReLU(),
+             nn.BatchNorm1d(se_filters),
+             nn.Conv1d(
+                 se_filters,
+                 out_filters,
+                 kernel_size=kernel_size,
+                 dilation=dilation,
+             ),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, input, length=None):
+         if length is None:
+             x = torch.mean(input, dim=2, keepdim=True)
+         else:
+             max_len = input.size(2)
+             mask, num_values = lens_to_mask(length, max_len=max_len, device=input.device)
+             x = torch.sum((input * mask), dim=2, keepdim=True) / (num_values)
+
+         out = self.se_layer(x)
+         return out * input
+
+
+ class TdnnSeModule(nn.Module):
+     """
+     Modified building SE_TDNN group module block from ECAPA implementation for faster training and inference
+     Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
+     inputs:
+         inp_filters: input filter channel size
+         out_filters: output filter channel size
+         group_scale: scale value to group wider conv channels (default: 8)
+         se_channels: squeeze and excite output channel size (default: 1024/8 = 128)
+         kernel_size: kernel_size for group conv1d layers (default: 1)
+         dilation: dilation size for group conv1d layers (default: 1)
+     """
+
+     def __init__(
+         self,
+         inp_filters: int,
+         out_filters: int,
+         group_scale: int = 8,
+         se_channels: int = 128,
+         kernel_size: int = 1,
+         dilation: int = 1,
+         init_mode: str = 'xavier_uniform',
+     ):
+         super().__init__()
+         self.out_filters = out_filters
+         padding_val = get_same_padding(kernel_size=kernel_size, dilation=dilation, stride=1)
+
+         group_conv = nn.Conv1d(
+             out_filters,
+             out_filters,
+             kernel_size=kernel_size,
+             dilation=dilation,
+             padding=padding_val,
+             groups=group_scale,
+         )
+         self.group_tdnn_block = nn.Sequential(
+             TdnnModule(inp_filters, out_filters, kernel_size=1, dilation=1),
+             group_conv,
+             nn.ReLU(),
+             nn.BatchNorm1d(out_filters),
+             TdnnModule(out_filters, out_filters, kernel_size=1, dilation=1),
+         )
+
+         self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)
+
+         self.apply(lambda x: init_weights(x, mode=init_mode))
+
+     def forward(self, input, length=None):
+         x = self.group_tdnn_block(input)
+         x = self.se_layer(x, length)
+         return x + input
+
+
+ class Res2NetBlock(nn.Module):
+     """
+     Res2Net module that splits input channels into groups and processes them separately before merging.
+     This allows multi-scale feature extraction.
+     """
+
+     def __init__(self, in_channels, out_channels, scale=4, kernel_size=1, dilation=1):
+         super().__init__()
+         assert in_channels % scale == 0, "in_channels must be divisible by scale"
+
+         self.scale = scale
+         self.width = in_channels // scale  # Number of channels per group
+
+         self.convs = nn.ModuleList([
+             nn.Conv1d(self.width, self.width, kernel_size=kernel_size, dilation=dilation, padding=dilation, bias=False)
+             for _ in range(scale - 1)
+         ])
+         self.bn = nn.BatchNorm1d(out_channels)
+         self.activation = nn.ReLU()
+
+     def forward(self, x):
+         """
+         x: [B, C, T]
+         """
+         splits = torch.split(x, self.width, dim=1)
+         outputs = [splits[0]]  # First part remains unchanged
+
+         for i in range(1, self.scale):
+             conv_out = self.convs[i - 1](splits[i])  # Apply convolution on each group
+             outputs.append(conv_out + outputs[i - 1])  # Hierarchical aggregation
+
+         out = torch.cat(outputs, dim=1)  # Merge groups
+         return self.activation(self.bn(out))
+
+
+ class TdnnSeRes2NetModule(nn.Module):
+     """
+     SE-TDNN module with Res2Net for ECAPA-TDNN.
+     """
+
+     def __init__(
+         self,
+         inp_filters: int,
+         out_filters: int,
+         group_scale: int = 1,
+         se_channels: int = 128,
+         kernel_size: int = 1,
+         dilation: int = 1,
+         res2net_scale: int = 8,  # New Res2Net parameter
+     ):
+         super().__init__()
+
+         # First TDNN layer
+         self.tdnn1 = TdnnModule(inp_filters, out_filters, kernel_size=1, dilation=1, groups=group_scale)
+
+         # Res2Net block replaces grouped TDNN
+         self.res2net = Res2NetBlock(out_filters, out_filters, scale=res2net_scale, kernel_size=kernel_size, dilation=dilation)
+
+         # Squeeze-and-Excite module
+         self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)
+
+     def forward(self, x, length=None):
+         residual = x
+         x = self.tdnn1(x)
+         x = self.res2net(x)  # Apply Res2Net block
+         x = self.se_layer(x, length)
+         return x + residual  # Residual connection
+
+
+ class MaskedConv1d(nn.Module):
+
+     __constants__ = ["use_conv_mask", "real_out_channels", "heads"]
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         dilation=1,
+         groups=1,
+         heads=-1,
+         bias=False,
+         use_mask=True,
+         quantize=False,
+     ):
+         super(MaskedConv1d, self).__init__()
+
+         if not (heads == -1 or groups == in_channels):
+             raise ValueError("Only use heads for depthwise convolutions")
+
+         self.real_out_channels = out_channels
+         if heads != -1:
+             in_channels = heads
+             out_channels = heads
+             groups = heads
+
+         # preserve original padding
+         self._padding = padding
+
+         # if padding is a tuple/list, it is considered as asymmetric padding
+         if type(padding) in (tuple, list):
+             self.pad_layer = nn.ConstantPad1d(padding, value=0.0)
+             # reset padding for conv since pad_layer will handle this
+             padding = 0
+         else:
+             self.pad_layer = None
+
+         self.conv = nn.Conv1d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+         )
+         self.use_mask = use_mask
+         self.heads = heads
+
+         # Calculations for "same" padding cache
+         self.same_padding = (self.conv.stride[0] == 1) and (
+             2 * self.conv.padding[0] == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
+         )
+         if self.pad_layer is None:
+             self.same_padding_asymmetric = False
+         else:
+             self.same_padding_asymmetric = (self.conv.stride[0] == 1) and (
+                 sum(self._padding) == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
+             )
+
+         # `self.lens` caches consecutive integers from 0 to `self.max_len` that are used to compute the mask for a
+         # batch. Recomputed to bigger size as needed. Stored on a device of the latest batch lens.
+         if self.use_mask:
+             self.max_len = torch.tensor(0)
+             self.lens = torch.tensor(0)
+
+     def get_seq_len(self, lens):
+         if self.same_padding or self.same_padding_asymmetric:
+             return lens
+
+         if self.pad_layer is None:
+             return (
+                 torch.div(
+                     lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
+                     self.conv.stride[0],
+                     rounding_mode='trunc',
+                 )
+                 + 1
+             )
+         else:
+             return (
+                 torch.div(
+                     lens + sum(self._padding) - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
+                     self.conv.stride[0],
+                     rounding_mode='trunc',
+                 )
+                 + 1
+             )
+
+     def forward(self, x, lens):
+         if self.use_mask:
+             # Generally will be called by ConvASREncoder, but kept as single gpu backup.
+             if x.size(2) > self.max_len:
+                 self.update_masked_length(x.size(2), device=lens.device)
+             x = self.mask_input(x, lens)
+
+         # Update lengths
+         lens = self.get_seq_len(lens)
+
+         # asymmetric pad if necessary
+         if self.pad_layer is not None:
+             x = self.pad_layer(x)
+
+         sh = x.shape
+         if self.heads != -1:
+             x = x.view(-1, self.heads, sh[-1])
+
+         out = self.conv(x)
+
+         if self.heads != -1:
+             out = out.view(sh[0], self.real_out_channels, -1)
+
+         return out, lens
+
+     def update_masked_length(self, max_len, seq_range=None, device=None):
+         if seq_range is None:
+             self.lens, self.max_len = _masked_conv_init_lens(self.lens, max_len, self.max_len)
+             self.lens = self.lens.to(device)
+         else:
+             self.lens = seq_range
+             self.max_len = torch.tensor(max_len)
+
+     def mask_input(self, x, lens):
+         max_len = x.size(2)
+         mask = self.lens[:max_len].unsqueeze(0).to(lens.device) < lens.unsqueeze(1)
+         x = x * mask.unsqueeze(1).to(device=x.device)
+         return x
+
+
+ @torch.jit.script
+ def _masked_conv_init_lens(lens: torch.Tensor, current_maxlen: int, original_maxlen: torch.Tensor):
+     if current_maxlen > original_maxlen:
+         new_lens = torch.arange(current_maxlen)
+         new_max_lens = torch.tensor(current_maxlen)
+     else:
+         new_lens = lens
+         new_max_lens = original_maxlen
+     return new_lens, new_max_lens
+
+
+ def get_same_padding(kernel_size, stride, dilation) -> int:
+     if stride > 1 and dilation > 1:
+         raise ValueError("Only stride OR dilation may be greater than 1")
+     return (dilation * (kernel_size - 1)) // 2
+
+
+ def lens_to_mask(lens: List[int], max_len: int, device: str = None):
+     """
+     outputs masking labels for list of lengths of audio features, with max length of any
+     mask as max_len
+     input:
+         lens: list of lens
+         max_len: max length of any audio feature
+     output:
+         mask: masked labels
+         num_values: sum of mask values for each feature (useful for computing statistics later)
+     """
+     lens_mat = torch.arange(max_len).to(device)
+     mask = lens_mat[:max_len].unsqueeze(0) < lens.unsqueeze(1)
+     mask = mask.unsqueeze(1)
+     num_values = torch.sum(mask, dim=2, keepdim=True)
+     return mask, num_values
+
+
+ def get_statistics_with_mask(x: torch.Tensor, m: torch.Tensor, dim: int = 2, eps: float = 1e-10):
+     """
+     compute mean and standard deviation of input (x) provided with its masking labels (m)
+     input:
+         x: feature input
+         m: averaged mask labels
+     output:
+         mean: mean of input features
+         std: standard deviation of input features
+     """
+     mean = torch.sum((m * x), dim=dim)
+     std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
+     return mean, std
+
+
+ @torch.jit.script_if_tracing
+ def make_seq_mask_like(
+     like: torch.Tensor, lengths: torch.Tensor, valid_ones: bool = True, time_dim: int = -1
+ ) -> torch.Tensor:
+     mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.unsqueeze(-1))
+     # Match number of dims in `like` tensor
+     for _ in range(like.dim() - mask.dim()):
+         mask = mask.unsqueeze(1)
+     # If time dim != -1, transpose to proper dim.
+     if time_dim != -1:
+         mask = mask.transpose(time_dim, -1)
+     if not valid_ones:
+         mask = ~mask
+     return mask
+
+
+ def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
+     if isinstance(m, MaskedConv1d):
+         init_weights(m.conv, mode)
+     if isinstance(m, (nn.Conv1d, nn.Linear)):
+         if mode is not None:
+             if mode == 'xavier_uniform':
+                 nn.init.xavier_uniform_(m.weight, gain=1.0)
+             elif mode == 'xavier_normal':
+                 nn.init.xavier_normal_(m.weight, gain=1.0)
+             elif mode == 'kaiming_uniform':
+                 nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
+             elif mode == 'kaiming_normal':
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+             elif mode == 'tds_uniform':
+                 tds_uniform_(m.weight)
+             elif mode == 'tds_normal':
+                 tds_normal_(m.weight)
+             else:
+                 raise ValueError("Unknown Initialization mode: {0}".format(mode))
+     elif isinstance(m, nn.BatchNorm1d):
+         if m.track_running_stats:
+             m.running_mean.zero_()
+             m.running_var.fill_(1)
+             m.num_batches_tracked.zero_()
+         if m.affine:
+             nn.init.ones_(m.weight)
+             nn.init.zeros_(m.bias)
+
+
+ def tds_uniform_(tensor, mode='fan_in'):
+     """
+     Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
+     Normalized to -
+
+     .. math::
+         \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+             preserves the magnitude of the variance of the weights in the
+             forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+             backwards pass.
+     """
+     fan = _calculate_correct_fan(tensor, mode)
+     gain = 2.0  # sqrt(4.0) = 2
+     std = gain / math.sqrt(fan)  # sqrt(4.0 / fan_in)
+     bound = std  # Calculate uniform bounds from standard deviation
+     with torch.no_grad():
+         return tensor.uniform_(-bound, bound)
+
+
+ def tds_normal_(tensor, mode='fan_in'):
+     """
+     Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
+     Normalized to -
+
+     .. math::
+         \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
+             preserves the magnitude of the variance of the weights in the
+             forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
+             backwards pass.
+     """
+     fan = _calculate_correct_fan(tensor, mode)
+     gain = 2.0
+     std = gain / math.sqrt(fan)  # sqrt(4.0 / fan_in)
+     bound = std  # Calculate uniform bounds from standard deviation
+     with torch.no_grad():
+         return tensor.normal_(0.0, bound)
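Finally, a small sketch contrasting the two pooling layers on a dummy encoder output; the channel count of 1536 is an assumption (a common ECAPA-TDNN width), not something fixed by this file.

```python
import torch

feats = torch.randn(3, 1536, 200)          # [batch, channels, time]
lengths = torch.tensor([200, 150, 90])     # valid frames per example

stats_pool = StatsPoolLayer(feat_in=1536, pool_mode='xvector')
print(stats_pool(feats, length=lengths).shape)   # torch.Size([3, 3072]): mean and std concatenated

attn_pool = AttentivePoolLayer(inp_filters=1536, attention_channels=128)
print(attn_pool(feats, length=lengths).shape)    # torch.Size([3, 3072, 1])
```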