Ux rule suggestions #423

Merged
bvdmitri merged 9 commits into main from ux-rule-suggestions on Nov 27, 2024

Conversation

wmkouw
Member

@wmkouw wmkouw commented Oct 15, 2024

I have added a rule lookup with the RuleMethodError to print the list of existing rules, with a corresponding test as part of the rule_method_error testset.

The inference procedure

@model function test(y)
   r ~ Gamma(1.,1.)
   x ~ Bernoulli(r)
   y ~ Beta(x,1.)
end
infer(model = test(), data = (y = 1.0,))

now returns:

ERROR: RuleMethodError: no method matching rule for the given arguments

Existing rule(s) for node:

Beta(μ(a) :: PointMass, μ(b) :: PointMass)


Possible fix, define:

@rule Beta(:a, Marginalisation) (q_out::PointMass, q_b::PointMass, ) = begin 
    return ...
end

I am happy with this result, but I noticed that there is now a clash with a functional form constraint error:

ERROR: The expression `q(r)` has an undefined functional form of type `ProductOf{GammaShapeRate{Float64}, Beta{Float64}}`. 
This is likely because the inference backend does not support the product of these distributions. 
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `q(r)`.

Possible solutions:
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
```julia
using ExponentialFamilyProjection

@constraints begin
    q(r) :: ProjectedTo(NormalMeanVariance)
end
```

So this existing-rule lookup may also need to be implemented there. What do you think?

fixes #397

@wmkouw wmkouw self-assigned this Oct 15, 2024
@wmkouw wmkouw linked an issue Oct 15, 2024 that may be closed by this pull request
@wmkouw
Member Author

wmkouw commented Oct 16, 2024

After a discussion in the RxInfer meeting today, we decided that the RuleMethodError is fine as it is, and that we will add an item to the bullet list in the constrain_form error message that points the user to the Wikipedia page on conjugate priors.

Member

@bvdmitri bvdmitri left a comment

A couple of comments here:

  • Tests were actually failing, but because of another bug in the existing code of get_messages_from_rule_method, which returns "μ(a) :: BayesBase.PointMass}" (yes, with this strange } at the end). The code for extracting those is old and should perhaps be rewritten entirely, but for now I have modified the tests a bit.
  • The suggestions in the PR are incomplete because they don't show rules that include marginals; we have a separate function, get_marginals_from_rule_method, for this purpose.
  • Simply including marginals from get_marginals_from_rule_method wouldn't be entirely correct either, because the order matters: e.g., a rule that accepts q(a)::Something, μ(b)::Something and a rule that accepts μ(b)::Something, q(a)::Something are two different rules. But perhaps we can skip that for now and just show marginals first and then messages?
  • Small comment: the rules use q_a while the error suggests q(a). WDYT about this discrepancy?
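
To illustrate the point about argument order, here is a minimal sketch in plain Julia. It does not use ReactiveMP; the `Marginal` and `Message` wrappers and the `toy_rule` function are hypothetical stand-ins, assumed only for this example. It shows that two methods whose argument types are merely swapped are distinct as far as dispatch is concerned.

```julia
# Hypothetical wrapper types, named here only for illustration.
struct Marginal{T}
    value::T
end

struct Message{T}
    value::T
end

# Two methods that differ only in the order of their argument types.
toy_rule(a::Marginal, b::Message) = "q(a), μ(b)"
toy_rule(b::Message, a::Marginal) = "μ(b), q(a)"

q = Marginal(1.0)
m = Message(2.0)

println(toy_rule(q, m))             # dispatches to the first method
println(toy_rule(m, q))             # dispatches to the second method
println(length(methods(toy_rule)))  # both orderings are stored as separate methods
```

Under this assumption, a listing that always prints marginals before messages would show both orderings as the same line, which is why a note about ordering seems useful.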

src/rule.jl Outdated
for node_rule in this_node_rules
    node_name = get_node_from_rule_method(node_rule)
    node_inputs = get_messages_from_rule_method(node_rule)
    if typeof(node_inputs) !== Vector{Any}
Member

Why do we need this check?

Member Author

get_messages_from_rule_method(node_rule) will return Any[] on an empty list (see rule.jl#1297).

For example:

> all_rules = methods(ReactiveMP.rule)
> this_node_rules = all_rules[ReactiveMP.get_node_from_rule_method.(all_rules) .== "MvNormalMeanVariance"]
> ReactiveMP.get_messages_from_rule_method(this_node_rules[1])

Without the type check, this snippet would print a rule for an Any type (which doesn't exist).

Member

Ah, perhaps then we should use !isempty(node_inputs)?
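
As a rough comparison of the two checks, the sketch below uses plain Julia vectors as stand-ins for the return value of get_messages_from_rule_method. The values are assumptions for illustration, not actual ReactiveMP output.

```julia
# Stand-in values (assumed shapes, not actual ReactiveMP output).
no_inputs   = Any[]                     # an empty parse result
typed_empty = String[]                  # empty, but not a Vector{Any}
some_inputs = Any["μ(a) :: PointMass"]  # non-empty, yet still a Vector{Any}

# Type-based check: passes only when the container is not a Vector{Any}.
type_check(x) = typeof(x) !== Vector{Any}

# Content-based check: passes whenever there is something to print.
content_check(x) = !isempty(x)

for (name, value) in [("no_inputs", no_inputs), ("typed_empty", typed_empty), ("some_inputs", some_inputs)]
    println(rpad(name, 12), " type_check=", type_check(value), " content_check=", content_check(value))
end
```

The type-based check would skip a non-empty `Vector{Any}` and accept an empty typed vector, whereas `!isempty` depends only on whether there are inputs to show.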

@wmkouw
Member Author

wmkouw commented Nov 4, 2024

> The suggestions in the PR are incomplete because they don't show rules that include marginals; we have a separate function, get_marginals_from_rule_method, for this purpose.

Well, the intention of this functionality is to give the user advice on how to specify their model. For example, if they specify a Bernoulli likelihood with a Gamma prior, then they will get a RuleMethodError. The new error reports that the Bernoulli node has a rule for Beta inputs, which may inspire the user to redefine their model with a Beta prior. Message rules are sufficient for achieving this goal, don't you think?
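
For context on why a Beta prior is the natural suggestion here: the Beta distribution is conjugate to the Bernoulli likelihood, so the posterior stays in the Beta family. Below is a minimal sketch of that update in plain Julia; the function name `beta_bernoulli_update` is made up for this example and is not part of ReactiveMP or RxInfer.

```julia
# Conjugate update for a Bernoulli likelihood with a Beta(a, b) prior:
# observing x in {0, 1} gives the posterior Beta(a + x, b + 1 - x).
# The function name is hypothetical and used only for illustration.
beta_bernoulli_update(a, b, x) = (a + x, b + 1 - x)

prior = (1.0, 1.0)            # Beta(1, 1), i.e. a uniform prior
observations = [1, 0, 1, 1]   # three successes and one failure

# Fold the observations into the prior one at a time.
posterior = foldl((ab, x) -> beta_bernoulli_update(ab..., x), observations; init = prior)

println(posterior)  # (4.0, 2.0)
```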

> Simply including marginals from get_marginals_from_rule_method wouldn't be entirely correct either, because the order matters: e.g., a rule that accepts q(a)::Something, μ(b)::Something and a rule that accepts μ(b)::Something, q(a)::Something are two different rules. But perhaps we can skip that for now and just show marginals first and then messages?

Yes. We could add "Note that the order of input arguments matters".

> Small comment: the rules use q_a while the error suggests q(a). WDYT about this discrepancy?

I thought about catching this and reverting it. But the goal of this error info is to advise the user to re-specify their model. If they re-define their node, then ReactiveMP will operate on q_a and this will not be a problem. But if you prefer q_a over q(a), then I can change this.

@bvdmitri
Member

bvdmitri commented Nov 4, 2024

> The new error reports that the Bernoulli node has a Beta rule, which may inspire the user to re-define their model to a Beta prior. Message rules are sufficient for achieving this goal, don't you think?

Ah, you're right. I see. Yes, it will work for this case, but not for all. For instance, while the Gamma distribution is a conjugate prior for the precision parameter of a Normal distribution, we lack sum-product rules that use Gamma as a message. Instead, we only have variational inference rules that use Gamma as a marginal. This means the user will see an empty list of suggestions, even though a VI rule exists that could recommend switching to a Gamma prior (though this would require adjusting the factorization constraint).

@wmkouw
Member Author

wmkouw commented Nov 4, 2024

Ah ok. Then I will convert it back to draft and think of a solution.

@wmkouw wmkouw marked this pull request as draft November 4, 2024 13:49
@bvdmitri
Member

bvdmitri commented Nov 4, 2024

We can also brainstorm together in the office. The proposed changes are also fine for me, since it's definitely better than the current state :)

@wmkouw
Member Author

wmkouw commented Nov 26, 2024

I have updated the procedure:

  1. I report the alternative rule suggestion after the possible fix suggestion.
  2. It now reports both messages and marginals.
  3. They are formatted with m_a, q_a instead of m(a),q(a).
  4. I print "Note that .. order matters".
  5. Two test cases have been added.

Example:

@model function test(y)
    x ~ Bernoulli(.5)
    y ~ Beta(x, 1.)
end
infer(model = test(), data = (y = 1.0,))

generates:

ERROR: RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule Beta(:a, Marginalisation) (q_out::PointMass, q_b::PointMass, ) = begin 
    return ...
end

Alternatively, consider re-specifying model using an existing rule:

Beta(m_a::PointMass, m_b::PointMass)
Beta(q_a::PointMass, q_b::PointMass)

Note that for marginal rules (i.e., involving q_*), the order of input types matters.
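
To show how the suggestion lines above are put together, here is a minimal sketch in plain Julia. The helper `format_rule_suggestion` and its argument shapes are hypothetical, assumed for illustration only; it is not the code in src/rule.jl.

```julia
# Hypothetical helper; the name and argument shapes are assumptions for illustration.
# `inputs` pairs an interface name with a type name, e.g. :a => "PointMass".
function format_rule_suggestion(node::AbstractString, prefix::AbstractString, inputs)
    # Build one "prefix_name::Type" entry per input, e.g. "m_a::PointMass".
    args = ["$(prefix)_$(name)::$(tname)" for (name, tname) in inputs]
    return string(node, "(", join(args, ", "), ")")
end

inputs = [:a => "PointMass", :b => "PointMass"]

println(format_rule_suggestion("Beta", "m", inputs))  # message rule
println(format_rule_suggestion("Beta", "q", inputs))  # marginal rule
```

With these inputs, the two printed lines match the two suggestions in the example output above.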

I cleaned up the code a little bit by adhering to the same structure as the earlier function calls in the error. I split @bartvanerp's get_messages_from_rule_method and get_marginals_from_rule_method into three separate functions each. These are also tested in a new test set titled get_from_rule_method.

Shall I add some documentation for this suggestion? Or would you agree that it speaks for itself?

@wmkouw wmkouw requested a review from bartvanerp November 26, 2024 14:23

codecov bot commented Nov 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 72.48%. Comparing base (9af77aa) to head (aea1dde).
Report is 52 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #423      +/-   ##
==========================================
- Coverage   72.62%   72.48%   -0.14%     
==========================================
  Files         190      190              
  Lines        5454     5492      +38     
==========================================
+ Hits         3961     3981      +20     
- Misses       1493     1511      +18     

@wmkouw wmkouw marked this pull request as ready for review November 26, 2024 15:32
@bartvanerp
Member

Looks very nice @wmkouw! Sorry that you had to go through my regex expressions :D

@bvdmitri
Member

The BayesBase.PointMass} seems to be a bug in method parsing with regexps. To me it sounds minor given the extra benefits of the suggestions; we can fix the } in a separate PR.
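
As a hedged illustration of how a regexp can pick up a trailing brace like this, consider the sketch below. The signature string and both patterns are hypothetical; the actual regexp in get_messages_from_rule_method is not shown in this thread and may fail for a different reason.

```julia
# Hypothetical signature string; the real input to the parser may look different.
signature = "Tuple{Message{BayesBase.PointMass}}"

# Greedy capture: `.*` runs to the end and backtracks only as far as the LAST `}`,
# so the capture keeps the inner closing brace.
greedy = match(r"Message\{(.*)\}", signature)

# Bounded capture: `[^{}]*` cannot cross a brace, so it stops before the first `}`.
bounded = match(r"Message\{([^{}]*)\}", signature)

println(greedy.captures[1])   # BayesBase.PointMass}
println(bounded.captures[1])  # BayesBase.PointMass
```

Note that the bounded pattern is only a sketch: it would not match a type that itself has parameters in braces, so a real fix would need to handle nesting.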

@bvdmitri
Member

Thanks @wmkouw !

@bvdmitri bvdmitri merged commit 7f9e27c into main Nov 27, 2024
4 of 5 checks passed
@bvdmitri bvdmitri deleted the ux-rule-suggestions branch November 27, 2024 08:14
@wmkouw
Member Author

wmkouw commented Nov 27, 2024

After readying the PR, I read through your comments on the original PR, Dmitry, and realized you had already mentioned that the BayesBase.PointMass} was coming from the regexp in get_messages_from_rule_method. I tried fixing it last night, but the weird thing is that there is a discrepancy between my REPL (where the regexp is correct) and the test runner (where the regexp picks up the curly bracket). I wanted to talk about it today, in the hope that we could resolve it before merging the PR. But maybe it is indeed better to look at it in a different PR.

@bartvanerp, no worries about the regexp. Thank you for all the work you've done. That made this task much easier.

Development

Successfully merging this pull request may close these issues.

Expand RuleMethodError with list of defined rules for given node
3 participants