Skip to content

Commit

Permalink
Address latest review comments from RicardoV94
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed Jul 9, 2021
1 parent 195221b commit 7b8d33b
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from typing import List, Optional, Tuple, Union

import aesara
import aesara.tensor as at
import numpy as np

Expand Down Expand Up @@ -4106,7 +4107,7 @@ def make_node(self, x, h, z):
z = at.as_tensor_variable(floatX(z))
shape = broadcast_shape(x, h, z)
broadcastable = [] if not shape else [False] * len(shape)
return Apply(self, [x, h, z], [at.TensorType(np.float64, broadcastable)()])
return Apply(self, [x, h, z], [at.TensorType(aesara.config.floatX, broadcastable)()])

def perform(self, node, ins, outs):
x, h, z = ins[0], ins[1], ins[2]
Expand Down Expand Up @@ -4200,12 +4201,12 @@ class PolyaGamma(PositiveContinuous):

@classmethod
def dist(cls, h=1.0, z=0.0, **kwargs):
hh = at.as_tensor_variable(floatX(h))
zz = at.as_tensor_variable(floatX(z))
h = at.as_tensor_variable(floatX(h))
z = at.as_tensor_variable(floatX(z))

msg = f"The variable {hh} specified for PolyaGamma has non-positive "
msg = f"The variable {h} specified for PolyaGamma has non-positive "
msg += "values, making it unsuitable for this parameter."
Assert(msg)(hh, at.all(at.gt(hh, 0.0)))
Assert(msg)(h, at.all(at.gt(h, 0.0)))

return super().dist([h, z], **kwargs)

Expand Down

0 comments on commit 7b8d33b

Please sign in to comment.