import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import inf
from torch.nn.init import _calculate_correct_fan


class StatsPoolLayer(nn.Module):
    """Statistics and time average pooling (TAP) layer.

    This computes mean and, optionally, standard deviation statistics across the time dimension.

    Args:
        feat_in: Input features with shape [B, D, T]
        pool_mode: Type of pool mode. Supported modes are 'xvector' (mean and standard deviation) and 'tap' (time
            average pooling, i.e., mean)
        eps: Epsilon, minimum value before taking the square root, when using 'xvector' mode.
        unbiased: Whether to use the unbiased estimator for the standard deviation when using 'xvector' mode. The
            default for torch.Tensor.std() is True.

    Returns:
        Pooled statistics with shape [B, D].

    Raises:
        ValueError if an unsupported pooling mode is specified.
    """

    def __init__(self, feat_in: int, pool_mode: str = 'xvector', eps: float = 1e-10, unbiased: bool = True):
        super().__init__()
        supported_modes = {"xvector", "tap"}
        if pool_mode not in supported_modes:
            raise ValueError(f"Pool mode must be one of {supported_modes}; got '{pool_mode}'")
        self.pool_mode = pool_mode
        self.feat_in = feat_in
        self.eps = eps
        self.unbiased = unbiased
        if self.pool_mode == 'xvector':
            # Mean and std vectors are concatenated, doubling the output feature size.
            self.feat_in *= 2

    def forward(self, encoder_output, length=None):
        if length is None:
            mean = encoder_output.mean(dim=-1)  # [B, D, T] -> [B, D]
            if self.pool_mode == 'xvector':
                correction = 1 if self.unbiased else 0
                std = encoder_output.std(dim=-1, correction=correction).clamp(min=self.eps)
                pooled = torch.cat([mean, std], dim=-1)
            else:
                pooled = mean
        else:
            mask = make_seq_mask_like(like=encoder_output, lengths=length, valid_ones=False)
            encoder_output = encoder_output.masked_fill(mask, 0.0)
            # Zero-padded frames contribute zeros to the sum, so rescale the mean
            # by T / length to average over the valid frames only.
            means = encoder_output.mean(dim=-1)
            means = means * (encoder_output.shape[-1] / length).unsqueeze(-1)
            if self.pool_mode == "xvector":
                correction = 1 if self.unbiased else 0
                stds = (
                    encoder_output.sub(means.unsqueeze(-1))
                    .masked_fill(mask, 0.0)
                    .pow(2.0)
                    .sum(-1)
                    .div(length.view(-1, 1).sub(correction))
                    .clamp(min=self.eps)
                    .sqrt()
                )
                pooled = torch.cat((means, stds), dim=-1)
            else:
                pooled = means
        return pooled


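# Usage sketch (illustrative values, not from the original source): pooling a
# [B, D, T] encoder output down to one statistics vector per utterance.
#   >>> pool = StatsPoolLayer(feat_in=192, pool_mode='xvector')
#   >>> x = torch.randn(4, 192, 120)               # [B, D, T]
#   >>> lengths = torch.tensor([120, 100, 80, 60])
#   >>> pool(x, length=lengths).shape              # torch.Size([4, 384]), mean and std concatenated

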
class AttentivePoolLayer(nn.Module):
    """
    Attention pooling layer for pooling speaker embeddings.
    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
    inputs:
        inp_filters: input feature channel length from encoder
        attention_channels: intermediate attention channel size
        kernel_size: kernel_size for TDNN and attention conv1d layers (default: 1)
        dilation: dilation size for TDNN and attention conv1d layers (default: 1)
    """

    def __init__(
        self,
        inp_filters: int,
        attention_channels: int = 128,
        kernel_size: int = 1,
        dilation: int = 1,
        eps: float = 1e-10,
    ):
        super().__init__()

        self.feat_in = 2 * inp_filters

        self.attention_layer = nn.Sequential(
            TdnnModule(inp_filters * 3, attention_channels, kernel_size=kernel_size, dilation=dilation),
            nn.Tanh(),
            nn.Conv1d(
                in_channels=attention_channels,
                out_channels=inp_filters,
                kernel_size=kernel_size,
                dilation=dilation,
            ),
        )
        self.eps = eps

    def forward(self, x, length=None):
        max_len = x.size(2)

        if length is None:
            # Without explicit lengths, treat every frame of every sample as valid.
            length = torch.full((x.shape[0],), max_len, device=x.device)

        mask, num_values = lens_to_mask(length, max_len=max_len, device=x.device)

        # Global (masked) statistics are tiled over time and concatenated with the
        # frame-level features to give the attention module utterance context.
        mean, std = get_statistics_with_mask(x, mask / num_values)
        mean = mean.unsqueeze(2).repeat(1, 1, max_len)
        std = std.unsqueeze(2).repeat(1, 1, max_len)
        attn = torch.cat([x, mean, std], dim=1)

        # Attention weights over time; padded frames are masked out before the softmax.
        attn = self.attention_layer(attn)
        attn = attn.masked_fill(mask == 0, -inf)
        alpha = F.softmax(attn, dim=2)
        mu, sg = get_statistics_with_mask(x, alpha)

        return torch.cat((mu, sg), dim=1).unsqueeze(2)


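# Usage sketch (illustrative): channel-attentive statistics pooling.
#   >>> att_pool = AttentivePoolLayer(inp_filters=192)
#   >>> x = torch.randn(4, 192, 120)
#   >>> lengths = torch.tensor([120, 100, 80, 60])
#   >>> att_pool(x, length=lengths).shape   # torch.Size([4, 384, 1]) -- attentive mean and std

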
class TdnnModule(nn.Module):
    """
    Time Delayed Neural Module (TDNN) - 1D
    input:
        inp_filters: input filter channels for conv layer
        out_filters: output filter channels for conv layer
        kernel_size: kernel weight size for conv layer
        dilation: dilation for conv layer
        stride: stride for conv layer
        groups: number of blocked connections from input to output channels (default: 1)
        padding: padding for conv layer (default None: chooses padding value such that input and output feature shapes match)
    output:
        tdnn layer output
    """

    def __init__(
        self,
        inp_filters: int,
        out_filters: int,
        kernel_size: int = 1,
        dilation: int = 1,
        stride: int = 1,
        groups: int = 1,
        padding: Optional[int] = None,
    ):
        super().__init__()
        if padding is None:
            padding = get_same_padding(kernel_size, stride=stride, dilation=dilation)

        self.conv_layer = nn.Conv1d(
            in_channels=inp_filters,
            out_channels=out_filters,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            padding=padding,
        )

        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_filters)

    def forward(self, x, length=None):
        x = self.conv_layer(x)
        x = self.activation(x)
        return self.bn(x)


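# Usage sketch (illustrative): the default padding keeps the time dimension intact.
#   >>> tdnn = TdnnModule(64, 128, kernel_size=3, dilation=2)
#   >>> tdnn(torch.randn(8, 64, 100)).shape   # torch.Size([8, 128, 100])

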
class MaskedSEModule(nn.Module):
    """
    Squeeze and Excite module implementation with conv1d layers
    input:
        inp_filters: input filter channel size
        se_filters: intermediate squeeze and excite channel output and input size
        out_filters: output filter channel size
        kernel_size: kernel_size for both conv1d layers
        dilation: dilation size for both conv1d layers

    output:
        squeeze and excite layer output
    """

    def __init__(self, inp_filters: int, se_filters: int, out_filters: int, kernel_size: int = 1, dilation: int = 1):
        super().__init__()
        self.se_layer = nn.Sequential(
            nn.Conv1d(
                inp_filters,
                se_filters,
                kernel_size=kernel_size,
                dilation=dilation,
            ),
            nn.ReLU(),
            nn.BatchNorm1d(se_filters),
            nn.Conv1d(
                se_filters,
                out_filters,
                kernel_size=kernel_size,
                dilation=dilation,
            ),
            nn.Sigmoid(),
        )

    def forward(self, input, length=None):
        if length is None:
            x = torch.mean(input, dim=2, keepdim=True)
        else:
            max_len = input.size(2)
            mask, num_values = lens_to_mask(length, max_len=max_len, device=input.device)
            # Masked mean: sum over valid frames divided by the number of valid frames.
            x = torch.sum(input * mask, dim=2, keepdim=True) / num_values

        out = self.se_layer(x)
        return out * input


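# Usage sketch (illustrative): channel gating; the output shape matches the input.
#   >>> se = MaskedSEModule(inp_filters=128, se_filters=32, out_filters=128)
#   >>> x = torch.randn(4, 128, 100)
#   >>> se(x, torch.tensor([100, 90, 50, 25])).shape   # torch.Size([4, 128, 100])

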
class TdnnSeModule(nn.Module):
    """
    Modified building SE_TDNN group module block from ECAPA implementation for faster training and inference
    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)
    inputs:
        inp_filters: input filter channel size
        out_filters: output filter channel size
        group_scale: scale value to group wider conv channels (default: 8)
        se_channels: squeeze and excite output channel size (default: 1024/8 = 128)
        kernel_size: kernel_size for group conv1d layers (default: 1)
        dilation: dilation size for group conv1d layers (default: 1)
    """

    def __init__(
        self,
        inp_filters: int,
        out_filters: int,
        group_scale: int = 8,
        se_channels: int = 128,
        kernel_size: int = 1,
        dilation: int = 1,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        self.out_filters = out_filters
        padding_val = get_same_padding(kernel_size=kernel_size, dilation=dilation, stride=1)

        group_conv = nn.Conv1d(
            out_filters,
            out_filters,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding_val,
            groups=group_scale,
        )
        self.group_tdnn_block = nn.Sequential(
            TdnnModule(inp_filters, out_filters, kernel_size=1, dilation=1),
            group_conv,
            nn.ReLU(),
            nn.BatchNorm1d(out_filters),
            TdnnModule(out_filters, out_filters, kernel_size=1, dilation=1),
        )

        self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    def forward(self, inputs, length=None):
        x = self.group_tdnn_block(inputs)
        x = self.se_layer(x, length)
        # Residual connection; requires inp_filters == out_filters.
        return x + inputs


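# Usage sketch (illustrative): one SE-TDNN residual block, as stacked in ECAPA-TDNN.
#   >>> block = TdnnSeModule(inp_filters=512, out_filters=512, kernel_size=3, dilation=2)
#   >>> x = torch.randn(4, 512, 100)
#   >>> block(x, length=torch.tensor([100, 90, 50, 25])).shape   # torch.Size([4, 512, 100])

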
class Res2NetBlock(nn.Module):
    """
    Res2Net module that splits input channels into groups and processes them separately before merging.
    This allows multi-scale feature extraction.
    """

    def __init__(self, in_channels, out_channels, scale=4, kernel_size=1, dilation=1):
        super().__init__()
        assert in_channels % scale == 0, "in_channels must be divisible by scale"

        self.scale = scale
        self.width = in_channels // scale

        # 'same' padding keeps the time dimension unchanged for any kernel_size/dilation,
        # so all splits can be concatenated back together.
        padding = get_same_padding(kernel_size, stride=1, dilation=dilation)
        self.convs = nn.ModuleList(
            [
                nn.Conv1d(
                    self.width, self.width, kernel_size=kernel_size, dilation=dilation, padding=padding, bias=False
                )
                for _ in range(scale - 1)
            ]
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        """
        x: [B, C, T]
        """
        splits = torch.split(x, self.width, dim=1)
        # The first split passes through unchanged; each subsequent split is convolved
        # and summed with the previous output (hierarchical residual connections).
        outputs = [splits[0]]

        for i in range(1, self.scale):
            conv_out = self.convs[i - 1](splits[i])
            outputs.append(conv_out + outputs[i - 1])

        out = torch.cat(outputs, dim=1)
        return self.activation(self.bn(out))


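# Usage sketch (illustrative): multi-scale processing over 4 channel splits.
#   >>> res2 = Res2NetBlock(in_channels=256, out_channels=256, scale=4, kernel_size=3)
#   >>> res2(torch.randn(4, 256, 100)).shape   # torch.Size([4, 256, 100])

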
class TdnnSeRes2NetModule(nn.Module):
    """
    SE-TDNN module with a Res2Net block for ECAPA-TDNN.
    inputs:
        inp_filters: input filter channel size (must equal out_filters for the residual connection)
        out_filters: output filter channel size
        group_scale: group size for the initial TDNN conv layer (default: 1)
        se_channels: squeeze and excite intermediate channel size (default: 128)
        kernel_size: kernel_size for the Res2Net conv layers (default: 1)
        dilation: dilation size for the Res2Net conv layers (default: 1)
        res2net_scale: number of channel splits inside the Res2Net block (default: 8)
    """

    def __init__(
        self,
        inp_filters: int,
        out_filters: int,
        group_scale: int = 1,
        se_channels: int = 128,
        kernel_size: int = 1,
        dilation: int = 1,
        res2net_scale: int = 8,
    ):
        super().__init__()

        # 1x1 TDNN projection into the Res2Net channel space.
        self.tdnn1 = TdnnModule(inp_filters, out_filters, kernel_size=1, dilation=1, groups=group_scale)

        # Multi-scale temporal modelling.
        self.res2net = Res2NetBlock(
            out_filters, out_filters, scale=res2net_scale, kernel_size=kernel_size, dilation=dilation
        )

        # Channel-wise recalibration.
        self.se_layer = MaskedSEModule(out_filters, se_channels, out_filters)

    def forward(self, x, length=None):
        residual = x
        x = self.tdnn1(x)
        x = self.res2net(x)
        x = self.se_layer(x, length)
        return x + residual


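# Usage sketch (illustrative): the Res2Net variant of the SE-TDNN block; ECAPA-style
# stacks typically increase the dilation (2, 3, 4) from block to block.
#   >>> block = TdnnSeRes2NetModule(inp_filters=256, out_filters=256, kernel_size=3, dilation=3)
#   >>> x = torch.randn(4, 256, 100)
#   >>> block(x, length=torch.tensor([100, 90, 50, 25])).shape   # torch.Size([4, 256, 100])

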
class MaskedConv1d(nn.Module):
    """
    1D convolution that zeroes out frames beyond each sequence's valid length before
    convolving, so padding does not leak into the output, and that also computes the
    resulting output sequence lengths.
    """

    __constants__ = ["use_mask", "real_out_channels", "heads"]

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        heads=-1,
        bias=False,
        use_mask=True,
        quantize=False,
    ):
        super().__init__()

        if not (heads == -1 or groups == in_channels):
            raise ValueError("Only use heads for depthwise convolutions")

        self.real_out_channels = out_channels
        if heads != -1:
            in_channels = heads
            out_channels = heads
            groups = heads

        # Preserve the original padding; a tuple or list means asymmetric padding.
        self._padding = padding

        if type(padding) in (tuple, list):
            self.pad_layer = nn.ConstantPad1d(padding, value=0.0)
            # The explicit pad layer handles the padding, so the conv itself pads nothing.
            padding = 0
        else:
            self.pad_layer = None

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.use_mask = use_mask
        self.heads = heads

        # When the convolution preserves the sequence length ('same' padding),
        # get_seq_len can skip the arithmetic and return the lengths unchanged.
        self.same_padding = (self.conv.stride[0] == 1) and (
            2 * self.conv.padding[0] == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
        )
        if self.pad_layer is None:
            self.same_padding_asymmetric = False
        else:
            self.same_padding_asymmetric = (self.conv.stride[0] == 1) and (
                sum(self._padding) == self.conv.dilation[0] * (self.conv.kernel_size[0] - 1)
            )

        # `self.lens` caches a range tensor used to build masks via comparison with the lengths.
        if self.use_mask:
            self.max_len = torch.tensor(0)
            self.lens = torch.tensor(0)

    def get_seq_len(self, lens):
        if self.same_padding or self.same_padding_asymmetric:
            return lens

        # Standard conv output-length formula: (L + 2p - d * (k - 1) - 1) // s + 1.
        if self.pad_layer is None:
            return (
                torch.div(
                    lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
                    self.conv.stride[0],
                    rounding_mode='trunc',
                )
                + 1
            )
        else:
            return (
                torch.div(
                    lens + sum(self._padding) - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1,
                    self.conv.stride[0],
                    rounding_mode='trunc',
                )
                + 1
            )

    def forward(self, x, lens):
        if self.use_mask:
            # Grow the cached range tensor if a longer sequence arrives, then zero out padding.
            if x.size(2) > self.max_len:
                self.update_masked_length(x.size(2), device=lens.device)
            x = self.mask_input(x, lens)

        # Update the lengths to reflect the convolution's stride/padding/dilation.
        lens = self.get_seq_len(lens)

        if self.pad_layer is not None:
            x = self.pad_layer(x)

        sh = x.shape
        if self.heads != -1:
            x = x.view(-1, self.heads, sh[-1])

        out = self.conv(x)

        if self.heads != -1:
            out = out.view(sh[0], self.real_out_channels, -1)

        return out, lens

    def update_masked_length(self, max_len, seq_range=None, device=None):
        if seq_range is None:
            self.lens, self.max_len = _masked_conv_init_lens(self.lens, max_len, self.max_len)
            self.lens = self.lens.to(device)
        else:
            self.lens = seq_range
            self.max_len = torch.tensor(max_len)

    def mask_input(self, x, lens):
        max_len = x.size(2)
        mask = self.lens[:max_len].unsqueeze(0).to(lens.device) < lens.unsqueeze(1)
        x = x * mask.unsqueeze(1).to(device=x.device)
        return x


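# Usage sketch (illustrative): a strided masked conv returns both the output and
# the updated sequence lengths.
#   >>> conv = MaskedConv1d(64, 128, kernel_size=3, stride=2, padding=1)
#   >>> out, new_lens = conv(torch.randn(4, 64, 100), torch.tensor([100, 80, 60, 40]))
#   >>> out.shape, new_lens   # (torch.Size([4, 128, 50]), tensor([50, 40, 30, 20]))

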
@torch.jit.script
def _masked_conv_init_lens(lens: torch.Tensor, current_maxlen: int, original_maxlen: torch.Tensor):
    if current_maxlen > original_maxlen:
        new_lens = torch.arange(current_maxlen)
        new_max_lens = torch.tensor(current_maxlen)
    else:
        new_lens = lens
        new_max_lens = original_maxlen
    return new_lens, new_max_lens


def get_same_padding(kernel_size, stride, dilation) -> int:
    if stride > 1 and dilation > 1:
        raise ValueError("Only stride OR dilation may be greater than 1")
    return (dilation * (kernel_size - 1)) // 2


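# Worked example (illustrative): a kernel of size 3 with dilation 2 spans
# 2 * (3 - 1) + 1 = 5 frames, so preserving the length needs 2 frames of padding per side.
#   >>> get_same_padding(3, stride=1, dilation=2)
#   2

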
def lens_to_mask(lens: torch.Tensor, max_len: int, device: Optional[str] = None):
    """
    outputs masking labels for a batch of audio feature lengths, with max length of any
    mask as max_len
    input:
        lens: tensor of sequence lengths
        max_len: max length of any audio feature
    output:
        mask: masked labels
        num_values: sum of mask values for each feature (useful for computing statistics later)
    """
    lens_mat = torch.arange(max_len).to(device)
    mask = lens_mat[:max_len].unsqueeze(0) < lens.unsqueeze(1)
    mask = mask.unsqueeze(1)
    num_values = torch.sum(mask, dim=2, keepdim=True)
    return mask, num_values


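# Usage sketch (illustrative): masks are shaped [B, 1, T] so they broadcast over channels.
#   >>> mask, num_values = lens_to_mask(torch.tensor([3, 1]), max_len=4, device='cpu')
#   >>> mask.squeeze(1)      # tensor([[ True,  True,  True, False],
#   ...                      #         [ True, False, False, False]])
#   >>> num_values.view(-1)  # tensor([3, 1])

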
def get_statistics_with_mask(x: torch.Tensor, m: torch.Tensor, dim: int = 2, eps: float = 1e-10):
    """
    compute mean and standard deviation of input (x) provided with its masking labels (m)
    input:
        x: feature input
        m: averaged mask labels
    output:
        mean: mean of input features
        std: standard deviation of input features
    """
    mean = torch.sum(m * x, dim=dim)
    std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(min=eps))
    return mean, std


@torch.jit.script_if_tracing
def make_seq_mask_like(
    like: torch.Tensor, lengths: torch.Tensor, valid_ones: bool = True, time_dim: int = -1
) -> torch.Tensor:
    """Build a sequence mask broadcastable to `like`, True at valid frames (or at padded
    frames when valid_ones=False)."""
    mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.unsqueeze(-1))

    # Add singleton dimensions until the mask has the same rank as `like`.
    for _ in range(like.dim() - mask.dim()):
        mask = mask.unsqueeze(1)

    # If time is not the last dimension, move the mask's time axis to match.
    if time_dim != -1:
        mask = mask.transpose(time_dim, -1)
    if not valid_ones:
        mask = ~mask
    return mask


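# Usage sketch (illustrative): a padding mask shaped to broadcast against [B, D, T] features.
#   >>> feats = torch.randn(2, 5, 4)
#   >>> pad_mask = make_seq_mask_like(like=feats, lengths=torch.tensor([4, 2]), valid_ones=False)
#   >>> pad_mask.shape                            # torch.Size([2, 1, 4]); True marks padded frames
#   >>> feats.masked_fill(pad_mask, 0.0).shape    # torch.Size([2, 5, 4])

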
def init_weights(m, mode: Optional[str] = 'xavier_uniform'):
    if isinstance(m, MaskedConv1d):
        init_weights(m.conv, mode)
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        if mode is not None:
            if mode == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight, gain=1.0)
            elif mode == 'xavier_normal':
                nn.init.xavier_normal_(m.weight, gain=1.0)
            elif mode == 'kaiming_uniform':
                nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
            elif mode == 'kaiming_normal':
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            elif mode == 'tds_uniform':
                tds_uniform_(m.weight)
            elif mode == 'tds_normal':
                tds_normal_(m.weight)
            else:
                raise ValueError(f"Unknown Initialization mode: {mode}")
    elif isinstance(m, nn.BatchNorm1d):
        if m.track_running_stats:
            m.running_mean.zero_()
            m.running_var.fill_(1)
            m.num_batches_tracked.zero_()
        if m.affine:
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)


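# Usage sketch (illustrative): initialize every conv/linear/batch-norm submodule
# recursively, as TdnnSeModule does in its constructor.
#   >>> layer = nn.Sequential(nn.Conv1d(16, 32, 3), nn.BatchNorm1d(32), nn.ReLU())
#   >>> _ = layer.apply(lambda m: init_weights(m, mode='kaiming_normal'))

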
def tds_uniform_(tensor, mode='fan_in'):
    """
    Uniform Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
    Normalized to -

    .. math::
        \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = 2.0
    std = gain / math.sqrt(fan)
    # The std value doubles as the uniform bound: bound = 2 * sqrt(1 / fan).
    bound = std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)


def tds_normal_(tensor, mode='fan_in'):
    """
    Normal Initialization from the paper [Sequence-to-Sequence Speech Recognition with Time-Depth Separable Convolutions](https://www.isca-speech.org/archive/Interspeech_2019/pdfs/2460.pdf)
    Normalized to -

    .. math::
        \\text{bound} = \\text{2} \\times \\sqrt{\\frac{1}{\\text{fan\\_mode}}}

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = 2.0
    std = gain / math.sqrt(fan)
    # Here the bound value is used as the standard deviation of the normal distribution.
    bound = std
    with torch.no_grad():
        return tensor.normal_(0.0, bound)