Skip to content

Commit

Permalink
Add truncated distribution for redshift smearing (#19)
Browse files Browse the repository at this point in the history
* small update

* add truncated cauchy and normal distribution
  • Loading branch information
echaussidon authored Nov 17, 2022
1 parent 93ec62f commit c55a6b8
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 27 deletions.
4 changes: 2 additions & 2 deletions desi/from_box_to_desi_cutsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def apply_radial_mask(cutsky, zmin=0., zmax=6., nz_filename='nz_qso_final.dat',
mask_radial = TabulatedRadialMask(z=zbin_mid, nbar=n_z / volume, interp_order=2, zrange=(zmin, zmax))

if apply_redshift_smearing:
from mockfactory.desi import RedshiftSmearing
from mockfactory.desi import TracerRedshiftSmearing
# Note: apply redshift smearing before the n(z) match, since n(z) is what we observe (i.e. it already contains the smearing)
cutsky['Z'] = cutsky['Z'] + RedshiftSmearing(tracer=tracer_smearing).sample(cutsky['Z'], seed=seed + 13)
cutsky['Z'] = cutsky['Z'] + TracerRedshiftSmearing(tracer=tracer_smearing).sample(cutsky['Z'], seed=seed + 13)

return cutsky[mask_radial(cutsky['Z'], seed=seed)]

Expand Down
2 changes: 1 addition & 1 deletion mockfactory/desi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .brick_pixel_quantities import get_brick_pixel_quantities
from .footprint import is_in_desi_footprint
from .redshift_smearing import QSORedshiftSmearing, RedshiftSmearing
from .redshift_smearing import TracerRedshiftSmearing
50 changes: 27 additions & 23 deletions mockfactory/desi/redshift_smearing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,34 @@ def TracerRedshiftSmearingRVS(tracer='QSO', fn=None):
table = vstack([table[table['mean_z'] < tt['mean_z'][0]], tt])
else:
table = tt

rvs_nongaussian, rvs_gaussian, laz = [], [], []
for iz, z in enumerate(table['mean_z']):
if tracer == 'QSO':
A0, x0, s0, sg, la = table['val_fit'][iz]
s0, sg = s0 / np.sqrt(2), sg / np.sqrt(2)
rvs_nongaussian.append(stats.laplace(x0, s0))
rvs_gaussian.append(stats.norm(x0, sg))
else:
elif tracer == 'LRG':
sigma, x0, p, mu, la = table['val_fit'][iz]
rvs_nongaussian.append(stats.cauchy(scale=p / 2, loc=mu))
rvs_gaussian.append(stats.norm(scale=sigma, loc=x0))
elif tracer in ['ELG', 'BGS']:
sigma, x0, p, mu, la = table['val_fit'][iz]
# Use the truncated Cauchy (utils.trunccauchy, restricted to the range [a, b]) instead of stats.cauchy.
# Do not use scipy.stats.truncnorm (it behaves strangely and does not work here).
# scale and loc cannot be used as parameter names --> use sc and lo instead.
""" TO DO HERE by Jiaxi --> can split ELG and BGS if they not have the same range"""
trunc = 150
rvs_nongaussian.append(utils.trunccauchy(a=-trunc, b=trunc).freeze(sc=p / 2, lo=mu))
rvs_gaussian.append(utils.truncnorm(a=-trunc, b=trunc).freeze(sc=sigma, lo=x0))
laz.append(la)
laz = np.array(laz)

if tracer == 'QSO':

def dztransform(z, dz):
return dz / (constants.c / 1e3) / (1. + z) # file unit is dz (1 + z) c [km / s]

else:

def dztransform(z, dz):
return dz / (constants.c / 1e3) * (1. + z) # file unit is c dz / (1 + z) [km / s]

Expand All @@ -79,6 +86,7 @@ def TracerRedshiftSmearing(tracer='QSO', fn=None):
dzscale = 200
else:
raise ValueError(f'{tracer} redshift smearing does not exist')

return RVS2DRedshiftSmearing.average([RVS2DRedshiftSmearing(z, rv, dzsize=10000, dzscale=dzscale, dztransform=dztransform) for rv in rvs], weights=weights)


Expand All @@ -88,38 +96,32 @@ def TracerRedshiftSmearing(tracer='QSO', fn=None):
from matplotlib import pyplot as plt
from mockfactory import setup_logging

setup_logging()

def parse_args():
parser = ArgumentParser()
parser.add_argument(
"--tracer", help="the tracer for redshift smearing: QSO, LRG, ELG, BGS",
type=str, default='QSO', required=True,
)
args = None
args = parser.parse_args()
return args
def collect_argparser(argv=None):
    """
    Parse the command-line options of the redshift smearing display script.

    Parameters
    ----------
    argv : list of str, default=None
        Arguments to parse. If ``None``, argparse falls back to ``sys.argv[1:]``,
        so calling ``collect_argparser()`` behaves as a regular command-line entry point.

    Returns
    -------
    args : argparse.Namespace
        Parsed options; ``args.tracer`` is the tracer name ('QSO' if ``--tracer`` is not given).
    """
    parser = ArgumentParser(description="Load and display the redshift smearing for args.tracer")
    # The option used to be declared with both required=True and default='QSO'; argparse never
    # applies the default of a required option, so it is optional here to make the default effective.
    parser.add_argument("--tracer", type=str, default='QSO',
                        help="the tracer for redshift smearing: QSO, LRG, ELG, BGS")
    return parser.parse_args(argv)

args = parse_args()
tracer = args.tracer
setup_logging()
args = collect_argparser()

# Instantiate redshift smearing class
rs = TracerRedshiftSmearing(tracer=tracer)
rs = TracerRedshiftSmearing(tracer=args.tracer)

# Load random variates, to get pdf to compare to
z, rvs, weights, dztransform = TracerRedshiftSmearingRVS(tracer=tracer)
z, rvs, weights, dztransform = TracerRedshiftSmearingRVS(tracer=args.tracer)

# z slices where to plot distributions
lz = np.linspace(z[0], z[-1], 15)
# Tabulated dz where to evaluate pdf
if tracer == 'QSO':
if args.tracer == 'QSO':
dvscale = 5e3
elif tracer in ['ELG', 'BGS']:
elif args.tracer in ['ELG', 'BGS']:
dvscale = 150
elif tracer == 'LRG':
elif args.tracer == 'LRG':
dvscale = 200

#unit = 'dz'
# unit = 'dz'
unit = 'dv [km/s]'

fig, lax = plt.subplots(3, 5, figsize=(20, 10))
Expand Down Expand Up @@ -152,4 +154,6 @@ def parse_args():
ax.set_xlim(xmin, xmax)

if rs.mpicomm.rank == 0:
plt.tight_layout()
plt.savefig('test.png')
plt.show()
2 changes: 1 addition & 1 deletion mockfactory/make_survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,7 +1902,7 @@ def average(cls, *others, weights=None):
weights = np.asarray(weights, dtype='f8')
weights = weights / np.sum(weights, axis=0)
new = others[0].copy()
for other in others:
for i, other in enumerate(others):
if not np.allclose(other.dz, new.dz):
raise ValueError('Input redshift smearing pdfs must have same support to be averaged')
# Remove first / end points (typically 0, 1) to avoid potential warning with infs in _support_transform
Expand Down
84 changes: 84 additions & 0 deletions mockfactory/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A few utilities."""

import numpy as np
from scipy import stats

from mpytools.utils import mkdir, setup_logging, BaseMetaClass, BaseClass

Expand Down Expand Up @@ -167,3 +168,86 @@ def vector_projection(vector, direction):
projection = projection[:, None] * direction

return projection


class trunccauchy(stats.rv_continuous):
    """
    Truncated Cauchy continuous random variable, where the range ``[a, b]`` is user-provided.

    The density is the ``scipy.stats.cauchy`` pdf restricted to ``[a, b]`` and divided by the
    Cauchy probability of that interval, ``cauchy.cdf(b) - cauchy.cdf(a)``.
    Defining ``_pdf`` alone would be enough, but drawing samples with ``.rvs()`` would then be slow;
    ``_cdf`` and ``_ppf`` are therefore provided as well, both in closed form.

    Remark: a more robust implementation would work with logpdf, as ``scipy.stats.truncnorm`` does,
    to avoid a division by zero when the truncation is done far from the core of the distribution.

    Warning: ``loc`` and ``scale`` are built-in keywords and cannot be used in ``_pdf``.
    The shape parameters are called ``lo`` (location) and ``sc`` (scale) instead.
    ``a`` and ``b`` are compared directly with the Cauchy variable of location ``lo`` and scale ``sc``.

    Example::

        e = trunccauchy(a=-1, b=1, shapes='lo, sc')
        e = e.freeze(lo=0, sc=0.1)  # freeze the parameters lo and sc
        samples = e.rvs(size=1000)

    References:
        * https://docs.scipy.org/doc/scipy/tutorial/stats.html#making-a-continuous-distribution-i-e-subclassing-rv-continuous
        * truncated normal distribution already implemented in scipy (see also its source code): https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html
    """

    def _argcheck(self, *args):
        """Accept any shape parameters: the default check requires args > 0, but ``lo`` may be negative."""
        return True

    def _bounds_cdf(self, lo, sc):
        """Return the untruncated Cauchy cdf evaluated at the lower and upper truncation bounds."""
        lower = stats.cauchy.cdf(self.a, loc=lo, scale=sc)
        upper = stats.cauchy.cdf(self.b, loc=lo, scale=sc)
        return lower, upper

    def _pdf(self, x, lo, sc):
        """Cauchy pdf renormalized by the probability enclosed in ``[a, b]``."""
        lower, upper = self._bounds_cdf(lo, sc)
        return stats.cauchy.pdf(x, loc=lo, scale=sc) / (upper - lower)

    def _cdf(self, x, lo, sc):
        """Cdf of the truncated law: 0 below ``a``, 1 above ``b``, rescaled Cauchy cdf in between."""
        lower, upper = self._bounds_cdf(lo, sc)
        # Clipping x to [a, b] yields exactly 0 / 1 outside the support and,
        # unlike boolean-mask assignment, also works for scalar input.
        inside = np.clip(x, self.a, self.b)
        return (stats.cauchy.cdf(inside, loc=lo, scale=sc) - lower) / (upper - lower)

    def _ppf(self, q, lo, sc):
        """
        Inverse of ``_cdf``, used by ``rvs()`` to draw samples.

        The cdf is inverted in closed form: a quantile ``q`` of the truncated law is the quantile
        ``cdf(a) + q * (cdf(b) - cdf(a))`` of the untruncated Cauchy law. No interpolation grid is involved.
        """
        lower, upper = self._bounds_cdf(lo, sc)
        x = stats.cauchy.ppf(lower + q * (upper - lower), loc=lo, scale=sc)
        # Guard against round-off pushing a quantile marginally outside the support.
        return np.clip(x, self.a, self.b)


class truncnorm(stats.rv_continuous):
    """
    Truncated normal continuous random variable, where the range ``[a, b]`` is user-provided.

    The density is the ``scipy.stats.norm`` pdf restricted to ``[a, b]`` and divided by the
    normal probability of that interval, ``norm.cdf(b) - norm.cdf(a)``.
    The same remarks as for ``utils.trunccauchy`` apply: the shape parameters are called ``lo`` (location)
    and ``sc`` (scale), because ``loc`` and ``scale`` are built-in keywords, and ``a`` and ``b`` are
    compared directly with the normal variable of location ``lo`` and scale ``sc``.

    Note: this class exists to have exactly the same behaviour as ``utils.trunccauchy``,
    instead of using ``scipy.stats.truncnorm``.

    Example::

        e = truncnorm(a=-1, b=1, shapes='lo, sc')
        e = e.freeze(lo=0, sc=0.1)  # freeze the parameters lo and sc
        samples = e.rvs(size=1000)
    """

    def _argcheck(self, *args):
        """Accept any shape parameters: the default check requires args > 0, but ``lo`` may be negative."""
        return True

    def _bounds_cdf(self, lo, sc):
        """Return the untruncated normal cdf evaluated at the lower and upper truncation bounds."""
        lower = stats.norm.cdf(self.a, loc=lo, scale=sc)
        upper = stats.norm.cdf(self.b, loc=lo, scale=sc)
        return lower, upper

    def _pdf(self, x, lo, sc):
        """Normal pdf renormalized by the probability enclosed in ``[a, b]``."""
        lower, upper = self._bounds_cdf(lo, sc)
        return stats.norm.pdf(x, loc=lo, scale=sc) / (upper - lower)

    def _cdf(self, x, lo, sc):
        """Cdf of the truncated law: 0 below ``a``, 1 above ``b``, rescaled normal cdf in between."""
        lower, upper = self._bounds_cdf(lo, sc)
        # Clipping x to [a, b] yields exactly 0 / 1 outside the support and,
        # unlike boolean-mask assignment, also works for scalar input.
        inside = np.clip(x, self.a, self.b)
        return (stats.norm.cdf(inside, loc=lo, scale=sc) - lower) / (upper - lower)

    def _ppf(self, q, lo, sc):
        """
        Inverse of ``_cdf``, used by ``rvs()`` to draw samples.

        The cdf is inverted in closed form: a quantile ``q`` of the truncated law is the quantile
        ``cdf(a) + q * (cdf(b) - cdf(a))`` of the untruncated normal law. This avoids tabulating the cdf
        on a grid, where it saturates to repeated values when the bounds lie many ``sc`` away from ``lo``.
        """
        lower, upper = self._bounds_cdf(lo, sc)
        x = stats.norm.ppf(lower + q * (upper - lower), loc=lo, scale=sc)
        # Guard against round-off (e.g. a target probability rounding to exactly 0 or 1,
        # whose normal quantile is infinite) pushing the result outside the support.
        return np.clip(x, self.a, self.b)

0 comments on commit c55a6b8

Please sign in to comment.