Skip to content

Commit

Permalink
Loc-scale variant of Pareto distribution (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash authored Jan 21, 2024
1 parent aebeaa1 commit 79c8794
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 26 deletions.
38 changes: 25 additions & 13 deletions jaxampler/_src/rvs/pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,53 +27,65 @@


class Pareto(ContinuousRV):
def __init__(self, alpha: Numeric | Any, scale: Numeric | Any, name: Optional[str] = None) -> None:
shape, self._alpha, self._scale = jx_cast(alpha, scale)
def __init__(
self, a: Numeric | Any, loc: Numeric | Any = 0.0, scale: Numeric | Any = 1.0, name: Optional[str] = None
) -> None:
shape, self._a, self._loc, self._scale = jx_cast(a, loc, scale)
self.check_params()
super().__init__(name=name, shape=shape)

def check_params(self) -> None:
assert jnp.all(self._alpha > 0.0), "alpha must be greater than 0"
assert jnp.all(self._a > 0.0), "alpha must be greater than 0"
assert jnp.all(self._scale > 0.0), "scale must be greater than 0"

@partial(jit, static_argnums=(0,))
def logpdf_x(self, x: Numeric) -> Numeric:
return jax_pareto.logpdf(x, self._alpha, scale=self._scale)
return jax_pareto.logpdf(
x=x,
b=self._a,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def pdf_x(self, x: Numeric) -> Numeric:
return jax_pareto.pdf(x, self._alpha, scale=self._scale)
return jax_pareto.pdf(
x=x,
b=self._a,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def logcdf_x(self, x: Numeric) -> Numeric:
return jnp.where(
self._scale <= x,
jnp.log1p(-jnp.power(self._scale / x, self._alpha)),
self._loc + self._scale <= x,
jnp.log1p(-jnp.power(self._scale / (x - self._loc), self._a)),
-jnp.inf,
)

@partial(jit, static_argnums=(0,))
def logppf_x(self, x: Numeric) -> Numeric:
def ppf_x(self, x: Numeric) -> Numeric:
conditions = [
x < 0.0,
(0.0 <= x) & (x < 1.0),
1.0 <= x,
]
choices = [
-jnp.inf,
jnp.log(self._scale) - (1.0 / self._alpha) * jnp.log(1 - x),
jnp.log(1.0),
0.0,
self._loc + jnp.exp(jnp.log(self._scale) - (1.0 / self._a) * jnp.log(1 - x)),
1.0,
]
return jnp.select(conditions, choices)

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
if key is None:
key = self.get_key()
new_shape = shape + self._shape
return jax.random.pareto(key, self._alpha, shape=new_shape) * self._scale
return self._loc + self._scale * jax.random.pareto(key, self._a, shape=new_shape)

def __repr__(self) -> str:
string = f"Pareto(alpha={self._alpha}, scale={self._scale}"
string = f"Pareto(a={self._a}, loc={self._loc}, scale={self._scale}"
if self._name is not None:
string += f", name={self._name}"
string += ")"
Expand Down
25 changes: 12 additions & 13 deletions tests/pareto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,40 @@


class TestPareto:

def test_pdf(self):
assert Pareto(alpha=0.5, scale=0.1).pdf_x(1) == 0.15811388
assert Pareto(a=0.5, scale=0.1).pdf_x(1) == 0.15811388

def test_shapes(self):
assert jnp.allclose(Pareto(alpha=[0.5, 0.1], scale=[0.1, 0.2]).pdf_x(1), jnp.array([0.15811388, 0.08513397]))
assert jnp.allclose(Pareto(a=[0.5, 0.1], scale=[0.1, 0.2]).pdf_x(1), jnp.array([0.15811388, 0.08513397]))
assert jnp.allclose(
Pareto(alpha=[0.5, 0.1, 0.2], scale=[0.1, 0.2, 0.2]).pdf_x(1),
jnp.array([0.15811388, 0.08513397, 0.14495593]))
Pareto(a=[0.5, 0.1, 0.2], scale=[0.1, 0.2, 0.2]).pdf_x(1), jnp.array([0.15811388, 0.08513397, 0.14495593])
)

def test_imcompatible_shapes(self):
with pytest.raises(ValueError):
Pareto(alpha=[0.5, 0.1, 0.9], scale=[0.1, 0.2])
Pareto(a=[0.5, 0.1, 0.9], scale=[0.1, 0.2])

def test_out_of_bound(self):
# when x is less than zero
assert jnp.allclose(Pareto(alpha=0.5, scale=0.1).pdf_x(-1), 0)
assert jnp.allclose(Pareto(a=0.5, scale=0.1).pdf_x(-1), 0)
# when x is greater than scale
with pytest.raises(AssertionError):
assert jnp.allclose(Pareto(alpha=0.5, scale=0.1).pdf_x(11), 0)
assert jnp.allclose(Pareto(a=0.5, scale=0.1).pdf_x(11), 0)
# when scale is negative
with pytest.raises(AssertionError):
Pareto(alpha=0.5, scale=-1)
Pareto(a=0.5, scale=-1)
# when alpha is negative
with pytest.raises(AssertionError):
Pareto(alpha=-1, scale=2)
Pareto(a=-1, scale=2)

def test_cdf_x(self):
# when x is less than scale
assert Pareto(alpha=0.5, scale=0.1).cdf_x(0.01) == 0
assert Pareto(a=0.5, scale=0.1).cdf_x(0.01) == 0
# when x is greater than scale
assert Pareto(alpha=0.5, scale=0.1).cdf_x(1) == 0.6837722
assert Pareto(a=0.5, scale=0.1).cdf_x(1) == 0.6837722

def test_rvs(self):
tpl_rvs = Pareto(alpha=0.1, scale=0.1)
tpl_rvs = Pareto(a=0.1, scale=0.1)
shape = (3, 4)

# with key
Expand Down

0 comments on commit 79c8794

Please sign in to comment.