diff --git a/flamedisx/nest/lxe_blocks/quanta_splitting.py b/flamedisx/nest/lxe_blocks/quanta_splitting.py
index 269b41490..3ff53651a 100644
--- a/flamedisx/nest/lxe_blocks/quanta_splitting.py
+++ b/flamedisx/nest/lxe_blocks/quanta_splitting.py
@@ -39,7 +39,7 @@ def setup(self):
         self.array_columns = (('ions_produced_min',
                                max(len(self.source.energies), 2)),)
 
-def _compute(self,
+    def _compute(self,
                  data_tensor, ptensor,
                  # Domain
                  electrons_produced, photons_produced,
@@ -47,6 +47,7 @@ def _compute(self,
                  ions_produced,
                  # Dependency domain and value
                  energy, rate_vs_energy):
+
         def compute_single_energy(args, approx=False):
             # Compute the block for a single energy.
             # Set approx to True for an approximate computation at higher energies
@@ -61,12 +62,7 @@ def compute_single_energy(args, approx=False):
             # Calculate the ion domain tensor for this energy
             _ions_produced = ions_produced_add + ions_min
-            #every event in the batch shares E therefore ions domain
-            _ions_produced_1D=_ions_produced[0,0,0,:]
-            #create p_ni/p_nq domain for ER/NR
-            nq_2D=tf.repeat(unique_quanta[:,o],tf.shape(_ions_produced_1D)[0],axis=1)
-            ni_2D=tf.repeat(_ions_produced_1D[o,:],tf.shape(unique_quanta)[0],axis=0)
-
+
             if self.is_ER:
                 nel_mean = self.gimme('mean_yield_electron',
                                       data_tensor=data_tensor, ptensor=ptensor,
                                       bonus_arg=energy)
@@ -76,25 +72,19 @@ def compute_single_energy(args, approx=False):
                                   bonus_arg=nq_mean)
 
                 if approx:
-                    p_nq_1D = tfp.distributions.Normal(loc=nq_mean,
-                                                       scale=tf.sqrt(nq_mean * fano) + 1e-10).prob(unique_quanta)
+                    p_nq = tfp.distributions.Normal(loc=nq_mean,
+                                                    scale=tf.sqrt(nq_mean * fano) + 1e-10).prob(nq)
                 else:
                     normal_dist_nq = tfp.distributions.Normal(loc=nq_mean,
-                                                              scale=tf.sqrt(nq_mean * fano) + 1e-10)
-                    p_nq_1D=normal_dist_nq.cdf(unique_quanta + 0.5) - normal_dist_nq.cdf(unique_quanta - 0.5)
-                #restore p_ni from unique_nq x n_ions -> unique_nel x n_electrons x n_photons (does not need n_ions)
-                p_nq=tf.gather_nd(params=p_nq_1D,indices=index_nq[:,o],batch_dims=0)
-                p_nq=tf.reshape(p_nq,[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[2]])
-
+                                                              scale=tf.sqrt(nq_mean * fano) + 1e-10)
+                    p_nq = normal_dist_nq.cdf(nq + 0.5) - normal_dist_nq.cdf(nq - 0.5)
                 ex_ratio = self.gimme('exciton_ratio',
                                       data_tensor=data_tensor, ptensor=ptensor,
                                       bonus_arg=energy)
                 alpha = 1. / (1. + ex_ratio)
-                p_ni_2D=tfp.distributions.Binomial(total_count=nq_2D, probs=alpha).prob(ni_2D)
-                #restore p_ni from unique_nq x n_ions -> unique_nel x n_electrons x n_photons x n_ions
-                p_ni=tf.gather_nd(params=p_ni_2D,indices=index_nq[:,o],batch_dims=0)
-                p_ni=tf.reshape(tf.reshape(p_ni,[-1]),[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[2],tf.shape(nq)[3]])
+                p_ni = tfp.distributions.Binomial(
+                    total_count=nq, probs=alpha).prob(_ions_produced)
 
             else:
                 yields = self.gimme('mean_yields',
                                     data_tensor=data_tensor, ptensor=ptensor,
@@ -108,47 +98,37 @@ def compute_single_energy(args, approx=False):
                                         bonus_arg=nq_mean)
                 ni_fano = yield_fano[0]
                 nex_fano = yield_fano[1]
-
-                #p_ni does not need to have an altered dimensionality, see tensordot.
+
                 if approx:
-                    p_ni_1D = tfp.distributions.Normal(loc=nq_mean*alpha,
-                                                       scale=tf.sqrt(nq_mean*alpha*ni_fano) + 1e-10).prob(_ions_produced_1D)
+                    p_ni = tfp.distributions.Normal(loc=nq_mean*alpha,
+                                                    scale=tf.sqrt(nq_mean*alpha*ni_fano) + 1e-10).prob(_ions_produced)
 
-                    p_nq_2D = tfp.distributions.Normal(loc=nq_mean*alpha*ex_ratio,
+                    p_nq = tfp.distributions.Normal(loc=nq_mean*alpha*ex_ratio,
                                                        scale=tf.sqrt(nq_mean*alpha*ex_ratio*nex_fano) + 1e-10).prob(
-                                                           nq_2D - ni_2D)
+                                                           nq - _ions_produced)
                 else:
                     normal_dist_ni = tfp.distributions.Normal(loc=nq_mean*alpha,
                                                               scale=tf.sqrt(nq_mean*alpha*ni_fano) + 1e-10)
-                    p_ni_1D = normal_dist_ni.cdf(_ions_produced_1D + 0.5) - \
-                        normal_dist_ni.cdf(_ions_produced_1D - 0.5)
+                    p_ni = normal_dist_ni.cdf(_ions_produced + 0.5) - \
+                        normal_dist_ni.cdf(_ions_produced - 0.5)
 
                     normal_dist_nq = tfp.distributions.Normal(loc=nq_mean*alpha*ex_ratio,
                                                               scale=tf.sqrt(nq_mean*alpha*ex_ratio*nex_fano) + 1e-10)
-                    p_nq_2D = normal_dist_nq.cdf(nq_2D - ni_2D + 0.5) \
-                        - normal_dist_nq.cdf(nq_2D - ni_2D - 0.5)
-
-                # restore p_nq from unique_nq x n_ions -> unique_nel x n_electrons x n_photons x n_ions
-                p_nq=tf.gather_nd(params=p_nq_2D,indices=index_nq[:,o],batch_dims=0)
-                p_nq=tf.reshape(tf.reshape(p_nq,[-1]),[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[2],tf.shape(nq)[3]])
-
-
-
-                nel_2D=tf.repeat(unique_nel[:,o],tf.shape(_ions_produced_1D)[0],axis=1)
-                ni_nel_2D=tf.repeat(_ions_produced_1D[o,:],tf.shape(unique_nel)[0],axis=0)
+                    p_nq = normal_dist_nq.cdf(nq - _ions_produced + 0.5) \
+                        - normal_dist_nq.cdf(nq - _ions_produced - 0.5)
 
             recomb_p = self.gimme('recomb_prob', data_tensor=data_tensor, ptensor=ptensor,
                                   bonus_arg=(nel_mean, nq_mean, ex_ratio))
             skew = self.gimme('skewness', data_tensor=data_tensor, ptensor=ptensor,
                               bonus_arg=nq_mean)
             var = self.gimme('variance', data_tensor=data_tensor, ptensor=ptensor,
-                             bonus_arg=(nel_mean, nq_mean, recomb_p, ni_nel_2D))
+                             bonus_arg=(nel_mean, nq_mean, recomb_p, _ions_produced))
             width_corr = self.gimme('width_correction', data_tensor=data_tensor, ptensor=ptensor,
                                     bonus_arg=skew)
             mu_corr = self.gimme('mu_correction', data_tensor=data_tensor, ptensor=ptensor,
                                   bonus_arg=(skew, var, width_corr))
 
-            mean = (tf.ones_like(ni_nel_2D, dtype=fd.float_type()) - recomb_p) * ni_nel_2D - mu_corr
+            mean = (tf.ones_like(_ions_produced, dtype=fd.float_type()) - recomb_p) * _ions_produced - mu_corr
             std_dev = tf.sqrt(var) / width_corr
 
             if self.is_ER:
@@ -157,27 +137,19 @@ def compute_single_energy(args, approx=False):
                 owens_t_terms = 5
 
             if approx:
-                p_nel_1D = fd.tfp_files.SkewGaussian(loc=mean, scale=std_dev,
-                                                     skewness=skew,
-                                                     owens_t_terms=owens_t_terms).prob(nel_2D)
+                p_nel = fd.tfp_files.SkewGaussian(loc=mean, scale=std_dev,
+                                                  skewness=skew,
+                                                  owens_t_terms=owens_t_terms).prob(electrons_produced)
             else:
-                p_nel_1D =fd.tfp_files.TruncatedSkewGaussianCC(loc=mean, scale=std_dev,
-                                                               skewness=skew,
-                                                               limit=ni_nel_2D,
-                                                               owens_t_terms=owens_t_terms).prob(nel_2D)
-
-
-            #Restore p_nel unique_nel x nions-> unique_nel x n_electrons x n_photons x n_ions
-            p_nel=tf.gather_nd(params=p_nel_1D,indices=index_nel[:,o],batch_dims=0)
-            p_nel=tf.reshape(tf.reshape(p_nel,[-1]),[tf.shape(nq)[0],tf.shape(nq)[1],tf.shape(nq)[3]])
-            p_nel=tf.repeat(p_nel[:,:,o,:],tf.shape(nq)[2],axis=2)
-            #modified contractions remove need for costly repeats in ions dimension.
-            if self.is_ER:
-                p_mult = p_ni * p_nel
-                p_final = tf.reduce_sum(p_mult, 3)*p_nq
-            else:
-                p_mult = p_nq*p_nel
-                p_final = tf.tensordot(p_mult,p_ni_1D,axes=[[3],[0]])
+                p_nel = fd.tfp_files.TruncatedSkewGaussianCC(loc=mean, scale=std_dev,
+                                                             skewness=skew,
+                                                             limit=_ions_produced,
+                                                             owens_t_terms=owens_t_terms).prob(electrons_produced)
+
+            p_mult = p_nq * p_ni * p_nel
+
+            # Contract over ions_produced
+            p_final = tf.reduce_sum(p_mult, 3)
 
             r_final = p_final * rate_vs_energy
 
@@ -197,12 +169,6 @@ def compute_single_energy_approx(args):
             return compute_single_energy(args, approx=True)
 
         nq = electrons_produced + photons_produced
 
-        #remove degenerate dimensions
-        # unique_quanta,index_nq=unique(nq[:,:,:,0])#nevts x nph x nel->unique_nq
-        # unique_nel,index_nel=unique(electrons_produced[:,:,0,0])#nevts x nel->unique_nel
-        unique_quanta,index_nq=tf.unique(tf.reshape(nq[:,:,:,0],[-1]))
-        unique_nel,index_nel=tf.unique(tf.reshape(electrons_produced[:,:,0,0],[-1]))
-
         ions_min_initial = self.source._fetch('ions_produced_min', data_tensor=data_tensor)[:, 0, o]
         ions_min_initial = tf.repeat(ions_min_initial, tf.shape(ions_produced)[1], axis=1)
@@ -213,7 +179,6 @@ def compute_single_energy_approx(args):
         # for the lowest energy
         ions_produced_add = ions_produced - ions_min_initial
-
 
         # Energy above which we use the approximate computation
         if self.is_ER:
             cutoff_energy = 5.
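Note (not part of the patch): the removed lines evaluated the quanta distribution only on the unique values of nq and then scattered the result back with tf.gather_nd, whereas the new code evaluates it directly on the full tensor. A minimal standalone sketch of that equivalence, using made-up shapes and values rather than flamedisx objects:

import tensorflow as tf
import tensorflow_probability as tfp

# Toy stand-in for nq = electrons_produced + photons_produced
# (hypothetical shape: n_events x n_electrons x n_photons).
nq = tf.constant([[[3., 4.], [4., 5.]],
                  [[3., 4.], [4., 5.]]])
nq_mean, fano = 4.0, 1.1

dist = tfp.distributions.Normal(loc=nq_mean,
                                scale=tf.sqrt(nq_mean * fano) + 1e-10)

# Direct evaluation on the full tensor, as in the new code.
p_direct = dist.cdf(nq + 0.5) - dist.cdf(nq - 0.5)

# Evaluate on the unique values only, then scatter back --
# the pattern the removed lines used (simplified to tf.gather here).
unique_vals, idx = tf.unique(tf.reshape(nq, [-1]))
p_unique = dist.cdf(unique_vals + 0.5) - dist.cdf(unique_vals - 0.5)
p_restored = tf.reshape(tf.gather(p_unique, idx), tf.shape(nq))

tf.debugging.assert_near(p_direct, p_restored)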
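Similarly, the removed NR branch contracted a shared 1-D ion factor with tf.tensordot instead of broadcasting it to the full four-dimensional block, while the replacement multiplies and reduces over the ions axis. A toy sketch of why the two contractions agree, again with assumed shapes only:

import tensorflow as tf

# Hypothetical shapes: (n_events, n_electrons, n_photons, n_ions);
# p_ni_1d stands for an ion-axis factor shared by every event.
p_mult = tf.random.uniform([2, 3, 4, 5])
p_ni_1d = tf.random.uniform([5])

# Old path: contract the shared factor over the ions axis with tensordot,
# avoiding a tf.repeat of p_ni up to the full 4-D shape.
old_style = tf.tensordot(p_mult, p_ni_1d, axes=[[3], [0]])

# New path: broadcast, multiply, and sum over axis 3.
new_style = tf.reduce_sum(p_mult * p_ni_1d, axis=3)

tf.debugging.assert_near(old_style, new_style)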