Skip to content

Commit 8fa7dbf

Browse files
Feat/kalman with fitting (#692)
* First proof of concept * Clean up and add checks * Remove support for multiple timeseries * Add tests and fixes * Add test given KF and fixes * Add docstrings * Add and adapt example * Improve covariates (error) handling * Remove unused imports Co-authored-by: Dennis Bader <[email protected]> * Improve docstring line length Co-authored-by: Dennis Bader <[email protected]> * Fix typos * Remove filterpy from requirements * Remove unnecessary parentheses * Add num_samples to docstring * Add num_block_rows to docstring * Clarify covariance formula * Update to nfoursid with relaxed requirements * Bump nfoursid requirement Co-authored-by: Dennis Bader <[email protected]>
1 parent ed88209 commit 8fa7dbf

File tree

6 files changed

+395
-128
lines changed

6 files changed

+395
-128
lines changed

darts/models/filtering/filtering_model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from abc import ABC, abstractmethod
99

1010
from darts.timeseries import TimeSeries
11-
from darts.logging import get_logger
11+
from darts.logging import get_logger, raise_if_not
1212

1313
logger = get_logger(__name__)
1414

@@ -35,4 +35,5 @@ def filter(self, series: TimeSeries) -> TimeSeries:
3535
TimeSeries
3636
A time series containing the filtered values.
3737
"""
38-
pass
38+
raise_if_not(series.is_deterministic, 'The input series must be '
39+
'deterministic (observations).')

darts/models/filtering/gaussian_process_filter.py

-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ def filter(self,
5959
A stochastic ``TimeSeries`` sampled from the Gaussian Process, or its mean
6060
if `num_samples` is set to 1.
6161
"""
62-
raise_if_not(series.is_deterministic, 'The input series for the Gaussian Process filter must be '
63-
'deterministic (observations).')
6462
super().filter(series)
6563

6664
values = series.values(copy=False)

darts/models/filtering/kalman_filter.py

+135-81
Original file line numberDiff line numberDiff line change
@@ -4,150 +4,204 @@
44
"""
55

66
from abc import ABC
7-
8-
from typing import Optional
9-
from filterpy.kalman import KalmanFilter as FpKalmanFilter
107
from copy import deepcopy
8+
from typing import Optional
9+
1110
import numpy as np
11+
import pandas as pd
12+
from nfoursid.kalman import Kalman
13+
from nfoursid.nfoursid import NFourSID
1214

1315
from darts.models.filtering.filtering_model import FilteringModel
1416
from darts.timeseries import TimeSeries
15-
from darts.utils.utils import raise_if_not
17+
from darts.logging import raise_if, raise_if_not
1618

1719

1820
class KalmanFilter(FilteringModel, ABC):
1921
def __init__(
2022
self,
2123
dim_x: int = 1,
22-
x_init: Optional[np.array] = None,
23-
P: Optional[np.array] = None,
24-
Q: Optional[np.array] = None,
25-
R: Optional[np.array] = None,
26-
H: Optional[np.array] = None,
27-
F: Optional[np.array] = None,
28-
kf: Optional[FpKalmanFilter] = None
24+
kf: Optional[Kalman] = None
2925
):
3026
"""
31-
This model implements a Kalman filter over a time series (without control signal).
27+
This model implements a Kalman filter over a time series.
3228
3329
The key method is `KalmanFilter.filter()`.
3430
It considers the provided time series as containing (possibly noisy) observations z obtained from a
3531
(possibly noisy) linear dynamical system with hidden state x. The function `filter(series)` returns a new
36-
`TimeSeries` describing the distribution of the state x, as inferred by the Kalman filter from
37-
sequentially observing z from `series`.
38-
Depending on the use case, this can be used to de-noise a series or infer the underlying hidden state of the
39-
data generating process (assuming notably that the dynamical system generating the data is known, as captured
40-
by the `F` matrix.).
32+
`TimeSeries` describing the distribution of the output z (without noise), as inferred by the Kalman filter from
33+
sequentially observing z from `series`, and the dynamics of the linear system of order dim_x.
4134
42-
This implementation wraps around filterpy.kalman.KalmanFilter, so more information the parameters can be found
43-
here: https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html
35+
The method `KalmanFilter.fit()` is used to initialize the Kalman filter by estimating the state space model of
36+
a linear dynamical system and the covariance matrices of the process and measurement noise using the N4SID
37+
algorithm.
4438
45-
The dimensionality of the measurements z is automatically inferred upon calling `filter()`.
46-
This implementation doesn't include control signal.
39+
This implementation uses Kalman from the NFourSID package. More information can be found here:
40+
https://nfoursid.readthedocs.io/en/latest/source/kalman.html.
41+
42+
The dimensionality of the measurements z and optional control signal (covariates) u is automatically inferred upon
43+
calling `fit()`, or taken from `kf` if one is provided.
4744
4845
Parameters
4946
----------
5047
dim_x : int
51-
Size of the Kalman filter state vector. It determines the dimensionality of the `TimeSeries`
52-
returned by the `filter()` function.
53-
x_init : ndarray (dim_x, 1), default: [0, 0, ..., 0]
54-
Initial state; will be updated at each time step.
55-
P : ndarray (dim_x, dim_x), default: identity matrix
56-
initial covariance matrix; will be update at each time step
57-
Q : ndarray (dim_x, dim_x), default: identity matrix
58-
Process noise covariance matrix
59-
R : ndarray (dim_z, dim_z), default: identity matrix
60-
Measurement noise covariance matrix. `dim_z` must match the dimensionality (width) of the `TimeSeries`
61-
used with `filter()`.
62-
H : ndarray (dim_z, dim_x), default: all-ones matrix
63-
measurement function; describes how the measurement z is obtained from the state vector x
64-
F : ndarray (dim_x, dim_x), default: identity matrix
65-
State transition matrix; describes how the state evolves from one time step to the next
66-
in the underlying dynamical system.
67-
kf : filterpy.kalman.KalmanFilter
68-
Optionally, an instance of `filterpy.kalman.KalmanFilter`.
69-
If this is provided, the other parameters are ignored. This instance will be copied for every
48+
Size of the Kalman filter state vector.
49+
kf : nfoursid.kalman.Kalman
50+
Optionally, an instance of `nfoursid.kalman.Kalman`.
51+
If this is provided, the parameter dim_x is ignored. This instance will be copied for every
7052
call to `filter()`, so the state is not carried over from one time series to another across several
7153
calls to `filter()`.
72-
The various dimensionality in the filter must match those in the `TimeSeries` used when calling `filter()`.
54+
The various dimensionalities of the filter must match those of the `TimeSeries` used when calling `filter()`.
7355
"""
56+
# TODO: Add support for x_init. Needs reimplementation of NFourSID.
57+
7458
super().__init__()
59+
self._expect_covariates = False
60+
7561
if kf is None:
76-
self.dim_x = dim_x
77-
self.x_init = x_init if x_init is not None else np.zeros(self.dim_x,)
78-
self.P = P if P is not None else np.eye(self.dim_x)
79-
self.Q = Q if Q is not None else np.eye(self.dim_x)
80-
self.R = R
81-
self.H = H
82-
self.F = F if F is not None else np.eye(self.dim_x)
8362
self.kf = None
84-
self.kf_provided = False
63+
self.dim_x = dim_x
64+
self._kf_provided = False
8565
else:
8666
self.kf = kf
87-
self.kf_provided = True
67+
self.dim_u = kf.state_space.u_dim
68+
self.dim_x = kf.state_space.x_dim
69+
self.dim_y = kf.state_space.y_dim
70+
self._kf_provided = True
71+
if self.dim_u > 0:
72+
self._expect_covariates = True
8873

8974
def __str__(self):
9075
return 'KalmanFilter(dim_x={})'.format(self.dim_x)
9176

77+
def fit(self,
78+
series: TimeSeries,
79+
covariates: Optional[TimeSeries] = None,
80+
num_block_rows: Optional[int] = None) -> None:
81+
"""
82+
Initializes the Kalman filter using the N4SID algorithm.
83+
84+
Parameters
85+
----------
86+
series : TimeSeries
87+
The series of outputs (observations) used to infer the underlying state space model.
88+
This must be a deterministic series (containing one sample).
89+
covariates : Optional[TimeSeries]
90+
An optional series of inputs (control signal) that will also be used to infer the underlying state space model.
91+
This must be a deterministic series (containing one sample).
92+
num_block_rows : Optional[int]
93+
The number of block rows to use in the block Hankel matrices used in the N4SID algorithm.
94+
See the documentation of nfoursid.nfoursid.NFourSID for more information.
95+
If not provided, the dimensionality of the state space model will be used, with a minimum of 10.
96+
"""
97+
if covariates is not None:
98+
self._expect_covariates = True
99+
covariates = covariates.slice_intersect(series)
100+
raise_if_not(series.has_same_time_as(covariates),
101+
'The number of timesteps in the series and the covariates must match.')
102+
103+
# TODO: Handle multiple timeseries. Needs reimplementation of NFourSID?
104+
self.dim_y = series.width
105+
outputs = series.pd_dataframe()
106+
outputs.columns = [f'y_{i}' for i in outputs.columns]
107+
108+
if covariates is not None:
109+
self.dim_u = covariates.width
110+
inputs = covariates.pd_dataframe()
111+
inputs.columns = [f'u_{i}' for i in inputs.columns]
112+
input_columns = list(inputs.columns)
113+
measurements = pd.concat([outputs, inputs], axis=1)
114+
else:
115+
measurements = outputs
116+
input_columns = None
117+
118+
if num_block_rows is None:
119+
num_block_rows = max(10, self.dim_x)
120+
nfoursid = NFourSID(measurements,
121+
output_columns=list(outputs.columns),
122+
input_columns=input_columns,
123+
num_block_rows=num_block_rows)
124+
nfoursid.subspace_identification()
125+
state_space_identified, covariance_matrix = nfoursid.system_identification(
126+
rank=self.dim_x
127+
)
128+
129+
self.kf = Kalman(state_space_identified, covariance_matrix)
130+
131+
92132
def filter(self,
93133
series: TimeSeries,
94-
num_samples: int = 1):
134+
covariates: Optional[TimeSeries] = None,
135+
num_samples: int = 1) -> TimeSeries:
95136
"""
96137
Sequentially applies the Kalman filter on the provided series of observations.
97138
98139
Parameters
99140
----------
100141
series : TimeSeries
101-
The series of observations used to infer the state values according to the specified Kalman process.
142+
The series of outputs (observations) used to infer the underlying outputs according to the specified Kalman process.
143+
This must be a deterministic series (containing one sample).
144+
covariates : Optional[TimeSeries]
145+
An optional series of inputs (control signal), necessary if the Kalman filter was initialized with covariates.
102146
This must be a deterministic series (containing one sample).
147+
num_samples : int, default: 1
148+
The number of samples to generate from the inferred distribution of the output z. If this is set to 1, the
149+
output is a `TimeSeries` containing a single sample using the mean of the distribution.
103150
104151
Returns
105152
-------
106153
TimeSeries
107-
A stochastic `TimeSeries` of state values, of dimension `dim_x`.
154+
A (stochastic) `TimeSeries` of the inferred output z, of the same width as the input series.
108155
"""
156+
super().filter(series)
109157

110-
raise_if_not(series.is_deterministic, 'The input series for the Kalman filter must be '
111-
'deterministic (observations).')
158+
raise_if(self.kf is None, 'The Kalman filter has not been fitted yet. Call `fit()` first '
159+
'or provide Kalman filter in constructor.')
160+
161+
raise_if_not(series.width == self.dim_y, 'The provided TimeSeries dimensionality does not match '
162+
'the output dimensionality of the Kalman filter.')
112163

113-
dim_z = series.width
164+
raise_if(covariates is not None and not self._expect_covariates,
165+
'Covariates were provided, but the Kalman filter was not fitted with covariates.')
114166

115-
if not self.kf_provided:
116-
kf = FpKalmanFilter(dim_x=self.dim_x, dim_z=dim_z)
117-
kf.x = self.x_init
118-
kf.P = self.P
119-
kf.Q = self.Q
120-
kf.R = self.R if self.R is not None else np.eye(dim_z)
121-
kf.H = self.H if self.H is not None else np.ones((dim_z, self.dim_x))
122-
kf.F = self.F
123-
else:
124-
raise_if_not(dim_z == self.kf.dim_z, 'The provided TimeSeries dimensionality does not match '
125-
'the filter observation dimensionality dim_z.')
126-
kf = deepcopy(self.kf)
167+
if self._expect_covariates:
168+
raise_if(covariates is None,
169+
'The Kalman filter was fitted with covariates, but these were not provided.')
127170

128-
super().filter(series)
129-
values = series.values(copy=False)
171+
raise_if_not(covariates.is_deterministic,
172+
'The covariates must be deterministic (observations).')
173+
174+
covariates = covariates.slice_intersect(series)
175+
raise_if_not(series.has_same_time_as(covariates),
176+
'The number of timesteps in the series and the covariates must match.')
130177

178+
kf = deepcopy(self.kf)
179+
180+
y_values = series.values(copy=False)
181+
if self._expect_covariates:
182+
u_values = covariates.values(copy=False)
183+
else:
184+
u_values = np.zeros((len(y_values), 0))
185+
131186
# For each time step, we'll sample "num_samples" from a multivariate Gaussian
132187
# whose mean vector and covariance matrix come from the Kalman filter.
133188
if num_samples == 1:
134-
sampled_states = np.zeros(((len(values)), self.dim_x, ))
189+
sampled_states = np.zeros((len(y_values), self.dim_y, ))
135190
else:
136-
sampled_states = np.zeros(((len(values)), self.dim_x, num_samples))
191+
sampled_states = np.zeros((len(y_values), self.dim_y, num_samples))
137192

138-
# process_means = np.zeros((len(values), self.dim_x)) # mean values
139-
# process_covariances = ... # covariance matrices; TODO
140-
for i in range(len(values)):
141-
obs = values[i, :]
142-
kf.predict()
143-
kf.update(obs)
144-
mean_vec = kf.x.reshape(self.dim_x,)
193+
for i in range(len(y_values)):
194+
y = y_values[i, :].reshape(-1, 1)
195+
u = u_values[i, :].reshape(-1, 1)
196+
kf.step(y, u)
197+
mean_vec = kf.y_filtereds[-1].reshape(self.dim_y,)
145198

146199
if num_samples == 1:
147-
# It's actually not sampled in this case
148200
sampled_states[i, :] = mean_vec
149201
else:
150-
cov_matrix = kf.P
202+
# The measurement covariance matrix is given by the sum of the covariance matrix of the
203+
# state estimate (transformed by C) and the covariance matrix of the measurement noise.
204+
cov_matrix = kf.state_space.c @ kf.p_filtereds[-1] @ kf.state_space.c.T + kf.r
151205
sampled_states[i, :, :] = np.random.multivariate_normal(mean_vec, cov_matrix, size=num_samples).T
152206

153207
# TODO: later on for a forecasting model we'll have to do something like

0 commit comments

Comments
 (0)