From b06bdfde7805ed423171bc5dd3c6c66f9b6729f6 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Sun, 21 Jan 2024 23:16:11 +0500 Subject: [PATCH 1/2] loc-scale variant of chi-sq distribution --- jaxampler/_src/rvs/chi2.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/jaxampler/_src/rvs/chi2.py b/jaxampler/_src/rvs/chi2.py index e5ecbd5..c14fd79 100644 --- a/jaxampler/_src/rvs/chi2.py +++ b/jaxampler/_src/rvs/chi2.py @@ -27,8 +27,8 @@ class Chi2(ContinuousRV): - def __init__(self, nu: Numeric | Any, name: Optional[str] = None) -> None: - shape, self._nu = jx_cast(nu) + def __init__(self, nu: Numeric | Any, loc: Numeric = 0.0, scale: Numeric = 1.0, name: Optional[str] = None) -> None: + shape, self._nu, self._loc, self._scale = jx_cast(nu, loc, scale) self.check_params() super().__init__(name=name, shape=shape) @@ -37,19 +37,39 @@ def check_params(self) -> None: @partial(jit, static_argnums=(0,)) def logpdf_x(self, x: Numeric) -> Numeric: - return jax_chi2.logpdf(x, self._nu) + return jax_chi2.logpdf( + x=x, + df=self._nu, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def pdf_x(self, x: Numeric) -> Numeric: - return jax_chi2.pdf(x, self._nu) + return jax_chi2.pdf( + x=x, + df=self._nu, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def logcdf_x(self, x: Numeric) -> Numeric: - return jax_chi2.logcdf(x, self._nu) + return jax_chi2.logcdf( + x=x, + df=self._nu, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def cdf_x(self, x: Numeric) -> Numeric: - return jax_chi2.cdf(x, self._nu) + return jax_chi2.cdf( + x=x, + df=self._nu, + loc=self._loc, + scale=self._scale, + ) @partial(jit, static_argnums=(0,)) def logppf_x(self, x: Numeric) -> Numeric: @@ -59,10 +79,10 @@ def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: if key is None: key = self.get_key() new_shape = shape + self._shape - return jax.random.chisquare(key, self._nu, shape=new_shape) + return self._loc + self._scale * jax.random.chisquare(key, self._nu, shape=new_shape) def __repr__(self) -> str: - string = f"Chi2(nu={self._nu}" + string = f"Chi2(nu={self._nu}, loc={self._loc}, scale={self._scale}" if self._name is not None: string += f", name={self._name}" string += ")" From a45f4f74ea719775632071359bf1345483834bcb Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 22 Jan 2024 01:01:57 +0500 Subject: [PATCH 2/2] loc-scale types fixed --- jaxampler/_src/rvs/chi2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/jaxampler/_src/rvs/chi2.py b/jaxampler/_src/rvs/chi2.py index c14fd79..ae5a4b7 100644 --- a/jaxampler/_src/rvs/chi2.py +++ b/jaxampler/_src/rvs/chi2.py @@ -27,7 +27,13 @@ class Chi2(ContinuousRV): - def __init__(self, nu: Numeric | Any, loc: Numeric = 0.0, scale: Numeric = 1.0, name: Optional[str] = None) -> None: + def __init__( + self, + nu: Numeric | Any, + loc: Numeric | Any = 0.0, + scale: Numeric | Any = 1.0, + name: Optional[str] = None, + ) -> None: shape, self._nu, self._loc, self._scale = jx_cast(nu, loc, scale) self.check_params() super().__init__(name=name, shape=shape)