from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn


class LlamaDecoderLayer(nn.Module):
    """Pre-norm decoder layer: differential self-attention followed by an MLP,
    each wrapped in an RMSNorm and a residual connection."""

    def __init__(self, config: LlamaConfig, layer_num: int):
        super().__init__()
        # DifferentialAttention serves as this layer's self-attention module.
        self.self_attn = DifferentialAttention(config, layer_num)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Self-attention block: normalize, attend, then add the residual.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        # Feed-forward block: normalize, apply the MLP, then add the residual.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
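

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the layer's API. It assumes that
# LlamaConfig is default-constructible and exposes `hidden_size` (the only
# field referenced here, matching its use in __init__ above); adjust the
# config construction to this repo's actual defaults if that does not hold.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LlamaConfig()  # assumption: default-constructible
    layer = LlamaDecoderLayer(config, layer_num=0)

    batch, seq_len = 2, 16
    hidden = torch.randn(batch, seq_len, config.hidden_size)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

    # The output keeps the input shape: (batch, seq_len, hidden_size).
    out = layer(hidden, attention_mask=None, position_ids=position_ids)
    print(out.shape)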