Skip to content

Commit ae47aba

Browse files
authored
Feat/stochastic inputs (#833)
* Use stochastic samples for training/inference of torch models * Fix typo * add a unit test
1 parent cb80cf3 commit ae47aba

File tree

5 files changed

+57
-6
lines changed

5 files changed

+57
-6
lines changed

darts/tests/models/forecasting/test_probabilistic_models.py

+21
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,24 @@ def _get_avgs(series):
242242
"The difference between the mean forecast and the mean series is larger "
243243
"than expected on component 1 for distribution {}".format(lkl),
244244
)
245+
246+
def test_stochastic_inputs(self):
    """
    Check that a deterministic model fed a stochastic series produces
    forecasts that vary from one `predict()` call to the next.
    """
    model = RNNModel(input_chunk_length=5)
    model.fit(self.constant_ts, epochs=2)

    # build a stochastic series: 100 noisy samples drawn around the
    # values of the constant series
    target_vals = self.constant_ts.values()
    stochastic_vals = np.random.normal(
        loc=target_vals, scale=1.0, size=(len(self.constant_ts), 100)
    )
    # insert a component axis so the array is (time, component, sample)
    stochastic_vals = np.expand_dims(stochastic_vals, axis=1)
    stochastic_series = TimeSeries.from_times_and_values(
        self.constant_ts.time_index, stochastic_vals
    )

    # A deterministic model forecasting a stochastic series
    # should return stochastic samples
    preds = [model.predict(series=stochastic_series, n=10) for _ in range(2)]

    # random samples should differ
    # (`np.all` rather than `np.alltrue`: the latter was removed in NumPy 2.0)
    self.assertFalse(np.all(preds[0].values() == preds[1].values()))

darts/timeseries.py

+22
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,28 @@ def values(self, copy=True, sample=0) -> np.ndarray:
11761176
else:
11771177
return self._xa.values[:, :, sample]
11781178

1179+
def random_component_values(self, copy=True) -> np.ndarray:
    """
    Return a 2-D array of shape (time, component), containing the values for
    one sample taken uniformly at random among this series' samples.

    Parameters
    ----------
    copy
        Whether to return a copy of the values, otherwise returns a view.
        Leave it to True unless you know what you are doing.

    Returns
    -------
    numpy.ndarray
        The values composing one sample taken at random from the time series.
    """
    # `high` is exclusive, so the drawn index lies in [0, n_samples - 1]
    sample = np.random.randint(low=0, high=self.n_samples)
    sample_values = self._xa.values[:, :, sample]
    return np.copy(sample_values) if copy else sample_values
1200+
11791201
def all_values(self, copy=True) -> np.ndarray:
11801202
"""
11811203
Return a 3-D array of dimension (time, component, sample),

darts/utils/data/horizon_based_dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __getitem__(
109109
# determine the index of the time series.
110110
ts_idx = idx // self.nr_samples_per_ts
111111
ts_target = self.target_series[ts_idx]
112-
target_vals = ts_target.values(copy=False)
112+
target_vals = ts_target.random_component_values(copy=False)
113113

114114
raise_if_not(
115115
len(target_vals)
@@ -168,7 +168,9 @@ def __getitem__(
168168
f"({idx}-th sample)",
169169
)
170170

171-
covariate = ts_covariate.values(copy=False)[cov_start:cov_end]
171+
covariate = ts_covariate.random_component_values(copy=False)[
172+
cov_start:cov_end
173+
]
172174

173175
raise_if_not(
174176
len(covariate) == len(past_target),

darts/utils/data/inference_dataset.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ def __getitem__(
163163
)
164164

165165
# extract past target values
166-
past_target = target_series.values(copy=False)[-self.input_chunk_length :]
166+
past_target = target_series.random_component_values(copy=False)[
167+
-self.input_chunk_length :
168+
]
167169

168170
# optionally, extract covariates
169171
cov_past, cov_future = None, None
@@ -181,7 +183,9 @@ def __getitem__(
181183
)
182184

183185
# extract covariate values and split into a past (historic) and future part
184-
covariate = covariate_series.values(copy=False)[cov_start:cov_end]
186+
covariate = covariate_series.random_component_values(copy=False)[
187+
cov_start:cov_end
188+
]
185189
if self.input_chunk_length != 0: # regular models
186190
cov_past, cov_future = (
187191
covariate[: self.input_chunk_length],

darts/utils/data/shifted_dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray
522522
# determine the index of the time series.
523523
ts_idx = idx // self.max_samples_per_ts
524524
ts_target = self.target_series[ts_idx]
525-
target_vals = ts_target.values(copy=False)
525+
target_vals = ts_target.random_component_values(copy=False)
526526

527527
# determine the actual number of possible samples in this time series
528528
n_samples_in_ts = len(target_vals) - self.size_of_both_chunks + 1
@@ -582,7 +582,9 @@ def __getitem__(self, idx) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray
582582
f"that don't extend far enough into the future. ({idx}-th sample)",
583583
)
584584

585-
covariate = ts_covariate.values(copy=False)[cov_start:cov_end]
585+
covariate = ts_covariate.random_component_values(copy=False)[
586+
cov_start:cov_end
587+
]
586588

587589
raise_if_not(
588590
len(covariate)

0 commit comments

Comments
 (0)