Skip to content

Commit

Permalink
Update tests to accept new model input
Browse files Browse the repository at this point in the history
  • Loading branch information
salomaestro committed Feb 4, 2025
1 parent f4e5591 commit d911e4a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from utils.models import ChristianModel


@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)])
def test_christian_model(in_channels, num_classes):
n, c, h, w = 5, in_channels, 16, 16
@pytest.mark.parametrize(
"image_shape, num_classes",
[((1, 16, 16), 6), ((3, 16, 16), 6)],
)
def test_christian_model(image_shape, num_classes):
n, c, h, w = 5, *image_shape

model = ChristianModel(c, num_classes)
model = ChristianModel(image_shape, num_classes)

x = torch.randn(n, c, h, w)
y = model(x)
Expand Down

0 comments on commit d911e4a

Please sign in to comment.