from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
from .attention import LlamaAttention  # alternative attention variant in this repo (not used in this layer)
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn  # alternative attention variant in this repo (not used in this layer)


class LlamaDecoderLayer(nn.Module):
    """A single pre-norm decoder layer: differential attention followed by the
    MLP, each sub-layer wrapped as RMSNorm -> sub-layer -> residual add."""

    def __init__(self, config: LlamaConfig, layer_num: int):
        super().__init__()
        self.self_attn = DifferentialAttention(config, layer_num)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer: normalize, attend, then add the residual back.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        # MLP sub-layer: normalize, transform, then add the residual back.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
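

# Minimal usage sketch (a comment-only example, not part of the module):
# assumes `config` is a fully populated LlamaConfig carrying whatever fields
# DifferentialAttention and LlamaMLP require beyond the hidden_size and
# rms_norm_eps used above, and that position_ids follow the usual
# (batch, seq_len) layout.
#
#     layer = LlamaDecoderLayer(config, layer_num=0)
#     x = torch.randn(2, 16, config.hidden_size)
#     position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
#     out = layer(x, position_ids=position_ids)  # attention_mask left as None
#     assert out.shape == x.shape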