-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added basic file logger #1803 * fixup! Added basic file logger #1803 * fixup! Added basic file logger #1803 * fixup! Added basic file logger #1803 * fixup! Added basic file logger #1803 * fixup! Added basic file logger #1803 * csv * Apply suggestions from code review * tests * tests * tests * miss * docs Co-authored-by: xmotli02 <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
- Loading branch information
1 parent
ac4a215
commit 767c449
Showing
7 changed files
with
320 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
""" | ||
CSV logger | ||
---------- | ||
CSV logger for basic experiment logging that does not require opening ports | ||
""" | ||
import io | ||
import os | ||
import csv | ||
import torch | ||
from argparse import Namespace | ||
from typing import Optional, Dict, Any, Union | ||
|
||
from pytorch_lightning import _logger as log | ||
from pytorch_lightning.core.saving import save_hparams_to_yaml | ||
from pytorch_lightning.loggers.base import LightningLoggerBase | ||
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only | ||
|
||
|
||
class ExperimentWriter(object): | ||
r""" | ||
Experiment writer for CSVLogger. | ||
Currently supports to log hyperparameters and metrics in YAML and CSV | ||
format, respectively. | ||
Args: | ||
log_dir: Directory for the experiment logs | ||
""" | ||
|
||
NAME_HPARAMS_FILE = 'hparams.yaml' | ||
NAME_METRICS_FILE = 'metrics.csv' | ||
|
||
def __init__(self, log_dir: str) -> None: | ||
self.hparams = {} | ||
self.metrics = [] | ||
|
||
self.log_dir = log_dir | ||
if os.path.exists(self.log_dir): | ||
rank_zero_warn( | ||
f"Experiment logs directory {self.log_dir} exists and is not empty." | ||
" Previous log files in this directory will be deleted when the new ones are saved!" | ||
) | ||
os.makedirs(self.log_dir, exist_ok=True) | ||
|
||
self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) | ||
|
||
def log_hparams(self, params: Dict[str, Any]) -> None: | ||
"""Record hparams""" | ||
self.hparams.update(params) | ||
|
||
def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: | ||
"""Record metrics""" | ||
def _handle_value(value): | ||
if isinstance(value, torch.Tensor): | ||
return value.item() | ||
return value | ||
|
||
if step is None: | ||
step = len(self.metrics) | ||
|
||
metrics = {k: _handle_value(v) for k, v in metrics_dict.items()} | ||
metrics['step'] = step | ||
self.metrics.append(metrics) | ||
|
||
def save(self) -> None: | ||
"""Save recorded hparams and metrics into files""" | ||
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) | ||
save_hparams_to_yaml(hparams_file, self.hparams) | ||
|
||
if not self.metrics: | ||
return | ||
|
||
last_m = {} | ||
for m in self.metrics: | ||
last_m.update(m) | ||
metrics_keys = list(last_m.keys()) | ||
|
||
with io.open(self.metrics_file_path, 'w', newline='') as f: | ||
self.writer = csv.DictWriter(f, fieldnames=metrics_keys) | ||
self.writer.writeheader() | ||
self.writer.writerows(self.metrics) | ||
|
||
|
||
class CSVLogger(LightningLoggerBase): | ||
r""" | ||
Log to local file system in yaml and CSV format. Logs are saved to | ||
``os.path.join(save_dir, name, version)``. | ||
Example: | ||
>>> from pytorch_lightning import Trainer | ||
>>> from pytorch_lightning.loggers import CSVLogger | ||
>>> logger = CSVLogger("logs", name="my_exp_name") | ||
>>> trainer = Trainer(logger=logger) | ||
Args: | ||
save_dir: Save directory | ||
name: Experiment name. Defaults to ``'default'``. | ||
version: Experiment version. If version is not specified the logger inspects the save | ||
directory for existing versions, then automatically assigns the next available version. | ||
""" | ||
|
||
def __init__(self, | ||
save_dir: str, | ||
name: Optional[str] = "default", | ||
version: Optional[Union[int, str]] = None): | ||
|
||
super().__init__() | ||
self._save_dir = save_dir | ||
self._name = name or '' | ||
self._version = version | ||
self._experiment = None | ||
|
||
@property | ||
def root_dir(self) -> str: | ||
""" | ||
Parent directory for all checkpoint subdirectories. | ||
If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used | ||
and the checkpoint will be saved in "save_dir/version_dir" | ||
""" | ||
if not self.name: | ||
return self.save_dir | ||
return os.path.join(self.save_dir, self.name) | ||
|
||
@property | ||
def log_dir(self) -> str: | ||
""" | ||
The log directory for this run. By default, it is named | ||
``'version_${self.version}'`` but it can be overridden by passing a string value | ||
for the constructor's version parameter instead of ``None`` or an int. | ||
""" | ||
# create a pseudo standard path ala test-tube | ||
version = self.version if isinstance(self.version, str) else f"version_{self.version}" | ||
log_dir = os.path.join(self.root_dir, version) | ||
return log_dir | ||
|
||
@property | ||
def save_dir(self) -> Optional[str]: | ||
return self._save_dir | ||
|
||
@property | ||
def experiment(self) -> ExperimentWriter: | ||
r""" | ||
Actual ExperimentWriter object. To use ExperimentWriter features in your | ||
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following. | ||
Example:: | ||
self.logger.experiment.some_experiment_writer_function() | ||
""" | ||
if self._experiment: | ||
return self._experiment | ||
|
||
os.makedirs(self.root_dir, exist_ok=True) | ||
self._experiment = ExperimentWriter(log_dir=self.log_dir) | ||
return self._experiment | ||
|
||
@rank_zero_only | ||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: | ||
params = self._convert_params(params) | ||
self.experiment.log_hparams(params) | ||
|
||
@rank_zero_only | ||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: | ||
self.experiment.log_metrics(metrics, step) | ||
|
||
@rank_zero_only | ||
def save(self) -> None: | ||
super().save() | ||
self.experiment.save() | ||
|
||
@rank_zero_only | ||
def finalize(self, status: str) -> None: | ||
self.save() | ||
|
||
@property | ||
def name(self) -> str: | ||
return self._name | ||
|
||
@property | ||
def version(self) -> int: | ||
if self._version is None: | ||
self._version = self._get_next_version() | ||
return self._version | ||
|
||
def _get_next_version(self): | ||
root_dir = os.path.join(self._save_dir, self.name) | ||
|
||
if not os.path.isdir(root_dir): | ||
log.warning('Missing logger folder: %s', root_dir) | ||
return 0 | ||
|
||
existing_versions = [] | ||
for d in os.listdir(root_dir): | ||
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): | ||
existing_versions.append(int(d.split("_")[1])) | ||
|
||
if len(existing_versions) == 0: | ||
return 0 | ||
|
||
return max(existing_versions) + 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.