Skip to content

Commit

Permalink
Merge pull request scipy#6946 from evgenyzhurko/hypergeom-betaln
Browse files Browse the repository at this point in the history
ENH: hypergeom.logpmf in terms of betaln
  • Loading branch information
ev-br authored Jan 19, 2017
2 parents 97bce4d + cd00f6d commit bae6752
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
8 changes: 4 additions & 4 deletions scipy/stats/_discrete_distns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import division, print_function, absolute_import

from scipy import special
from scipy.special import entr, logsumexp, gammaln as gamln
from scipy.special import entr, logsumexp, betaln, gammaln as gamln
from scipy._lib._numpy_compat import broadcast_to

from numpy import floor, ceil, log, exp, sqrt, log1p, expm1, tanh, cosh, sinh
Expand Down Expand Up @@ -321,9 +321,9 @@ def _argcheck(self, M, n, N):
def _logpmf(self, k, M, n, N):
tot, good = M, n
bad = tot - good
return gamln(good+1) - gamln(good-k+1) - gamln(k+1) + gamln(bad+1) \
- gamln(bad-N+k+1) - gamln(N-k+1) - gamln(tot+1) + gamln(tot-N+1) \
+ gamln(N+1)
return betaln(good+1, 1) + betaln(bad+1,1) + betaln(tot-N+1, N+1)\
- betaln(k+1, good-k+1) - betaln(N-k+1,bad-N+k+1)\
- betaln(tot+1, 1)

def _pmf(self, k, M, n, N):
# same as the following but numerically more precise
Expand Down
34 changes: 34 additions & 0 deletions scipy/stats/tests/test_discrete_distns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import division, print_function, absolute_import

from scipy.stats import hypergeom, bernoulli
import numpy as np
from numpy.testing import assert_almost_equal, run_module_suite

def test_hypergeom_logpmf():
# symmetries test
# f(k,N,K,n) = f(n-k,N,N-K,n) = f(K-k,N,K,N-n) = f(k,N,n,K)
k = 5
N = 50
K = 10
n = 5
logpmf1 = hypergeom.logpmf(k,N,K,n)
logpmf2 = hypergeom.logpmf(n-k,N,N-K,n)
logpmf3 = hypergeom.logpmf(K-k,N,K,N-n)
logpmf4 = hypergeom.logpmf(k,N,n,K)
assert_almost_equal(logpmf1, logpmf2, decimal=12)
assert_almost_equal(logpmf1, logpmf3, decimal=12)
assert_almost_equal(logpmf1, logpmf4, decimal=12)

# test related distribution
# Bernoulli distribution if n = 1
k = 1
N = 10
K = 7
n = 1
hypergeom_logpmf = hypergeom.logpmf(k,N,K,n)
bernoulli_logpmf = bernoulli.logpmf(k,K/N)
assert_almost_equal(hypergeom_logpmf, bernoulli_logpmf, decimal=12)


if __name__ == "__main__":
run_module_suite()

0 comments on commit bae6752

Please sign in to comment.