from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn


class LlamaDecoderLayer(nn.Module):
    """A single decoder block: pre-norm differential attention followed by a pre-norm MLP."""

    def __init__(self, config: LlamaConfig, layer_num):
        super().__init__()
        # Differential attention is used here in place of the standard LlamaAttention.
        self.self_attn = DifferentialAttention(config, layer_num)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Self-attention sub-block with a residual connection (pre-norm).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block with a residual connection (pre-norm).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
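

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes
    # LlamaConfig can be constructed with defaults and exposes `hidden_size`;
    # any other config fields consumed by DifferentialAttention or LlamaMLP
    # are not visible in this file and are taken on faith here.
    config = LlamaConfig()
    layer = LlamaDecoderLayer(config, layer_num=0)

    batch_size, seq_len = 2, 16
    hidden_states = torch.randn(batch_size, seq_len, config.hidden_size)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    # The block preserves the (batch, seq_len, hidden_size) shape.
    out = layer(hidden_states, attention_mask=None, position_ids=position_ids)
    assert out.shape == hidden_states.shape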