Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sampler utils #217

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
138 changes: 35 additions & 103 deletions enterprise_extensions/hypermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from enterprise import constants as const
from PTMCMCSampler.PTMCMCSampler import PTSampler as ptmcmc

from .sampler import JumpProposal, get_parameter_groups, save_runtime_info
from .sampler import (JumpProposal, get_parameter_groups,
save_runtime_info, BuildPriorDraw, EmpDistrDraw)


class HyperModel(object):
Expand Down Expand Up @@ -216,111 +217,42 @@ def setup_sampler(self, outdir='chains', resume=False, sample_nmodel=True,
sampler.jp = jp

# always add draw from prior
sampler.addProposalToCycle(jp.draw_from_prior, 5)
sampler.addProposalToCycle(BuildPriorDraw(self.params,
self.param_names[:-1], # ignore nmodel
name='draw_from_prior'), 5)

# try adding empirical proposals
if empirical_distr is not None:
print('Adding empirical proposals...\n')
sampler.addProposalToCycle(jp.draw_from_empirical_distr, 25)

# Red noise prior draw
if 'red noise' in self.snames:
print('Adding red noise prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_red_prior, 10)

# DM GP noise prior draw
if 'dm_gp' in self.snames:
print('Adding DM GP noise prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_dm_gp_prior, 10)

# DM annual prior draw
if 'dm_s1yr' in jp.snames:
print('Adding DM annual prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_dm1yr_prior, 10)

# DM dip prior draw
if 'dmexp' in '\t'.join(jp.snames):
print('Adding DM exponential dip prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_dmexpdip_prior, 10)

# DM cusp prior draw
if 'dm_cusp' in jp.snames:
print('Adding DM exponential cusp prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_dmexpcusp_prior, 10)

# DMX prior draw
if 'dmx_signal' in jp.snames:
print('Adding DMX prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_dmx_prior, 10)

# Chromatic GP noise prior draw
if 'chrom_gp' in self.snames:
print('Adding Chromatic GP noise prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_chrom_gp_prior, 10)

# SW prior draw
if 'gp_sw' in jp.snames:
print('Adding Solar Wind DM GP prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_dm_sw_prior, 10)

# Chromatic GP noise prior draw
if 'chrom_gp' in self.snames:
print('Adding Chromatic GP noise prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_chrom_gp_prior, 10)

# Ephemeris prior draw
if 'd_jupiter_mass' in self.param_names:
print('Adding ephemeris model prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_ephem_prior, 10)

# GWB uniform distribution draw
if np.any([('gw' in par and 'log10_A' in par) for par in self.param_names]):
print('Adding GWB uniform distribution draws...\n')
sampler.addProposalToCycle(jp.draw_from_gwb_log_uniform_distribution, 10)

# Dipole uniform distribution draw
if 'dipole_log10_A' in self.param_names:
print('Adding dipole uniform distribution draws...\n')
sampler.addProposalToCycle(jp.draw_from_dipole_log_uniform_distribution, 10)

# Monopole uniform distribution draw
if 'monopole_log10_A' in self.param_names:
print('Adding monopole uniform distribution draws...\n')
sampler.addProposalToCycle(jp.draw_from_monopole_log_uniform_distribution, 10)

# BWM prior draw
if 'bwm_log10_A' in self.param_names:
print('Adding BWM prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_bwm_prior, 10)

# FDM prior draw
if 'fdm_log10_A' in self.param_names:
print('Adding FDM prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_fdm_prior, 10)

# CW prior draw
if 'cw_log10_h' in self.param_names:
print('Adding CW prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_cw_log_uniform_distribution, 10)

# free spectrum prior draw
if np.any(['log10_rho' in par for par in self.param_names]):
print('Adding free spectrum prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_gw_rho_prior, 25)

# Prior distribution draw for parameters named GW
if any([str(p).split(':')[0] for p in list(self.params) if 'gw' in str(p)]):
print('Adding gw param prior draws...\n')
sampler.addProposalToCycle(jp.draw_from_par_prior(
par_names=[str(p).split(':')[0] for
p in list(self.params)
if 'gw' in str(p)]), 10)

# Model index distribution draw
if sample_nmodel:
if 'nmodel' in self.param_names:
print('Adding nmodel uniform distribution draws...\n')
sampler.addProposalToCycle(self.draw_from_nmodel_prior, 25)
print('Attempting to add empirical proposals...\n')
sampler.addProposalToCycle(EmpDistrDraw(jp.empirical_distr,
self.param_names[:-1], # ignore nmodel
name='draw_from_empirical_distr'), 10)

# list of typical signal names
snames = ['red noise', 'dm_gp', 'chrom_gp',
'dmx_signal', 'phys_ephem', 'bwm', 'fdm', 'cw', 'gp_sw',
'linear timing model', 'ecorr_sherman-morrison',
'measurement_noise', 'tnequad']

for sname in snames:
# adding prior draws
if (sname in jp.snames) and (len(jp.snames[sname]) >= 1):
print(f'Adding {sname} prior draws...\n')
param_names = [p.name for p in jp.snames[sname]]
sampler.addProposalToCycle(BuildPriorDraw(self.params,
param_names,
name='draw_from_'+sname), 10)

# adding other signal draws
param_names = ['dipole', 'monopole', 'hd', 'log10_rho',
'dmexp', 'dm_cusp', 'dm_s1yr']

for p in param_names:
par_names = [par for par in self.param_names if p in par]
if len(par_names) >= 1:
print(f'Adding {p} prior draws...\n')
sampler.addProposalToCycle(BuildPriorDraw(self.params, par_names,
name='draw_from_'+p), 10)

return sampler

Expand Down
Loading