Skip to content

Commit

Permalink
Implement log_graph for CometLogger. (#5295)
Browse files Browse the repository at this point in the history
Co-authored-by: chaton <[email protected]>
  • Loading branch information
neighthan and tchaton authored Jan 14, 2021
1 parent 9515750 commit ddd9cc1
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torch import is_tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, _module_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -310,3 +311,7 @@ def __getstate__(self):
# needed later
state["_experiment"] = None
return state

def log_graph(self, model: LightningModule, input_array=None) -> None:
if self._experiment is not None:
self._experiment.set_model_graph(model)

0 comments on commit ddd9cc1

Please sign in to comment.