Skip to content

Commit

Permalink
An alternate fix that unifies the behavior of fabric CSVLogger and fa…
Browse files Browse the repository at this point in the history
…bric TensorBoardLogger
  • Loading branch information
water-vapor committed Mar 21, 2023
1 parent 046bd5f commit 4907616
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/lightning/fabric/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def experiment(self) -> "_ExperimentWriter":
if self._experiment is not None:
return self._experiment

os.makedirs(self.root_dir, exist_ok=True)
os.makedirs(os.path.join(self.root_dir, self.name), exist_ok=True)
self._experiment = _ExperimentWriter(log_dir=self.log_dir)
return self._experiment

Expand Down Expand Up @@ -152,13 +152,14 @@ def finalize(self, status: str) -> None:

def _get_next_version(self) -> int:
root_dir = self.root_dir
save_dir = os.path.join(root_dir, self.name)

if not self._fs.isdir(root_dir):
log.warning("Missing logger folder: %s", root_dir)
return 0

existing_versions = []
for d in self._fs.listdir(root_dir):
for d in self._fs.listdir(save_dir):
full_path = d["name"]
name = os.path.basename(full_path)
if self._fs.isdir(full_path) and name.startswith("version_"):
Expand Down
6 changes: 3 additions & 3 deletions tests/tests_fabric/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_file_logger_automatic_versioning(tmpdir):
root_dir = tmpdir.mkdir("exp")
root_dir.mkdir("version_0")
root_dir.mkdir("version_1")
logger = CSVLogger(root_dir=root_dir, name="exp")
logger = CSVLogger(root_dir=tmpdir, name="exp")
assert logger.version == 2


Expand All @@ -37,7 +37,7 @@ def test_file_logger_automatic_versioning_relative_root_dir(tmpdir, monkeypatch)
logs_dir.mkdir("version_0")
logs_dir.mkdir("version_1")
monkeypatch.chdir(tmpdir)
logger = CSVLogger(root_dir="exp/logs", name="logs")
logger = CSVLogger(root_dir="exp", name="logs")
assert logger.version == 2


Expand All @@ -47,7 +47,7 @@ def test_file_logger_manual_versioning(tmpdir):
root_dir.mkdir("version_0")
root_dir.mkdir("version_1")
root_dir.mkdir("version_2")
logger = CSVLogger(root_dir=root_dir, name="exp", version=1)
logger = CSVLogger(root_dir=tmpdir, name="exp", version=1)
assert logger.version == 1


Expand Down

0 comments on commit 4907616

Please sign in to comment.