import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math, os, json, glob, time, random
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import AutoTokenizer
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
from cut_cross_entropy import linear_cross_entropy
from torch.nn.utils import clip_grad_norm_
from utils.trainutils import count_parameters_layerwise, save_checkpoint, TBLogger

from llama_modeling.front_end import LlamaForCausalLM
from llama_modeling.config import LlamaConfig


class JSONLDataset(Dataset):
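    """Read every *.jsonl file under `directory_path`, tokenize the `text_key`
    field of each record (skipping texts shorter than 100 characters), and pack
    the tokens into fixed-length sequences of `seq_length` ids, padding the
    tail of each document with `pad_token_id`."""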

    def __init__(self, directory_path, tokenizer, seq_length=1024,
                 text_key="text", max_files=None, batch_size=1000,
                 pad_token_id=0):
        self.seq_length = seq_length
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.sequences = []

        files = glob.glob(os.path.join(directory_path, "*.jsonl"))
        if max_files is not None:
            files = files[:max_files]

        text_batch = []
        for file_path in files:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get(text_key, "")
                        if len(text) >= 100:
                            text_batch.append(text)

                            if len(text_batch) >= batch_size:
                                self._process_batch(text_batch)
                                text_batch = []
                    except json.JSONDecodeError:
                        continue

        if text_batch:
            self._process_batch(text_batch)

        if self.sequences:
            self.sequences = torch.tensor(self.sequences, dtype=torch.long)
        else:
            self.sequences = torch.empty((0, seq_length), dtype=torch.long)

    def _process_batch(self, texts):
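        """Tokenize a batch of raw texts and split each token stream into
        seq_length-sized chunks, padding the final chunk of every document."""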
        encoded = self.tokenizer(
            texts,
            add_special_tokens=False,
            truncation=True,
            padding=False,
            return_attention_mask=False,
            return_tensors=None
        )['input_ids']

        for token_ids in encoded:
            # Walk the token stream in fixed-size windows.
            for i in range(0, len(token_ids), self.seq_length):
                chunk = token_ids[i:i + self.seq_length]

                # Pad the final (short) chunk up to seq_length.
                if len(chunk) < self.seq_length:
                    chunk += [self.pad_token_id] * (self.seq_length - len(chunk))

                self.sequences.append(chunk)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]


def train_model(model, train_loader, optimizer, device, epochs=5, forward_dtype=torch.float32):
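    """Train `model` on `train_loader` with mixed-precision autocast, the fused
    linear cross-entropy loss from cut_cross_entropy, gradient clipping, a
    cosine learning-rate schedule, and TensorBoard logging; a checkpoint is
    saved after every epoch."""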
    model.train()
    # GradScaler is only needed for fp16; bf16 training does not require loss scaling.
    scaler = torch.amp.GradScaler("cuda", enabled=(forward_dtype == torch.float16))

    logger = TBLogger(log_dir=f'logs/run-{time.time()}')

    total_steps = len(train_loader) * epochs
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=5e-6
    )

    model = torch.compile(model)

    global_step = 0
    for epoch in range(epochs):
        running_loss = 0.0
        total_batches = 0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch_idx, data in enumerate(progress_bar):
            data = data.to(device)
            optimizer.zero_grad(set_to_none=True)

            with torch.autocast(device_type='cuda', dtype=forward_dtype):
                hidden_states, classifier_weights = model(data)
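
                # linear_cross_entropy (cut_cross_entropy) fuses the classifier
                # matmul with the cross-entropy computation, so the full
                # [batch, seq, vocab] logit tensor is never materialized;
                # shift=True shifts the labels for next-token prediction.
                # Padding ids in `data` are treated as ordinary labels here.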
                loss = linear_cross_entropy(
                    hidden_states,
                    classifier_weights,
                    data,
                    shift=True,
                    reduction="mean"
                )

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            running_loss += loss.item()
            total_batches += 1
            global_step += 1
            avg_loss = running_loss / total_batches
            perplexity = math.exp(min(avg_loss, 100))

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'ppl': f'{perplexity:.2f}'
            })

            metrics = {
                'loss': loss.item(),
                'perplexity': perplexity,
                'learning_rate': optimizer.param_groups[0]['lr'],
                'batch_size': data.size(0)
            }

            logger.log(metrics, step=global_step, model=model, grad_checking=True)

            if batch_idx % 100 == 0:
                print(f'\nBatch {batch_idx}/{len(train_loader)}: '
                      f'Loss: {avg_loss:.4f}, '
                      f'Perplexity: {perplexity:.2f}, '
                      f'Batches Processed: {total_batches}')

        epoch_loss = running_loss / total_batches
        epoch_ppl = math.exp(min(epoch_loss, 100))
        print(f'\nEpoch {epoch+1} Summary:')
        print(f'Average Loss: {epoch_loss:.4f}')
        print(f'Perplexity: {epoch_ppl:.2f}')
        print(f'Total Batches Processed: {total_batches}\n')

        save_checkpoint(model, f'epoch_{epoch+1}.safetensors')


def sample_examples(dataset, tokenizer, num_samples=5):
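    """Print a few randomly sampled token sequences from `dataset` together
    with their decoded text, as a quick sanity check of the packing."""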
    if len(dataset) == 0:
        print("The dataset is empty.")
        return

    num_samples = min(num_samples, len(dataset))
    sampled_indices = random.sample(range(len(dataset)), num_samples)

    for i, idx in enumerate(sampled_indices):
        sequence = dataset[idx]
        print(f"Sample {i + 1} (Index {idx}):")
        print(sequence)
        decoded_text = tokenizer.decode(sequence, skip_special_tokens=False)
        print(decoded_text)
        print("-" * 40)


def main():
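    """Build the tokenizer, model, dataset, and DistributedShampoo optimizer
    (with Adam grafting), then run training in bfloat16 autocast."""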
    BATCH_SIZE = 36
    SEQ_LENGTH = 512
    EPOCHS = 3
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-135M-Instruct")

    config_path = "config.json"
    with open(config_path) as f:
        config_dict = json.load(f)
    # Keep only the keys that LlamaConfig actually defines.
    config = LlamaConfig(**{k: v for k, v in config_dict.items() if k in LlamaConfig.__dataclass_fields__})

    model = LlamaForCausalLM(config).to(DEVICE)

    dataset = JSONLDataset(
        directory_path="./Data_big",
        tokenizer=tokenizer,
        seq_length=SEQ_LENGTH,
        text_key="text",
        max_files=None,
    )

    train_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.0001,
        betas=(0.9, 0.999),
        epsilon=1e-12,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        # Graft Adam step sizes onto Shampoo's preconditioned directions.
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-12,
        ),
    )

    print("*" * 100)
    # Allow TF32 matmuls for fp32 ops outside the autocast region.
    torch.set_float32_matmul_precision('high')

    count_parameters_layerwise(model)

    train_model(model, train_loader, optimizer, DEVICE, EPOCHS, forward_dtype=torch.bfloat16)


if __name__ == "__main__":
    main()