from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .module import NeuralModule
from .tdnn_attention import (
    StatsPoolLayer,
    AttentivePoolLayer,
    TdnnModule,
    TdnnSeModule,
    TdnnSeRes2NetModule,
    init_weights,
)


class EcapaTdnnEncoder(NeuralModule):
    """
    Modified ECAPA encoder that, by default, omits the Res2Net module for faster training and
    inference while achieving better numbers on speaker diarization tasks.
    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)

    input:
        feat_in: input feature size (mel spectrogram feature dimension)
        filters: list of filter sizes for the SE-TDNN modules
        kernel_sizes: list of kernel sizes for the SE-TDNN modules
        dilations: list of dilations for the grouped conv SE layers
        scale: scale value for grouping wider conv channels (default: 8)
        res2net: if True, use TdnnSeRes2NetModule blocks instead of TdnnSeModule blocks (default: False)
        res2net_scale: scale value for the Res2Net blocks when res2net is True (default: 8)
        init_mode: weight initialization mode passed to init_weights (default: 'xavier_uniform')

    output:
        outputs: encoded output
        output_length: masked output lengths
    """

    def __init__(
        self,
        feat_in: int,
        filters: list,
        kernel_sizes: list,
        dilations: list,
        scale: int = 8,
        res2net: bool = False,
        res2net_scale: int = 8,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(TdnnModule(feat_in, filters[0], kernel_size=kernel_sizes[0], dilation=dilations[0]))

        for i in range(len(filters) - 2):
            if res2net:
                self.layers.append(
                    TdnnSeRes2NetModule(
                        filters[i],
                        filters[i + 1],
                        group_scale=scale,
                        se_channels=128,
                        kernel_size=kernel_sizes[i + 1],
                        dilation=dilations[i + 1],
                        res2net_scale=res2net_scale,
                    )
                )
            else:
                self.layers.append(
                    TdnnSeModule(
                        filters[i],
                        filters[i + 1],
                        group_scale=scale,
                        se_channels=128,
                        kernel_size=kernel_sizes[i + 1],
                        dilation=dilations[i + 1],
                    )
                )
        # the SE-TDNN block outputs (all but the first layer) are concatenated in forward(),
        # so their channel sum must equal filters[-1] expected by the aggregation layer
        self.feature_agg = TdnnModule(filters[-1], filters[-1], kernel_sizes[-1], dilations[-1])
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def forward(self, audio_signal, length=None):
        x = audio_signal
        outputs = []

        for layer in self.layers:
            x = layer(x, length=length)
            outputs.append(x)

        # aggregate the SE-TDNN block outputs (skipping the initial TDNN layer) and fuse them
        x = torch.cat(outputs[1:], dim=1)
        x = self.feature_agg(x)
        return x, length
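
# Encoder usage sketch (illustrative only): the filter/kernel/dilation values below are
# assumptions in the style of common ECAPA-TDNN configs, not values defined in this file.
# Note that the concatenated SE-TDNN outputs must sum to filters[-1].
#
#     encoder = EcapaTdnnEncoder(
#         feat_in=80,
#         filters=[512, 512, 512, 512, 1536],
#         kernel_sizes=[5, 3, 3, 3, 1],
#         dilations=[1, 2, 3, 4, 1],
#     )
#     feats = torch.randn(4, 80, 200)                    # (batch, feat_in, time)
#     lengths = torch.tensor([200, 180, 160, 150])
#     encoded, encoded_len = encoder(feats, length=lengths)
#     # encoded: (batch, filters[-1], time)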


class SpeakerDecoder(NeuralModule):
    """
    Speaker Decoder creates the final neural layers that map the encoder outputs
    (e.g., from the Jasper encoder) to the embedding layer, followed by a speaker-based
    softmax loss.

    Args:
        feat_in (int): Number of channels being input to this module
        num_classes (int): Number of unique speakers in the dataset
        emb_sizes (int or list): sizes of the intermediate embedding layers (the returned
            speaker embedding is taken from the last of these layers). Defaults to 256
        pool_mode (str): Pooling strategy type. Options are 'xvector', 'tap', 'attention'.
            Defaults to 'xvector'.
            xvector (mean and variance)
            tap (temporal average pooling: just mean)
            attention (attention-based pooling)
        angular (bool): If True, the final linear layer is bias-free and logits are computed
            from L2-normalized embeddings and weights (for angular-margin softmax losses).
            Defaults to False
        attention_channels (int): Number of channels in the attentive pooling layer.
            Defaults to 128
        init_mode (str): Describes how neural network parameters are
            initialized. Options are ['xavier_uniform', 'xavier_normal',
            'kaiming_uniform', 'kaiming_normal'].
            Defaults to "xavier_uniform".
    """

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = "xavier_uniform",
    ):
        super().__init__()
        self.angular = angular
        self.emb_id = 2
        bias = not self.angular
        emb_sizes = [emb_sizes] if isinstance(emb_sizes, int) else emb_sizes

        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()
        if self.pool_mode == 'xvector' or self.pool_mode == 'tap':
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'
        else:
            raise ValueError(f"Unsupported pool_mode '{pool_mode}'. Supported modes are 'xvector', 'tap' and 'attention'.")

        shapes = [self._pooling.feat_in]
        for size in emb_sizes:
            shapes.append(int(size))

        emb_layers = []
        for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
            layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
            emb_layers.append(layer)

        self.emb_layers = nn.ModuleList(emb_layers)

        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    def affine_layer(
        self,
        inp_shape,
        out_shape,
        learn_mean=True,
        affine_type='conv',
    ):
        if affine_type == 'conv':
            # 1x1 conv variant: used after attentive pooling, which keeps a channel-first layout
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )
        else:
            # linear variant: used after stats pooling, which produces flat (batch, channels) vectors
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )

        return layer

    def forward(self, encoder_output, length=None):
        pool = self._pooling(encoder_output, length)
        embs = []

        for layer in self.emb_layers:
            # the full block output feeds the next layer; the first emb_id sub-layers give the embedding
            pool, emb = layer(pool), layer[: self.emb_id](pool)
            embs.append(emb)

        pool = pool.squeeze(-1)
        if self.angular:
            # assign through .data so the normalization actually updates the classifier weights;
            # a bare `W = F.normalize(W, ...)` would only rebind the loop variable
            for W in self.final.parameters():
                W.data = F.normalize(W.data, p=2, dim=1)
            pool = F.normalize(pool, p=2, dim=1)

        out = self.final(pool)

        return out, embs[-1].squeeze(-1)
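

# Decoder usage sketch (illustrative only, continuing the encoder sketch above):
# feat_in must match the encoder's output channels (filters[-1]); num_classes and
# emb_sizes below are assumptions, not values defined in this file.
#
#     decoder = SpeakerDecoder(
#         feat_in=1536,
#         num_classes=100,
#         emb_sizes=192,
#         pool_mode='attention',
#         angular=False,
#     )
#     logits, emb = decoder(encoded, length=encoded_len)
#     # logits: (batch, num_classes) for the softmax loss, emb: (batch, 192) speaker embedding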