
Commit

Add utils for setting priors
ziatdinovmax committed Nov 26, 2024
1 parent 0b59e60 commit 762a27f
Showing 1 changed file with 190 additions and 61 deletions.
251 changes: 190 additions & 61 deletions neurobayes/utils/priors.py
@@ -1,4 +1,6 @@
from typing import Callable, Dict, List
import inspect

from typing import Callable, Dict, Union, Type
from dataclasses import dataclass
import numpyro
import numpyro.distributions as dist
@@ -12,63 +14,190 @@ class GPPriors:
output_scale_prior: dist.Distribution = dist.LogNormal(0.0, 1.0)


def sample_weights(name: str, in_channels: int, out_channels: int, scale: float = 1.0) -> jnp.ndarray:
"""Sampling weights matrix"""
w = numpyro.sample(name=name, fn=dist.Normal(
jnp.zeros((in_channels, out_channels)),
scale * jnp.ones((in_channels, out_channels)))
)
return w


def sample_biases(name: str, channels: int, scale: float = 1.0) -> jnp.ndarray:
"""Sampling bias vector"""
b = numpyro.sample(name=name, fn=dist.Cauchy(
jnp.zeros((channels)),
scale * jnp.ones((channels)))
)
return b


def get_mlp_prior(input_dim: int, output_dim: int,
architecture: List[int], name: str = "main",
scale: float = 1.0
) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP"""
def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"{name}_w{i}"] = sample_weights(
f"{name}_w{i}", in_channels, out_channels, scale)
params[f"{name}_b{i}"] = sample_biases(
f"{name}_b{i}", out_channels, scale)
in_channels = out_channels
# Output layer
params[f"{name}_w{len(architecture)}"] = sample_weights(
f"{name}_w{len(architecture)}", in_channels, output_dim, scale)
params[f"{name}_b{len(architecture)}"] = sample_biases(
f"{name}_b{len(architecture)}", output_dim, scale)
return params
return mlp_prior


def get_heteroskedastic_mlp_prior(input_dim: int, output_dim: int,
architecture: List[int],
scale: float = 1.0
) -> Callable[[], Dict[str, jnp.ndarray]]:
"""Priors over weights and biases for a Bayesian MLP with heteroskedastic outputs"""
def mlp_prior():
params = {}
in_channels = input_dim
for i, out_channels in enumerate(architecture):
params[f"w{i}"] = sample_weights(f"w{i}", in_channels, out_channels, scale)
params[f"b{i}"] = sample_biases(f"b{i}", out_channels, scale)
in_channels = out_channels
# Output layers for mean and variance
params['w_mean'] = sample_weights('w_mean', in_channels, output_dim, scale)
params['b_mean'] = sample_biases('b_mean', output_dim, scale)
params['w_variance'] = sample_weights('w_variance', in_channels, output_dim, scale)
params['b_variance'] = sample_biases('b_variance', output_dim, scale)
return params
return mlp_prior
def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Samples a value from a normal distribution with the specified mean (loc) and standard deviation (scale),
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
return numpyro.sample(param_name, normal_dist(loc, scale))
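
A minimal usage sketch (not part of this diff): place_normal_prior registers a named latent site, so it has to run inside a NumPyro handler or inference context. The function name `toy_model` and the import path are illustrative assumptions.

import numpyro
from neurobayes.utils.priors import place_normal_prior  # assumed import path

def toy_model():
    # "slope" becomes a latent site with a Normal(0, 1) prior
    return place_normal_prior("slope", loc=0.0, scale=1.0)

with numpyro.handlers.seed(rng_seed=0):
    slope_draw = toy_model()  # concrete value drawn from the prior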


def place_lognormal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Samples a value from a log-normal distribution parameterized by the mean (loc) and standard deviation (scale) of the underlying normal distribution,
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
return numpyro.sample(param_name, lognormal_dist(loc, scale))


def place_halfnormal_prior(param_name: str, scale: float = 1.0):
"""
Samples a value from a half-normal distribution with the specified standard deviation (scale),
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
return numpyro.sample(param_name, halfnormal_dist(scale))


def place_uniform_prior(param_name: str,
low: float = None,
high: float = None,
X: jnp.ndarray = None):
"""
Samples a value from a uniform distribution with the specified low and high values,
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
d = uniform_dist(low, high, X)
return numpyro.sample(param_name, d)


def place_gamma_prior(param_name: str,
c: float = None,
r: float = None,
X: jnp.ndarray = None):
"""
Samples a value from a gamma distribution with the specified concentration (c) and rate (r) values,
and assigns it to a named random variable in the probabilistic model. Can be useful for defining prior mean functions
in structured Gaussian processes.
"""
d = gamma_dist(c, r, X)
return numpyro.sample(param_name, d)
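
A hedged sketch of the data-driven variants above (not part of this diff): when explicit bounds or shape parameters are omitted, place_uniform_prior and place_gamma_prior fall back to statistics of the supplied array X. All variable names below are illustrative.

import jax.numpy as jnp
import numpyro
from neurobayes.utils.priors import place_uniform_prior, place_gamma_prior  # assumed import path

X_train = jnp.linspace(-2.0, 3.0, 100)

def data_informed_priors():
    # Uniform over [X_train.min(), X_train.max()] = [-2, 3]
    shift = place_uniform_prior("shift", X=X_train)
    # Gamma with concentration (3 - (-2)) / 2 = 2.5 and rate 1.0
    length = place_gamma_prior("length", X=X_train)
    return shift, length

with numpyro.handlers.seed(rng_seed=1):
    shift_draw, length_draw = data_informed_priors()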


def normal_dist(loc: float = None, scale: float = None
) -> numpyro.distributions.Distribution:
"""
Generate a Normal distribution based on the provided mean (loc) and standard deviation (scale) parameters.
If not provided, they default to 0 and 1, respectively. It can be used to pass custom priors to GP models.
"""
loc = loc if loc is not None else 0.0
scale = scale if scale is not None else 1.0
return numpyro.distributions.Normal(loc, scale)


def lognormal_dist(loc: float = None, scale: float = None) -> numpyro.distributions.Distribution:
"""
Generate a LogNormal distribution based on the provided mean (loc) and standard deviation (scale) of the underlying normal distribution.
If not provided, they default to 0 and 1, respectively. It can be used to pass custom priors to GP models.
"""
loc = loc if loc is not None else 0.0
scale = scale if scale is not None else 1.0
return numpyro.distributions.LogNormal(loc, scale)


def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distribution:
"""
Generate a half-normal distribution based on provided standard deviation (scale).
If none is provided, uses 1.0 by default. It can be used to pass custom priors to GP models.
"""
scale = scale if scale is not None else 1.0
return numpyro.distributions.HalfNormal(scale)


def gamma_dist(c: float = None,
r: float = None,
input_vec: jnp.ndarray = None
) -> numpyro.distributions.Distribution:
"""
Generate a Gamma distribution based on provided shape (c) and rate (r) parameters. If the shape (c) is not provided,
it attempts to infer it using the range of the input vector divided by 2. The rate parameter defaults to 1.0 if not provided.
It can be used to pass custom priors to GP models.
"""
if c is None:
if input_vec is not None:
c = (input_vec.max() - input_vec.min()) / 2
else:
raise ValueError("Provide either c or an input array")
if r is None:
r = 1.0
return numpyro.distributions.Gamma(c, r)
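
Worked example (illustrative, not from this commit): for an input vector spanning [0, 10], the inferred concentration is (10 - 0) / 2 = 5.0 and the rate defaults to 1.0.

import jax.numpy as jnp
from neurobayes.utils.priors import gamma_dist  # assumed import path

x = jnp.linspace(0.0, 10.0, 50)
d = gamma_dist(input_vec=x)            # Gamma(concentration=5.0, rate=1.0)
print(d.concentration, d.rate)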


def uniform_dist(low: float = None,
high: float = None,
input_vec: jnp.ndarray = None
) -> numpyro.distributions.Distribution:
"""
Generate a Uniform distribution based on provided low and high bounds. If one of the bounds is not provided,
it attempts to infer the missing bound(s) using the minimum or maximum value from the input vector.
It can be used to pass custom priors to GP models.
Examples:
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dim, kernel, lengthscale_prior_dist=gpax.priors.uniform_dist(1, 3))
Train as usual
>>> model.fit(rng_key, X, y)
"""
if (low is None or high is None) and input_vec is None:
raise ValueError(
"If 'low' or 'high' is not provided, an input array must be provided.")
low = low if low is not None else input_vec.min()
high = high if high is not None else input_vec.max()

return numpyro.distributions.Uniform(low, high)


def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal or log-normal distributions
for each parameter of the given deterministic function, skipping the first params_begin_with ones.
Args:
func (Callable): The deterministic function for which to set normal or log-normal priors.
params_begin_with (int): Index of the first parameter to place a prior on; earlier parameters (e.g. the input variable) are skipped.
dist_type (str, optional): Type of prior distribution, 'normal' or 'lognormal'. Defaults to 'normal'.
loc (float, optional): Mean of the normal or log-normal distribution. Defaults to 0.0.
scale (float, optional): Standard deviation of the normal or log-normal distribution. Defaults to 1.0.
Returns:
A function that, when invoked, returns a dictionary of sampled values
from normal or log-normal distributions for each parameter of the original function.
"""
place_prior = place_lognormal_prior if dist_type == 'lognormal' else place_normal_prior

# Get the names of the function's parameters, excluding the first params_begin_with ones (e.g. the input variable)
params_names = list(inspect.signature(func).parameters.keys())[params_begin_with:]

def sample_priors() -> Dict[str, Union[float, Type[Callable]]]:
# Return a dictionary with a sampled value from the chosen prior for each parameter
return {name: place_prior(name, loc, scale) for name in params_names}

return sample_priors
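
A small sketch of auto_priors on a deterministic function (not part of this diff): every parameter from index params_begin_with onwards gets its own named prior. The function `linear_fn` and its parameter names are made up for illustration.

import numpyro
from neurobayes.utils.priors import auto_priors  # assumed import path

def linear_fn(x, slope, intercept):
    return slope * x + intercept

# Skip the first argument (the input variable); place log-normal priors on the rest
prior_fn = auto_priors(linear_fn, params_begin_with=1, dist_type='lognormal')

with numpyro.handlers.seed(rng_seed=2):
    samples = prior_fn()   # {'slope': <sample>, 'intercept': <sample>}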


def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Places normal priors over function parameters.
Args:
func (Callable): The deterministic function for which to set normal priors.
loc (float, optional): Mean of the normal distribution. Defaults to 0.0.
scale (float, optional): Standard deviation of the normal distribution. Defaults to 1.0.
Returns:
A function that, when invoked, returns a dictionary of sampled values
from normal distributions for each parameter of the original function.
"""
return auto_priors(func, 1, 'normal', loc, scale)


def auto_lognormal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Places log-normal priors over function parameters.
Args:
func (Callable): The deterministic function for which to set log-normal priors.
loc (float, optional): Mean of the log-normal distribution. Defaults to 0.0.
scale (float, optional): Standard deviation of the log-normal distribution. Defaults to 1.0.
Returns:
A function that, when invoked, returns a dictionary of sampled values
from log-normal distributions for each parameter of the original function.
"""
return auto_priors(func, 1, 'lognormal', loc, scale)
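
Illustrative follow-up (names are assumptions): the sampled dictionary returned by these wrappers can be unpacked straight into the deterministic function, e.g. when it serves as a prior mean function for a structured GP.

import numpyro
from neurobayes.utils.priors import auto_normal_priors  # assumed import path

def quadratic_mean(x, a, b, c):
    return a * x**2 + b * x + c

mean_prior_fn = auto_normal_priors(quadratic_mean)    # normal priors over a, b, c

with numpyro.handlers.seed(rng_seed=3):
    params = mean_prior_fn()                          # {'a': ..., 'b': ..., 'c': ...}
    y_at_x = quadratic_mean(1.5, **params)            # mean evaluated with sampled parameters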
