Skip to content

Commit 0d9b394

Browse files
grllhrznTheMP
authored
Features/indexing (#150)
* add support for columns to the TimeSeries object * add colum support indexing to timeseries * fix wrong docstring * refactor indexing, fix docstring, columns as last arg * clean indexing method * refactor indexing only based on loc and iloc * Update darts/timeseries.py Co-authored-by: Julien Herzen <[email protected]> * use underlying columns by default * fix column added on intern _df and use self.freq_str * fix parameter position in from_times_and_values * fix the tests to use str columns * fix docstring timeseries * remove None check on df that should exists * add comment for clarifying that _df is a copy * add separate function to process columns * adapt map with str col indexing Co-authored-by: Julien Herzen <[email protected]> Co-authored-by: TheMP <[email protected]>
1 parent f6a31e9 commit 0d9b394

File tree

4 files changed

+157
-101
lines changed

4 files changed

+157
-101
lines changed

darts/models/forecasting_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _build_forecast_series(self,
9393

9494
time_index = self._generate_new_dates(len(points_preds))
9595

96-
return TimeSeries.from_times_and_values(time_index, points_preds, self.training_series.freq())
96+
return TimeSeries.from_times_and_values(time_index, points_preds, freq=self.training_series.freq())
9797

9898

9999
class UnivariateForecastingModel(ForecastingModel):

darts/tests/test_timeseries.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,9 @@ def test_getitem(self):
347347
seriesA: TimeSeries = self.series1.drop_after(pd.Timestamp("20130105"))
348348
self.assertEqual(self.series1[pd.date_range('20130101', ' 20130104')], seriesA)
349349
self.assertEqual(self.series1[:4], seriesA)
350-
self.assertTrue(self.series1[pd.Timestamp('20130101')].equals(self.series1.pd_dataframe()[:1]))
351-
self.assertEqual(self.series1[pd.Timestamp('20130101'):pd.Timestamp('20130105')], seriesA)
350+
self.assertTrue(self.series1[pd.Timestamp('20130101')] == TimeSeries(self.series1.pd_dataframe()[:1],
351+
freq=self.series1.freq()))
352+
self.assertEqual(self.series1[pd.Timestamp('20130101'):pd.Timestamp('20130104')], seriesA)
352353

353354
with self.assertRaises(IndexError):
354355
self.series1[pd.date_range('19990101', '19990201')]
@@ -441,21 +442,21 @@ def test_map(self):
441442
df_01 = series.pd_dataframe()
442443
df_012 = series.pd_dataframe()
443444

444-
df_0[[0]] = df_0[[0]].applymap(fn)
445-
df_2[[2]] = df_2[[2]].applymap(fn)
446-
df_01[[0, 1]] = df_01[[0, 1]].applymap(fn)
445+
df_0[["0"]] = df_0[["0"]].applymap(fn)
446+
df_2[["2"]] = df_2[["2"]].applymap(fn)
447+
df_01[["0", "1"]] = df_01[["0", "1"]].applymap(fn)
447448
df_012 = df_012.applymap(fn)
448449

449450
series_0 = TimeSeries(df_0, 'D')
450451
series_2 = TimeSeries(df_2, 'D')
451452
series_01 = TimeSeries(df_01, 'D')
452453
series_012 = TimeSeries(df_012, 'D')
453454

454-
self.assertEqual(series_0, series.map(fn, 0))
455-
self.assertEqual(series_0, series.map(fn, [0]))
456-
self.assertEqual(series_2, series.map(fn, 2))
457-
self.assertEqual(series_01, series.map(fn, [0, 1]))
458-
self.assertEqual(series_012, series.map(fn, [0, 1, 2]))
455+
self.assertEqual(series_0, series.map(fn, "0"))
456+
self.assertEqual(series_0, series.map(fn, ["0"]))
457+
self.assertEqual(series_2, series.map(fn, "2"))
458+
self.assertEqual(series_01, series.map(fn, ["0", "1"]))
459+
self.assertEqual(series_012, series.map(fn, ["0", "1", "2"]))
459460
self.assertEqual(series_012, series.map(fn))
460461

461462
self.assertNotEqual(series_01, series.map(fn))

darts/tests/test_timeseries_multivariate.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@ class TimeSeriesMultivariateTestCase(unittest.TestCase):
1313
times1 = pd.date_range('20130101', '20130110')
1414
times2 = pd.date_range('20130206', '20130215')
1515
dataframe1 = pd.DataFrame({
16-
0: range(10),
17-
1: range(5, 15),
18-
2: range(10, 20)
16+
"0": range(10),
17+
"1": range(5, 15),
18+
"2": range(10, 20)
1919
}, index=times1)
2020
dataframe2 = pd.DataFrame({
21-
0: np.arange(1, 11),
22-
1: np.arange(1, 11) * 3,
23-
2: np.arange(1, 11) * 5
21+
"0": np.arange(1, 11),
22+
"1": np.arange(1, 11) * 3,
23+
"2": np.arange(1, 11) * 5
2424
}, index=times1)
2525
dataframe3 = pd.DataFrame({
26-
0: np.arange(1, 11),
27-
1: np.arange(11, 21),
26+
"0": np.arange(1, 11),
27+
"1": np.arange(11, 21),
2828
}, index=times2)
2929
series1 = TimeSeries(dataframe1)
3030
series2 = TimeSeries(dataframe2)
@@ -44,7 +44,7 @@ def test_creation(self):
4444
# Series cannot be lower than three without passing frequency as argument to constructor
4545
with self.assertRaises(ValueError):
4646
TimeSeries(self.dataframe1.iloc[:2, :])
47-
TimeSeries(self.dataframe1.iloc[:2, :], 'D')
47+
TimeSeries(self.dataframe1.iloc[:2, :], freq='D')
4848

4949
def test_eq(self):
5050
seriesA = TimeSeries(self.dataframe1)
@@ -73,9 +73,9 @@ def test_rescale(self):
7373
seriesB = self.series2.rescale_with_value(1)
7474
self.assertEqual(seriesB, TimeSeries(
7575
pd.DataFrame({
76-
0: np.arange(1, 11),
77-
1: np.arange(1, 11),
78-
2: np.arange(1, 11)
76+
"0": np.arange(1, 11),
77+
"1": np.arange(1, 11),
78+
"2": np.arange(1, 11)
7979
}, index=self.dataframe2.index).astype(float)
8080
))
8181

@@ -119,7 +119,7 @@ def test_stack(self):
119119
self.series1.stack(self.series3)
120120
seriesA = self.series1.stack(self.series2)
121121
dataframeA = pd.concat([self.dataframe1, self.dataframe2], axis=1)
122-
dataframeA.columns = range(6)
122+
dataframeA.columns = [str(i) for i in range(6)]
123123
self.assertTrue((seriesA.pd_dataframe() == dataframeA).all().all())
124124
self.assertEqual(seriesA.values().shape, (len(self.dataframe1),
125125
len(self.dataframe1.columns) + len(self.dataframe2.columns)))
@@ -130,7 +130,7 @@ def test_univariate_component(self):
130130
with self.assertRaises(ValueError):
131131
self.series1.univariate_component(3)
132132
seriesA = self.series1.univariate_component(1)
133-
self.assertTrue(seriesA == TimeSeries.from_times_and_values(self.times1, range(5, 15)))
133+
self.assertTrue(seriesA == TimeSeries.from_times_and_values(self.times1, range(5, 15), columns=["1"]))
134134
seriesB = self.series1.univariate_component(0).stack(seriesA).stack(self.series1.univariate_component(2))
135135
self.assertTrue(self.series1 == seriesB)
136136

0 commit comments

Comments
 (0)