diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index dd066ba2de5..cbd613abdcc 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -53,6 +53,9 @@ from pytensor.scalar import ( Abs, Add, + ArcCosh, + ArcSinh, + ArcTanh, Cosh, Erf, Erfc, @@ -71,6 +74,9 @@ from pytensor.tensor.math import ( abs, add, + arccosh, + arcsinh, + arctanh, cosh, erf, erfc, @@ -369,7 +375,23 @@ def apply(self, fgraph: FunctionGraph): class MeasurableTransform(MeasurableElemwise): """A placeholder used to specify a log-likelihood for a transformed measurable variable""" - valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx) + valid_scalar_types = ( + Exp, + Log, + Add, + Mul, + Pow, + Abs, + Sinh, + Cosh, + Tanh, + ArcSinh, + ArcCosh, + ArcTanh, + Erf, + Erfc, + Erfcx, + ) # Cannot use `transform` as name because it would clash with the property added by # the `TransformValuesRewrite` @@ -501,7 +523,9 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] -@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx]) +@node_rewriter( + [exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx] +) def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: """Find measurable transformations from Elemwise operators.""" @@ -544,6 +568,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li Sinh: SinhTransform(), Cosh: CoshTransform(), Tanh: TanhTransform(), + ArcSinh: ArcsinhTransform(), + ArcCosh: ArccoshTransform(), + ArcTanh: ArctanhTransform(), Erf: ErfTransform(), Erfc: ErfcTransform(), Erfcx: ErfcxTransform(), @@ -660,6 +687,39 @@ def backward(self, value, *inputs): return pt.arctanh(value) +class ArcsinhTransform(RVTransform): + name = "arcsinh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.arcsinh(value) + + def backward(self, value, *inputs): + return pt.sinh(value) + + +class ArccoshTransform(RVTransform): + name = "arccosh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.arccosh(value) + + def backward(self, value, *inputs): + return pt.cosh(value) + + +class ArctanhTransform(RVTransform): + name = "arctanh" + ndim_supp = 0 + + def forward(self, value, *inputs): + return pt.arctanh(value) + + def backward(self, value, *inputs): + return pt.tanh(value) + + class ErfTransform(RVTransform): name = "erf" ndim_supp = 0 diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 147d4d46a1b..785e2599fbb 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -51,6 +51,9 @@ from pymc.logprob.abstract import MeasurableVariable, _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.logprob.transforms import ( + ArccoshTransform, + ArcsinhTransform, + ArctanhTransform, ChainedTransform, CoshTransform, ErfcTransform, @@ -1028,6 +1031,9 @@ def test_multivariate_transform(shift, scale): (pt.sinh, SinhTransform()), (pt.cosh, CoshTransform()), (pt.tanh, TanhTransform()), + (pt.arcsinh, ArcsinhTransform()), + (pt.arccosh, ArccoshTransform()), + (pt.arctanh, ArctanhTransform()), ], ) def test_erf_logp(pt_transform, transform): @@ -1060,6 +1066,9 @@ def test_erf_logp(pt_transform, transform): SinhTransform(), CoshTransform(), TanhTransform(), + ArcsinhTransform(), + ArccoshTransform(), + ArctanhTransform(), ], ) def test_check_jac_det(transform):