Failing to sample from Dirichlet Process Mixture Model using SGLD #260
-
Bug Description
I'm trying to reproduce the TensorFlow Probability example titled "Fitting Dirichlet Process Mixture Model Using Preconditioned Stochastic Gradient Langevin Dynamics" using the BlackJAX API, with the SGLD class. I followed the example notebook that shows how to use SGMCMC in BlackJAX, but I can't reproduce the results of the TensorFlow example.

Steps/Code to Reproduce
I've implemented the BlackJAX version of the example in a Google Colab notebook located here.

Expected Results
Samples from a posterior with three clusters.

Actual Results
I'm getting

Versions
BlackJAX 0.8.2

Additional comment
I suspect there are two causes of the issue:
-
Thank you for opening an issue. To help me debug, could you do the following:
-
Reducing the
-
The step size is probably too small. Moving this to discussions as it does not seem to be a bug. Have you tried the parameters given in the TFP article?

```python
# Learning rates and decay
starter_learning_rate = 1e-6
end_learning_rate = 1e-10
decay_steps = 1e4
# Number of training steps
training_steps = 10000
# Mini-batch size
batch_size = 20
# Sample size for parameter posteriors
sample_size = 100
```
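As a rough illustration of how these values act as a step size, here is a minimal JAX sketch. It assumes the article decays the rate with a polynomial schedule of the kind `tf.compat.v1.train.polynomial_decay` computes; `power=1.0` is an assumption. `polynomial_decay` and `sgld_step` are hypothetical helpers, not BlackJAX API, and the update is the textbook SGLD form, which may not match BlackJAX's internal parameterization.

```python
import jax
import jax.numpy as jnp

# Values quoted from the TFP article's hyperparameters.
starter_learning_rate = 1e-6
end_learning_rate = 1e-10
decay_steps = 1e4


def polynomial_decay(step, power=1.0):
    # Hypothetical helper using the polynomial_decay formula with cycle=False:
    # clamp the step, then interpolate from the starting to the final rate.
    # power=1.0 (linear decay) is an assumption, not taken from the thread.
    step = jnp.minimum(step, decay_steps)
    frac = 1.0 - step / decay_steps
    return (starter_learning_rate - end_learning_rate) * frac**power + end_learning_rate


def sgld_step(rng_key, position, grad_logdensity, step_size):
    # Hypothetical helper: textbook SGLD update (Welling and Teh), i.e. half a
    # step along the stochastic gradient of the log posterior plus Gaussian
    # noise with variance step_size. BlackJAX's SGLD kernel may scale differently.
    noise = jax.random.normal(rng_key, position.shape)
    return position + 0.5 * step_size * grad_logdensity + jnp.sqrt(step_size) * noise
```

Under this parameterization a step size of 1e-6 injects noise with standard deviation sqrt(1e-6) = 1e-3 per coordinate per step, so the chain moves slowly. How the TFP learning rate translates into a BlackJAX step size depends on how each scales the mini-batch gradient, so treat these values as a starting point rather than a drop-in setting.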
-
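Most likely it is because you are trying to sample a bounded variable, which results in an invalid proposal (the log prob returns `nan`). I recommend checking out https://github.com/blackjax-devs/blackjax/blob/main/examples/change_of_variable_hmc.ipynb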
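To illustrate the change-of-variables fix, here is a minimal JAX sketch. The Gamma(2, 1) prior on a positive `precision` is a hypothetical stand-in, not the model from the TFP article, and the function names are illustrative rather than BlackJAX API. Instead of running SGLD on `precision` directly, where a proposal can leave the support, you sample `z = log(precision)` on the real line and add the log-Jacobian of the inverse map, which for `exp` is simply `z`.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def constrained_logdensity(precision):
    # Hypothetical prior: precision ~ Gamma(shape=2, rate=1), support (0, inf).
    # jax.scipy returns -inf for precision <= 0, so any proposal that leaves
    # the support is an invalid state for the sampler.
    return stats.gamma.logpdf(precision, a=2.0)


def unconstrained_logdensity(z):
    # Sample z = log(precision) on the real line instead.
    # Change of variables: log p_z(z) = log p_precision(exp(z)) + log|d exp(z)/dz|,
    # and log|d exp(z)/dz| = z.
    precision = jnp.exp(z)
    return constrained_logdensity(precision) + z


# Gradient an SGLD kernel would use in the unconstrained space; well defined
# for all moderate z (exp(z) only overflows for very large z).
grad_unconstrained = jax.grad(unconstrained_logdensity)


def to_constrained(z_samples):
    # Map draws of z back to the original positive parameter.
    return jnp.exp(z_samples)
```

The mixture weights live on a simplex rather than the positive half-line, so they would need a different bijection (for example stick-breaking) with its own log-Jacobian term; the linked notebook walks through the same idea for HMC.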
-
@junpenglao, you are right! I forgot about the constrained support of some of the distributions. After applying the required transformations, I get exactly the same result as the example! (Thanks to @rlouf as well for taking the time to respond.)
-
If you are open to PR submissions, I can create an example notebook with the BlackJAX version and open a PR. Please let me know.