Skip to content

Commit

Permalink
Remove trainer from DeepAR and TFT
Browse files Browse the repository at this point in the history
  • Loading branch information
d.a.bunin committed Nov 25, 2022
1 parent 45e62b3 commit 276ec77
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
5 changes: 2 additions & 3 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def __init__(
self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict()
self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
self.model: Optional[Union[LightningModule, DeepAR]] = None
self.trainer: Optional[pl.Trainer] = None
self._last_train_timestamp = None
self._freq: Optional[str] = None

Expand Down Expand Up @@ -163,11 +162,11 @@ def fit(self, ts: TSDataset) -> "DeepARModel":
)
trainer_kwargs.update(self.trainer_kwargs)

self.trainer = pl.Trainer(**trainer_kwargs)
trainer = pl.Trainer(**trainer_kwargs)

train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)

self.trainer.fit(self.model, train_dataloader)
trainer.fit(self.model, train_dataloader)

return self

Expand Down
5 changes: 2 additions & 3 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def __init__(
self.trainer_kwargs = trainer_kwargs if trainer_kwargs is not None else dict()
self.quantiles_kwargs = quantiles_kwargs if quantiles_kwargs is not None else dict()
self.model: Optional[Union[LightningModule, TemporalFusionTransformer]] = None
self.trainer: Optional[pl.Trainer] = None
self._last_train_timestamp = None
self._freq: Optional[str] = None

Expand Down Expand Up @@ -170,11 +169,11 @@ def fit(self, ts: TSDataset) -> "TFTModel":
)
trainer_kwargs.update(self.trainer_kwargs)

self.trainer = pl.Trainer(**trainer_kwargs)
trainer = pl.Trainer(**trainer_kwargs)

train_dataloader = pf_transform.pf_dataset_train.to_dataloader(train=True, batch_size=self.batch_size)

self.trainer.fit(self.model, train_dataloader)
trainer.fit(self.model, train_dataloader)

return self

Expand Down

0 comments on commit 276ec77

Please sign in to comment.