Skip to content

Commit

Permalink
Revert "Experimental Changes to CB"
Browse files Browse the repository at this point in the history
This reverts commit 7a0ddb0.

Accidental Push, sorry
  • Loading branch information
josh0-jrg committed Oct 23, 2023
1 parent 7a0ddb0 commit 4ce1649
Showing 1 changed file with 32 additions and 67 deletions.
99 changes: 32 additions & 67 deletions flamedisx/nest/lxe_blocks/quanta_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ def setup(self):
self.array_columns = (('ions_produced_min',
max(len(self.source.energies), 2)),)

def _compute(self,
def _compute(self,
data_tensor, ptensor,
# Domain
electrons_produced, photons_produced,
# Bonus dimension
ions_produced,
# Dependency domain and value
energy, rate_vs_energy):

def compute_single_energy(args, approx=False):
# Compute the block for a single energy.
# Set approx to True for an approximate computation at higher energies
Expand All @@ -61,12 +62,7 @@ def compute_single_energy(args, approx=False):

# Calculate the ion domain tensor for this energy
_ions_produced = ions_produced_add + ions_min
#every event in the batch shares E therefore ions domain
_ions_produced_1D=_ions_produced[0,0,0,:]
#create p_ni/p_nq domain for ER/NR
nq_2D=tf.repeat(unique_quanta[:,o],tf.shape(_ions_produced_1D)[0],axis=1)
ni_2D=tf.repeat(_ions_produced_1D[o,:],tf.shape(unique_quanta)[0],axis=0)


if self.is_ER:
nel_mean = self.gimme('mean_yield_electron', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=energy)
Expand All @@ -76,25 +72,19 @@ def compute_single_energy(args, approx=False):
bonus_arg=nq_mean)

if approx:
p_nq_1D = tfp.distributions.Normal(loc=nq_mean,
scale=tf.sqrt(nq_mean * fano) + 1e-10).prob(unique_quanta)
p_nq = tfp.distributions.Normal(loc=nq_mean,
scale=tf.sqrt(nq_mean * fano) + 1e-10).prob(nq)
else:
normal_dist_nq = tfp.distributions.Normal(loc=nq_mean,
scale=tf.sqrt(nq_mean * fano) + 1e-10)
p_nq_1D=normal_dist_nq.cdf(unique_quanta + 0.5) - normal_dist_nq.cdf(unique_quanta - 0.5)
#restore p_ni from unique_nq x n_ions -> unique_nel x n_electrons x n_photons (does not need n_ions)
p_nq=tf.gather_nd(params=p_nq_1D,indices=index_nq[:,o],batch_dims=0)
p_nq=tf.reshape(p_nq,[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[2]])

scale=tf.sqrt(nq_mean * fano) + 1e-10)
p_nq = normal_dist_nq.cdf(nq + 0.5) - normal_dist_nq.cdf(nq - 0.5)

ex_ratio = self.gimme('exciton_ratio', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=energy)
alpha = 1. / (1. + ex_ratio)

p_ni_2D=tfp.distributions.Binomial(total_count=nq_2D, probs=alpha).prob(ni_2D)
#restore p_ni from unique_nq x n_ions -> unique_nel x n_electrons x n_photons x n_ions
p_ni=tf.gather_nd(params=p_ni_2D,indices=index_nq[:,o],batch_dims=0)
p_ni=tf.reshape(tf.reshape(p_ni,[-1]),[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[2],tf.shape(nq)[3]])
p_ni = tfp.distributions.Binomial(
total_count=nq, probs=alpha).prob(_ions_produced)

else:
yields = self.gimme('mean_yields', data_tensor=data_tensor, ptensor=ptensor,
Expand All @@ -108,47 +98,37 @@ def compute_single_energy(args, approx=False):
bonus_arg=nq_mean)
ni_fano = yield_fano[0]
nex_fano = yield_fano[1]

#p_ni does not need to have an altered dimensionality, see tensordot.

if approx:
p_ni_1D = tfp.distributions.Normal(loc=nq_mean*alpha,
scale=tf.sqrt(nq_mean*alpha*ni_fano) + 1e-10).prob(_ions_produced_1D)
p_ni = tfp.distributions.Normal(loc=nq_mean*alpha,
scale=tf.sqrt(nq_mean*alpha*ni_fano) + 1e-10).prob(_ions_produced)

p_nq_2D = tfp.distributions.Normal(loc=nq_mean*alpha*ex_ratio,
p_nq = tfp.distributions.Normal(loc=nq_mean*alpha*ex_ratio,
scale=tf.sqrt(nq_mean*alpha*ex_ratio*nex_fano) + 1e-10).prob(
nq_2D - ni_2D)
nq - _ions_produced)
else:
normal_dist_ni = tfp.distributions.Normal(loc=nq_mean*alpha,
scale=tf.sqrt(nq_mean*alpha*ni_fano) + 1e-10)
p_ni_1D = normal_dist_ni.cdf(_ions_produced_1D + 0.5) - \
normal_dist_ni.cdf(_ions_produced_1D - 0.5)
p_ni = normal_dist_ni.cdf(_ions_produced + 0.5) - \
normal_dist_ni.cdf(_ions_produced - 0.5)

