diff --git a/tests/test_models.py b/tests/test_models.py index 4747490..15a7504 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,11 +4,14 @@ from utils.models import ChristianModel -@pytest.mark.parametrize("in_channels, num_classes", [(1, 6), (3, 6)]) -def test_christian_model(in_channels, num_classes): - n, c, h, w = 5, in_channels, 16, 16 +@pytest.mark.parametrize( + "image_shape, num_classes", + [((1, 16, 16), 6), ((3, 16, 16), 6)], +) +def test_christian_model(image_shape, num_classes): + n, c, h, w = 5, *image_shape - model = ChristianModel(c, num_classes) + model = ChristianModel(image_shape, num_classes) x = torch.randn(n, c, h, w) y = model(x)