forked from macabdul9/CASA-Dialogue-Act-Classifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
63 lines (49 loc) · 1.61 KB
/
main.py
File metadata and controls
63 lines (49 loc) · 1.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from config import config
from transformers import AutoTokenizer
from models.ContextAwareDAC import ContextAwareDAC
from Trainer import LightningModel
from pytorch_lightning.callbacks import EarlyStopping, ProgressBar, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
if __name__=="__main__":
    # Entry point: wires up logging, callbacks, model, and tokenizer, then
    # runs a single fit + test cycle with PyTorch Lightning.

    # Weights & Biases experiment logger; log_model=True uploads checkpoints
    # to the W&B run as artifacts.
    logger = WandbLogger(
        name="grammarly-context-aware-attention",
        save_dir=config["save_dir"],
        project=config["project"],
        log_model=True,
    )

    # Stop early when the monitored metric fails to improve by min_delta
    # for 5 consecutive checks.
    early_stopping = EarlyStopping(
        monitor=config["monitor"],
        min_delta=config["min_delta"],
        patience=5,
    )

    # Keep only the single best checkpoint according to the monitored metric.
    # NOTE(review): `filepath` is the legacy (pre-1.0) ModelCheckpoint argument;
    # newer Lightning versions use `dirpath`/`filename` — confirm the pinned version.
    checkpoints = ModelCheckpoint(
        filepath=config["filepath"],
        monitor=config["monitor"],
        save_top_k=1
    )

    # Build the base network, its tokenizer, and the LightningModule wrapper
    # that owns the optimization/evaluation logic.
    base = ContextAwareDAC()
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
    model = LightningModel(model=base, tokenizer=tokenizer, config=config)

    # Single trainer construction. The original script built an identical
    # Trainer twice and discarded the first (it was shadowed before use);
    # constructing it once preserves behavior while avoiding the wasted setup.
    trainer = pl.Trainer(
        logger=logger,
        gpus=[0],
        checkpoint_callback=checkpoints,
        callbacks=[early_stopping],
        default_root_dir="../working/",
        max_epochs=config["epochs"],
        precision=config["precision"],
        automatic_optimization=True
    )

    trainer.fit(model)
    trainer.test(model)