-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
152 lines (119 loc) · 5.14 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
from typing import List, Optional, Tuple

import hydra
import pyrootutils
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger

# from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1
# allow_ops_in_compiled_graph()

# NOTE: this call must run BEFORE any `src.*` import below. It is what puts the
# project root on PYTHONPATH; importing local modules first only works when the
# script happens to be launched from the project root.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

# Local imports deliberately follow setup_root (see note above); the original
# relative order is kept in case the package has import-order dependencies.
from src.utils.logging import setup_logging  # noqa: E402
from src.utils.distributed import get_rank, get_world_size  # noqa: E402
import src.utils as utils  # noqa: E402

# Cap torch's CPU intra-op thread pool for every run launched from this script.
torch.set_num_threads(8)
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Train the model and optionally evaluate it on the test set.

    When ``cfg.test`` is enabled, the best checkpoint recorded by the trainer's
    checkpoint callback is used for testing if one exists; otherwise (no
    checkpoint callback, or no checkpoint saved) the current in-memory weights
    are tested.

    This method is wrapped in the ``@utils.task_wrapper`` decorator, which
    controls the behavior during failure (useful for multiruns, saving info
    about the crash, etc.).

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with merged train/test metrics and dict with
        all instantiated objects.
    """
    # Clear up residual CUDA cache from previous runs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    log = setup_logging(
        output=cfg.paths.output_dir,
        num_gpu=get_rank(),
        num_shards=get_world_size(),
        name="smart_canada_goose",
        color=True,
    )
    log.info("Environment info:\n" + utils.collect_env_info())

    # Seed pytorch, numpy and python.random. Compare against None so that an
    # explicit `seed: 0` is honored (0 is falsy and would otherwise be skipped).
    seed = cfg.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    # HACK: the size of the model's classification head is copied from the
    # dataset, so the datamodule must expose `dataset.n_classes` right after
    # instantiation.
    cfg.model.model.num_class = datamodule.dataset.n_classes
    model: pl.LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # Debugging tip: extra kwargs such as fast_dev_run=True or overfit_batches=1
    # can be passed here.
    trainer: pl.Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("compile"):
        log.info("Compiling model!")
        model = torch.compile(model, mode="reduce-overhead")

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    # Copy the metrics so the test stage cannot alter this snapshot.
    train_metrics = dict(trainer.callback_metrics)

    if cfg.get("test"):
        log.info("Starting testing!")
        # The trainer may have no checkpoint callback at all (e.g. checkpointing
        # disabled in the config); treat that like "no best checkpoint".
        ckpt_callback = trainer.checkpoint_callback
        ckpt_path = ckpt_callback.best_model_path if ckpt_callback is not None else ""
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = dict(trainer.callback_metrics)

    # Merge train and test metrics; test values win on key collisions.
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="./configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run ``train`` and return the value Hydra should optimize.

    ``utils.extras`` is applied first (extra utilities such as asking for tags
    when none are provided or printing the config tree).

    Args:
        cfg (DictConfig): Configuration composed by Hydra from ``configs/train.yaml``.

    Returns:
        Optional[float]: Whatever ``utils.get_metric_value`` yields for the
        metric named by ``cfg.optimized_metric``; used by Hydra-based
        hyperparameter optimization.
    """
    utils.extras(cfg)

    results, _objects = train(cfg)

    # Look up the optimized metric safely (the name may be absent from cfg).
    target_name = cfg.get("optimized_metric")
    return utils.get_metric_value(metric_dict=results, metric_name=target_name)


if __name__ == "__main__":
    main()