Skip to content

Commit 7ad25b4

Browse files
hrzndennisbader
andauthored
Fix/filter static cov component (unit8co#1128)
* simplify and improve univariate_component() * add unit test * Update darts/timeseries.py Co-authored-by: Dennis Bader <[email protected]> Co-authored-by: Dennis Bader <[email protected]>
1 parent eb18103 commit 7ad25b4

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

darts/tests/test_timeseries.py

+25
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,31 @@ def test_integer_indexing(self):
105105
list(indexed_ts.time_index) == list(pd.RangeIndex(2, 7, step=1))
106106
)
107107

108+
def test_univariate_component(self):
109+
series = TimeSeries.from_values(np.array([10, 20, 30])).with_columns_renamed(
110+
"0", "component"
111+
)
112+
mseries = concatenate([series] * 3, axis="component")
113+
mseries = mseries.with_hierarchy(
114+
{"component_1": ["component"], "component_2": ["component"]}
115+
)
116+
117+
static_cov = pd.DataFrame(
118+
{"dim0": [1, 2, 3], "dim1": [-2, -1, 0], "dim2": [0.0, 0.1, 0.2]}
119+
)
120+
121+
mseries = mseries.with_static_covariates(static_cov)
122+
123+
for univ_series in [
124+
mseries.univariate_component(1),
125+
mseries.univariate_component("component_1"),
126+
]:
127+
# hierarchy should be dropped
128+
self.assertIsNone(univ_series.hierarchy)
129+
130+
# only the right static covariate column should be retained
131+
self.assertEqual(univ_series.static_covariates.sum().sum(), 1.1)
132+
108133
def test_column_names(self):
109134
# test the column names resolution
110135
columns_before = [

darts/timeseries.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2702,6 +2702,9 @@ def univariate_component(self, index: Union[str, int]) -> "TimeSeries":
27022702
Retrieve one of the components of the series
27032703
and return it as new univariate ``TimeSeries`` instance.
27042704
2705+
This drops the hierarchy (if any), and retains only the relevant static
2706+
covariates column.
2707+
27052708
Parameters
27062709
----------
27072710
index
@@ -2713,11 +2716,8 @@ def univariate_component(self, index: Union[str, int]) -> "TimeSeries":
27132716
TimeSeries
27142717
A new univariate TimeSeries instance.
27152718
"""
2716-
if isinstance(index, int):
2717-
new_xa = self._xa.isel(component=index).expand_dims(DIMS[1], axis=1)
2718-
else:
2719-
new_xa = self._xa.sel(component=index).expand_dims(DIMS[1], axis=1)
2720-
return self.__class__(new_xa)
2719+
2720+
return self[index if isinstance(index, str) else self.components[index]]
27212721

27222722
def add_datetime_attribute(
27232723
self, attribute, one_hot: bool = False, cyclic: bool = False

0 commit comments

Comments
 (0)