fix forward declarations
justusschock committed Apr 20, 2021
1 parent e031dc3 commit 6241792
Showing 4 changed files with 9 additions and 9 deletions.
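
Every hunk below makes the same two-line change: the eager class import (from pytorch_lightning.core.lightning import LightningModule) becomes a plain package import, and the annotation on log_graph becomes the string 'pl.LightningModule'. Quoting the annotation turns it into a forward reference: the class is never looked up while the logger modules are being imported, which is the standard fix when a direct import would create an import cycle (here, presumably between pytorch_lightning.core and the loggers package). A minimal sketch of the pattern, with log_graph reduced to a free-function stub:

# Before: binds the class at import time. If pytorch_lightning.core.lightning
# in turn imports this module, Python fails with an ImportError about a
# partially initialized module.
#   from pytorch_lightning.core.lightning import LightningModule

# After: only the package object is bound; no attribute lookup happens yet.
import pytorch_lightning as pl


def log_graph(model: 'pl.LightningModule', input_array=None) -> None:
    # The quoted annotation is stored as a plain string and resolved only
    # when something introspects it, e.g. typing.get_type_hints(log_graph).
    ...
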
6 changes: 3 additions & 3 deletions pytorch_lightning/loggers/base.py
@@ -24,7 +24,7 @@
 import numpy as np
 import torch

-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_only


@@ -289,7 +289,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
             kwargs: Optional keywoard arguments, depends on the specific logger being used
         """

-    def log_graph(self, model: LightningModule, input_array=None) -> None:
+    def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
         """
         Record model graph

@@ -381,7 +381,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         for logger in self._logger_iterable:
             logger.log_hyperparams(params)

-    def log_graph(self, model: LightningModule, input_array=None) -> None:
+    def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
         for logger in self._logger_iterable:
             logger.log_graph(model, input_array)

4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/comet.py
@@ -24,7 +24,7 @@
 import torch
 from torch import is_tensor

-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _module_available, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -318,6 +318,6 @@ def __getstate__(self):
         state["_experiment"] = None
         return state

-    def log_graph(self, model: LightningModule, input_array=None) -> None:
+    def log_graph(self, model: 'pl.LightningModule', input_array=None) -> None:
         if self._experiment is not None:
             self._experiment.set_model_graph(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/tensorboard.py
@@ -25,7 +25,7 @@
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard.summary import hparams

-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn

@@ -210,7 +210,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
             raise ValueError(m) from ex

     @rank_zero_only
-    def log_graph(self, model: LightningModule, input_array=None):
+    def log_graph(self, model: 'pl.LightningModule', input_array=None):
         if self._log_graph:
             if input_array is None:
                 input_array = model.example_input_array

4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/test_tube.py
@@ -18,7 +18,7 @@
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union

-from pytorch_lightning.core.lightning import LightningModule
+import pytorch_lightning as pl
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import _module_available
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn

@@ -153,7 +153,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
         self.experiment.log(metrics, global_step=step)

     @rank_zero_only
-    def log_graph(self, model: LightningModule, input_array=None):
+    def log_graph(self, model: 'pl.LightningModule', input_array=None):
         if self._log_graph:
             if input_array is None:
                 input_array = model.example_input_array

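For comparison only (this commit does not use it), the same cycle can be broken with typing.TYPE_CHECKING, which keeps the exact import visible to static type checkers while skipping it entirely at runtime; a minimal sketch:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers (mypy, pyright) but never at
    # runtime, so the import cannot participate in a circular import.
    from pytorch_lightning.core.lightning import LightningModule


def log_graph(model: 'LightningModule', input_array=None) -> None:
    ...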
