
Feat/pytorch lightning draft #697

Closed · wants to merge 3 commits

Conversation

dennisbader (Collaborator) commented Dec 17, 2021

Summary

This is a (by no means finished) draft of what it could look like if we integrated pytorch-lightning into our TorchForecastingModels (you can ignore the failing tests).

It requires (using TFTModel as the example):

  • a new abstract class for pytorch-lightning modules; all _*Model implementations (e.g. _TFTModel, a torch.nn.Module) inherit from it
  • moving some methods from TorchForecastingModel to the new abstract class
  • moving some methods from the TFTModel class to the _TFTModel class (torch.nn.Module)
  • fit and predict are performed by a pytorch_lightning.Trainer():
    • it can either be provided by the user or is automatically set up by TorchForecastingModel
    • it can be used for callbacks and customization of nearly everything

Needs further work

  • loading, saving, tensorboard, ... (I think this can all be handled by the trainer)
  • random seeds
  • check that iterating through batches works properly
  • general sanity checks

My thoughts

  • I actually think it is a good idea: pytorch-lightning can be customized almost everywhere
  • it can help make the code cleaner
  • it would add a lot of new features for users through the PL Trainer

Let me know if this is the way we want to go.

Here is a little code snippet to play around with (right now it still outputs a lot of warnings and an ugly progress bar, which needs to be investigated/improved):

import pandas as pd
import pytorch_lightning as pl

from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models.forecasting.tft_model import TFTModel
from darts.metrics import mape
from darts.datasets import AirPassengersDataset

from matplotlib import pyplot as plt

# Read data:
series = AirPassengersDataset().load()

n_epochs = 200
use_pl_trainer = False
seed = 42

# TODO: we have to check how we handle random seeds !!!
# this will work but user would have to specify it themselves
pl.seed_everything(seed)
if use_pl_trainer:
    trainer = pl.Trainer(max_epochs=n_epochs)
else:
    trainer = None

# Create training and validation sets:
training_cutoff = pd.Timestamp('19570101')
train, val = series.split_after(training_cutoff)

# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)

my_model = TFTModel(
    optimizer_kwargs={'lr': 1e-3},
    hidden_size=32,
    input_chunk_length=24,
    output_chunk_length=12,
    force_reset=True,
    dropout=0.1,
    num_attention_heads=4,
    full_attention=False,
    batch_size=16,
    n_epochs=n_epochs,
    add_relative_index=False,
    add_encoders={
        'datetime_attribute': {'future': ['year', 'month']},
        'position': {'future': ['absolute']},
        'transformer': Scaler()
    },
    random_state=seed,
    lstm_layers=1,
    trainer=trainer,
)

my_model.fit(train_transformed,
             val_series=val_transformed,
             verbose=True)

def eval_model(model):
    pred_series = model.predict(n=13, num_samples=100)
    plt.figure(figsize=(8, 5))
    # plot the last 36 steps of the actual series against the prediction
    series_transformed[pred_series.end_time() - 36 * pred_series.freq:pred_series.end_time()].plot(label='actual')
    pred_series.plot(label='1-99%', low_quantile=0.01, high_quantile=0.99)
    pred_series.plot(label='10-90%', low_quantile=0.1, high_quantile=0.9)
    plt.title('MAPE: {:.2f}%'.format(mape(pred_series, val_transformed)))
    plt.legend()
    plt.show()
    return pred_series

pred_series = eval_model(my_model)

dennisbader (Collaborator, Author)

Closing as deprecated version of #702

@madtoinou madtoinou deleted the feat/pytorch_lightning_draft branch December 11, 2023 13:29