import argparse

parser = argparse.ArgumentParser(description='DECIDIA training program')
parser.add_argument('--input_dir', type=str, help='input directory')
parser.add_argument('--sequence_embedding', type=str, help='sequence embedding directory')
parser.add_argument('--num_hidden_layers', type=int, default=1, help='num_hidden_layers [1]')
parser.add_argument('--train_file', type=str, help='training file')
parser.add_argument('--val_file', type=str, help='validation file')
parser.add_argument('--device', type=str, help='device', default='cuda:1')
parser.add_argument('--num_classes', type=int, help='num_classes [32]', default=32)
parser.add_argument('--diseases', type=str, default=None, help='diseases included, e.g. "LUAD,LUSC"')
parser.add_argument('--weight_decay', type=float, help='weight_decay [1e-5]', default=1e-5)
parser.add_argument('--modeling_context', action='store_true', help='whether to use OPT to model context dependency')
parser.add_argument("--lr_scheduler_type", type=str,
                    choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                    default="constant", help="The scheduler type to use.")
parser.add_argument('--pretrained_weight', type=str, help='pretrained weight')
parser.add_argument('--pretrained_cls_token', type=str, help='pretrained cls token')
parser.add_argument('--epochs', type=int, default=100, help='epochs (default: 100)')
parser.add_argument('--num_sequences', type=int, default=None, help='number of sequences to sample per patient from the training set')
parser.add_argument('--num_train_patients', type=int, default=None, help='number of patients to sample from the training set')

args = parser.parse_args()
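
# Example invocation (the script filename and the paths below are illustrative, not part of the repo):
#
#   python train.py \
#       --input_dir /path/to/dataset \
#       --sequence_embedding /path/to/opt_tokenizer_and_weights \
#       --num_classes 2 --num_sequences 100 --epochs 100 --device cuda:0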

import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import random

import torch
import torch.nn as nn
import pandas as pd
from tqdm import tqdm
from torch.optim import AdamW
from transformers import get_scheduler, PreTrainedTokenizerFast, OPTForCausalLM

from model import DeepAttnMIL

torch.set_num_threads(2)
device = args.device
random.seed(123)

# Load the tokenizer and the pretrained OPT backbone; the backbone is kept frozen and used only
# to embed individual reads (mean-pooled last hidden states).
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.sequence_embedding)
net = OPTForCausalLM.from_pretrained(args.sequence_embedding)
net = net.to(device)
net.eval()

feature_dim = net.config.hidden_size
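
# Data layout assumption (inferred from how the dataframes are used below, not from a separate
# spec): {input_dir}/trn.csv.gz, val.csv.gz and test.csv.gz each provide at least the columns
# 'patient' (sample id), 'seq' (read sequence, tokenized character by character) and
# 'label' (integer class id), with the same number of reads for every patient.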

trn_df = pd.read_csv(f'{args.input_dir}/trn.csv.gz')

# Every patient is expected to contribute the same number of reads.
reads_per_patient = trn_df.patient.value_counts().unique()
assert len(reads_per_patient) == 1
reads_per_patient = reads_per_patient[0]

# Optionally subsample reads per patient; default to using all reads when --num_sequences is
# not given (or asks for more reads than are available).
if args.num_sequences is None or args.num_sequences > reads_per_patient:
    args.num_sequences = reads_per_patient
if args.num_sequences < reads_per_patient:
    trn_df = pd.concat([df.sample(args.num_sequences, random_state=123) for patient, df in trn_df.groupby('patient')])

# Optionally subsample training patients.
num_train_samples = len(trn_df.patient.unique())
if args.num_train_patients is None:
    args.num_train_patients = num_train_samples
if args.num_train_patients < num_train_samples:
    trn_df = trn_df[trn_df.patient.isin(random.sample(trn_df.patient.unique().tolist(), args.num_train_patients))]

# One bag of read embeddings per training patient: shape (patients, reads, feature_dim).
trn_x = torch.zeros(args.num_train_patients, args.num_sequences, feature_dim)
trn_y = torch.as_tensor([-1] * args.num_train_patients)

test_df = pd.read_csv(f'{args.input_dir}/test.csv.gz')
num_test_samples = len(test_df.patient.unique())
test_x = torch.zeros(num_test_samples, reads_per_patient, feature_dim)
test_y = torch.as_tensor([-1] * num_test_samples)
test_patients = []

val_df = pd.read_csv(f'{args.input_dir}/val.csv.gz')
num_val_samples = len(val_df.patient.unique())
val_x = torch.zeros(num_val_samples, reads_per_patient, feature_dim)
val_y = torch.as_tensor([-1] * num_val_samples)
val_patients = []

pad_token_id = net.config.pad_token_id

# Embed the training reads: each read is space-separated into characters, tokenized, run through
# the frozen OPT backbone, and mean-pooled over the sequence dimension to give one vector per read.
for i, (patient, e) in tqdm(enumerate(trn_df.groupby('patient')), total=args.num_train_patients):
    a = [' '.join(list(s)) for s in e.seq]
    inputs = tokenizer(a, max_length=100, padding='max_length', truncation=True,
                       return_tensors='pt', return_token_type_ids=False)
    for k, v in inputs.items():
        inputs[k] = v.to(device)
    with torch.inference_mode():
        out = net.model(**inputs)
    features = out.last_hidden_state.mean(1).cpu()
    trn_x[i] = features
    trn_y[i] = e.label.tolist()[0]

for i, (patient, e) in tqdm(enumerate(test_df.groupby('patient')), total=num_test_samples):
    a = [' '.join(list(s)) for s in e.seq]
    inputs = tokenizer(a, max_length=100, padding='max_length', truncation=True,
                       return_tensors='pt', return_token_type_ids=False)
    for k, v in inputs.items():
        inputs[k] = v.to(device)
    with torch.inference_mode():
        out = net.model(**inputs)
    features = out.last_hidden_state.mean(1).cpu()
    test_x[i] = features
    test_y[i] = e.label.tolist()[0]
    test_patients.append(patient)

for i, (patient, e) in tqdm(enumerate(val_df.groupby('patient')), total=num_val_samples):
    a = [' '.join(list(s)) for s in e.seq]
    inputs = tokenizer(a, max_length=100, padding='max_length', truncation=True,
                       return_tensors='pt', return_token_type_ids=False)
    for k, v in inputs.items():
        inputs[k] = v.to(device)
    with torch.inference_mode():
        out = net.model(**inputs)
    features = out.last_hidden_state.mean(1).cpu()
    val_x[i] = features
    val_y[i] = e.label.tolist()[0]
    val_patients.append(patient)
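
# At this point trn_x / val_x / test_x each hold one bag of read embeddings per patient and
# trn_y / val_y / test_y hold the corresponding patient-level labels, so the MIL classifier
# below can be trained without touching the OPT backbone again.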

fout = open(f'{args.input_dir}/log-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.txt', 'w')
print("epoch\ttrain_loss\ttrain_acc\tval_loss\tval_acc\teval_loss\teval_acc", file=fout)

model = DeepAttnMIL(input_dim=feature_dim, n_classes=args.num_classes, size_arg='big')

# Optionally warm-start from a pretrained checkpoint; if the pretrained classifier head was
# trained for a different number of classes, drop it and keep only the remaining weights.
if args.pretrained_weight:
    state_dict = torch.load(args.pretrained_weight, map_location='cpu')
    if state_dict['classifier.weight'].size(0) != args.num_classes:
        del state_dict['classifier.weight']
        del state_dict['classifier.bias']

    msg = model.load_state_dict(state_dict, strict=False)
    print(msg)

model = model.to(device)

print(model)

criterion = nn.CrossEntropyLoss()

# Exclude biases and LayerNorm weights from weight decay, as is standard for transformer-style models.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
opt = AdamW(optimizer_grouped_parameters, lr=2e-5)

# One optimizer step is taken per training patient, so the number of update steps per epoch is
# len(trn_y); warm up for one epoch when the chosen scheduler uses warm-up.
num_update_steps_per_epoch = len(trn_y)
max_train_steps = args.epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(name=args.lr_scheduler_type, optimizer=opt,
                             num_warmup_steps=num_update_steps_per_epoch,
                             num_training_steps=max_train_steps)

best_eval_acc = 0.0
best_eval_loss = 100000.0
best_val_loss = 100000.0

for epoch in range(args.epochs):
    # Train on one patient bag at a time, in a freshly shuffled order each epoch (batch size 1).
    model.train()
    total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
    idxs = random.sample(range(len(trn_y)), len(trn_y))
    for idx in idxs:
        x = trn_x[idx]
        y = trn_y[idx].unsqueeze(0)
        x = x.to(device)
        y = y.to(device)

        logit = model(x)
        loss = criterion(logit, y)

        opt.zero_grad()
        loss.backward()
        opt.step()
        lr_scheduler.step()

        total_loss += loss.item()
        total_batch += 1
        total_num += len(y)
        correct_k += logit.argmax(1).eq(y).sum().item()

    train_acc = correct_k / total_num
    train_loss = total_loss / total_batch

    # Evaluate on the held-out test set.
    model.eval()
    total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
    eval_probs = []
    for x, y, pid in zip(test_x, test_y, test_patients):
        y = y.unsqueeze(0).to(device)
        x = x.to(device)

        with torch.inference_mode():
            logit = model(x)
            loss = criterion(logit, y)

        eval_probs.append(logit.flatten().softmax(0).tolist())

        total_loss += loss.item()
        total_batch += 1
        total_num += len(y)
        correct_k += logit.argmax(1).eq(y).sum().item()

    eval_acc = correct_k / total_num
    eval_loss = total_loss / total_batch

    # Evaluate on the validation set (used for model selection below).
    model.eval()
    total_loss, total_batch, total_num, correct_k = 0, 0, 0, 0
    val_probs = []
    for x, y, pid in zip(val_x, val_y, val_patients):
        y = y.unsqueeze(0).to(device)
        x = x.to(device)

        with torch.inference_mode():
            logit = model(x)
            loss = criterion(logit, y)

        val_probs.append(logit.flatten().softmax(0).tolist())

        total_loss += loss.item()
        total_batch += 1
        total_num += len(y)
        correct_k += logit.argmax(1).eq(y).sum().item()

    val_acc = correct_k / total_num
    val_loss = total_loss / total_batch
print(f"{epoch+1}\t{train_loss}\t{train_acc}\t{val_loss}\t{val_acc}\t{eval_loss}\t{eval_acc}", file=fout) |
|
fout.flush() |
|
|
|
if val_loss < best_val_loss: |
|
torch.save(model.state_dict(), f'{args.input_dir}/model-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.pt') |
|
best_val_loss = val_loss |
|
|
|
eval_probs = pd.DataFrame(eval_probs, columns=['p_normal', 'p_cancer']) |
|
info = pd.DataFrame({'patient':test_patients, 'label':test_y.tolist()}) |
|
info = pd.concat([info, eval_probs], axis=1) |
|
info.to_csv(f'{args.input_dir}/test_prediction-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.csv', index=False) |
|
|
|
val_probs = pd.DataFrame(val_probs, columns=['p_normal', 'p_cancer']) |
|
info = pd.DataFrame({'patient':val_patients, 'label':val_y.tolist()}) |
|
info = pd.concat([info, val_probs], axis=1) |
|
info.to_csv(f'{args.input_dir}/val_prediction-reads-{args.num_sequences}-patients-trn{args.num_train_patients}-val{num_val_samples}-test{num_test_samples}-tiny.csv', index=False) |
|
|
|
fout.close() |
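
# Sketch (illustrative, not executed here): reloading the saved best checkpoint for inference.
# The constructor arguments mirror the ones used above; the path mirrors the torch.save call.
#
#   model = DeepAttnMIL(input_dim=feature_dim, n_classes=args.num_classes, size_arg='big')
#   model.load_state_dict(torch.load('<input_dir>/model-reads-...-tiny.pt', map_location='cpu'))
#   model.eval()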