normal_dist_nq = tfp.distributions.Normal(loc=nq_mean*alpha*ex_ratio,
scale=tf.sqrt(nq_mean*alpha*ex_ratio*nex_fano) + 1e-10)
p_nq_2D = normal_dist_nq.cdf(nq_2D - ni_2D + 0.5) \
- normal_dist_nq.cdf(nq_2D - ni_2D - 0.5)

# restore p_nq from unique_nq x n_ions -> unique_nel x n_electrons x n_photons x n_ions
p_nq=tf.gather_nd(params=p_nq_2D,indices=index_nq[:,o],batch_dims=0)
p_nq=tf.reshape(tf.reshape(p_nq,[-1]),[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[2],tf.shape(nq)[3]])



nel_2D=tf.repeat(unique_nel[:,o],tf.shape(_ions_produced_1D)[0],axis=1)
ni_nel_2D=tf.repeat(_ions_produced_1D[o,:],tf.shape(unique_nel)[0],axis=0)
p_nq = normal_dist_nq.cdf(nq - _ions_produced + 0.5) \
- normal_dist_nq.cdf(nq - _ions_produced - 0.5)

recomb_p = self.gimme('recomb_prob', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=(nel_mean, nq_mean, ex_ratio))
skew = self.gimme('skewness', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=nq_mean)
var = self.gimme('variance', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=(nel_mean, nq_mean, recomb_p, ni_nel_2D))
bonus_arg=(nel_mean, nq_mean, recomb_p, _ions_produced))
width_corr = self.gimme('width_correction', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=skew)
mu_corr = self.gimme('mu_correction', data_tensor=data_tensor, ptensor=ptensor,
bonus_arg=(skew, var, width_corr))

mean = (tf.ones_like(ni_nel_2D, dtype=fd.float_type()) - recomb_p) * ni_nel_2D - mu_corr
mean = (tf.ones_like(_ions_produced, dtype=fd.float_type()) - recomb_p) * _ions_produced - mu_corr
std_dev = tf.sqrt(var) / width_corr

if self.is_ER:
Expand All @@ -157,27 +137,19 @@ def compute_single_energy(args, approx=False):
owens_t_terms = 5

if approx:
p_nel_1D = fd.tfp_files.SkewGaussian(loc=mean, scale=std_dev,
skewness=skew,
owens_t_terms=owens_t_terms).prob(nel_2D)
p_nel = fd.tfp_files.SkewGaussian(loc=mean, scale=std_dev,
skewness=skew,
owens_t_terms=owens_t_terms).prob(electrons_produced)
else:
p_nel_1D =fd.tfp_files.TruncatedSkewGaussianCC(loc=mean, scale=std_dev,
skewness=skew,
limit=ni_nel_2D,
owens_t_terms=owens_t_terms).prob(nel_2D)


#Restore p_nel unique_nel x nions-> unique_nel x n_electrons x n_photons x n_ions
p_nel=tf.gather_nd(params=p_nel_1D,indices=index_nel[:,o],batch_dims=0)
p_nel=tf.reshape(tf.reshape(p_nel,[-1]),[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[3]])
p_nel=tf.repeat(p_nel[:,:,o,:],tf.shape(nq)[2],axis=2)
#modified contractions remove need for costly repeats in ions dimension.
if self.is_ER:
p_mult = p_ni * p_nel
p_final = tf.reduce_sum(p_mult, 3)*p_nq
else:
p_mult = p_nq*p_nel
p_final = tf.tensordot(p_mult,p_ni_1D,axes=[[3],[0]])
p_nel = fd.tfp_files.TruncatedSkewGaussianCC(loc=mean, scale=std_dev,
skewness=skew,
limit=_ions_produced,
owens_t_terms=owens_t_terms).prob(electrons_produced)

p_mult = p_nq * p_ni * p_nel

# Contract over ions_produced
p_final = tf.reduce_sum(p_mult, 3)

r_final = p_final * rate_vs_energy

Expand All @@ -197,12 +169,6 @@ def compute_single_energy_approx(args):
return compute_single_energy(args, approx=True)

nq = electrons_produced + photons_produced
#remove degenerate dimensions
# unique_quanta,index_nq=unique(nq[:,:,:,0])#nevts x nph x nel->unique_nq
# unique_nel,index_nel=unique(electrons_produced[:,:,0,0])#nevts x nel->unique_nel
unique_quanta,index_nq=tf.unique(tf.reshape(nq[:,:,:,0],[-1]))
unique_nel,index_nel=tf.unique(tf.reshape(electrons_produced[:,:,0,0],[-1]))


ions_min_initial = self.source._fetch('ions_produced_min', data_tensor=data_tensor)[:, 0, o]
ions_min_initial = tf.repeat(ions_min_initial, tf.shape(ions_produced)[1], axis=1)
Expand All @@ -213,7 +179,6 @@ def compute_single_energy_approx(args):
# for the lowest energy
ions_produced_add = ions_produced - ions_min_initial


# Energy above which we use the approximate computation
if self.is_ER:
cutoff_energy = 5.
Expand Down

0 comments on commit 4ce1649

Please sign in to comment.