-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrainer.py
66 lines (53 loc) · 1.64 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from src.datamodule import NL2NLDM
from src.model import NL2NL
from config import (
GLOBAL_SEED,
PROJECT_NAME
)
def train_nl2nl(args):
    """Train the NL2NL encoder-decoder model end to end and save its weights.

    Pipeline: seed all RNGs for reproducibility, build the datamodule and
    model from CLI ``args``, pick a logger (TensorBoard by default, Weights &
    Biases when ``args.logger == "wandb"``), run ``Trainer.fit``, then persist
    the trained encoder/decoder/LM to the paths given in ``args``.

    Args:
        args: Parsed CLI namespace. Fields read here: tokenizer, pl,
            path_base_models, path_cache_datasets, max_seq_len, padding,
            batch_size, workers, nl_en_model, nl_de_model, lr, wd,
            path_logs, run_name, logger, jobid, gpus, epochs,
            path_save_nl_encoder, path_save_nl_decoder, path_save_nl_lm.

    Returns:
        None. Side effects: training run, logger output under
        ``args.path_logs``, and model files written by ``model.save``.
    """
    # Fix every RNG (python/numpy/torch) so runs are reproducible.
    seed_everything(GLOBAL_SEED)

    dm = NL2NLDM(
        tokenizer_model=args.tokenizer,
        pl=args.pl,
        path_base_models=args.path_base_models,
        path_cache_dataset=args.path_cache_datasets,
        max_seq_len=args.max_seq_len,
        padding=args.padding,
        batch_size=args.batch_size,
        num_workers=args.workers,
    )
    model = NL2NL(
        encoder_model=args.nl_en_model,
        decoder_model=args.nl_de_model,
        learning_rate=args.lr,
        weight_decay=args.wd,
    )

    # Default logger. BUGFIX: TensorBoardLogger takes `name`, not `run_name`;
    # the original call raised TypeError whenever the wandb branch was skipped.
    logger = TensorBoardLogger(
        save_dir=args.path_logs,
        name=args.run_name,
    )
    if args.logger == "wandb":
        # `id=run_name` makes the wandb run resumable under a stable id.
        logger = WandbLogger(
            save_dir=args.path_logs,
            name=args.run_name,
            id=args.run_name,
            project=PROJECT_NAME
        )

    logger.log_hyperparams({"jobid": args.jobid})  # Logging jobid of HPC

    trainer = pl.Trainer(
        logger=logger,
        accelerator="gpu",
        devices=args.gpus,
        max_epochs=args.epochs,
        log_every_n_steps=2,
        deterministic=True  # Hopefully get same results on different GPUs
    )
    trainer.fit(model, datamodule=dm)

    # Persist the trained sub-modules for downstream use.
    model.save(
        encoder_path=args.path_save_nl_encoder,
        decoder_path=args.path_save_nl_decoder,
        lm_path=args.path_save_nl_lm
    )