Skip to content

Commit

Permalink
Test expected (inferred) and actual shape of draws in `TestBaseDistri…
Browse files Browse the repository at this point in the history
…butionRandom`
  • Loading branch information
ricardoV94 committed Jan 24, 2022
1 parent 1e58bce commit 475ffbf
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,9 @@ def check_rv_size(self):
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
for size, expected in zip(sizes_to_check, sizes_expected):
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
actual = tuple(pymc_rv.shape.eval())
assert actual == expected, f"size={size}, expected={expected}, actual={actual}"
expected_symbolic = tuple(pymc_rv.shape.eval())
actual = pymc_rv.eval().shape
assert actual == expected_symbolic == expected

# test multi-parameters sampling for univariate distributions (with univariate inputs)
if (
Expand All @@ -386,8 +387,9 @@ def check_rv_size(self):
]
for size, expected in zip(sizes_to_check, sizes_expected):
pymc_rv = self.pymc_dist.dist(**params, size=size)
actual = tuple(pymc_rv.shape.eval())
assert actual == expected
expected_symbolic = tuple(pymc_rv.shape.eval())
actual = pymc_rv.eval().shape
assert actual == expected_symbolic == expected

def validate_tests_list(self):
assert len(self.tests_to_run) == len(
Expand Down

0 comments on commit 475ffbf

Please sign in to comment.