diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index e5ea99af0..da1cb08c5 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -168,7 +168,80 @@ def predicted_tm_score(logits, breaks, residue_weights = None, return (per_alignment * residue_weights).max() -def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False): + +def predicted_tm_score_chain(logits, breaks, residue_weights = None, + asym_id = None, use_jnp=False, chain_num=None): + """Computes predicted the chain matrix of pTM scores. + + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation. + use_jnp: Whether to use JAX. + chain_num: The number of chains in the multimer model. + + + Returns: + iptm_matrix: The predicted TM alignment or the predicted iTM score. + predicted_tm_term: The predicted TM term for each bin. + """ + if use_jnp: + _np, _softmax = jnp, jax.nn.softmax + else: + _np, _softmax = np, scipy.special.softmax + + if chain_num is None: + chain_num = 1 + + # residue_weights has to be in [0, 1], but can be floating-point, i.e. the + # exp. resolved head's probability. + if residue_weights is None: + residue_weights = _np.ones(logits.shape[0]) + + bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp) + num_res = residue_weights.shape[0] + + # Clip num_res to avoid negative/undefined d0. + clipped_num_res = _np.maximum(residue_weights.sum(), 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick + # "Scoring function for automated assessment of protein structure template + # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 + + # Convert logits to probs. + probs = _softmax(logits, axis=-1) + + # TM-Score term for every bin. + tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0)) + # E_distances tm(distance). + predicted_tm_term = (probs * tm_per_bin).sum(-1) + + # jax.debug.print('residue weights = {x}',x=residue_weights) + + def get_cross_iptm(i, j): + pair_mask = jnp.logical_and(i * jnp.ones((num_res))[:, None] == asym_id[None, :] , j*jnp.ones((num_res))[None, :] == asym_id[:, None]) + chain_chain_predicted_tm_term = predicted_tm_term * pair_mask + pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True)) + per_alignment = (chain_chain_predicted_tm_term * normed_residue_mask).sum(-1) + return (per_alignment * residue_weights).max() + + iptm_matrix_list = [] + + for i in jnp.arange(chain_num): + local_list = [] + for j in jnp.arange(chain_num): + local_list.append(get_cross_iptm(i, j)) + iptm_matrix_list.append(local_list) + + return(iptm_matrix_list, predicted_tm_term) + +def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False, chain_num=None): """Post processes prediction_result to get confidence metrics.""" confidence_metrics = {} plddt = compute_plddt(prediction_result['predicted_lddt']['logits'], use_jnp=use_jnp) @@ -195,7 +268,14 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F residue_weights=mask, asym_id=prediction_result['predicted_aligned_error']['asym_id'], use_jnp=use_jnp) - + confidence_metrics['chain_iptm'], confidence_metrics['ptm_matrix'] = predicted_tm_score_chain( + logits=prediction_result['predicted_aligned_error']['logits'], + breaks=prediction_result['predicted_aligned_error']['breaks'], + residue_weights=mask, + asym_id=prediction_result['predicted_aligned_error']['asym_id'], + use_jnp=use_jnp, + chain_num=chain_num, + ) # compute mean_score if rank_by == "multimer": mean_score = 80 * confidence_metrics["iptm"] + 20 * confidence_metrics["ptm"] diff --git a/alphafold/model/modules_multimer.py b/alphafold/model/modules_multimer.py index 7cd8a6fd5..1e909c944 100644 --- a/alphafold/model/modules_multimer.py +++ b/alphafold/model/modules_multimer.py @@ -461,12 +461,16 @@ def apply_network(prev, safe_key): if not return_representations: del ret['representations'] + # Extract chain NUM + chain_num = c.embeddings_and_evoformer.max_relative_chain + 1 + # add confidence metrics ret.update(confidence.get_confidence_metrics( prediction_result=ret, mask=batch["seq_mask"], rank_by=self.config.rank_by, - use_jnp=True)) + use_jnp=True, + chain_num=chain_num)) ret["tol"] = confidence.compute_tol( prev["prev_pos"],