Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chain_iptm #9

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions alphafold/common/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,80 @@ def predicted_tm_score(logits, breaks, residue_weights = None,

return (per_alignment * residue_weights).max()

def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False):

def predicted_tm_score_chain(logits, breaks, residue_weights = None,
asym_id = None, use_jnp=False, chain_num=None):
"""Computes predicted the chain matrix of pTM scores.

Args:
logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
breaks: [num_bins] the error bins.
residue_weights: [num_res] the per residue weights to use for the
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation.
use_jnp: Whether to use JAX.
chain_num: The number of chains in the multimer model.


Returns:
iptm_matrix: The predicted TM alignment or the predicted iTM score.
predicted_tm_term: The predicted TM term for each bin.
"""
if use_jnp:
_np, _softmax = jnp, jax.nn.softmax
else:
_np, _softmax = np, scipy.special.softmax

if chain_num is None:
chain_num = 1

# residue_weights has to be in [0, 1], but can be floating-point, i.e. the
# exp. resolved head's probability.
if residue_weights is None:
residue_weights = _np.ones(logits.shape[0])

bin_centers = _calculate_bin_centers(breaks, use_jnp=use_jnp)
num_res = residue_weights.shape[0]

# Clip num_res to avoid negative/undefined d0.
clipped_num_res = _np.maximum(residue_weights.sum(), 19)

# Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
# "Scoring function for automated assessment of protein structure template
# quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8

# Convert logits to probs.
probs = _softmax(logits, axis=-1)

# TM-Score term for every bin.
tm_per_bin = 1. / (1 + _np.square(bin_centers) / _np.square(d0))
# E_distances tm(distance).
predicted_tm_term = (probs * tm_per_bin).sum(-1)

# jax.debug.print('residue weights = {x}',x=residue_weights)

def get_cross_iptm(i, j):
pair_mask = jnp.logical_and(i * jnp.ones((num_res))[:, None] == asym_id[None, :] , j*jnp.ones((num_res))[None, :] == asym_id[:, None])
chain_chain_predicted_tm_term = predicted_tm_term * pair_mask
pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True))
per_alignment = (chain_chain_predicted_tm_term * normed_residue_mask).sum(-1)
return (per_alignment * residue_weights).max()

iptm_matrix_list = []

for i in jnp.arange(chain_num):
local_list = []
for j in jnp.arange(chain_num):
local_list.append(get_cross_iptm(i, j))
iptm_matrix_list.append(local_list)

return(iptm_matrix_list, predicted_tm_term)

def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False, chain_num=None):
"""Post processes prediction_result to get confidence metrics."""
confidence_metrics = {}
plddt = compute_plddt(prediction_result['predicted_lddt']['logits'], use_jnp=use_jnp)
Expand All @@ -195,7 +268,14 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F
residue_weights=mask,
asym_id=prediction_result['predicted_aligned_error']['asym_id'],
use_jnp=use_jnp)

confidence_metrics['chain_iptm'], confidence_metrics['ptm_matrix'] = predicted_tm_score_chain(
logits=prediction_result['predicted_aligned_error']['logits'],
breaks=prediction_result['predicted_aligned_error']['breaks'],
residue_weights=mask,
asym_id=prediction_result['predicted_aligned_error']['asym_id'],
use_jnp=use_jnp,
chain_num=chain_num,
)
# compute mean_score
if rank_by == "multimer":
mean_score = 80 * confidence_metrics["iptm"] + 20 * confidence_metrics["ptm"]
Expand Down
6 changes: 5 additions & 1 deletion alphafold/model/modules_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,16 @@ def apply_network(prev, safe_key):
if not return_representations:
del ret['representations']

# Extract chain NUM
chain_num = c.embeddings_and_evoformer.max_relative_chain + 1

# add confidence metrics
ret.update(confidence.get_confidence_metrics(
prediction_result=ret,
mask=batch["seq_mask"],
rank_by=self.config.rank_by,
use_jnp=True))
use_jnp=True,
chain_num=chain_num))

ret["tol"] = confidence.compute_tol(
prev["prev_pos"],
Expand Down