DECIDIA-code / train.py
lixiangchun
first commit
25f71fa
#!/opt/software/install/miniconda38/bin/python
import argparse
# Command-line interface.  Each entry is (flag, keyword arguments forwarded
# verbatim to ``parser.add_argument``); the order here is the order shown in
# ``--help``.
parser = argparse.ArgumentParser(description='DECIDIA training program')

_ARGUMENT_SPECS = (
    # Data and embedding-model locations.
    ('--input_dir', dict(type=str, help='input directory')),
    ('--sequence_embedding', dict(type=str, help='sequence embedding directory')),
    ('--num_hidden_layers', dict(type=int, default=1, help='num_hidden_layers [1]')),
    ('--train_file', dict(type=str, help='training file')),
    ('--val_file', dict(type=str, help='validation file')),
    # Runtime and task configuration.
    ('--device', dict(type=str, default='cuda:1', help='device')),
    ('--num_classes', dict(type=int, default=32, help='num_classes [32]')),
    ('--diseases', dict(type=str, default=None, help='diseases included, e.g "LUAD,LUSC"')),
    ('--weight_decay', dict(type=float, default=1e-5, help='weight_decay [1e-5]')),
    ('--modeling_context', dict(action='store_true',
                                help='whether use OPT to model context dependency')),
    ("--lr_scheduler_type", dict(
        type=str,
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial",
                 "constant", "constant_with_warmup"],
        default="constant",
        help="The scheduler type to use.",
    )),
    # Warm-start checkpoints.
    ('--pretrained_weight', dict(type=str, help='pretrained weight')),
    ('--pretrained_cls_token', dict(type=str, help='pretrained cls token')),
    # Training length and training-set subsampling.
    ('--epochs', dict(type=int, default=100, help='epochs (default: 100)')),
    ('--num_sequences', dict(type=int, default=None,
                             help='num of sequences to sample from training set')),
    ('--num_train_patients', dict(type=int, default=None,
                                  help='num of patients data to sample from training set')),
)

for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parsed before the heavy imports below, so ``--help`` and usage errors
# return without loading torch/transformers.
args = parser.parse_args()
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import sys
import glob
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import AdamW, Adam, SGD, Adagrad
from sklearn.utils import resample
from transformers import get_scheduler
import numpy as np
import pandas as pd
import random
import time
from transformers import (
PreTrainedTokenizerFast,
OPTForCausalLM
)
from model import DeepAttnMIL
# ---- Runtime setup and frozen sequence encoder ----
# Cap the number of CPU threads torch may use for intra-op parallelism.
torch.set_num_threads(2)
device = args.device
# Seeds Python's ``random`` module only.
random.seed(123)

# Tokenizer and OPT language model are both read from --sequence_embedding.
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.sequence_embedding)
net = OPTForCausalLM.from_pretrained(args.sequence_embedding).to(device)
# Put the encoder in evaluation mode.
net.eval()

# Width of the encoder's hidden states, i.e. the per-read feature size.
feature_dim = net.config.hidden_size
# ---- Training split: load, then optionally subsample reads and patients ----
# The CSV is read for its ``patient`` column here; the ``seq`` and ``label``
# columns are consumed further down the script.
trn_df = pd.read_csv(f'{args.input_dir}/trn.csv.gz')

# Each patient is stored as one fixed-size slice of a dense tensor, so every
# patient has to contribute exactly the same number of reads.
reads_per_patient = trn_df.patient.value_counts().unique()
if len(reads_per_patient) != 1:
    # Explicit raise (instead of ``assert``) so the check survives ``python -O``.
    raise ValueError(
        'every patient in trn.csv.gz must have the same number of reads; '
        f'found distinct counts: {sorted(int(c) for c in reads_per_patient)}'
    )
# Plain int rather than a numpy scalar, for use as a tensor dimension.
reads_per_patient = int(reads_per_patient[0])

# --num_sequences defaults to None, meaning "use every read".  It is
# normalised here because the value is used below both in a comparison and as
# a tensor dimension, neither of which accepts None.
if args.num_sequences is None:
    args.num_sequences = reads_per_patient
if not 1 <= args.num_sequences <= reads_per_patient:
    raise ValueError(
        f'--num_sequences must be between 1 and {reads_per_patient} '
        f'(reads available per patient), got {args.num_sequences}'
    )
if args.num_sequences < reads_per_patient:
    # Draw the same number of reads from every patient, with a fixed seed.
    trn_df = pd.concat([
        df.sample(args.num_sequences, random_state=123)
        for _, df in trn_df.groupby('patient')
    ])

# --num_train_patients defaults to None, meaning "use every patient".
patient_ids = trn_df.patient.unique().tolist()
num_train_samples = len(patient_ids)
if args.num_train_patients is None:
    args.num_train_patients = num_train_samples
if not 1 <= args.num_train_patients <= num_train_samples:
    # Asking for more patients than exist would allocate rows that never get
    # filled and keep the placeholder label -1.
    raise ValueError(
        f'--num_train_patients must be between 1 and {num_train_samples} '
        f'(patients available), got {args.num_train_patients}'
    )
