Generalize likelihoods #93
-
This is very high on the "to do list". If you are interested, I would be very happy to support and work with you to get more likelihoods implemented in GPJax. In addition to the Poisson and/or Negative Binomial, I would like to see categorical and exponential likelihood functions supported. Fortunately, implementing new distributions in Distrax is much more straightforward than in TFP. Further, any TFP distribution can be used seamlessly within Distrax, e.g.,

```python
import distrax as dx
from tensorflow_probability.substrates.jax import distributions as tfd

d_tfp = tfd.Normal(0.0, 1.0)  # TFP distribution (JAX substrate)
d_dx = dx.Normal(1.0, 2.5)    # Distrax distribution
d_dx.kl_divergence(d_tfp)     # KL(d_dx || d_tfp), mixing the two libraries
```

My first thought would be that we should open a PR to the Distrax group for any new likelihood that we support, as it would be good to minimise the amount of TFP code within GPJax. From the Distrax docs, a new likelihood would simply involve completing the following class:

```python
import distrax


class MyDistribution(distrax.Distribution):

    def __init__(self, *params):
        super().__init__()
        # Store the distribution's parameters on self here.

    def _sample_n(self, key, n):
        # Draw n samples using the JAX PRNG key.
        samples = ...
        return samples

    def log_prob(self, value):
        log_prob = ...
        return log_prob

    @property
    def event_shape(self):
        # Distrax declares event_shape as an abstract property.
        event_shape = ...
        return event_shape

    def _sample_n_and_log_prob(self, key, n):
        # Optional. Only when a more efficient implementation is possible.
        samples, log_prob = ...
        return samples, log_prob
```
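As a concrete illustration of what the `log_prob` and `_sample_n` slots of such a class might contain for a Poisson likelihood, here is a minimal sketch in plain JAX. It assumes a log link (rate = exp(f) for a latent GP value f); the helper names `poisson_log_prob` and `poisson_sample_n` are hypothetical and are not part of Distrax or GPJax.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_log_prob(value, rate):
    """Poisson log-pmf: k * log(rate) - rate - log(k!)."""
    value = jnp.asarray(value, dtype=jnp.float32)
    return value * jnp.log(rate) - rate - gammaln(value + 1.0)


def poisson_sample_n(key, n, rate):
    """Draw n samples, shape (n,) + rate.shape, from Poisson(rate)."""
    rate = jnp.asarray(rate)
    return jax.random.poisson(key, rate, shape=(n,) + rate.shape)


# Hypothetical use as a GP likelihood with a log link: rate = exp(f).
f = jnp.array([-1.0, 0.0, 1.5])
rate = jnp.exp(f)
y = poisson_sample_n(jax.random.PRNGKey(0), 4, rate)
print(y.shape)  # (4, 3)
print(poisson_log_prob(y, rate).sum(axis=0))
```

Inside a `distrax.Distribution` subclass these bodies would read the rate from `self` rather than taking it as an argument.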
-
Update: I have some working code that uses `tensorflow_probability.substrates.jax.distributions`. Swapping in Distrax distributions once they're available should be easy, but there seems to be little traction on my PR. @thomaspinder @daniel-dodd shall I open a PR for a negative binomial likelihood based on …
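For reference, a negative binomial likelihood is often parameterised by its mean μ and a dispersion (total count) r, so that Var[y] = μ + μ²/r; with a log link, μ = exp(f). The sketch below writes that log-pmf directly in JAX; `negbin_log_prob` is a hypothetical name. If TFP's `NegativeBinomial(total_count, logits)` is used instead, the matching choice should be `logits = log(μ) - log(r)`, but that mapping is an assumption worth checking against the TFP docs.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def negbin_log_prob(y, mean, total_count):
    """Log-pmf of NB(y | mean, r), with Var[y] = mean + mean**2 / r."""
    y = jnp.asarray(y, dtype=jnp.float32)
    r = total_count
    # log Gamma(y + r) - log Gamma(r) - log y!
    log_norm = gammaln(y + r) - gammaln(r) - gammaln(y + 1.0)
    return log_norm + r * jnp.log(r / (r + mean)) + y * jnp.log(mean / (r + mean))


# Log link from latent GP values f to the NB mean.
f = jnp.array([0.0, 1.0, 2.0])
mean = jnp.exp(f)
print(negbin_log_prob(jnp.array([0, 3, 10]), mean, total_count=5.0))
```

As r grows the variance approaches μ and the likelihood approaches a Poisson, which makes r a natural hyperparameter for over-dispersed count data.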
-
The first thing I would like to work on is implementing different likelihoods (Poisson, Negative Binomial). It seems that currently only Bernoulli is supported: https://github.com/thomaspinder/GPJax/blob/master/gpjax/likelihoods.py#L131
Having some docs on how to implement different likelihoods would be good as well.
Unfortunately, neither Poisson nor Negative Binomial seems to be implemented in Distrax at the moment. Is this something that should be done first?
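For non-Gaussian likelihoods like these, a variational GP typically needs the expected log-density E_q(f)[log p(y | f)] under a Gaussian q(f) = N(m, v). One generic way to compute it is Gauss–Hermite quadrature. The sketch below is not GPJax's implementation; the helpers `poisson_log_lik` and `expected_log_prob` are hypothetical. It assumes a Poisson likelihood with a log link and checks the quadrature against the closed form available in that case, y·m − exp(m + v/2) − log y!.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import gammaln


def poisson_log_lik(y, f):
    """Poisson log-pmf with log link: rate = exp(f)."""
    return y * f - jnp.exp(f) - gammaln(y + 1.0)


def expected_log_prob(log_lik_fn, y, mean, var, num_points=20):
    """Gauss-Hermite estimate of E_{f ~ N(mean, var)}[log p(y | f)]."""
    # Nodes/weights for integrals of the form  int exp(-x^2) g(x) dx.
    x, w = np.polynomial.hermite.hermgauss(num_points)
    # Change of variables f = mean + sqrt(2 var) x; shape (..., num_points).
    f = mean[..., None] + jnp.sqrt(2.0 * var)[..., None] * x
    vals = log_lik_fn(y[..., None], f)
    return (vals * w).sum(axis=-1) / jnp.sqrt(jnp.pi)


y = jnp.array([0.0, 2.0, 7.0])
m = jnp.array([-0.5, 0.3, 1.8])
v = jnp.array([0.2, 0.5, 0.1])
quad = expected_log_prob(poisson_log_lik, y, m, v)
exact = y * m - jnp.exp(m + v / 2.0) - gammaln(y + 1.0)
print(quad, exact)
```

Because only `log_lik_fn` changes between likelihoods, the same quadrature routine works for a negative binomial or any other likelihood that has no closed-form expectation.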