Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pretty-printing and fix latex underscores #6533

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
)
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
from pymc.pytensorf import collect_default_updates, convert_observed_data
from pymc.util import UNSET, _add_future_warning_tag
from pymc.vartypes import string_types
Expand Down Expand Up @@ -198,7 +197,7 @@ class SymbolicRandomVariable(OpFromGraph):
"""

_print_name: Tuple[str, str] = ("Unknown", "\\operatorname{Unknown}")
"""Tuple of (name, latex name) used for for pretty-printing variables of this type"""
"Tuple of (name, latex name) used for for pretty-printing variables of this type"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason for this change?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry - Refreshing on pep 257 I've realized triple quotes are probably preferred!


def __init__(self, *args, ndim_supp, **kwargs):
self.ndim_supp = ndim_supp
Expand Down Expand Up @@ -320,9 +319,12 @@ def __new__(
)

# add in pretty-printing support
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
from pymc.printing import str_for_model_var

rv_out.str_repr = types.MethodType(str_for_model_var, rv_out)
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_dist, formatting="latex"), rv_out
functools.partial(str_for_model_var, formatting="latex"),
rv_out,
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
Expand Down
4 changes: 2 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ class _LKJCholeskyCovBaseRV(RandomVariable):
ndim_supp = 1
ndims_params = [0, 0, 1]
dtype = "floatX"
_print_name = ("_lkjcholeskycovbase", "\\operatorname{_lkjcholeskycovbase}")
_print_name = ("_lkjcholeskycovbase", r"\operatorname{\_lkjcholeskycovbase}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the addition of \ in front of _? All the other random variables follow ("rv_name", "\\operatorname{rv_name}"). Without the \, do you obtain a strange looking string?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, thinking about this again... the answer is probably in the title of this PR and related to issue #6508

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep exactly!


def make_node(self, rng, size, dtype, n, eta, D):
n = at.as_tensor_variable(n)
Expand Down Expand Up @@ -1164,7 +1164,7 @@ def rng_fn(self, rng, n, eta, D, size):
# be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
class _LKJCholeskyCovRV(SymbolicRandomVariable):
default_output = 1
_print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")
_print_name = ("_lkjcholeskycov", r"\operatorname{\_lkjcholeskycov}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above


def update(self, node):
return {node.inputs[0]: node.outputs[0]}
Expand Down
16 changes: 6 additions & 10 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,15 +2015,13 @@ def Deterministic(name, var, model=None, dims=None):
model.deterministics.append(var)
model.add_named_variable(var, dims)

from pymc.printing import str_for_potential_or_deterministic
from pymc.printing import str_for_model_var

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
functools.partial(str_for_model_var, dist_name="Deterministic"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
),
functools.partial(str_for_model_var, dist_name="Deterministic", formatting="latex"),
var,
)

Expand All @@ -2047,15 +2045,13 @@ def Potential(name, var, model=None):
model.potentials.append(var)
model.add_named_variable(var)

from pymc.printing import str_for_potential_or_deterministic
from pymc.printing import str_for_model_var

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
functools.partial(str_for_model_var, dist_name="Potential"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
),
functools.partial(str_for_model_var, dist_name="Potential", formatting="latex"),
var,
)

Expand Down
Loading