import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import flash_attn_func

from .extact import xATGLU
from .liger_rope import LigerRopeFunction
from .config import LlamaConfig


# The four-flash attn strategy comes from here:
# https://github.com/microsoft/unilm/blob/master/Diff-Transformer/multihead_flashdiff_2.py
class DifferentialAttention(nn.Module):
    """Differential attention built from four flash-attention calls.

    Queries and keys are projected to ``2 * num_heads`` / ``2 * num_kv_heads``
    heads of size ``head_dim = hidden_size // (2 * num_heads)`` and split into
    two groups. Each group attends over both halves of the values; the result
    of the second group, scaled by a learned scalar ``lambda``, is subtracted
    from the result of the first group.
    """

    def __init__(self, config: LlamaConfig, layer_num):
        """Build projections, lambda parameters and the RoPE tables.

        Args:
            config: Model configuration. Reads ``hidden_size``,
                ``num_attention_heads``, ``num_key_value_heads``,
                ``max_position_embeddings`` and ``rope_theta``.
            layer_num: Index of this layer; only used to derive
                ``lambda_init``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.n_rep = self.num_heads // self.num_kv_heads
        # Each "differential" head is a pair of half-width heads.
        self.head_dim = self.hidden_size // (2 * self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim ** -0.5

        # NOTE: creation order of the modules/parameters below is kept
        # stable so that seeded initialisation stays reproducible.
        self.q_proj = nn.Linear(
            self.hidden_size, 2 * self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            2 * self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        # Depth-dependent starting point for the lambda re-parameterisation.
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))

        # Normalises each head's concatenated (2 * head_dim) output.
        self.subln = nn.LayerNorm(2 * self.head_dim, elementwise_affine=False)

        # Compute the cos/sin tables once and register both halves; they are
        # non-persistent, so they are rebuilt rather than loaded from
        # checkpoints.
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def _compute_rope_embeddings(
        self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None
    ):
        """Return rotary ``(cos, sin)`` tables.

        Args:
            max_position_embeddings: Number of positions to tabulate.
            head_dim: Rotary dimension; frequencies are generated for every
                second index and duplicated to fill ``head_dim``.
            base: Base of the geometric frequency progression.
            dtype: Optional dtype for the returned tables. When ``None`` the
                tables keep the float32 dtype they were computed in.
            device: Device on which the tables are created.

        Returns:
            Tuple ``(cos, sin)``, each with a leading singleton dimension
            followed by ``(max_position_embeddings, head_dim)`` (the latter
            for even ``head_dim``).
        """
        inv_freq = 1.0 / (
            base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
        )
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        if dtype is not None:
            cos = cos.to(dtype)
            sin = sin.to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Apply differential attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(bsz, seq_len, embed_dim)``.
            attention_mask: Only inspected for ``None``-ness: the
                flash-attention calls are causal when it is ``None`` and
                non-causal otherwise. The mask values themselves are never
                passed to flash-attn.
            position_ids: Integer positions used to index the RoPE tables.
                When ``None``, ``0..seq_len-1`` is used for every batch row.

        Returns:
            Tensor of shape ``(bsz, seq_len, hidden_size)``.
        """
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = rearrange(q, 'b s (h d) -> b s h d', h=2 * self.num_heads, d=self.head_dim)
        k = rearrange(k, 'b s (h d) -> b s h d', h=2 * self.num_kv_heads, d=self.head_dim)
        # Reshaped for GQA
        v = rearrange(
            v, 'b s (h g d) -> b s h g d', h=self.num_kv_heads, g=2, d=self.head_dim
        )

        # Apply rotary embeddings using LigerRopeFunction
        # NOTE(review): LigerRopeFunction is project-local; the q/k layout
        # ('b s h d') and cos/sin shape it expects are not visible here --
        # confirm against .liger_rope.
        cos = self.cos_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        sin = self.sin_cached[:, position_ids]  # [1, bsz, seq_len, dim]
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # Rearrange into GQA style: consecutive head pairs become the two
        # attention groups.
        q = rearrange(q, 'b s (h g) d -> b s h g d', h=self.num_heads, g=2)
        k = rearrange(k, 'b s (h g) d -> b s h g d', h=self.num_kv_heads, g=2)

        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

        # @Z TODO:: a provided attention_mask currently only disables causal
        # masking; its contents are not applied.
        causal = attention_mask is None

        # First attention group on q1/k1 and the v's
        attn11 = flash_attn_func(q1, k1, v1, dropout_p=0.0, causal=causal)
        attn12 = flash_attn_func(q1, k1, v2, dropout_p=0.0, causal=causal)
        attn1 = torch.cat([attn11, attn12], dim=-1)

        # Second attention group on q2/k2 and the v's
        attn21 = flash_attn_func(q2, k2, v1, dropout_p=0.0, causal=causal)
        attn22 = flash_attn_func(q2, k2, v2, dropout_p=0.0, causal=causal)
        attn2 = torch.cat([attn21, attn22], dim=-1)

        # Scalar lambda; the exponentials are evaluated in float32 and cast
        # back to the activation dtype.
        lambda_1 = torch.exp(
            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
        ).type_as(q)
        lambda_2 = torch.exp(
            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
        ).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)

        attn_output = rearrange(attn, "b s h d -> b s (h d)")
        return self.o_proj(attn_output)