if args.num_train_patients < num_train_samples:
    # Uses the module-level ``random`` state for the patient draw.
    kept_patients = random.sample(patient_ids, args.num_train_patients)
    trn_df = trn_df[trn_df.patient.isin(kept_patients)]

# Preallocated per-patient feature bags and labels; -1 marks "not yet filled".
trn_x = torch.zeros(args.num_train_patients, args.num_sequences, feature_dim)
trn_y = torch.as_tensor([-1] * args.num_train_patients)
# ---- Held-out splits: load tables and preallocate feature/label storage ----
def _allocate_split(csv_path):
    """Read one split and allocate empty storage for it.

    Returns ``(frame, n_patients, bags, labels)`` where ``bags`` is a zero
    tensor of shape ``(n_patients, reads_per_patient, feature_dim)`` and
    ``labels`` holds ``-1`` for every patient as a "not yet filled" marker.
    """
    frame = pd.read_csv(csv_path)
    n_patients = len(frame.patient.unique())
    bags = torch.zeros(n_patients, reads_per_patient, feature_dim)
    labels = torch.as_tensor([-1] * n_patients)
    return frame, n_patients, bags, labels


test_df, num_test_samples, test_x, test_y = _allocate_split(f'{args.input_dir}/test.csv.gz')
test_patients = []  # patient ids, appended in the order their rows are filled

val_df, num_val_samples, val_x, val_y = _allocate_split(f'{args.input_dir}/val.csv.gz')
val_patients = []  # patient ids, appended in the order their rows are filled

pad_token_id = net.config.pad_token_id
def _embed_patients(df, bag_features, bag_labels, total, patient_ids=None):
    """Encode every patient's reads and store the result in preallocated tensors.

    Args:
        df: table with ``patient``, ``seq`` and ``label`` columns.
        bag_features: tensor written in place; row ``i`` receives the features
            of the ``i``-th patient group.
        bag_labels: tensor written in place; entry ``i`` receives the first
            label value found in that patient's rows.
        total: expected number of patient groups (progress-bar length only).
        patient_ids: optional list; when given, each patient id is appended in
            the same order the rows are filled.
    """
    for i, (patient, e) in tqdm(enumerate(df.groupby('patient')), total=total):
        # The tokenizer is fed each sequence with its characters separated by
        # single spaces.
        spaced = [' '.join(s) for s in e.seq]
        inputs = tokenizer(spaced, max_length=100, padding='max_length', truncation=True,
                           return_tensors='pt', return_token_type_ids=False)
        # Build a fresh dict instead of writing into the mapping while
        # iterating over it.
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.inference_mode():
            out = net.model(**inputs)
            # One vector per read: average of the last hidden states over all
            # 100 token positions.
            # NOTE(review): the average includes padded positions; kept as-is
            # so the features stay identical to what this script has always
            # produced -- confirm whether mask-aware pooling is wanted.
            features = out.last_hidden_state.mean(1).cpu()
        bag_features[i] = features
        bag_labels[i] = e.label.tolist()[0]
        if patient_ids is not None:
            patient_ids.append(patient)


# Same procedure for all three splits; only the held-out splits record the
# patient ids.
_embed_patients(trn_df, trn_x, trn_y, args.num_train_patients)
_embed_patients(test_df, test_x, test_y, num_test_samples, test_patients)
_embed_patients(val_df, val_x, val_y, num_val_samples, val_patients)
# Per-epoch metrics log (tab-separated).  The file name records the number of
# reads and the train/val/test patient counts of this run.
log_path = (
    f'{args.input_dir}/log-reads-{args.num_sequences}'
    f'-patients-trn{args.num_train_patients}'
    f'-val{num_val_samples}-test{num_test_samples}-tiny.txt'
)
fout = open(log_path, 'w')
# Header row; one line per epoch is appended in the same column order.
print("epoch\ttrain_loss\ttrain_acc\tval_loss\tval_acc\teval_loss\teval_acc", file=fout)
# ---- Classifier: build, optionally warm-start, move to the target device ----
model = DeepAttnMIL(input_dim=feature_dim, n_classes=args.num_classes, size_arg='big')

if args.pretrained_weight:
    # NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.pretrained_weight, map_location='cpu')
    # A checkpoint whose classifier has a different number of output rows
    # cannot be loaded into this head, so the head entries are discarded and
    # the head keeps its fresh initialisation.
    if checkpoint['classifier.weight'].size(0) != args.num_classes:
        for head_key in ('classifier.weight', 'classifier.bias'):
            del checkpoint[head_key]
    # strict=False tolerates the entries removed above; the returned report
    # lists missing/unexpected keys.
    load_report = model.load_state_dict(checkpoint, strict=False)
    print(load_report)#, file=fout)

model = model.to(device)
print(model)#, file=fout)
# ---- Loss, optimizer and learning-rate schedule ----
criterion = nn.CrossEntropyLoss()

