WARNING:absl:grids_and_weights is None, this will create a dummy integration that always returns 0. The gradient of this dummy integration would still work.
0.0
import jax
import jax.numpy as jnp
from jaxtyping import Float32, Array
from autofd import function
import autofd.operators as o
# Define the normal distribution q_lambda(theta) with dependency on lambda
@function
def q_lambda(theta: Float32[Array, ""], lambda_: Float32[Array, ""]) -> Float32[Array, ""]:
    mu, sigma = 0.0, 1.0  # sigma is the standard deviation; mu is unused because lambda_ acts as the mean
    return jnp.exp(-(theta - lambda_) ** 2 / (2 * sigma ** 2)) / (sigma * jnp.sqrt(2 * jnp.pi))
# Define the likelihood p(x|theta)
@function
def p_x_given_theta(x: Float32[Array, ""], theta: Float32[Array, ""]) -> Float32[Array, ""]:
    return jnp.exp(-(x - theta) ** 2 / 2) / jnp.sqrt(2 * jnp.pi)
# Define the functional to compute E_{q_lambda}[p(x|theta)]
@function
def E_q_p(x: Float32[Array, ""], lambda_: Float32[Array, ""]) -> Float32[Array, ""]:
    @function  # Decorate integrand with @function
    def integrand(theta: Float32[Array, ""]) -> Float32[Array, ""]:  # Explicit return type
        return p_x_given_theta(x, theta) * q_lambda(theta, lambda_)

    return o.integrate(integrand)
x = jnp.array(50.0)
lambda_ = jnp.array(1.0) # Initial lambda value
result = E_q_p(x, lambda_)
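
As a cross-check that does not rely on autofd, the same expectation can be approximated in plain JAX on an explicit grid. This is only a minimal sketch, not the library's integration API: the helper names `expectation_on_grid` and `closed_form`, the grid bounds, and the uniform Riemann-sum weights are all assumptions chosen for illustration. The closed form comes from convolving two unit-variance Gaussians, which gives a normal density in x with mean lambda and variance 2.

```python
import jax
import jax.numpy as jnp


def q_lambda(theta, lambda_, sigma=1.0):
    # Normal density in theta with mean lambda_ and standard deviation sigma.
    return jnp.exp(-(theta - lambda_) ** 2 / (2 * sigma ** 2)) / (sigma * jnp.sqrt(2 * jnp.pi))


def p_x_given_theta(x, theta):
    # Unit-variance Gaussian likelihood p(x | theta).
    return jnp.exp(-(x - theta) ** 2 / 2) / jnp.sqrt(2 * jnp.pi)


def expectation_on_grid(x, lambda_, grid):
    # Hypothetical helper: Riemann sum of p(x|theta) * q_lambda(theta)
    # over a uniformly spaced grid of theta values.
    dx = grid[1] - grid[0]
    values = p_x_given_theta(x, grid) * q_lambda(grid, lambda_)
    return jnp.sum(values) * dx


def closed_form(x, lambda_):
    # Analytic value of the integral: N(x; lambda_, variance 2).
    return jnp.exp(-(x - lambda_) ** 2 / 4) / jnp.sqrt(4 * jnp.pi)


if __name__ == "__main__":
    # Assumed grid; it only needs to cover the region where the integrand is non-negligible.
    grid = jnp.linspace(-10.0, 12.0, 2001)
    x, lam = jnp.array(2.0), jnp.array(1.0)

    print("grid estimate:", expectation_on_grid(x, lam, grid))
    print("closed form:  ", closed_form(x, lam))
    # Gradient with respect to lambda_ flows through the grid sum.
    print("d/dlambda:    ", jax.grad(expectation_on_grid, argnums=1)(x, lam, grid))
```

Note that for x = 50 and lambda = 1 the exact value is exp(-600.25) / sqrt(4 * pi), which underflows to zero in float32. A result of 0.0 would therefore be expected for those inputs even with a real integration grid, so a value of x closer to lambda makes a more informative test.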