import argparse
import os

import torch
from torch.utils.data import random_split
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    MambaConfig,
    MambaForCausalLM,
    Trainer,
    TrainingArguments,
)
import wandb

from dataset import UniRefDataset
from eval_metrics import compute_metrics
from utils import set_seed, LogTrainingMetricsCallback
os.environ["WANDB_DISABLE_CODE"] = "true" #WEIGHTS AND BIASES
set_seed(42)
#IMPORTANT PART
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=25, help="number of epochs to train the model")
parser.add_argument("--batch_size", type=int, default=128, help="batch size per device for both training and evaluation")
parser.add_argument("--nhl", type=int, default=32, help="number of hidden layers in the model")
parser.add_argument("--state_size", type=int, default=16, help="size of the state-space latents used in the model")
parser.add_argument("--hidden_size", type=int, default=768, help="dimensionality of the embeddings and hidden states")
parser.add_argument("--max_length", type=int, default=512, help="context size, i.e. length of each input sequence to the model")
parser.add_argument("--lr", type=float, default=1e-3, help="maximum learning rate used during training")
parser.add_argument("--warmup_ratio", type=float, default=0.0001, help="ratio of the learning-rate warmup phase to the entire training run")
parser.add_argument("--weight_decay", type=float, default=0.001, help="L2 regularization constant")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="maximum gradient norm threshold for norm clipping")
parser.add_argument("--device_index", type=int, default=0, help="index of the CUDA device")
parser.add_argument("--tokenizer_path", type=str, default="facebook/esm2_t6_8M_UR50D", help="path to the pretrained tokenizer, either a Hugging Face Hub id or a local directory")
parser.add_argument("--model_save_path", type=str, help="directory to save the model weights and optimizer state")
parser.add_argument("--dataset_path", type=str, help="path to the dataset in FASTA format")
parser.add_argument("--model_load_path", type=str, help="path to a saved checkpoint if not training from scratch")
parser.add_argument("--wandb_run_id", type=str, help="wandb run id, only needed when continuing a previous run")
parser.add_argument("--wandb_run_name", type=str, help="wandb run name for new runs")
parser.add_argument("--wandb_project_name", type=str, help="wandb project name for new runs")
parser.add_argument("--log_freq", type=int, default=5000, help="how often (in training steps) to log training metrics to wandb")
parser.add_argument("--eval_freq", type=int, default=30000, help="how often (in training steps) to run evaluation during training")
parser.add_argument("--num_data", type=int, default=int(60e6), help="total number of protein sequences to use from the UniRef50 dataset")
args = parser.parse_args()
device = torch.device(f'cuda:{args.device_index}' if torch.cuda.is_available() else 'cpu')
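# The ESM-2 tokenizer maps each amino-acid residue to its own token, so a protein of
# length N tokenizes to roughly N tokens plus special tokens.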
# Left padding keeps the end of every sequence aligned within a batch, which is the
# side decoder-only generation expects; the collator masks pad tokens out of the loss either way.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side='left')
def train():
    run_config = {
        'batch_size': args.batch_size,
        'lr': args.lr,
        'hidden_size': args.hidden_size,
        'state_size': args.state_size,
        'nhl': args.nhl,
        'weight_decay': args.weight_decay,
        'epochs': args.epochs
    }
    # Reattach to an existing wandb run when a run id is supplied; otherwise start fresh.
    run = wandb.init(
        project=args.wandb_project_name,
        mode='online',
        name=args.wandb_run_name,
        id=args.wandb_run_id,
        resume='allow' if args.wandb_run_id else None,
        config=run_config
    )
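    # save_steps mirrors eval_steps below so every saved checkpoint has a matching
    # eval_loss, which load_best_model_at_end needs in order to pick the best checkpoint.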
    training_args = TrainingArguments(
        output_dir=os.path.join(args.model_save_path, args.wandb_run_name),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=1,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        logging_dir=os.path.join(args.model_save_path, "logs"),
        logging_steps=args.log_freq,
        evaluation_strategy="steps",
        eval_steps=args.eval_freq,
        eval_on_start=False,
        save_strategy="steps",
        save_steps=args.eval_freq,
        save_total_limit=1,
        load_best_model_at_end=True,  # required for metric_for_best_model to take effect
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        bf16=True,
    )
    mamba_config = MambaConfig(
        vocab_size=tokenizer.vocab_size,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.nhl,
        state_size=args.state_size,
        use_cache=False,  # the recurrent cache only helps at inference time
    )
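    # Fields not set here (e.g. expand, conv_kernel, time_step_rank) fall back to the
    # Hugging Face MambaConfig defaults.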
    # Train from scratch unless a checkpoint to continue from was provided.
    model = (
        MambaForCausalLM(mamba_config)
        if args.model_load_path is None
        else MambaForCausalLM.from_pretrained(args.model_load_path)
    )
    dataset = UniRefDataset(
        dataset_path=args.dataset_path,
        tokenizer=tokenizer,
        num_data=args.num_data,
        max_len=args.max_length,
        device=device
    )
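    # random_split draws from torch's global RNG; assuming set_seed(42) above calls
    # torch.manual_seed, the 80/10/10 split below is reproducible across runs.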
    len_dataset = len(dataset)
    train_size = int(0.8 * len_dataset)
    test_size = int(0.1 * len_dataset)
    # The validation split absorbs the rounding remainder so the sizes sum to len(dataset).
    val_size = len_dataset - train_size - test_size
    train_dataset, test_dataset, val_dataset = random_split(dataset, [train_size, test_size, val_size])
    # mlm=False gives causal-LM batches: labels are a copy of input_ids with pad
    # positions set to -100; the model applies the one-token shift internally.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )
    callbacks = [
        LogTrainingMetricsCallback(run)
    ]
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=callbacks
    )
    # TrainingArguments.resume_from_checkpoint is not acted on by Trainer itself, so the
    # checkpoint (including optimizer/scheduler state) is passed to train() explicitly.
    trainer.train(resume_from_checkpoint=args.model_load_path)
    # Final evaluation on the held-out test split.
    metrics = trainer.evaluate(eval_dataset=test_dataset)
    print(metrics)
    run.finish()
if __name__ == "__main__":
    train()
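
# Example invocation (paths and names below are illustrative placeholders):
#   python train.py \
#       --dataset_path /data/uniref50.fasta \
#       --model_save_path ./checkpoints/ \
#       --wandb_project_name protein-mamba \
#       --wandb_run_name mamba-768h-32l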