Incorrect results when sampling from the prior #90

Open
elanmart opened this issue Mar 5, 2021 · 13 comments
Labels
bug (Something isn't working), example (Fixing or adding an example), priority-1 (Not bug, but high priority issue / PR)

Comments

@elanmart

elanmart commented Mar 5, 2021

While going through Statistical Rethinking I wanted to execute a prior-predictive simulation, but the results did not match the textbook example, see below.

I also tried a few synthetic examples, and they give unintuitive results as well (see further down).

Examples

Example from Statistical Rethinking

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    
    return h

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=model, 
    model_args=(), 
    num_samples=10_000
)

fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)

sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])

plt.tight_layout()

Result

image

Expected

image
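
For reference, here is a minimal sketch in plain JAX (not mcx) of the same prior-predictive simulation. It assumes the intended semantics are one independent random key per variable; the function name `prior_predictive_heights` is hypothetical. Under that assumption the mean of h should be close to 178 and its standard deviation close to sqrt(20² + 50²/3) ≈ 35.

```python
import jax

def prior_predictive_heights(rng_key, num_samples):
    # Assumption: every random variable gets its own independent key.
    key_mu, key_sigma, key_h = jax.random.split(rng_key, 3)
    # μ ~ Normal(178, 20)
    mu = 178.0 + 20.0 * jax.random.normal(key_mu, (num_samples,))
    # σ ~ Uniform(0, 50)
    sigma = jax.random.uniform(key_sigma, (num_samples,), minval=0.0, maxval=50.0)
    # h ~ Normal(μ, σ)
    h = mu + sigma * jax.random.normal(key_h, (num_samples,))
    return {"μ": mu, "σ": sigma, "h": h}

samples = prior_predictive_heights(jax.random.PRNGKey(0), 10_000)
print(float(samples["h"].mean()), float(samples["h"].std()))
```

The returned dictionary can be passed to `sns.kdeplot` exactly like `prior_predictive` in the code above.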

Synthetic example 1

In this example I sample an offset from Uniform(0, 1).
Then I sample the outcome from Uniform(12 - offset, 12 + offset).
So I expect my samples to lie in the range [11, 13],
but I get samples in the range [-15, 15].

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_1():
    
    center = 12
    offset <~ dist.Uniform(0, 1)
    
    low = (center - offset)
    high = (center + offset)
    
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_1, 
    model_args=(), 
    num_samples=10_000
)


ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

image

Synthetic example 2

This is the same example as above, but the center variable is passed as an argument instead of being hardcoded, and the results are different (although still not in the range [11, 13]).

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_2(center):
    
    offset <~ dist.Uniform(0, 1)
    
    low = (center - offset)
    high = (center + offset)
    
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_2, 
    model_args=(12, ), 
    num_samples=10_000
)


ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

image

Expectation

For examples 1 and 2, here's what I'd expect to get:

image
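
The expectation for examples 1 and 2 can be sketched in plain JAX as follows. This is not mcx code: the helper name `expected_outcome` is hypothetical, and it assumes the offset and the outcome are drawn with two independent keys.

```python
import jax

def expected_outcome(rng_key, center, num_samples):
    # Assumption: independent keys for the two random variables.
    key_offset, key_outcome = jax.random.split(rng_key)
    # offset ~ Uniform(0, 1)
    offset = jax.random.uniform(key_offset, (num_samples,))
    low = center - offset
    high = center + offset
    # outcome ~ Uniform(low, high), with per-sample bounds
    return jax.random.uniform(key_outcome, (num_samples,), minval=low, maxval=high)

outcome = expected_outcome(jax.random.PRNGKey(0), 12.0, 10_000)
print(float(outcome.min()), float(outcome.max()))
```

Every sample stays between `center - 1` and `center + 1`, which for `center = 12` is the range [11, 13].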

Environment

Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d
@elanmart
Author

elanmart commented Mar 6, 2021

Hi @rlouf

I've looked into this a bit more and identified two issues:

  1. In mcx models, the same random key seems to be used for multiple distributions, which gives incorrect results.
  2. The subtraction Op and the negation Op seem broken.

Please find the examples of the two issues in the notebook: https://gist.github.com/elanmart/9ab0ba21f282f6b24d972cbfb76b4578

Hope this is helpful
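
To illustrate why the first issue matters, here is a small plain-JAX sketch (not mcx code, and mcx's distributions may sample differently). It relies on the assumption that `jax.random.uniform` scales one underlying unit draw by `minval` and `maxval`, so two calls with the same key and shape share that draw. The names `predicted` and `*_ok` are only for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n = 1_000

# Same key for both draws: they share the same underlying random bits.
offset = jax.random.uniform(key, (n,), minval=0.0, maxval=1.0)
outcome = jax.random.uniform(key, (n,), minval=12.0 - offset, maxval=12.0 + offset)

# With a shared key, outcome = low + (high - low) * offset,
# i.e. a fixed function of offset rather than a fresh draw.
predicted = (12.0 - offset) + 2.0 * offset**2
print(bool(jnp.allclose(outcome, predicted, atol=1e-4)))

# With split keys the two draws are independent, so the relation no longer holds.
key_offset, key_outcome = jax.random.split(key)
offset_ok = jax.random.uniform(key_offset, (n,))
outcome_ok = jax.random.uniform(
    key_outcome, (n,), minval=12.0 - offset_ok, maxval=12.0 + offset_ok
)
predicted_ok = (12.0 - offset_ok) + 2.0 * offset_ok**2
print(bool(jnp.allclose(outcome_ok, predicted_ok, atol=1e-4)))
```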

@elanmart elanmart changed the title Unintuitive results when sampling from the prior Incorrect results when sampling from the prior Mar 6, 2021
@rlouf
Owner

rlouf commented Mar 6, 2021

Hi @elanmart,

Thank you for taking the time to share this with me! Regarding what you identified:

  1. Indeed, I noticed that recently. It is problematic if you use the same distribution more than once in the model. This should be corrected soon.
  2. In what sense? Would you mind pasting the output of print(example_1.sample_joint_src) and print(example_2.sample_joint_src) ?

@elanmart
Author

elanmart commented Mar 6, 2021

Thanks for the answer! I was wondering how I could inspect the models; sample_joint_src does indeed reveal what goes wrong!

The following model

@mcx.model
def example_2_mcx_v1():
    
    offset  <~ dist.Uniform(0, 5)
    low     =  12 - offset
    outcome <~ dist.Uniform(low, 12)
    
    return outcome

is transformed into

def example_2_mcx_v1_sample_forward(rng_key):
    offset = dist.Uniform(0, 5).sample(rng_key)
    low = offset - 12
    outcome = dist.Uniform(low, 12).sample(rng_key)
    forward_samples = {'offset': offset, 'outcome': outcome}
    return forward_samples

Notice how

low = 12 - offset

became

low = offset - 12
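
The effect of the swapped operands on the sample range can be sketched in plain JAX. This is an illustration only, not the code mcx generates: both function names are hypothetical, and independent keys are used on purpose so that only the subtraction differs between the two versions.

```python
import jax

def sample_as_written(rng_key, num_samples):
    key_offset, key_outcome = jax.random.split(rng_key)
    offset = jax.random.uniform(key_offset, (num_samples,), minval=0.0, maxval=5.0)
    low = 12.0 - offset  # as in the model: low lies in [7, 12]
    return jax.random.uniform(key_outcome, (num_samples,), minval=low, maxval=12.0)

def sample_with_swapped_operands(rng_key, num_samples):
    key_offset, key_outcome = jax.random.split(rng_key)
    offset = jax.random.uniform(key_offset, (num_samples,), minval=0.0, maxval=5.0)
    low = offset - 12.0  # as in the generated code: low lies in [-12, -7]
    return jax.random.uniform(key_outcome, (num_samples,), minval=low, maxval=12.0)

rng_key = jax.random.PRNGKey(0)
written = sample_as_written(rng_key, 10_000)
swapped = sample_with_swapped_operands(rng_key, 10_000)
print(float(written.min()), float(written.max()))
print(float(swapped.min()), float(swapped.max()))
```

The model as written keeps the outcome within [7, 12], while the swapped version can produce values anywhere in [-12, 12].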

EDIT

The issue is not limited to constants. The arguments in subtraction are switched to match the order in which they were defined,
so

A <~ ...
B <~ ...
B - A

becomes

A - B

and so the model here

@mcx.model
def example():
    A <~ dist.Normal(0, 1)
    B <~ dist.Normal(0, 2)
    
    μ = B - A
    Y <~ dist.Normal(μ, 1)
    
    return Y

becomes

def example_sample_forward(rng_key):
    B = dist.Normal(0, 2).sample(rng_key)
    A = dist.Normal(0, 1).sample(rng_key)
    μ = A - B
    Y = dist.Normal(μ, 1).sample(rng_key)
    forward_samples = {'A': A, 'B': B, 'Y': Y}
    return forward_samples

@elanmart
Author

elanmart commented Mar 6, 2021

Ah, and also regarding point 1. (same rng_key used many times):
is there any simple workaround I could use as a temporary solution, however hacky?

@rlouf
Owner

rlouf commented Mar 7, 2021

That's strange regarding A-B, I identified the problem 10 days ago and I thought I'd fixed it. Are you running the latest version (latest commit)?

Unfortunately no workaround for the rng_key but I can try to push a fix next week. I'll make sure it works on these examples. In the meantime you can keep moving forward, checking the source code each time there's something weird. You'd just have to re-run your notebooks once the fixes are made.

Now I see how convenient compiling to a python function is for debugging 😄 Thank you for dealing with the teething problems here, it is really helpful for us.

@rlouf rlouf added bug Something isn't working example Fixing or adding an example priority-1 Not bug, but high priority issue / PR labels Mar 7, 2021
@elanmart
Author

elanmart commented Mar 7, 2021

OK, so my poetry.lock file indicated that I had the latest commit, but after a clean re-install the issue is indeed resolved...
I'm really sorry for generating noise 😢

Do you want me to close this ticket and open a clean one for rng_key topic? Out of curiosity -- what is the fix you envision there? Adding _, key = jax.random.split(key) statement to the graph after each sample() call? Or is there a nicer solution?

Thank you for dealing with the teething problems

No worries, I would love to understand the compiler a bit better to be able to debug similar issues myself.

@rlouf
Owner

rlouf commented Mar 7, 2021

I'm really sorry for generating noise 😢

No worries, you're really helpful :)

Do you want me to close this ticket and open a clean one for rng_key topic?

Yes please! Leave this one open until we solve the issue completely.

Out of curiosity -- what is the fix you envision there? Adding _, key = jax.random.split(key) statement to the graph after each sample() call? Or is there a nicer solution?

So that would be the quick and dirty solution. I think that I might instead generate as many keys as needed at the beginning of the function.
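
The two options discussed here can be sketched in plain JAX for the A/B/Y example from earlier in the thread. This is a hand-written illustration under the assumption that each variable needs its own key; it is not the code the mcx compiler emits, and both function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def sample_forward_split_as_you_go(rng_key):
    # Quick fix: derive a fresh subkey before every draw.
    rng_key, subkey = jax.random.split(rng_key)
    A = jax.random.normal(subkey)        # A ~ Normal(0, 1)
    rng_key, subkey = jax.random.split(rng_key)
    B = 2.0 * jax.random.normal(subkey)  # B ~ Normal(0, 2)
    mu = B - A
    rng_key, subkey = jax.random.split(rng_key)
    Y = mu + jax.random.normal(subkey)   # Y ~ Normal(mu, 1)
    return {"A": A, "B": B, "Y": Y}

def sample_forward_presplit(rng_key):
    # Alternative: generate as many keys as needed at the beginning.
    key_A, key_B, key_Y = jax.random.split(rng_key, 3)
    A = jax.random.normal(key_A)         # A ~ Normal(0, 1)
    B = 2.0 * jax.random.normal(key_B)   # B ~ Normal(0, 2)
    mu = B - A
    Y = mu + jax.random.normal(key_Y)    # Y ~ Normal(mu, 1)
    return {"A": A, "B": B, "Y": Y}

rng_key = jax.random.PRNGKey(0)
for sample_fn in (sample_forward_split_as_you_go, sample_forward_presplit):
    print({name: float(value) for name, value in sample_fn(rng_key).items()})
```

Both variants give every variable its own key. The pre-split version keeps all key handling in one line at the top of the function, which may be easier to emit from a compiler that already knows how many random variables the model has.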

No worries, I would love to understand the compiler a bit better to be able to debug similar issues myself.

Well, now you know that you can at least print the code generated by the compiler. It's a good starting point.

@tblazina
Contributor

@elanmart - with regard to the compiler - I made some (not really organized) notes here, some of which I suppose are (or will soon be) invalid after the <~ operator is phased out. In any case, maybe they will be helpful to you!

@rlouf
Owner

rlouf commented Mar 14, 2021

some of which I suppose are (or will soon be) invalid after the <~ operator is phased out.

Actually, the general principle will stay exactly the same.

@elanmart
Author

Thank you @tblazina ! This looks extremely useful, will go through it over the weekend!

@rlouf
Owner

rlouf commented May 5, 2021

Just to let you know, I'll make some time to work on this and the other issue once my NUTS PR is merged on BlackJAX (which means MCX will support NUTS). How is the implementation of SR going?

@elanmart
Author

elanmart commented May 5, 2021

Thanks for the update! Looking forward to the NUTS sampler as well.

I've decided to first go through the theory, and then make a second pass implementing the examples.

I've just finished the book, so I'm going back to the code, which hopefully should go faster now.

There were a few places in the book where some advanced Stan features were used.
I'm a bit worried about those, but we'll see how it goes.

@rlouf
Owner

rlouf commented May 6, 2021

Great! If you remember which ones, don't hesitate to open issues now.

@rlouf rlouf mentioned this issue Jul 29, 2021
14 tasks