Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ux rule suggestions #423

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft

Ux rule suggestions #423

wants to merge 5 commits into from

Conversation

wmkouw
Copy link
Member

@wmkouw wmkouw commented Oct 15, 2024

I have added a rule lookup with the RuleMethodError to print the list of existing rules, with a corresponding test as part of the rule_method_error testset.

The inference procedure

@model function test(y)
   r ~ Gamma(1.,1.)
   x ~ Bernoulli(r)
   y ~ Beta(x,1.)
end
infer(model = test(), data = (y = 1.0,))

now returns:

ERROR: RuleMethodError: no method matching rule for the given arguments

Existing rule(s) for node:

Beta(μ(a) :: PointMass, μ(b) :: PointMass)


Possible fix, define:

@rule Beta(:a, Marginalisation) (q_out::PointMass, q_b::PointMass, ) = begin 
    return ...
end

I am happy with this result but I noticed that there is a clash with a functional form constrain error now:

ERROR: The expression `q(r)` has an undefined functional form of type `ProductOf{GammaShapeRate{Float64}, Beta{Float64}}`. 
This is likely because the inference backend does not support the product of these distributions. 
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `q(r)`.

Possible solutions:
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
```julia
using ExponentialFamilyProjection

@constraints begin
    q(r) :: ProjectedTo(NormalMeanVariance)
end

So, this existing rule look may also need to be implemented in there. What do you think?

fixes #397

@wmkouw wmkouw self-assigned this Oct 15, 2024
@wmkouw wmkouw linked an issue Oct 15, 2024 that may be closed by this pull request
@wmkouw
Copy link
Member Author

wmkouw commented Oct 16, 2024

After a discussion in the RxInfer meeting today, we decided that the RuleMethodError was fine and that we will add a point to the bullet list in the constrain_form error message pointing the user to the Wikipedia page on conjugate priors.

Copy link
Member

@bvdmitri bvdmitri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of comments here:

  • Tests were actually failing, but due to another bug in the previous code in get_messages_from_rule_method, that returns "μ(a) :: BayesBase.PointMass}" (yes with this strange } at the end). The code for extracting those is old and perhaps should be rewritten entirely. But for now I modified the tests a bit.
  • The suggestion in the PR are incomplete because it doesn't show rules that include marginals, we have a separate function get_marginals_from_rule_method for this purpose.
  • Simply including marginals from get_marginals_from_rule_method wouldn't be entirely correct though, because the order matters, e.g rule that accepts q(a)::Something, μ(b)::Something and μ(b):: Something, q(a)::Something are two different rules, but perhaps we may skip it for now and just show first marginals and then messages?
  • Small comment, the rules use q_a while the error suggests q(a). WDYT about this discrepancy?

for node_rule in this_node_rules
node_name = get_node_from_rule_method(node_rule)
node_inputs = get_messages_from_rule_method(node_rule)
if typeof(node_inputs) !== Vector{Any}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this check?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_messages_from_rule_method(node_rule) will return Any[] on an empty list (see rule.jl#1297).

For example:

> all_rules = methods(ReactiveMP.rule)
> this_node_rules = all_rules[ReactiveMP.get_node_from_rule_method.(all_rules) .== "MvNormalMeanVariance"]
> ReactiveMP.get_messages_from_rule_method(this_node_rules[1])

Without the type check, this snippet would print a rule for an Any type (which doesn't exist).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, perhaps then we should use !isempty(node_inputs)?

@wmkouw
Copy link
Member Author

wmkouw commented Nov 4, 2024

The suggestion in the PR are incomplete because it doesn't show rules that include marginals, we have a separate function get_marginals_from_rule_method for this purpose.

Well, the intention of this functionality is to give the user advice on how to specify their model. For example, if they specify a Bernoulli likelihood with a Gamma prior, then they will get a RuleNotFound error. The new error reports that the Bernoulli node has a Beta rule, which may inspire the user to re-define their model to a Beta prior. Message rules are sufficient for achieving this goal, don't you think?

Simply including marginals from get_marginals_from_rule_method wouldn't be entirely correct though, because the order matters, e.g rule that accepts q(a)::Something, μ(b)::Something and μ(b):: Something, q(a)::Something are two different rules, but perhaps we may skip it for now and just show first marginals and then messages?

Yes. We could add "Note that the order of input arguments matters".

Small comment, the rules use q_a while the error suggests q(a). WDYT about this discrepancy?

I thought about catching this and reverting it. But the goal of this error info is to advise the user to re-specify their model. If they re-define their node, then ReactiveMP will operate on q_a and this will not be a problem. But if you prefer q_a over q(a), then I can change this.

@bvdmitri
Copy link
Member

bvdmitri commented Nov 4, 2024

The new error reports that the Bernoulli node has a Beta rule, which may inspire the user to re-define their model to a Beta prior. Message rules are sufficient for achieving this goal, don't you think?

Ah, you're right. I see. Yes, it will work for this case, but not for all. For instance, while the Gamma distribution is a conjugate prior for the precision parameter in a Normal distribution, we lack sum-product rules that use Gamma as a message. Instead, we only have variational inference rules that use Gamma as a marginal. This means the user will see an empty list of suggestions, even though a VI rule exists that could recommend switching to a Gamma prior (though this would require adjusting the factorization constraint....)

@wmkouw
Copy link
Member Author

wmkouw commented Nov 4, 2024

Ah ok. Then I will convert it back to draft and think of a solution.

@wmkouw wmkouw marked this pull request as draft November 4, 2024 13:49
@bvdmitri
Copy link
Member

bvdmitri commented Nov 4, 2024

We can also brain-storm together in the office. The proposed changes are also fine for me since its definitely better than the current state :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Expand RuleMethodError with list of defined rules for given node
2 participants