Skip to content

Commit

Permalink
DOC added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Nov 18, 2023
1 parent eada628 commit 1416586
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions jax_galsim/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,22 @@ def _shoot(self, photons, rng):
fluxes = jnp.array(
[obj.positive_flux + obj.negative_flux for obj in self.obj_list]
)
# for a sum of objects, we use a slightly different approach than galsim
# galsim uses a binomial distribution to compute the number of photons per object
# we take an equivalent but different approach in order to use fixed size arrays
# of photons. it means we draw more photons but the code is JIT compilable and a bit simpler
#
# this all works as follows:
#
# - for each photon, we draw from a categorical distribution with probabilities
# proportional to the total absolute fluxes of the objects.
# - we then shoot the photons from each object and rescale the fluxes (see comment below)
# - finally, we get the photons that correspond to this object in the cetegorical distribution
# and assign them to the photons object there is a special private method on the
# PhotonArray that does this assignment
#
# one nice thing about this is that the photons come out pre-shuffled and so we don't have
# to mark them as correlated.
rng = BaseDeviate(rng)
key = rng._state.split_one()
cat_inds = jax.random.choice(
Expand All @@ -205,30 +221,17 @@ def _shoot(self, photons, rng):
for i, obj in enumerate(self.obj_list):
pa = obj.shoot(photons.size(), rng=rng)
# now we rescale the fluxes of the photons
# the photons start with
# in galsim, photons end up with a flux that is
#
# flux_per_photon = (obj.positive_flux + obj.negative_flux) / photons.size()
# fluxes[i] / thisN * tot_flux / photons.size() * thisN / fluxes[i]
# = tot_flux / photons.size()
#
# but they should have had a flux per photon of
# our photons start with a flux of
#
# flux_per_photon = (self.positive_flux + self.negative_flux) / thisN
# = fluxes[i] / thisN
#
# where thisN = jnp.sum(cat_inds == i). We drew photons.size() photons instead
# of thisN, above. so we scale their fluxes by a factor of
#
# _scale_fac = photons.size() / thisN
#
# next we want them to have a total flux of
#
# tot_flux_per_photon = (self.positive_flux + self.negative_flux) / photons.size()
# flux[i] / photons.size()
#
# so we scale by a factor of
#
# _scale_fac = tot_flux_per_photon / flux_per_photon_thisN
#
# so we get a total factor of
#
# _scale_fac = tot_flux / fluxes[i]
_scale_fac = tot_flux / fluxes[i]
pa.scaleFlux(_scale_fac)
Expand Down

0 comments on commit 1416586

Please sign in to comment.