Skip to content

Commit

Permalink
Loc-scale variant of Normal and LogNormal distribution (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash authored Jan 21, 2024
1 parent d8c9522 commit aebeaa1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
18 changes: 9 additions & 9 deletions jaxampler/_src/rvs/lognormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,35 @@


class LogNormal(ContinuousRV):
def __init__(self, mu: Numeric | Any = 0.0, sigma: Numeric | Any = 1.0, name: Optional[str] = None) -> None:
shape, self._mu, self._sigma = jx_cast(mu, sigma)
def __init__(self, loc: Numeric | Any = 0.0, scale: Numeric | Any = 1.0, name: Optional[str] = None) -> None:
shape, self._loc, self._scale = jx_cast(loc, scale)
self.check_params()
super().__init__(name=name, shape=shape)

def check_params(self) -> None:
assert jnp.all(self._sigma > 0.0), "All sigma must be greater than 0.0"
assert jnp.all(self._scale > 0.0), "All sigma must be greater than 0.0"

@partial(jit, static_argnums=(0,))
def logpdf_x(self, x: Numeric) -> Numeric:
constants = -(jnp.log(self._sigma) + 0.5 * jnp.log(2 * jnp.pi))
constants = -(jnp.log(self._scale) + 0.5 * jnp.log(2 * jnp.pi))
logpdf_val = jnp.where(
x <= 0,
-jnp.inf,
constants - jnp.log(x) - (0.5 * jnp.power(self._sigma, -2)) * jnp.power((jnp.log(x) - self._mu), 2),
constants - jnp.log(x) - (0.5 * jnp.power(self._scale, -2)) * jnp.power((jnp.log(x) - self._loc), 2),
)
return logpdf_val

@partial(jit, static_argnums=(0,))
def logcdf_x(self, x: Numeric) -> Numeric:
return log_ndtr((jnp.log(x) - self._mu) / self._sigma)
return log_ndtr((jnp.log(x) - self._loc) / self._scale)

@partial(jit, static_argnums=(0,))
def cdf_x(self, x: Numeric) -> Numeric:
return ndtr((jnp.log(x) - self._mu) / self._sigma)
return ndtr((jnp.log(x) - self._loc) / self._scale)

@partial(jit, static_argnums=(0,))
def ppf_x(self, x: Numeric) -> Numeric:
return jnp.exp(self._mu + self._sigma * ndtri(x))
return jnp.exp(self._loc + self._scale * ndtri(x))

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
if key is None:
Expand All @@ -65,7 +65,7 @@ def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
return self.ppf_x(U)

def __repr__(self) -> str:
string = f"LogNormal(mu={self._mu}, sigma={self._sigma}"
string = f"LogNormal(loc={self._loc}, scale={self._scale}"
if self._name is not None:
string += f", name={self._name}"
string += ")"
Expand Down
40 changes: 30 additions & 10 deletions jaxampler/_src/rvs/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,63 @@


class Normal(ContinuousRV):
def __init__(self, mu: Numeric | Any = 0.0, sigma: Numeric | Any = 1.0, name: Optional[str] = None) -> None:
shape, self._mu, self._sigma = jx_cast(mu, sigma)
def __init__(self, loc: Numeric | Any = 0.0, scale: Numeric | Any = 1.0, name: Optional[str] = None) -> None:
shape, self._loc, self._scale = jx_cast(loc, scale)
self.check_params()
self._logZ = 0.0
super().__init__(name=name, shape=shape)

def check_params(self) -> None:
assert jnp.all(self._sigma > 0.0), "All sigma must be greater than 0.0"
assert jnp.all(self._scale > 0.0), "All sigma must be greater than 0.0"

@partial(jit, static_argnums=(0,))
def logpdf_x(self, x: Numeric) -> Numeric:
return jax_norm.logpdf(x, self._mu, self._sigma)
return jax_norm.logpdf(
x=x,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def logcdf_x(self, x: Numeric) -> Numeric:
return jax_norm.logcdf(x, self._mu, self._sigma)
return jax_norm.logcdf(
x=x,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def pdf_x(self, x: Numeric) -> Numeric:
return jax_norm.pdf(x, self._mu, self._sigma)
return jax_norm.pdf(
x=x,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def cdf_x(self, x: Numeric) -> Numeric:
return jax_norm.cdf(x, self._mu, self._sigma)
return jax_norm.cdf(
x=x,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def ppf_x(self, x: Numeric) -> Numeric:
return jax_norm.ppf(x, self._mu, self._sigma)
return jax_norm.ppf(
q=x,
loc=self._loc,
scale=self._scale,
)

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
if key is None:
key = self.get_key()
new_shape = shape + self._shape
return jax.random.normal(key, shape=new_shape) * self._sigma + self._mu
return self._loc + self._scale * jax.random.normal(key, shape=new_shape)

def __repr__(self) -> str:
string = f"Normal(mu={self._mu}, sigma={self._sigma}"
string = f"Normal(loc={self._loc}, scale={self._scale}"
if self._name is not None:
string += f", name={self._name}"
string += ")"
Expand Down
2 changes: 1 addition & 1 deletion tests/lognormal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
class TestLogNormal:
def test_invalid_params(self):
with pytest.raises(AssertionError):
LogNormal(sigma=-1.0)
LogNormal(scale=-1.0)

0 comments on commit aebeaa1

Please sign in to comment.