Skip to content

Commit eb18103

Browse files
authored
fix torch import for flavors (unit8co#1129)
1 parent 4e5f1e6 commit eb18103

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

darts/tests/models/forecasting/test_transformer_model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import tempfile
33

44
import pandas as pd
5-
import torch.nn as nn
65

76
from darts import TimeSeries
87
from darts.logging import get_logger
@@ -12,6 +11,8 @@
1211
logger = get_logger(__name__)
1312

1413
try:
14+
import torch.nn as nn
15+
1516
from darts.models.components.transformer import (
1617
CustomFeedForwardDecoderLayer,
1718
CustomFeedForwardEncoderLayer,

0 commit comments

Comments
 (0)