[ADD] Regression test cases #19

Merged · 1 commit · Nov 30, 2022
6 changes: 3 additions & 3 deletions curvlinops/_base.py
@@ -18,7 +18,7 @@
 from torch import device as torch_device
 from torch import from_numpy, tensor, zeros_like
 from torch.autograd import grad
-from torch.nn import Module
+from torch.nn import Module, Parameter
 from torch.nn.utils import parameters_to_vector
 from tqdm import tqdm

@@ -33,7 +33,7 @@ def __init__(
         self,
         model_func: Callable[[Tensor], Tensor],
         loss_func: Callable[[Tensor, Tensor], Tensor],
-        params: List[Tensor],
+        params: List[Parameter],
         data: Iterable[Tuple[Tensor, Tensor]],
         progressbar: bool = False,
         check_deterministic: bool = True,
@@ -87,7 +87,7 @@ def __init__(
         self.to_device(old_device)

     @staticmethod
-    def _infer_device(params: List[Tensor]) -> torch_device:
+    def _infer_device(params: List[Parameter]) -> torch_device:
        """Infer the device on which to carry out matvecs.

        Args:
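Note on the annotation change: `params` is typically built from `model.parameters()`, which yields `torch.nn.Parameter` objects (a subclass of `Tensor`), so `List[Parameter]` documents the expected input more precisely. A minimal sketch of that usage (illustrative only, not part of this diff):

from torch.nn import Linear, Parameter, ReLU, Sequential

# model.parameters() yields Parameter objects, which is what the updated
# List[Parameter] annotation for `params` describes.
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
params = [p for p in model.parameters() if p.requires_grad]
assert all(isinstance(p, Parameter) for p in params)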
4 changes: 2 additions & 2 deletions curvlinops/gradient_moments.py
@@ -42,15 +42,15 @@ class EFLinearOperator(_LinearOperator):
     def _matvec_batch(
         self, X: Tensor, y: Tensor, x_list: List[Tensor]
     ) -> Tuple[Tensor, ...]:
-        """Apply the mini-batch GGN to a vector.
+        """Apply the mini-batch uncentered gradient covariance to a vector.

        Args:
            X: Input to the DNN.
            y: Ground truth.
            x_list: Vector in list format (same shape as trainable model parameters).

        Returns:
-            Result of GGN-multiplication in list format.
+            Result of uncentered gradient covariance-multiplication in list format.

        Raises:
            ValueError: If the loss function's reduction cannot be determined.
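For context on the docstring fix: `EFLinearOperator` represents the uncentered gradient covariance (empirical Fisher), i.e. the average of per-sample gradient outer products, which is not the same matrix as the GGN. A conceptual sketch of the matrix-vector product it stands for, assuming mean reduction (illustrative only; curvlinops' internal implementation and its handling of other reductions may differ):

from torch import rand, randint, zeros_like
from torch.autograd import grad
from torch.nn import CrossEntropyLoss, Linear
from torch.nn.utils import parameters_to_vector

model, loss_func = Linear(10, 2), CrossEntropyLoss(reduction="mean")
X, y = rand(4, 10), randint(0, 2, (4,))
params = [p for p in model.parameters() if p.requires_grad]
v = rand(sum(p.numel() for p in params))

# EF @ v = 1/N * sum_n (g_n^T v) g_n, where g_n is the gradient of the
# n-th per-sample loss, flattened into a vector.
N = X.shape[0]
EF_v = zeros_like(v)
for n in range(N):
    loss_n = loss_func(model(X[n : n + 1]), y[n : n + 1])
    g_n = parameters_to_vector(grad(loss_n, params))
    EF_v += (g_n @ v) * g_n / N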
37 changes: 35 additions & 2 deletions test/cases.py
@@ -1,9 +1,17 @@
 """Contains test cases for linear operators."""

-from test.utils import classification_targets, get_available_devices
+from test.utils import classification_targets, get_available_devices, regression_targets

 from torch import rand, rand_like
-from torch.nn import BatchNorm1d, CrossEntropyLoss, Dropout, Linear, ReLU, Sequential
+from torch.nn import (
+    BatchNorm1d,
+    CrossEntropyLoss,
+    Dropout,
+    Linear,
+    MSELoss,
+    ReLU,
+    Sequential,
+)
 from torch.utils.data import DataLoader, TensorDataset

 DEVICES = get_available_devices()
@@ -13,6 +21,9 @@

 # Add test cases here
 CASES_NO_DEVICE = [
+    ###############################################################################
+    #                                CLASSIFICATION                               #
+    ###############################################################################
     {
         "model_func": lambda: Sequential(Linear(10, 5), ReLU(), Linear(5, 2)),
         "loss_func": lambda: CrossEntropyLoss(reduction="mean"),
@@ -32,6 +43,28 @@
         ],
         "seed": 0,
     },
+    ###############################################################################
+    #                                  REGRESSION                                 #
+    ###############################################################################
+    {
+        "model_func": lambda: Sequential(Linear(8, 5), ReLU(), Linear(5, 3)),
+        "loss_func": lambda: MSELoss(reduction="mean"),
+        "data": lambda: [
+            (rand(2, 8), regression_targets((2, 3))),
+            (rand(6, 8), regression_targets((6, 3))),
+        ],
+        "seed": 0,
+    },
+    # same as above, but uses reduction='sum'
+    {
+        "model_func": lambda: Sequential(Linear(8, 5), ReLU(), Linear(5, 3)),
+        "loss_func": lambda: MSELoss(reduction="sum"),
+        "data": lambda: [
+            (rand(2, 8), regression_targets((2, 3))),
+            (rand(6, 8), regression_targets((6, 3))),
+        ],
+        "seed": 0,
+    },
 ]

 CASES = []
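Each case is a dict of factories plus a seed, so a test can rebuild the model, loss, and data deterministically. A rough sketch of how one of the new regression cases might be consumed, reusing the definitions from test/cases.py above (illustrative; the actual fixtures live elsewhere in the test suite):

from torch import manual_seed

case = CASES_NO_DEVICE[-1]  # the new MSELoss(reduction="sum") regression case
manual_seed(case["seed"])  # makes the model initialization and random data reproducible
model_func = case["model_func"]()
loss_func = case["loss_func"]()
data = case["data"]()  # list of (X, y) mini-batches
params = [p for p in model_func.parameters() if p.requires_grad]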
2 changes: 1 addition & 1 deletion test/test_gradient_moments.py
@@ -20,4 +20,4 @@ def test_EFLinearOperator_matmat(case, num_vecs: int = 3):
     EF_functorch = functorch_empirical_fisher(*case).detach().cpu().numpy()

     X = random.rand(EF.shape[1], num_vecs).astype(EF.dtype)
-    report_nonclose(EF @ X, EF_functorch @ X)
+    report_nonclose(EF @ X, EF_functorch @ X, atol=1e-7, rtol=1e-4)
2 changes: 1 addition & 1 deletion test/test_hessian.py
@@ -20,4 +20,4 @@ def test_HessianLinearOperator_matmat(case, num_vecs: int = 3):
     H_functorch = functorch_hessian(*case).detach().cpu().numpy()

     X = random.rand(H.shape[1], num_vecs)
-    report_nonclose(H @ X, H_functorch @ X)
+    report_nonclose(H @ X, H_functorch @ X, atol=1e-6, rtol=5e-4)
7 changes: 6 additions & 1 deletion test/utils.py
@@ -1,6 +1,6 @@
 """Utility functions to test `curvlinops`."""

-from torch import cuda, device, randint
+from torch import cuda, device, rand, randint


 def get_available_devices():
@@ -20,3 +20,8 @@ def get_available_devices():
 def classification_targets(size, num_classes):
     """Create random targets for classes 0, ..., `num_classes - 1`."""
     return randint(size=size, low=0, high=num_classes)
+
+
+def regression_targets(size):
+    """Create random targets for regression."""
+    return rand(*size)
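For comparison of the two helpers defined above: `classification_targets` produces integer class labels as expected by `CrossEntropyLoss`, while the new `regression_targets` produces float targets whose shape matches the model output, as expected by `MSELoss`. A quick illustrative check (assumes the helpers from test/utils.py are in scope):

from torch import rand
from torch.nn import CrossEntropyLoss, MSELoss

y_cls = classification_targets(size=(6,), num_classes=2)  # int64 class indices in {0, 1}
y_reg = regression_targets((6, 3))  # float targets matching the output shape

CrossEntropyLoss()(rand(6, 2), y_cls)  # expects integer labels
MSELoss()(rand(6, 3), y_reg)  # expects float targets shaped like the prediction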