-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathstan_utility.py
110 lines (100 loc) · 4.94 KB
/
stan_utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#Copyright 2017 Columbia University, 2017 Jeff Alstott
#
#Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
#
#1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
#
#2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
#
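#3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
#
#THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.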
import pystan
import pickle
import numpy
def check_div(fit):
    """Check transitions that ended with a divergence.

    Parameters
    ----------
    fit : pystan fit object
        Anything exposing ``get_sampler_params(inc_warmup=False)`` that
        returns one mapping per chain containing a ``'divergent__'``
        sequence of 0/1 flags.

    Prints how many post-warmup iterations diverged and, if any did, a hint
    to increase ``adapt_delta``. Returns None.
    """
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    divergent = [x for chain in sampler_params for x in chain['divergent__']]
    # The flags may arrive as floats (0.0/1.0); int() keeps the printed
    # count integral ("3 of ..." rather than "3.0 of ...").
    n = int(sum(divergent))
    N = len(divergent)
    if N == 0:
        # Nothing was sampled after warmup; avoid dividing by zero below.
        print('No post-warmup iterations to check for divergences')
        return
    # 100.0 forces true division even under Python 2.
    print('{} of {} iterations ended with a divergence ({}%)'.format(
        n, N, 100.0 * n / N))
    if n > 0:
        print('Try running with larger adapt_delta to remove the divergences')
def check_treedepth(fit, max_depth=10):
    """Check transitions that ended prematurely due to the maximum tree depth.

    Parameters
    ----------
    fit : pystan fit object
        Anything exposing ``get_sampler_params(inc_warmup=False)`` that
        returns one mapping per chain containing a ``'treedepth__'`` sequence.
    max_depth : int, optional
        The tree-depth limit the sampler was run with (default 10). A
        transition whose depth equals this value is counted as saturated.

    Prints a summary line in ANSI red if any transition saturated and in
    green otherwise. On saturation it also prints a hint to raise the limit.
    Returns None.
    """
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    depths = [x for chain in sampler_params for x in chain['treedepth__']]
    n = sum(1 for x in depths if x == max_depth)
    N = len(depths)
    if N == 0:
        # Nothing was sampled after warmup; avoid dividing by zero below.
        print('No post-warmup iterations to check for tree depth saturation')
        return
    c = '31' if n > 0 else '32'  # ANSI color code: red on saturation, else green
    # 100.0 forces true division even under Python 2.
    print(("\x1b[{}m\"{} of {} iterations saturated the maximum tree depth of {}"
           + " ({}%)\"\x1b[0m").format(c, n, N, max_depth, 100.0 * n / N))
    if n > 0:
        # The pystan sampler control option is named `max_treedepth`.
        print("\x1b[{}m\"Run again with max_treedepth set to a larger value to avoid saturation\"\x1b[0m".format(c))
def check_energy(fit):
    """Check the energy Bayesian fraction of missing information (E-BFMI).

    For each chain the estimate is ``sum(diff(E)**2) / len(E) / var(E)``,
    where ``E`` is that chain's post-warmup ``'energy__'`` trace. The squared
    differences are divided by the number of draws, not the number of
    differences.

    Parameters
    ----------
    fit : pystan fit object
        Anything exposing ``get_sampler_params(inc_warmup=False)`` that
        returns one mapping per chain containing an ``'energy__'`` sequence.

    Output per chain:
    - values below 0.2 are printed in ANSI red, followed by a
      reparameterization hint;
    - other values are printed in green;
    - chains with fewer than two draws are reported and skipped, since no
      difference can be formed.

    Returns None.
    """
    sampler_params = fit.get_sampler_params(inc_warmup=False)
    for chain_num, s in enumerate(sampler_params):
        energies = numpy.asarray(s['energy__'], dtype=float)
        if len(energies) < 2:
            print("Chain {}: too few draws to estimate E-BFMI".format(chain_num))
            continue
        numer = numpy.sum(numpy.diff(energies) ** 2) / len(energies)
        denom = numpy.var(energies)
        ratio = numer / denom
        c = '31' if ratio < 0.2 else '32'  # ANSI color code: red / green
        print("\x1b[{}m\"Chain {}: E-BFMI = {}\"\x1b[0m".format(c, chain_num, ratio))
        if ratio < 0.2:
            print("\x1b[{}m\"E-BFMI below 0.2 indicates you may need to reparameterize your model\"\x1b[0m".format(c))
def _by_chain(unpermuted_extraction):
num_chains = len(unpermuted_extraction[0])
result = [[] for _ in range(num_chains)]
for c in range(num_chains):
for i in range(len(unpermuted_extraction)):
result[c].append(unpermuted_extraction[i][c])
return numpy.array(result)
def _shaped_ordered_params(fit):
    """Return the post-warmup draws of every model parameter, keyed by name.

    Draws are laid out chain by chain: all iterations of the first chain,
    then the second, and so on. Each value has shape ``(n_draws,) + dim``,
    where ``dim`` is the parameter's entry in ``fit.par_dims``. Scalars
    (empty ``dim``) therefore come back as 1-D arrays of draws.
    """
    # Unpermuted draws, indexed by (iteration, chain, flat parameter).
    ef = fit.extract(permuted=False, inc_warmup=False)
    ef = _by_chain(ef)                  # -> (chain, iteration, flat parameter)
    ef = ef.reshape(-1, ef.shape[-1])   # -> (chain-major draw, flat parameter)
    ef = ef[:, 0:len(fit.flatnames)]    # drop lp__
    shaped = {}
    idx = 0
    for dim, param_name in zip(fit.par_dims, fit.extract().keys()):
        dim = list(dim)
        length = int(numpy.prod(dim))
        flat = ef[:, idx:idx + length]
        # ndarray.reshape returns a new array rather than reshaping in place,
        # so the result must be assigned.
        # NOTE(review): assumes fit.flatnames lists multi-dimensional
        # parameters in Stan's column-major order, as pystan does. With the
        # draw axis first, order='F' maps each row onto `dim` column-major
        # while keeping draws separate.
        shaped[param_name] = flat.reshape([-1] + dim, order='F')
        idx += length
    return shaped
def partition_div(fit):
    """Split every parameter's post-warmup draws by divergence status.

    Returns a pair ``(nondiv_params, div_params)``. Each is a dict mapping
    parameter name to the draws whose transition did not / did end with a
    divergence. Divergence flags are concatenated chain by chain and cast to
    int before being compared against 0 and 1.
    """
    flags = numpy.concatenate(
        [chain['divergent__'] for chain in fit.get_sampler_params(inc_warmup=False)]
    ).astype('int')
    params = _shaped_ordered_params(fit)
    nondiv_params = {name: draws[flags == 0] for name, draws in params.items()}
    div_params = {name: draws[flags == 1] for name, draws in params.items()}
    return nondiv_params, div_params
def compile_model(filename, model_name=None, **kwargs):
    """Compile the Stan program in *filename*, caching the result on disk.

    This will automatically cache models - great if you're just running a
    script on the command line. The compiled ``pystan.StanModel`` is pickled
    to a file in the current working directory whose name embeds the MD5
    hash of the model source (and *model_name*, if given). An unchanged model
    is loaded from that file instead of being recompiled.
    See http://pystan.readthedocs.io/en/latest/avoiding_recompilation.html

    Parameters
    ----------
    filename : str
        Path to the Stan program, read as UTF-8 text.
    model_name : str, optional
        Only used to make the cache file name recognizable.
    **kwargs
        Forwarded to ``pystan.StanModel`` when the model must be compiled.
        They are NOT part of the cache key, so a cached model compiled with
        different options is still reused.

    Returns
    -------
    The cached or freshly compiled model.

    Security: the cache is read with ``pickle.load``. Never run this in a
    directory where someone else could plant a ``cached-*.pkl`` file.
    """
    import io
    from hashlib import md5

    with io.open(filename, encoding='utf-8') as f:
        model_code = f.read()
    # UTF-8 bytes equal ASCII bytes for ASCII sources, so existing cache file
    # names stay valid, while non-ASCII characters (e.g. in comments) no
    # longer raise UnicodeEncodeError.
    code_hash = md5(model_code.encode('utf-8')).hexdigest()
    if model_name is None:
        cache_fn = 'cached-model-{}.pkl'.format(code_hash)
    else:
        cache_fn = 'cached-{}-{}.pkl'.format(model_name, code_hash)
    try:
        with open(cache_fn, 'rb') as f:
            sm = pickle.load(f)
    except Exception:
        # Missing, truncated, or incompatible cache file: fall back to a
        # fresh compile. Deliberately broad because unpickling can fail in
        # many ways, but KeyboardInterrupt/SystemExit still propagate.
        sm = pystan.StanModel(model_code=model_code, **kwargs)
        try:
            with open(cache_fn, 'wb') as f:
                pickle.dump(sm, f)
        except (IOError, OSError) as err:
            # Compilation is expensive; don't lose the model just because
            # the cache could not be written (e.g. read-only directory).
            print("Could not write model cache {}: {}".format(cache_fn, err))
    else:
        print("Using cached StanModel")
    return sm