From 466a941b8869712489a94168a323a3711995eec3 Mon Sep 17 00:00:00 2001 From: danhphan Date: Fri, 4 Feb 2022 17:34:08 +1100 Subject: [PATCH] add test_shape_inputs for _OrderedLogistic --- pymc/tests/test_distributions_random.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index eedf8301681..73adc47abb7 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -1688,6 +1688,26 @@ class TestOrderedLogistic(BaseTestDistributionRandom): "check_rv_size", ] + @pytest.mark.parametrize( + "eta, cutpoints, expected", + [ + (0, [-2.0, 0, 2.0], (4,)), + ([-1], [-2.0, 0, 2.0], (1, 4)), + ([1.0, -2.0], [-1.0, 0, 1.0], (2, 4)), + ([[1.0, -1.0, 0.0], [-1.0, 3.0, 5.0]], [-2.0, 0, 1.0], (2, 3, 4)), + ], + ) + def test_shape_inputs(self, eta, cutpoints, expected): + """ + This test checks when providing different shapes for `eta` parameters. + """ + categorical = _OrderedLogistic.dist( + eta=eta, + cutpoints=cutpoints, + ) + p = categorical.owner.inputs[3].eval() + assert p.shape == expected + class TestOrderedProbit(BaseTestDistributionRandom): pymc_dist = _OrderedProbit