Skip to content

Commit

Permalink
FIX: hiden problem with revert method
Browse files Browse the repository at this point in the history
  • Loading branch information
MarekWadinger committed Apr 29, 2024
1 parent 08472bb commit cbac2c5
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions river/decomposition/osvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def update(self, x: dict | np.ndarray):
P = A - self._U @ M # m x c
# TODO: [1] suggest computing orthogonal basis of P.
# Results seems to be the same for non rank-increasing updates.
Pot = np.linalg.qr(P)[0].T # c x m or m x m if m < c
Po = np.linalg.qr(P)[0]
Pot = Po.T # c x m or m x m if m < c
R_A = np.pad(
Pot @ P, ((0, P.shape[1] - Pot.shape[0]), (0, 0))
) # c x c
Expand Down Expand Up @@ -308,7 +309,7 @@ def update(self, x: dict | np.ndarray):
random_state=self.seed,
) # r + c x r; ...; r x r + c

U_ = np.column_stack((self._U, Pot.T)) @ U_ # m x r
U_ = np.column_stack((self._U, Po)) @ U_ # m x r
Vt_ = Vt_ @ np.row_stack((_Vt, Qot)) # r x n + c

if self.force_orth:
Expand All @@ -322,15 +323,15 @@ def revert(self, x: dict | np.ndarray, idx: int = 0):
c = 1 if isinstance(x, dict) else x.shape[0]
nc = self._Vt.shape[1]
B = np.zeros((nc, c)) # n + c x c
B[-c:] = np.identity(c)
# Schmid takes first c columns of Vt
# N = _Vt @ B # r x c
if idx >= 0:
B[idx : idx + c, :] = np.identity(c)
N = self._Vt[:, idx : idx + c] # r x c
elif idx == -1:
B[-c:, :] = np.identity(c)
N = self._Vt[:, -c:] # r x c
else:
B[-c + idx + 1 : idx + 1, :] = np.identity(c)

# Schmid takes first c columns of Vt
N = self._Vt @ B # r x c
N = self._Vt[:, -c + idx + 1 : idx + 1] # r x c
V = self._Vt.T # n + c x r
Q = B - V @ N # n + c x c
Qot = np.linalg.qr(Q)[
Expand All @@ -348,6 +349,7 @@ def revert(self, x: dict | np.ndarray, idx: int = 0):
U_, S_, Vt_ = _svd(
K,
self.n_components,
# Seems like this converges to different results
v0=np.row_stack((self._Vt, Qot))[:, 0],
solver=self.solver,
random_state=self.seed,
Expand Down Expand Up @@ -576,7 +578,7 @@ def update(self, x: dict | np.ndarray):
d = Q.T @ (self.W @ A) # k x c
e = A - Q @ d # m x c
p = np.sqrt(e.T @ self.W @ e) # c x c
p[np.isnan(p)] = 0.0 # c x c
p[np.isnan(p)] = np.zeros((c,c)) # c x c
# Step 2: Check tolerance
if (p < self.tol).all(): # n_incr += c
self.q += 1 # 1 x 1
Expand Down

0 comments on commit cbac2c5

Please sign in to comment.