auto set log_interval if None
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed May 8, 2024
1 parent 960f08c commit 2eaa06d
Showing 5 changed files with 11 additions and 7 deletions.
danling/runner/accelerate_runner.py (6 additions, 2 deletions)
@@ -87,6 +87,8 @@ def __post_init__(self) -> None:
             dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True))
             self.dataloaders[k] = self.prepare(utils.data.DataLoader(d, **dataloader_kwargs[k]))
         default_kwargs.update(dataloader_kwargs)
+        if self.state.get("log_interval") is None:
+            self.state.log_interval = max(len(d) for d in self.dataloaders.values()) // 10
 
     @property
     def deepspeed(self) -> dict | None:
@@ -157,6 +159,7 @@ def train_epoch(self, split: str = "train") -> NestedDict:
         loader = self.dataloaders[split]
         length = len(loader) - 1
         last_print_iteration = -1
+        log_interval = self.state.get("log_interval", -1)
         self.meters.reset()
         if self.metrics is not None:
             self.metrics.reset()
@@ -176,7 +179,7 @@ def train_epoch(self, split: str = "train") -> NestedDict:
                 self.metrics.update(pred, target)
             self.step(loss)
 
-            if self.log_interval > 0 and (iteration > 0 and iteration % self.log_interval == 0 or iteration == length):
+            if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length):
                 interval = iteration - last_print_iteration
                 if self.device == torch.device("cuda"):
                     torch.cuda.synchronize()
@@ -233,6 +236,7 @@ def evaluate_epoch(self, split: str = "val") -> NestedDict:
         loader = self.dataloaders[split]
         length = len(loader) - 1
         last_print_iteration = -1
+        log_interval = self.state.get("log_interval", -1)
         self.meters.reset()
         if self.metrics is not None:
             self.metrics.reset()
@@ -246,7 +250,7 @@ def evaluate_epoch(self, split: str = "val") -> NestedDict:
             if self.metrics is not None:
                 self.metrics.update(pred, target)
 
-            if self.log_interval > 0 and (iteration > 0 and iteration % self.log_interval == 0 or iteration == length):
+            if log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length):
                 interval = iteration - last_print_iteration
                 if self.device == torch.device("cuda"):
                     torch.cuda.synchronize()
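The change in this file has two halves: `__post_init__` now derives `log_interval` from the longest dataloader when it is unset, and the epoch loops read the resolved value from state instead of `self.log_interval`. A minimal standalone sketch of the resulting behavior (the expressions are copied from the diff; the function wrappers and names are assumed for illustration):

    def resolve_log_interval(dataloaders: dict, log_interval=None) -> int:
        # Mirrors __post_init__ above: default to 1/10 of the longest split.
        # Assumes dataloaders is non-empty and every value supports len().
        if log_interval is None:
            return max(len(d) for d in dataloaders.values()) // 10
        return log_interval

    def should_log(iteration: int, length: int, log_interval: int) -> bool:
        # Mirrors the train_epoch/evaluate_epoch condition: log every
        # log_interval iterations (skipping iteration 0) and on the last one.
        return log_interval > 0 and (iteration > 0 and iteration % log_interval == 0 or iteration == length)

For example, a longest split of 640 batches resolves to 640 // 10 == 64, so logs print at iterations 64, 128, ..., and at the final iteration. Note that a longest split shorter than 10 batches resolves to 0, which the `log_interval > 0` guard treats as logging disabled.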
danling/runner/base_runner.py (1 addition, 1 deletion)
@@ -217,7 +217,7 @@ def init_deepspeed(  # pylint: disable=too-many-branches, too-many-statements
         if isinstance(config, str):
            config = NestedDict.load(config)
         if config.get("steps_per_print", "auto") == "auto":
-            config["steps_per_print"] = self.log_interval
+            config["steps_per_print"] = self.state.log_interval
         if config.get("train_micro_batch_size_per_gpu", "auto") == "auto":
             config["train_micro_batch_size_per_gpu"] = self.batch_size
         if "amp" in config:
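The only change here is that the "auto" placeholder for `steps_per_print` now resolves from `self.state.log_interval`, which `__post_init__` populates when it was left as `None`. A hedged sketch of the resolution pattern, using a plain dict and assumed values in place of the runner:

    # Plain-dict stand-in for the DeepSpeed config that init_deepspeed receives.
    config = {"steps_per_print": "auto", "train_micro_batch_size_per_gpu": "auto"}
    log_interval, batch_size = 64, 32  # assumed values for illustration

    if config.get("steps_per_print", "auto") == "auto":
        config["steps_per_print"] = log_interval  # now sourced from state.log_interval
    if config.get("train_micro_batch_size_per_gpu", "auto") == "auto":
        config["train_micro_batch_size_per_gpu"] = batch_size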
danling/runner/state.py (2 additions, 2 deletions)
@@ -73,7 +73,7 @@ class RunnerState(NestedDict):  # pylint: disable=too-many-instance-attributes
         tensorboard (bool): Whether to use `tensorboard`.
             Defaults to `False`.
         log_interval (int): Interval of printing logs.
-            Defaults to -1.
+            Defaults to `None`, print logs every 1/10 of the longest split.
         save_interval (int): Interval of saving intermediate checkpoints.
             Defaults to `None`, never save checkpoints.
             If <= 0, save only the latest and the best checkpoints.
@@ -112,7 +112,7 @@ class RunnerState(NestedDict):  # pylint: disable=too-many-instance-attributes
     checkpoint_dir_name: str = "checkpoints"
     log: bool = True
     tensorboard: bool = False
-    log_interval: int = -1
+    log_interval: Optional[int] = None
     save_interval: Optional[int] = None
 
     distributed: Optional[bool] = None
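With the field typed `Optional[int]` and defaulting to `None`, leaving `log_interval` unset now opts into the auto-derived interval, while explicit values keep their old meaning. Roughly (field names from the diff; keyword-argument construction is assumed, since `RunnerState` is a `NestedDict`):

    state = RunnerState()                   # None: derived as longest split // 10
    state = RunnerState(log_interval=100)   # explicit: used as-is
    state = RunnerState(log_interval=-1)    # non-positive: periodic logging disabled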
tests/runner/test_base_runner.py (1 addition, 1 deletion)
@@ -26,7 +26,7 @@ def __init__(self):
         self.log = False
         self.tensorboard = False
         self.gradient_clip = False
-        self.log_interval = 10
+        self.log_interval = None
         self.save_interval = None
         self.train_iterations_per_epoch = 64
         self.val_iterations_per_epoch = 16
tests/runner/test_torch_runner.py (1 addition, 1 deletion)
@@ -23,7 +23,7 @@ def __init__(self):
         self.optim.weight_decay = 1e-4
         self.log = False
         self.tensorboard = False
-        self.log_interval = 1000
+        self.log_interval = None
         self.save_interval = None
         self.score_split = "val"
         self.score_name = "loss"
