
Commit

Add docs
duydl committed Feb 18, 2025
1 parent 6c3c38d commit ba576bf
Showing 2 changed files with 28 additions and 4 deletions.
24 changes: 24 additions & 0 deletions docs/source-pytorch/common/trainer.rst
@@ -1077,6 +1077,30 @@ With :func:`torch.inference_mode` disabled, you can enable the grad of your model
trainer = Trainer(inference_mode=False)
trainer.validate(model)
enable_autolog_hparams
^^^^^^^^^^^^^^^^^^^^^^

Whether to log hyperparameters at the start of a run. Defaults to ``True``.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(enable_autolog_hparams=True)

    # disable logging hyperparams
    trainer = Trainer(enable_autolog_hparams=False)

With the parameter set to ``False``, you can add custom code to log hyperparameters:

.. code-block:: python

    from lightning.pytorch.loggers import CSVLogger

    model = LitModel()
    trainer = Trainer(enable_autolog_hparams=False)

    for logger in trainer.loggers:
        if isinstance(logger, CSVLogger):
            logger.log_hyperparams(hparams_dict_1)
        else:
            logger.log_hyperparams(hparams_dict_2)
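
When autologging stays enabled, the hyperparameters logged are the ones the module captured; a minimal sketch of that side of the workflow (``learning_rate`` and ``hidden_dim`` are placeholder names, not part of this change):

.. code-block:: python

    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):
        def __init__(self, learning_rate=1e-3, hidden_dim=128):
            super().__init__()
            # save_hyperparameters() records the __init__ arguments in
            # self.hparams; with enable_autolog_hparams=True they are sent
            # to every configured logger at the start of the run.
            self.save_hyperparameters()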
-----

8 changes: 4 additions & 4 deletions src/lightning/pytorch/trainer/trainer.py
@@ -128,7 +128,7 @@ def __init__(
sync_batchnorm: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
default_root_dir: Optional[_PATH] = None,
-log_hyperparams_enabled: bool = True,
+enable_autolog_hparams: bool = True,
) -> None:
r"""Customize every aspect of training via flags.
@@ -291,7 +291,7 @@ def __init__(
Default: ``os.getcwd()``.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
-log_hyperparams_enabled: Whether to log hyperparameters at the start of a run.
+enable_autolog_hparams: Whether to log hyperparameters at the start of a run.
Default: ``True``.
Raises:
@@ -500,7 +500,7 @@ def __init__(
num_sanity_val_steps,
)

-self.log_hyperparams_enabled = log_hyperparams_enabled
+self.enable_autolog_hparams = enable_autolog_hparams

def fit(
self,
@@ -969,7 +969,7 @@ def _run(
call._call_lightning_module_hook(self, "on_fit_start")

# only log hparams if enabled
-if self.log_hyperparams_enabled:
+if self.enable_autolog_hparams:
_log_hyperparams(self)

if self.strategy.restore_checkpoint_after_setup:
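For reference, the user-facing effect of the rename above is just a keyword change when constructing the Trainer; a minimal before/after sketch (hypothetical caller code, not part of the diff):

    from lightning.pytorch import Trainer

    # old keyword, removed by this commit
    # trainer = Trainer(log_hyperparams_enabled=False)

    # new keyword introduced by this commit
    trainer = Trainer(enable_autolog_hparams=False)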
