ENH: extend power_iteration to accept a matrix in implicit form
PiperOrigin-RevId: 613924237
fabianp authored and OptaxDev committed Mar 14, 2024
1 parent f45b2eb commit 053c55b
Showing 3 changed files with 228 additions and 52 deletions.
123 changes: 82 additions & 41 deletions optax/_src/linear_algebra.py
@@ -14,12 +14,13 @@
# ==============================================================================
"""Linear algebra utilities used in optimisation."""

from typing import Callable, Optional, Union

import chex
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np

from optax import tree_utils as otu
from optax._src import base
from optax._src import numerics

@@ -30,54 +31,94 @@ def global_norm(updates: base.PyTree) -> chex.Array:
jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))


def power_iteration(matrix: chex.Array,
num_iters: int = 100,
error_tolerance: float = 1e-6,
precision: lax.Precision = lax.Precision.HIGHEST):
def power_iteration(
matrix: Union[chex.Array, Callable[[chex.ArrayTree], chex.ArrayTree]],
*,
v0: Optional[chex.ArrayTree] = None,
num_iters: int = 100,
error_tolerance: float = 1e-6,
precision: lax.Precision = lax.Precision.HIGHEST,
key: Optional[chex.PRNGKey] = None,
):
r"""Power iteration algorithm.
The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
of `A`, and a vector v, which is the corresponding eigenvector of `A`.
This algorithm computes the dominant eigenvalue and its associated eigenvector
of a diagonalizable matrix. This matrix can be given as an array or as a
callable that implements a matrix-vector product.
References:
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
Wikipedia contributors. `Power iteration
<https://en.wikipedia.org/w/index.php?title=Power_iteration>`_.
Args:
matrix: the symmetric PSD matrix.
num_iters: Number of iterations.
error_tolerance: Iterative exit condition.
precision: precision XLA related flag, the available options are:
a) lax.Precision.DEFAULT (better step time, but not precise);
b) lax.Precision.HIGH (increased precision, slower);
c) lax.Precision.HIGHEST (best possible precision, slowest).
matrix: a square matrix, either as an array or a callable implementing a
matrix-vector product.
v0: initial vector approximating the dominant eigenvector of ``matrix``.
If ``matrix`` is an array of size (n, n), v0 must be a vector of size
(n,). If ``matrix`` is a callable, then v0 must be a tree with the same
structure as the input of this callable. If this argument is None and
``matrix`` is an array, then a random vector sampled from a uniform
distribution in [-1, 1] is used as initial vector.
num_iters: Number of power iterations.
error_tolerance: Iterative exit condition. The procedure stops when the
relative error of the estimate of the dominant eigenvalue is below this
threshold.
precision: precision XLA related flag, the available options are: a)
lax.Precision.DEFAULT (better step time, but not precise); b)
lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST
(best possible precision, slowest).
key: random key for the initialization of ``v0`` when not given
explicitly. When this argument is None, `jax.random.PRNGKey(0)` is used.
Returns:
eigen vector, eigen value
A pair (eigenvalue, eigenvector), where eigenvalue is the dominant
eigenvalue of ``matrix`` and eigenvector is its associated eigenvector.
"""
matrix_size = matrix.shape[-1]
def _iter_condition(state):
i, unused_v, unused_s, unused_s_v, run_step = state
return jnp.logical_and(i < num_iters, run_step)

def _iter_body(state):
"""One step of power iteration."""
i, new_v, s, s_v, unused_run_step = state
new_v = new_v / jnp.linalg.norm(new_v)

s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
return (i + 1, s_v, s_new, s_v,
jnp.greater(jnp.abs(s_new - s), error_tolerance))

# Figure out how to use step as seed for random.
v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)

init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
_, v_out, s_out, _, _ = lax.while_loop(
_iter_condition, _iter_body, init_state)
v_out = v_out / jnp.linalg.norm(v_out)
return v_out, s_out
if callable(matrix):
mvp = matrix
if v0 is None:
# v0 must be given as we have no way of knowing the underlying shape.
raise ValueError('v0 must be provided when `matrix` is a callable.')
else:
def mvp(v):
return jnp.matmul(matrix, v, precision=precision)
if v0 is None:
if key is None:
key = jax.random.PRNGKey(0)
# v0 is uniformly distributed in [-1, 1]
v0 = jax.random.uniform(
key,
shape=matrix.shape[-1:],
dtype=matrix.dtype,
minval=-1.0,
maxval=1.0,
)

def _normalize_tree(x):
# divide by the L2 norm of the tree weights.
return otu.tree_scalar_mul(1.0 / otu.tree_l2_norm(x), x)

v0 = _normalize_tree(v0)

def _cond_fun(loop_vars):
x, z, eig, iter_num = loop_vars
residual = otu.tree_l2_norm(otu.tree_sub(z, otu.tree_scalar_mul(eig, x)))
converged = jnp.abs(residual / eig) < error_tolerance
return ~converged & (iter_num < num_iters)

def _body_fun(loop_vars):
_, z, _, iter_num = loop_vars
x = _normalize_tree(z)
z = mvp(x)
eig = otu.tree_vdot(x, z)
return x, z, eig, iter_num + 1

init_vars = (v0, mvp(v0), jnp.asarray(0.0), jnp.asarray(0))
_, unnormalized_eigenvector, dominant_eigenvalue, _ = (
jax.lax.while_loop(_cond_fun, _body_fun, init_vars)
)
normalized_eigenvector = _normalize_tree(unnormalized_eigenvector)
return dominant_eigenvalue, normalized_eigenvector
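
A minimal usage sketch of the new interface (illustrative only, not part of the changed file; the small matrix `a` and variable names below are stand-ins). The dense form passes the array directly and may rely on the random default for v0, while the implicit form passes a matrix-vector product and must supply v0:

import jax
import jax.numpy as jnp
from optax._src import linear_algebra

a = jnp.array([[2.0, 0.0], [0.0, 1.0]])

# Dense form: the matrix is passed as an array; v0 is optional and a key
# can seed the random initial vector.
eigval, eigvec = linear_algebra.power_iteration(a, key=jax.random.PRNGKey(1))

# Implicit form: only a matrix-vector product is supplied, so v0 is required
# to fix the shape (and pytree structure) of the iterates.
eigval_mvp, _ = linear_algebra.power_iteration(lambda v: a @ v, v0=jnp.ones(2))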


def matrix_inverse_pth_root(matrix: chex.Array,
156 changes: 145 additions & 11 deletions optax/_src/linear_algebra_test.py
@@ -12,45 +12,179 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for optax._src.linear_algebra."""

from absl.testing import absltest
from typing import Iterable

from absl.testing import absltest
from absl.testing import parameterized
import chex
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from optax import tree_utils
from optax._src import linear_algebra
import scipy.stats


class LinearAlgebraTest(absltest.TestCase):
class MLP(nn.Module):
# Multi-layer perceptron (MLP).
num_outputs: int
hidden_sizes: Iterable[int]

@nn.compact
def __call__(self, x):
for num_hidden in self.hidden_sizes:
x = nn.Dense(num_hidden)(x)
x = nn.gelu(x)
return nn.Dense(self.num_outputs)(x)


class LinearAlgebraTest(chex.TestCase):

def test_global_norm(self):
flat_updates = jnp.array([2., 4., 3., 5.], dtype=jnp.float32)
flat_updates = jnp.array([2.0, 4.0, 3.0, 5.0], dtype=jnp.float32)
nested_updates = dict(
a=jnp.array([2., 4.], dtype=jnp.float32),
b=jnp.array([3., 5.], dtype=jnp.float32))
a=jnp.array([2.0, 4.0], dtype=jnp.float32),
b=jnp.array([3.0, 5.0], dtype=jnp.float32),
)
np.testing.assert_array_equal(
jnp.sqrt(jnp.sum(flat_updates**2)),
linear_algebra.global_norm(nested_updates))
linear_algebra.global_norm(nested_updates),
)

@chex.all_variants
@parameterized.parameters(
dict(implicit=True),
dict(implicit=False),
)
def test_power_iteration(
self, implicit=True, dim=6, tol=1e-3, num_iters=100
):
"""Test power_iteration by comparing to numpy.linalg.eigh."""

if implicit:
# test the function when the matrix is given in implicit form by a
# matrix-vector product.
def power_iteration(matrix, *, v0):
return linear_algebra.power_iteration(
lambda x: matrix @ x,
v0=v0,
error_tolerance=tol,
num_iters=num_iters,
)
else:
power_iteration = linear_algebra.power_iteration

# test this function with/without jax.jit and on different devices
power_iteration = self.variant(power_iteration)

# create a random PSD matrix
matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
matrix = matrix @ matrix.T
v0 = jnp.ones((dim,))

eigval_power, eigvec_power = power_iteration(matrix, v0=v0)
all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(matrix)

self.assertAlmostEqual(eigval_power, all_eigenval[-1], delta=10 * tol)
np.testing.assert_array_almost_equal(
all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0]),
eigvec_power * jnp.sign(eigvec_power[0]),
decimal=3,
)

@chex.all_variants
def test_power_iteration_pytree(
self, dim=6, tol=1e-3, num_iters=100
):
"""Test power_iteration for matrix-vector products acting on pytrees."""

def matrix_vector_product(x):
# Implements a block-diagonal matrix where each block is a scaled identity
# matrix. The scaling factors are 2 and 1 for the first and second block,
# respectively.
return {'a': 2 * x['a'], 'b': x['b']}

@self.variant
def power_iteration(*, v0):
return linear_algebra.power_iteration(
matrix_vector_product,
v0=v0,
error_tolerance=tol,
num_iters=num_iters,
)

v0 = {'a': jnp.ones((dim,)), 'b': jnp.ones((dim,))}

eigval_power, _ = power_iteration(v0=v0)

# From the block-diagonal structure of the matrix, the largest eigenvalue is 2.
self.assertAlmostEqual(eigval_power, 2., delta=10 * tol)

@chex.all_variants
def test_power_iteration_mlp_hessian(
self, input_dim=16, output_dim=4, tol=1e-3
):
"""Test power_iteration on the Hessian of an MLP."""
mlp = MLP(num_outputs=output_dim, hidden_sizes=[input_dim, 8, output_dim])
key = jax.random.PRNGKey(0)
key_params, key_input, key_output = jax.random.split(key, 3)
# initialize the mlp
params = mlp.init(key_params, jnp.ones(input_dim))
x = jax.random.normal(key_input, (input_dim,))
y = jax.random.normal(key_output, (output_dim,))

@self.variant
def train_obj(params_):
z = mlp.apply(params_, x)
return jnp.sum((z - y) ** 2)

def hessian_vector_product(tangents_):
return jax.jvp(jax.grad(train_obj), (params,), (tangents_,))[1]

eigval_power, eigvec_power = linear_algebra.power_iteration(
hessian_vector_product, v0=tree_utils.tree_ones_like(params)
)

params_flat, unravel = jax.flatten_util.ravel_pytree(params)
eigvec_power_flat, _ = jax.flatten_util.ravel_pytree(eigvec_power)

def train_obj_flat(params_flat_):
params_ = unravel(params_flat_)
return train_obj(params_)

hessian = jax.hessian(train_obj_flat)(params_flat)
all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(hessian)

self.assertAlmostEqual(all_eigenval[-1], eigval_power, delta=10 * tol)
np.testing.assert_array_almost_equal(
all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0]),
eigvec_power_flat * jnp.sign(eigvec_power_flat[0]),
decimal=3,
)
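
The hessian_vector_product helper above uses the standard forward-over-reverse trick: applying jax.jvp to jax.grad evaluates H @ v without materializing the Hessian. A tiny self-contained sketch (the quadratic objective f is hypothetical, chosen so the Hessian is known):

import jax
import jax.numpy as jnp

def f(p):
  # Hypothetical scalar objective whose Hessian is diag([2., 4., 6.]).
  return jnp.sum(jnp.array([1.0, 2.0, 3.0]) * p ** 2)

def hvp(p, v):
  # jvp of the gradient evaluates H(p) @ v in a single combined pass.
  return jax.jvp(jax.grad(f), (p,), (v,))[1]

hvp(jnp.zeros(3), jnp.ones(3))  # ~ [2., 4., 6.]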

def test_matrix_inverse_pth_root(self):
"""Test for matrix inverse pth root."""

def _gen_symmetrix_matrix(dim, condition_number):
u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64)
v = u.T
diag = np.diag([condition_number ** (-i/(dim-1)) for i in range(dim)])
diag = np.diag([condition_number ** (-i / (dim - 1)) for i in range(dim)])
return u @ diag @ v

# Fails after it reaches a particular condition number.
for e in range(2, 12):
condition_number = 10 ** e
condition_number = 10**e
ms = _gen_symmetrix_matrix(16, condition_number)
self.assertLess(
np.abs(np.linalg.cond(ms) - condition_number),
condition_number * 0.01)
np.abs(np.linalg.cond(ms) - condition_number), condition_number * 0.01
)
error = linear_algebra.matrix_inverse_pth_root(
ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
ms.astype(np.float32), 4, ridge_epsilon=1e-12
)[1]
if e < 7:
self.assertLess(error, 0.1)
else:
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
@@ -30,3 +30,4 @@
from optax.tree_utils._tree_math import tree_sum
from optax.tree_utils._tree_math import tree_vdot
from optax.tree_utils._tree_math import tree_zeros_like
