Logp with transforms fails in FAST_COMPILE #5618

Closed
katosh opened this issue Mar 18, 2022 · 5 comments · Fixed by #6735

Comments

@katosh
Contributor

katosh commented Mar 18, 2022

Description of your problem

It seems the Dirichlet distribution does not work in the current beta, although it is expected to.

import numpy as np
import pymc as pm

with pm.Model():
    a = pm.Dirichlet('a', np.ones(3))
    pm.sample()
Complete error traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
    962 try:
    963     outputs = (
--> 964         self.fn()
    965         if output_subset is None
    966         else self.fn(output_subset=output_subset)
    967     )
    968 except Exception:

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    518 @is_thunk_type
    519 def rval(
    520     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    521 ):
--> 522     r = p(n, [x[0] for x in i], o)
    523     for o in node.outputs:

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
     47 def perform(self, node, inputs, outputs):
---> 48     raise NotImplementedError(
     49         "These `Op`s should be removed from graphs used for computation."
     50     )

NotImplementedError: These `Op`s should be removed from graphs used for computation.

During handling of the above exception, another exception occurred:

NotImplementedError                       Traceback (most recent call last)
Input In [9], in <cell line: 4>()
      4 with pm.Model() as model:
      5     a = pm.Dirichlet('a', np.ones(3))
----> 6     pm.sample()

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/sampling.py:487, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    485 # One final check that shapes and logps at the starting points are okay.
    486 for ip in initial_points:
--> 487     model.check_start_vals(ip)
    488     _check_start_shape(model, ip)
    490 sample_args = {
    491     "draws": draws,
    492     "step": step,
   (...)
    503     "discard_tuned_samples": discard_tuned_samples,
    504 }

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1680, in Model.check_start_vals(self, start)
   1674     valid_keys = ", ".join(self.named_vars.keys())
   1675     raise KeyError(
   1676         "Some start parameters do not appear in the model!\n"
   1677         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1678     )
-> 1680 initial_eval = self.point_logps(point=elem)
   1682 if not all(np.isfinite(v) for v in initial_eval.values()):
   1683     raise SamplingError(
   1684         "Initial evaluation of model at starting point failed!\n"
   1685         f"Starting values:\n{elem}\n\n"
   1686         f"Initial evaluation results:\n{initial_eval}"
   1687     )

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1721, in Model.point_logps(self, point, round_vals)
   1715 factors = self.basic_RVs + self.potentials
   1716 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
   1717 return {
   1718     factor.name: np.round(np.asarray(factor_logp), round_vals)
   1719     for factor, factor_logp in zip(
   1720         factors,
-> 1721         self.compile_fn(factor_logps_fn)(point),
   1722     )
   1723 }

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/pymc/model.py:1820, in PointFunc.__call__(self, state)
   1819 def __call__(self, state):
-> 1820     return self.f(**state)

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:977, in Function.__call__(self, *args, **kwargs)
    975     if hasattr(self.fn, "thunks"):
    976         thunk = self.fn.thunks[self.fn.position_of_error]
--> 977     raise_with_op(
    978         self.maker.fgraph,
    979         node=self.fn.nodes[self.fn.position_of_error],
    980         thunk=thunk,
    981         storage_map=getattr(self.fn, "storage_map", None),
    982     )
    983 else:
    984     # old-style linkers raise their own exceptions
    985     raise

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    533     warnings.warn(
    534         f"{exc_type} error does not allow us to add an extra error message"
    535     )
    536     # Some exception need extra parameter in inputs. So forget the
    537     # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
    961 t0_fn = time.time()
    962 try:
    963     outputs = (
--> 964         self.fn()
    965         if output_subset is None
    966         else self.fn(output_subset=output_subset)
    967     )
    968 except Exception:
    969     restore_defaults()

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    518 @is_thunk_type
    519 def rval(
    520     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    521 ):
--> 522     r = p(n, [x[0] for x in i], o)
    523     for o in node.outputs:
    524         compute_map[o][0] = True

File ~/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py:48, in TransformedVariable.perform(self, node, inputs, outputs)
     47 def perform(self, node, inputs, outputs):
---> 48     raise NotImplementedError(
     49         "These `Op`s should be removed from graphs used for computation."
     50     )

NotImplementedError: These `Op`s should be removed from graphs used for computation.
Apply node that caused the error: TransformedVariable(Softmax{axis=0}.0, a_simplex__)
Toposort index: 20
Inputs types: [TensorType(float64, (None,)), TensorType(float64, (None,))]
Inputs shapes: [(3,), (2,)]
Inputs strides: [(8,), (8,)]
Inputs values: [array([0.33333333, 0.33333333, 0.33333333]), array([0., 0.])]
Outputs clients: [[Elemwise{eq,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0}), Elemwise{gt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 1}), Elemwise{lt,no_inplace}(a_simplex___simplex, TensorConstant{(1,) of 0})]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 203, in apply
    return self.default_transform_opt.optimize(fgraph)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 103, in optimize
    ret = self.apply(fgraph, *args, **kwargs)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1960, in apply
    nb += self.process_node(fgraph, node)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1850, in process_node
    replacements = lopt.transform(fgraph, node)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/opt.py", line 1055, in transform
    return self.fn(fgraph, node)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 148, in transform_values
    new_value_var = transformed_variable(
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aesara/graph/op.py", line 294, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/dotto/.conda/envs/pymc_v1/lib/python3.9/site-packages/aeppl/transforms.py", line 45, in make_node
    return Apply(self, [tran_value, value], [tran_value.type()])

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The error above seems to indicate that Softmax is applied to the transformed RV of the Dirichlet distribution. However, the transform currently used is aeppl.transforms.Simplex, which does not explicitly use the Softmax function:

https://github.com/aesara-devs/aeppl/blob/751979802f1aef5478fdbf7cc1839df07df60825/aeppl/transforms.py#L289-L311
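The linked transform does not call Softmax by name, but its backward map (append the negated sum of the unconstrained values, then exponentiate and normalize) appears to be a softmax written out by hand. The NumPy sketch below assumes the transform follows the linked aeppl code, handles only 1-D inputs, and uses hypothetical function names rather than aeppl's API. Plausibly, Aesara's graph rewrites recognize this exp-normalize pattern and turn it into a `Softmax` Op, which would account for the `Softmax{axis=0}` node in the traceback.

```python
import numpy as np


def simplex_forward(x):
    """Map a point on the K-simplex to K-1 unconstrained values (1-D only)."""
    log_x = np.log(x)
    # Centered log-ratio: subtract the mean log, then drop the last coordinate
    return log_x[:-1] - log_x.mean()


def simplex_backward(y):
    """Map K-1 unconstrained values back onto the K-simplex (1-D only)."""
    # Restore the dropped coordinate so that all K values sum to zero
    z = np.concatenate([y, [-y.sum()]])
    # Numerically stable exp-normalize, i.e. a softmax over z
    e = np.exp(z - z.max())
    return e / e.sum()


def softmax(z):
    """Plain softmax over a 1-D array, for comparison."""
    e = np.exp(z - z.max())
    return e / e.sum()
```

With the initial point `a_simplex__ = [0., 0.]`, `simplex_backward(np.zeros(2))` returns `[1/3, 1/3, 1/3]`, matching the two input values reported for the failing `TransformedVariable` node.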

Versions and main components

  • PyMC/PyMC3 Version: 4.0.0b4
  • Aesara/Theano Version: 2.5.1
  • aePPL Version: 0.0.27
  • Python Version: 3.9.10
  • Operating system: Ubuntu 18.04.5 LTS
  • How did you install PyMC/PyMC3: conda
@ricardoV94
Member

@katosh
Contributor Author

katosh commented Mar 19, 2022

Closing this for now since it cannot be reproduced with a different user account on the same machine.

@katosh katosh closed this as completed Mar 19, 2022
@katosh
Contributor Author

katosh commented Mar 19, 2022

I solved the issue by removing the following option from my ~/.aesararc:

[global]
optimizer = fast_compile
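To check whether an rc file sets this option without opening it by hand, a small sketch using the standard library's `configparser` may help. It assumes the file is INI-style as shown above; the helper names are hypothetical. Aesara can also pick up flags from elsewhere (for example the `AESARA_FLAGS` environment variable), which this sketch does not inspect.

```python
import configparser
import os
from typing import Optional


def rc_optimizer(rc_text: str) -> Optional[str]:
    """Return the [global] optimizer setting from .aesararc-style INI text, if any."""
    parser = configparser.ConfigParser()
    parser.read_string(rc_text)
    # fallback=None also covers a missing [global] section
    return parser.get("global", "optimizer", fallback=None)


def home_rc_optimizer(path: str = "~/.aesararc") -> Optional[str]:
    """Read the rc file at `path` (if it exists) and return its optimizer setting."""
    full_path = os.path.expanduser(path)
    if not os.path.exists(full_path):
        return None
    with open(full_path, encoding="utf-8") as fh:
        return rc_optimizer(fh.read())


if __name__ == "__main__":
    if home_rc_optimizer() == "fast_compile":
        print("~/.aesararc sets optimizer = fast_compile")
```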

@ricardoV94 ricardoV94 reopened this Mar 19, 2022
@ricardoV94
Member

ricardoV94 commented Mar 19, 2022

> I solved the issue by removing the following option from my ~/.aesararc:
>
> [global]
> optimizer = fast_compile

If that's what caused it, it's important to address.
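A toy pure-Python model of why an optimizer setting could matter (not Aesara's or aeppl's code; all names are hypothetical). The transformed value is wrapped in a placeholder op that raises when computed, like `TransformedVariable.perform` in the traceback, and evaluation only works once a rewrite has replaced the placeholder with its first input. In Aesara, rewrites are registered with mode tags and FAST_COMPILE runs only a small subset of them; assuming the removal rewrite was not in that subset, the placeholder would survive compilation and raise.

```python
from typing import Any, Callable, List


class Node:
    """A node in a toy expression graph: an op applied to inputs."""

    def __init__(self, op: Callable[..., Any], inputs: List[Any]):
        self.op = op
        self.inputs = inputs


def evaluate(node: Any) -> Any:
    """Recursively evaluate a toy graph; non-Node leaves are constants."""
    if not isinstance(node, Node):
        return node
    return node.op(*[evaluate(i) for i in node.inputs])


def placeholder(transformed_value, value):
    """Marker op: carries (transformed value, value) but must never be computed."""
    raise NotImplementedError(
        "These `Op`s should be removed from graphs used for computation."
    )


def remove_placeholders(node: Any) -> Any:
    """Rewrite: replace every placeholder node by its first input."""
    if not isinstance(node, Node):
        return node
    if node.op is placeholder:
        return remove_placeholders(node.inputs[0])
    return Node(node.op, [remove_placeholders(i) for i in node.inputs])


# Each rewrite is registered with the set of modes that run it.
# Here the removal rewrite is deliberately NOT tagged for "fast_compile".
REWRITES = [(remove_placeholders, {"fast_run"})]


def compile_graph(node: Any, mode: str) -> Any:
    """Apply only the rewrites whose tags include the requested mode."""
    for rewrite, modes in REWRITES:
        if mode in modes:
            node = rewrite(node)
    return node


# A toy "logp" that sums the values flowing out of the placeholder node
graph = Node(sum, [Node(placeholder, [[0.25, 0.25, 0.5], [0.0, 0.0]])])
```

In this toy, `evaluate(compile_graph(graph, "fast_run"))` gives `1.0`, while the `"fast_compile"` version raises `NotImplementedError`; tagging the removal rewrite for every mode makes both succeed.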

@michaelosthege michaelosthege added this to the v4.0.0b5 milestone Mar 19, 2022
@ricardoV94 ricardoV94 changed the title Broken Dirichlet distribution in 4.0.0b4 Logp with transforms fail in FAST_COMPILE Mar 20, 2022
@ricardoV94 ricardoV94 changed the title Logp with transforms fail in FAST_COMPILE Logp with transforms fails in FAST_COMPILE Mar 21, 2022
@ricardoV94 ricardoV94 modified the milestones: v4.0.0b5, v4.0.0b6 Mar 22, 2022
@ricardoV94 ricardoV94 modified the milestones: v4.0.0b6, v4.0.0b7 Mar 30, 2022
@ricardoV94
Member

This is now our responsibility; it should be easy to fix.

3 participants