|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .config import LlamaConfig |
|
|
|
class LlamaMLP(nn.Module):
    """Gated feed-forward block.

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``. All three
    projections are bias-free ``nn.Linear`` layers: ``gate_proj`` and
    ``up_proj`` map ``hidden_size -> intermediate_size`` and ``down_proj``
    maps ``intermediate_size -> hidden_size``.
    """

    def __init__(self, config: LlamaConfig):
        """Create the projections and activation.

        Args:
            config: Object providing ``hidden_size`` and
                ``intermediate_size``; both are also stored on the module.
        """
        super().__init__()
        model_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.hidden_size = model_dim
        self.intermediate_size = inner_dim

        # Creation order is kept as gate -> up -> down -> activation: it
        # determines submodule ordering and the order in which the default
        # weight initialisation consumes the RNG.
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP to ``x``.

        Args:
            x: Input tensor; it is fed directly to ``gate_proj`` and
                ``up_proj``, so its last dimension must be ``hidden_size``.

        Returns:
            The output of ``down_proj`` applied to the element-wise product
            of the activated gate branch and the up branch.
        """
        gate_branch = self.act_fn(self.gate_proj(x))
        up_branch = self.up_proj(x)
        return self.down_proj(gate_branch * up_branch)