-
Notifications
You must be signed in to change notification settings - Fork 387
/
multinomial.py
49 lines (41 loc) · 1.19 KB
/
multinomial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Code by Tae-Hwan Hung(@graykode)
https://en.wikipedia.org/wiki/Multinomial_distribution
3-Class Example
"""
import numpy as np
from matplotlib import pyplot as plt
import operator as op
from functools import reduce
def factorial(n):
return reduce(op.mul, range(1, n + 1), 1)
def const(n, a, b, c):
"""
return n! / a! b! c!, where a+b+c == n
"""
assert a + b + c == n
numer = factorial(n)
denom = factorial(a) * factorial(b) * factorial(c)
return numer / denom
def multinomial(n):
"""
:param x : list, sum(x) should be `n`
:param n : number of trial
:param p: list, sum(p) should be `1`
"""
# get all a,b,c where a+b+c == n, a<b<c
ls = []
for i in range(1, n + 1):
for j in range(i, n + 1):
for k in range(j, n + 1):
if i + j + k == n:
ls.append([i, j, k])
y = [const(n, l[0], l[1], l[2]) for l in ls]
x = np.arange(len(y))
return x, y, np.mean(y), np.std(y)
for n_experiment in [20, 21, 22]:
x, y, u, s = multinomial(n_experiment)
plt.scatter(x, y, label=r'$trial=%d$' % (n_experiment))
plt.legend()
plt.savefig('graph/multinomial.png')
plt.show()