Skip to content

Commit

Permalink
Allow more custom contact filtering in lddt()
Browse files Browse the repository at this point in the history
  • Loading branch information
padix-key committed Jan 25, 2025
1 parent 6aae94c commit 4d05f3c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 10 deletions.
64 changes: 54 additions & 10 deletions src/biotite/structure/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def lddt(
inclusion_radius=15,
distance_bins=(0.5, 1.0, 2.0, 4.0),
exclude_same_residue=True,
exclude_same_chain=False,
filter_function=None,
symmetric=False,
):
"""
Expand Down Expand Up @@ -299,9 +301,16 @@ def lddt(
The distance bins for the score calculation, i.e if a distance deviation is
within the first bin, the score is 1, if it is outside all bins, the score is 0.
exclude_same_residue : bool, optional
If set to False, distances between atoms of the same residue are also
considered.
By default, only atom distances between different residues are considered.
If true, only atom distances between different residues are considered.
Otherwise, also atom distances within the same residue are included.
exclude_same_chain : bool, optional
If true, only atom distances between different chains are considered.
Otherwise, also atom distances within the same chain are included.
filter_function : Callable(ndarray, shape=(n,2), dtype=int -> ndarray, shape=(n,), dtype=bool), optional
Used for custom contact filtering, if the other parameters are not sufficient.
A function that takes an array of contact atom indices and returns a mask that
is ``True`` for all contacts that should be retained.
All other contacts are not considered for lDDT computation.
symmetric : bool, optional
If set to true, the *lDDT* score is computed symmetrically.
This means both contacts found in the `reference` and `subject` structure are
Expand Down Expand Up @@ -394,7 +403,13 @@ def lddt(
)

contacts = _find_contacts(
reference, atom_mask, partner_mask, inclusion_radius, exclude_same_residue
reference,
atom_mask,
partner_mask,
inclusion_radius,
exclude_same_residue,
exclude_same_chain,
filter_function,
)
if symmetric:
if not isinstance(subject, AtomArray):
Expand All @@ -403,7 +418,13 @@ def lddt(
f"but got '{type(subject).__name__}'"
)
subject_contacts = _find_contacts(
subject, atom_mask, partner_mask, inclusion_radius, exclude_same_residue
subject,
atom_mask,
partner_mask,
inclusion_radius,
exclude_same_residue,
exclude_same_chain,
filter_function,
)
contacts = np.concatenate((contacts, subject_contacts), axis=0)
# Adding additional contacts may introduce duplicates between the existing and
Expand Down Expand Up @@ -532,7 +553,9 @@ def _find_contacts(
atom_mask=None,
partner_mask=None,
inclusion_radius=15,
exclude_same_residue=True,
exclude_same_residue=False,
exclude_same_chain=True,
filter_function=None,
):
"""
Find contacts between the atoms in the given structure.
Expand All @@ -555,9 +578,16 @@ def _find_contacts(
inclusion_radius : float, optional
Pairwise atom distances are considered within this radius.
exclude_same_residue : bool, optional
If set to False, distances between atoms of the same residue are also
considered.
By default, only atom distances between different residues are considered.
If true, only atom distances between different residues are considered.
Otherwise, also atom distances within the same residue are included.
exclude_same_chain : bool, optional
If true, only atom distances between different chains are considered.
Otherwise, also atom distances within the same chain are included.
filter_function : Callable(ndarray, shape=(n,2), dtype=int -> ndarray, shape=(n,), dtype=bool), optional
Used for custom contact filtering, if the other parameters are not sufficient.
A function that takes an array of contact atom indices and returns a mask that
is ``True`` for all contacts that should be retained.
All other contacts are not considered for lDDT computation.
Returns
-------
Expand Down Expand Up @@ -588,7 +618,13 @@ def _find_contacts(
# Convert into pairs of indices
contacts = _to_sparse_indices(all_contacts)

if exclude_same_residue:
if exclude_same_chain:
# Do the same for the chain level
chain_indices = get_chain_positions(atoms, contacts.flatten()).reshape(
contacts.shape
)
contacts = contacts[chain_indices[:, 0] != chain_indices[:, 1]]
elif exclude_same_residue:
# Find the index of the residue for each atom
residue_indices = get_residue_positions(atoms, contacts.flatten()).reshape(
contacts.shape
Expand All @@ -598,6 +634,14 @@ def _find_contacts(
else:
# In any case self-contacts should not be considered
contacts = contacts[contacts[:, 0] != contacts[:, 1]]
if filter_function is not None:
mask = filter_function(contacts)
if mask.shape != (contacts.shape[0],):
raise IndexError(
f"Mask returned from filter function has shape {mask.shape}, "
f"but expected ({contacts.shape[0]},)"
)
contacts = contacts[filter_function(contacts), :]
return contacts


Expand Down
61 changes: 61 additions & 0 deletions tests/structure/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,67 @@ def test_lddt_mask(models, seed):
assert test_lddt.tolist() == pytest.approx(ref_lddt.tolist())


@pytest.mark.parametrize("exclude_same_chain", [False, True])
@pytest.mark.parametrize("exclude_same_residue", [False, True])
def test_lddt_filter_function(models, exclude_same_residue, exclude_same_chain):
"""
In :func:`lddt()`, mimic the `exclude_same_residue` or `exclude_same_chain`
parameter using a custom `filter_function` and expect to get the same results
compared to using these parameters directly.
"""
# Cut the model into two chains to test 'exclude_same_chain'
models = models.copy()
models.chain_id[models.res_id >= 11] = "B"
reference = models[0]
subject = models[1]

if exclude_same_residue and exclude_same_chain:

def filter_function(contacts):
return (
reference.res_id[contacts[:, 0]] != reference.res_id[contacts[:, 1]]
) & (
reference.chain_id[contacts[:, 0]] != reference.chain_id[contacts[:, 1]]
)

elif exclude_same_residue:

def filter_function(contacts):
return reference.res_id[contacts[:, 0]] != reference.res_id[contacts[:, 1]]

elif exclude_same_chain:

def filter_function(contacts):
return (
reference.chain_id[contacts[:, 0]] != reference.chain_id[contacts[:, 1]]
)

else:

def filter_function(contacts):
return np.full(contacts.shape[0], True)

# Do not aggregate to make the test more strict
ref_lddt = struc.lddt(
reference,
subject,
exclude_same_residue=exclude_same_residue,
exclude_same_chain=exclude_same_chain,
aggregation="atom",
)

test_lddt = struc.lddt(
reference,
subject,
exclude_same_residue=False,
exclude_same_chain=False,
filter_function=filter_function,
aggregation="atom",
)

assert test_lddt.tolist() == pytest.approx(ref_lddt.tolist())


def test_custom_lddt_symmetric(models):
"""
Check that in :func:`lddt()` with ``symmetric=True`` the *lDDT* score is independent
Expand Down

0 comments on commit 4d05f3c

Please sign in to comment.