# Parameters whose name contains one of these fragments get no weight decay.
no_decay = ["bias", "LayerNorm.weight"]


def _skips_decay(param_name):
    """Return True when ``param_name`` matches an entry of ``no_decay``."""
    return any(nd in param_name for nd in no_decay)


optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not _skips_decay(n)],
        # Honour --weight_decay (default 1e-5) instead of a hard-coded value.
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if _skips_decay(n)],
        "weight_decay": 0.0,
    },
]
opt = AdamW(optimizer_grouped_parameters, lr=2e-5)

# The training loop performs one optimizer step and one scheduler step per
# training patient (one entry of ``trn_y``), not per read, so the schedule
# length must be counted in patients.  Counting rows of ``trn_df`` would
# inflate both the warmup and the total horizon by the reads-per-patient
# factor.
num_update_steps_per_epoch = len(trn_y)
max_train_steps = args.epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=opt,
    # Warm up for exactly one epoch (used only by schedulers with warmup).
    num_warmup_steps=num_update_steps_per_epoch,
    num_training_steps=max_train_steps,
)
# ---- Training loop with per-epoch evaluation and best-on-validation saving ----
def _train_one_epoch():
    """Run one pass over the training patients, one optimizer step per patient.

    Returns ``(mean_loss, accuracy)`` over the patients visited.
    """
    model.train()
    total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
    # Fresh random visiting order each epoch, drawn from the ``random`` module.
    for idx in random.sample(range(len(trn_y)), len(trn_y)):
        x = trn_x[idx].to(device)
        y = trn_y[idx].unsqueeze(0).to(device)
        # Assumes model(x) yields logits of shape (1, n_classes), as implied
        # by argmax(1) and the length-1 target -- TODO confirm against model.py.
        logit = model(x)
        loss = criterion(logit, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        lr_scheduler.step()
        total_loss += loss.item()
        total_batch += 1
        total_num += len(y)
        correct_k += logit.argmax(1).eq(y).sum()
    return total_loss / total_batch, correct_k / total_num


def _evaluate(bags, labels, patients):
    """Score one held-out split without tracking gradients.

    Returns ``(mean_loss, accuracy, probs)`` where ``probs`` holds one list of
    softmax probabilities per patient, in iteration order.
    """
    model.eval()
    total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
    probs = []
    # ``patients`` takes part in the zip so iteration stops at the number of
    # patients that were actually recorded.
    for x, y, _pid in zip(bags, labels, patients):
        y = y.unsqueeze(0).to(device)
        x = x.to(device)
        with torch.inference_mode():
            logit = model(x)
            loss = criterion(logit, y)
        probs.append(logit.flatten().softmax(0).tolist())
        total_loss += loss.item()
        total_batch += 1
        total_num += len(y)
        correct_k += logit.argmax(1).eq(y).sum()
    return total_loss / total_batch, correct_k / total_num, probs


def _write_predictions(path, patients, labels, probs):
    """Write patient id, true label and per-class probabilities to ``path``.

    The historical column names ``p_normal``/``p_cancer`` are kept for
    two-class runs; for any other class count the columns are named
    ``p_class0``, ``p_class1``, ... so the frame can be built at all (two
    fixed names cannot label more than two probability columns).
    """
    n_classes = len(probs[0]) if probs else 0
    if n_classes == 2:
        columns = ['p_normal', 'p_cancer']
    else:
        columns = [f'p_class{k}' for k in range(n_classes)]
    info = pd.DataFrame({'patient': patients, 'label': labels.tolist()})
    info = pd.concat([info, pd.DataFrame(probs, columns=columns)], axis=1)
    info.to_csv(path, index=False)


# Shared suffix of every output file written below.
run_tag = (
    f'reads-{args.num_sequences}-patients-trn{args.num_train_patients}'
    f'-val{num_val_samples}-test{num_test_samples}-tiny'
)

# Start from +inf so the first finite validation loss always counts as an
# improvement, whatever its magnitude.
best_val_loss = float('inf')

try:
    for epoch in range(args.epochs):
        train_loss, train_acc = _train_one_epoch()

        # Evaluate on the test split first, then on the validation split.
        eval_loss, eval_acc, eval_probs = _evaluate(test_x, test_y, test_patients)
        val_loss, val_acc, val_probs = _evaluate(val_x, val_y, val_patients)

        print(f"{epoch+1}\t{train_loss}\t{train_acc}\t{val_loss}\t{val_acc}\t{eval_loss}\t{eval_acc}", file=fout)
        fout.flush()

        # Model selection uses the validation loss only; the checkpoint and
        # both prediction tables are overwritten on every improvement so they
        # always describe the same epoch.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f'{args.input_dir}/model-{run_tag}.pt')
            _write_predictions(f'{args.input_dir}/test_prediction-{run_tag}.csv',
                               test_patients, test_y, eval_probs)
            _write_predictions(f'{args.input_dir}/val_prediction-{run_tag}.csv',
                               val_patients, val_y, val_probs)
finally:
    # Close the log even when an epoch raises, so buffered rows reach disk.
    fout.close()