from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .module import NeuralModule
from .tdnn_attention import (
    StatsPoolLayer,
    AttentivePoolLayer,
    TdnnModule,
    TdnnSeModule,
    TdnnSeRes2NetModule,
    init_weights,
)


class EcapaTdnnEncoder(NeuralModule):
    """
    Modified ECAPA-TDNN encoder. By default the Res2Net blocks are replaced with plain SE-TDNN
    blocks for faster training and inference, which achieves better numbers on speaker diarization
    tasks; set ``res2net=True`` to use SE-Res2Net blocks instead.
    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)

    Args:
        feat_in: number of input feature channels (e.g. the mel-spectrogram feature size).
        filters: output channel counts. filters[0] configures the initial TdnnModule,
            filters[1:-1] the SE-TDNN (or SE-Res2Net) blocks, and filters[-1] the aggregation layer.
            The aggregation layer is built with filters[-1] input channels but receives the SE block
            outputs concatenated along dim 1, so presumably sum(filters[1:-1]) must equal
            filters[-1] (TODO confirm against the TdnnModule signature).
        kernel_sizes: kernel size per layer, index-aligned with ``filters``; the last entry
            configures the aggregation layer.
        dilations: dilation per layer, index-aligned with ``filters``; the last entry configures
            the aggregation layer.
        scale: group scale used to widen the grouped convolutions of the SE blocks (default: 8).
        res2net: if True, use TdnnSeRes2NetModule blocks instead of TdnnSeModule (default: False).
        res2net_scale: Res2Net scale passed to TdnnSeRes2NetModule; only used when ``res2net`` is
            True (default: 8).
        init_mode: weight initialization mode passed to ``init_weights`` (default: 'xavier_uniform').

    forward returns:
        outputs: output of the aggregation layer applied to the concatenated SE block outputs.
        length: the input ``length``, returned unchanged.
    """

    def __init__(
        self,
        feat_in: int,
        filters: list,
        kernel_sizes: list,
        dilations: list,
        scale: int = 8,
        res2net: bool = False,
        res2net_scale: int = 8,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(TdnnModule(feat_in, filters[0], kernel_size=kernel_sizes[0], dilation=dilations[0]))

        # SE blocks: block i maps filters[i] -> filters[i + 1]; filters[-1] is reserved for feature_agg.
        for i in range(len(filters) - 2):
            if res2net:
                block = TdnnSeRes2NetModule(
                    filters[i],
                    filters[i + 1],
                    group_scale=scale,
                    se_channels=128,
                    kernel_size=kernel_sizes[i + 1],
                    dilation=dilations[i + 1],
                    res2net_scale=res2net_scale,
                )
            else:
                block = TdnnSeModule(
                    filters[i],
                    filters[i + 1],
                    group_scale=scale,
                    se_channels=128,
                    kernel_size=kernel_sizes[i + 1],
                    dilation=dilations[i + 1],
                )
            self.layers.append(block)

        self.feature_agg = TdnnModule(filters[-1], filters[-1], kernel_sizes[-1], dilations[-1])
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def forward(self, audio_signal, length=None):
        x = audio_signal
        outputs = []

        for layer in self.layers:
            x = layer(x, length=length)
            outputs.append(x)

        # Multi-layer feature aggregation: concatenate every SE block output (the initial TDNN output
        # at index 0 is excluded) along dim 1, assumed to be the channel dim, before feature_agg.
        x = torch.cat(outputs[1:], dim=1)
        x = self.feature_agg(x)
        return x, length


class SpeakerDecoder(NeuralModule):
    """
    Speaker decoder: pools the encoder output, maps it through one or more embedding layers and
    projects onto the speaker classes, to be followed by a speaker-based softmax loss.

    Args:
        feat_in (int): Number of channels being input to this module.
        num_classes (int): Number of unique speakers in the dataset.
        emb_sizes (int or list): Output sizes of the embedding layers; an int means a single layer.
            The speaker embedding returned by ``forward`` comes from the LAST of these layers.
            Defaults to 256.
        pool_mode (str): Pooling strategy (case-insensitive). Options are 'xvector', 'tap' and
            'attention'. Defaults to 'xvector'.
            xvector: mean and variance pooling
            tap: temporal average pooling (just mean)
            attention: attention-based pooling
        angular (bool): If True, the final projection has no bias, and both the pooled features and
            the rows of the final weight matrix are L2-normalized, so the logits are cosine
            similarities. Defaults to False.
        attention_channels (int): Channels of the attention pooling layer; only used when
            pool_mode is 'attention'. Defaults to 128.
        init_mode (str): Describes how neural network parameters are initialized. Options are
            ['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal'].
            Defaults to "xavier_uniform".

    Raises:
        ValueError: if ``pool_mode`` is not one of the supported options.
    """

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = "xavier_uniform",
    ):
        super().__init__()
        self.angular = angular
        # Number of leading sub-modules of each embedding layer whose output is taken as the embedding.
        self.emb_id = 2
        bias = not self.angular
        emb_sizes = [emb_sizes] if isinstance(emb_sizes, int) else emb_sizes
        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()

        if self.pool_mode in ('xvector', 'tap'):
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'
        else:
            raise ValueError(
                f"pool_mode must be one of 'xvector', 'tap' or 'attention', got {pool_mode!r}"
            )

        shapes = [self._pooling.feat_in] + [int(size) for size in emb_sizes]

        emb_layers = [
            self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
            for shape_in, shape_out in zip(shapes[:-1], shapes[1:])
        ]
        self.emb_layers = nn.ModuleList(emb_layers)

        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)

        self.apply(lambda x: init_weights(x, mode=init_mode))

    def affine_layer(
        self, inp_shape, out_shape, learn_mean=True, affine_type='conv',
    ):
        """
        Build one embedding layer.

        'conv': BatchNorm1d(inp_shape) -> Conv1d(inp_shape, out_shape, kernel_size=1).
        otherwise: Linear(inp_shape, out_shape) -> BatchNorm1d(out_shape, affine=learn_mean) -> ReLU.
        """
        if affine_type == 'conv':
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )

        else:
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )

        return layer

    def forward(self, encoder_output, length=None):
        pool = self._pooling(encoder_output, length)
        embs = []

        for layer in self.emb_layers:
            # Run the first `emb_id` sub-modules once to obtain the embedding, then the remaining
            # sub-modules on top of it. Evaluating the full layer and its prefix separately would run
            # BatchNorm twice and update its running statistics twice per training step.
            emb = layer[: self.emb_id](pool)
            pool = layer[self.emb_id :](emb)
            embs.append(emb)

        pool = pool.squeeze(-1)
        if self.angular:
            # Normalize the class weight rows functionally: rebinding a loop variable to
            # F.normalize(W) would leave the actual parameters untouched.
            weight = F.normalize(self.final.weight, p=2, dim=1)
            pool = F.normalize(pool, p=2, dim=1)
            out = F.linear(pool, weight, self.final.bias)
        else:
            out = self.final(pool)

        return out, embs[-1].squeeze(-1)