diff --git a/jaxampler/_src/rvs/lognormal.py b/jaxampler/_src/rvs/lognormal.py index c63c994..0267351 100644 --- a/jaxampler/_src/rvs/lognormal.py +++ b/jaxampler/_src/rvs/lognormal.py @@ -27,35 +27,35 @@ class LogNormal(ContinuousRV): - def __init__(self, mu: Numeric | Any = 0.0, sigma: Numeric | Any = 1.0, name: Optional[str] = None) -> None: - shape, self._mu, self._sigma = jx_cast(mu, sigma) + def __init__(self, loc: Numeric | Any = 0.0, scale: Numeric | Any = 1.0, name: Optional[str] = None) -> None: + shape, self._loc, self._scale = jx_cast(loc, scale) self.check_params() super().__init__(name=name, shape=shape) def check_params(self) -> None: - assert jnp.all(self._sigma > 0.0), "All sigma must be greater than 0.0" + assert jnp.all(self._scale > 0.0), "All sigma must be greater than 0.0" @partial(jit, static_argnums=(0,)) def logpdf_x(self, x: Numeric) -> Numeric: - constants = -(jnp.log(self._sigma) + 0.5 * jnp.log(2 * jnp.pi)) + constants = -(jnp.log(self._scale) + 0.5 * jnp.log(2 * jnp.pi)) logpdf_val = jnp.where( x <= 0, -jnp.inf, - constants - jnp.log(x) - (0.5 * jnp.power(self._sigma, -2)) * jnp.power((jnp.log(x) - self._mu), 2), + constants - jnp.log(x) - (0.5 * jnp.power(self._scale, -2)) * jnp.power((jnp.log(x) - self._loc), 2), ) return logpdf_val @partial(jit, static_argnums=(0,)) def logcdf_x(self, x: Numeric) -> Numeric: - return log_ndtr((jnp.log(x) - self._mu) / self._sigma) + return log_ndtr((jnp.log(x) - self._loc) / self._scale) @partial(jit, static_argnums=(0,)) def cdf_x(self, x: Numeric) -> Numeric: - return ndtr((jnp.log(x) - self._mu) / self._sigma) + return ndtr((jnp.log(x) - self._loc) / self._scale) @partial(jit, static_argnums=(0,)) def ppf_x(self, x: Numeric) -> Numeric: - return jnp.exp(self._mu + self._sigma * ndtri(x)) + return jnp.exp(self._loc + self._scale * ndtri(x)) def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: if key is None: @@ -65,7 +65,7 @@ def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: return self.ppf_x(U) def __repr__(self) -> str: - string = f"LogNormal(mu={self._mu}, sigma={self._sigma}" + string = f"LogNormal(loc={self._loc}, scale={self._scale}" if self._name is not None: string += f", name={self._name}" string += ")" diff --git a/jaxampler/_src/rvs/normal.py b/jaxampler/_src/rvs/normal.py index 2e7516f..e6825a8 100644 --- a/jaxampler/_src/rvs/normal.py +++ b/jaxampler/_src/rvs/normal.py @@ -27,43 +27,63 @@ class Normal(ContinuousRV): - def __init__(self, mu: Numeric | Any = 0.0, sigma: Numeric | Any = 1.0, name: Optional[str] = None) -> None: - shape, self._mu, self._sigma = jx_cast(mu, sigma) + def __init__(self, loc: Numeric | Any = 0.0, scale: Numeric | Any = 1.0, name: Optional[str] = None) -> None: + shape, self._loc, self._scale = jx_cast(loc, scale) self.check_params() self._logZ = 0.0 super().__init__(name=name, shape=shape) def check_params(self) -> None: - assert jnp.all(self._sigma > 0.0), "All sigma must be greater than 0.0" + assert jnp.all(self._scale > 0.0), "All sigma must be greater than 0.0" @partial(jit, static_argnums=(0,)) def logpdf_x(self, x: Numeric) -> Numeric: - return jax_norm.logpdf(x, self._mu, self._sigma) + return jax_norm.logpdf( + x=x, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def logcdf_x(self, x: Numeric) -> Numeric: - return jax_norm.logcdf(x, self._mu, self._sigma) + return jax_norm.logcdf( + x=x, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def pdf_x(self, x: Numeric) -> Numeric: - return jax_norm.pdf(x, self._mu, self._sigma) + return jax_norm.pdf( + x=x, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def cdf_x(self, x: Numeric) -> Numeric: - return jax_norm.cdf(x, self._mu, self._sigma) + return jax_norm.cdf( + x=x, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def ppf_x(self, x: Numeric) -> Numeric: - return jax_norm.ppf(x, self._mu, self._sigma) + return jax_norm.ppf( + q=x, + loc=self._loc, + scale=self._scale, + ) def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: if key is None: key = self.get_key() new_shape = shape + self._shape - return jax.random.normal(key, shape=new_shape) * self._sigma + self._mu + return self._loc + self._scale * jax.random.normal(key, shape=new_shape) def __repr__(self) -> str: - string = f"Normal(mu={self._mu}, sigma={self._sigma}" + string = f"Normal(loc={self._loc}, scale={self._scale}" if self._name is not None: string += f", name={self._name}" string += ")" diff --git a/tests/lognormal_test.py b/tests/lognormal_test.py index edb1481..d0c5f9b 100644 --- a/tests/lognormal_test.py +++ b/tests/lognormal_test.py @@ -24,4 +24,4 @@ class TestLogNormal: def test_invalid_params(self): with pytest.raises(AssertionError): - LogNormal(sigma=-1.0) + LogNormal(scale=-1.0)