From 276ec77f6f4c91da1605efb51be422322568d26b Mon Sep 17 00:00:00 2001
From: "d.a.bunin"
Date: Fri, 25 Nov 2022 11:13:48 +0300
Subject: [PATCH] Remove trainer from DeepAR and TFT

---
 etna/models/nn/deepar.py | 5 ++---
 etna/models/nn/tft.py    | 5 ++---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/etna/models/nn/deepar.py b/etna/models/nn/deepar.py
index 5d2d7e97e..8d49d8692 100644
--- a/etna/models/nn/deepar.py
+++ b/etna/models/nn/deepar.py
@@ -103,7 +103,6 @@ def __init__(
         self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict()
         self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
         self.model: Optional[Union[LightningModule, DeepAR]] = None
-        self.trainer: Optional[pl.Trainer] = None
         self._last_train_timestamp = None
         self._freq: Optional[str] = None
 
@@ -163,11 +162,11 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
         )
         trainer_kwargs.update(self.trainer_kwargs)
 
-        self.trainer = pl.Trainer(**trainer_kwargs)
+        trainer = pl.Trainer(**trainer_kwargs)
 
         train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)
 
-        self.trainer.fit(self.model, train_dataloader)
+        trainer.fit(self.model, train_dataloader)
 
         return self
 
diff --git a/etna/models/nn/tft.py b/etna/models/nn/tft.py
index e945cbddc..8e5835ee1 100644
--- a/etna/models/nn/tft.py
+++ b/etna/models/nn/tft.py
@@ -110,7 +110,6 @@ def __init__(
         self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict()
         self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
         self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None
-        self.trainer: Optional[pl.Trainer] = None
         self._last_train_timestamp = None
         self._freq: Optional[str] = None
 
@@ -170,11 +169,11 @@ def fit(self, ts: TSDataset) -> "TFTModel":
         )
         trainer_kwargs.update(self.trainer_kwargs)
 
-        self.trainer = pl.Trainer(**trainer_kwargs)
+        trainer = pl.Trainer(**trainer_kwargs)
 
         train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)
 
-        self.trainer.fit(self.model, train_dataloader)
+        trainer.fit(self.model, train_dataloader)
 
         return self
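
Note for reviewers (not part of the diff): a minimal sketch of the
user-visible effect. It assumes `ts` is a TSDataset already prepared with
the pytorch-forecasting transform that fit() expects (the diff's
`pf_transform`); the trainer_kwargs value shown is illustrative.

    from etna.models.nn import DeepARModel

    # `ts` is assumed to be a prepared TSDataset; building it is elided here.
    model = DeepARModel(trainer_kwargs={"max_epochs": 1})  # illustrative args
    model.fit(ts)

    # Before this patch, fit() kept the fitted pl.Trainer on the instance as
    # `model.trainer`. After it, the trainer is local to fit(), so it is
    # released when fit() returns and no longer appears on the model:
    assert not hasattr(model, "trainer")

The same holds for TFTModel, since both fit() methods are changed identically.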