Skip to content

Commit

Permalink
Merge pull request #740 from padix-key/lddt
Browse files Browse the repository at this point in the history
Allow more custom contact filtering in `lddt()`
  • Loading branch information
padix-key authored Jan 28, 2025
2 parents 6aae94c + 8ee7779 commit d979def
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/biotite/structure/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def get_chain_starts(array, add_exclusive_stop=False):
--------
get_residue_starts
"""
if array.array_length() == 0:
return np.array([], dtype=int)

diff = np.diff(array.res_id)
res_id_decrement = diff < 0
# This mask is 'true' at indices where the value changes
Expand Down
64 changes: 54 additions & 10 deletions src/biotite/structure/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def lddt(
inclusion_radius=15,
distance_bins=(0.5, 1.0, 2.0, 4.0),
exclude_same_residue=True,
exclude_same_chain=False,
filter_function=None,
symmetric=False,
):
"""
Expand Down Expand Up @@ -299,9 +301,16 @@ def lddt(
The distance bins for the score calculation, i.e if a distance deviation is
within the first bin, the score is 1, if it is outside all bins, the score is 0.
exclude_same_residue : bool, optional
If set to False, distances between atoms of the same residue are also
considered.
By default, only atom distances between different residues are considered.
If true, only atom distances between different residues are considered.
Otherwise, also atom distances within the same residue are included.
exclude_same_chain : bool, optional
If true, only atom distances between different chains are considered.
Otherwise, also atom distances within the same chain are included.
filter_function : Callable(ndarray, shape=(n,2), dtype=int -> ndarray, shape=(n,), dtype=bool), optional
Used for custom contact filtering, if the other parameters are not sufficient.
A function that takes an array of contact atom indices and returns a mask that
is ``True`` for all contacts that should be retained.
All other contacts are not considered for lDDT computation.
symmetric : bool, optional
If set to true, the *lDDT* score is computed symmetrically.
This means both contacts found in the `reference` and `subject` structure are
Expand Down Expand Up @@ -394,7 +403,13 @@ def lddt(
)

contacts = _find_contacts(
reference, atom_mask, partner_mask, inclusion_radius, exclude_same_residue
reference,
atom_mask,
partner_mask,
inclusion_radius,
exclude_same_residue,
exclude_same_chain,
filter_function,
)
if symmetric:
if not isinstance(subject, AtomArray):
Expand All @@ -403,7 +418,13 @@ def lddt(
f"but got '{type(subject).__name__}'"
)
subject_contacts = _find_contacts(
subject, atom_mask, partner_mask, inclusion_radius, exclude_same_residue
subject,
atom_mask,
partner_mask,
inclusion_radius,
exclude_same_residue,
exclude_same_chain,
filter_function,
)
contacts = np.concatenate((contacts, subject_contacts), axis=0)
# Adding additional contacts may introduce duplicates between the existing and
Expand Down Expand Up @@ -532,7 +553,9 @@ def _find_contacts(
atom_mask=None,
partner_mask=None,
inclusion_radius=15,
exclude_same_residue=True,
exclude_same_residue=False,
exclude_same_chain=True,
filter_function=None,
):
"""
Find contacts between the atoms in the given structure.
Expand All @@ -555,9 +578,16 @@ def _find_contacts(
inclusion_radius : float, optional
Pairwise atom distances are considered within this radius.
exclude_same_residue : bool, optional
If set to False, distances between atoms of the same residue are also
considered.
By default, only atom distances between different residues are considered.
If true, only atom distances between different residues are considered.
Otherwise, also atom distances within the same residue are included.
exclude_same_chain : bool, optional
If true, only atom distances between different chains are considered.
Otherwise, also atom distances within the same chain are included.
filter_function : Callable(ndarray, shape=(n,2), dtype=int -> ndarray, shape=(n,), dtype=bool), optional
Used for custom contact filtering, if the other parameters are not sufficient.
A function that takes an array of contact atom indices and returns a mask that
is ``True`` for all contacts that should be retained.
All other contacts are not considered for lDDT computation.
Returns
-------
Expand Down Expand Up @@ -588,7 +618,13 @@ def _find_contacts(
# Convert into pairs of indices
contacts = _to_sparse_indices(all_contacts)

if exclude_same_residue:
if exclude_same_chain:
# Do the same for the chain level
chain_indices = get_chain_positions(atoms, contacts.flatten()).reshape(
contacts.shape
)
contacts = contacts[chain_indices[:, 0] != chain_indices[:, 1]]
elif exclude_same_residue:
# Find the index of the residue for each atom
residue_indices = get_residue_positions(atoms, contacts.flatten()).reshape(
contacts.shape
Expand All @@ -598,6 +634,14 @@ def _find_contacts(
else:
# In any case self-contacts should not be considered
contacts = contacts[contacts[:, 0] != contacts[:, 1]]
if filter_function is not None:
mask = filter_function(contacts)
if mask.shape != (contacts.shape[0],):
raise IndexError(
f"Mask returned from filter function has shape {mask.shape}, "
f"but expected ({contacts.shape[0]},)"
)
contacts = contacts[filter_function(contacts), :]
return contacts


Expand Down
3 changes: 3 additions & 0 deletions src/biotite/structure/residues.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def get_residue_starts(array, add_exclusive_stop=False):
[ 0 16 35 56 75 92 116 135 157 169 176 183 197 208 219 226 250 264
278 292 304]
"""
if array.array_length() == 0:
return np.array([], dtype=int)

# These mask are 'true' at indices where the value changes
chain_id_changes = array.chain_id[1:] != array.chain_id[:-1]
res_id_changes = array.res_id[1:] != array.res_id[:-1]
Expand Down
61 changes: 61 additions & 0 deletions tests/structure/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,67 @@ def test_lddt_mask(models, seed):
assert test_lddt.tolist() == pytest.approx(ref_lddt.tolist())


@pytest.mark.parametrize("exclude_same_chain", [False, True])
@pytest.mark.parametrize("exclude_same_residue", [False, True])
def test_lddt_filter_function(models, exclude_same_residue, exclude_same_chain):
"""
In :func:`lddt()`, mimic the `exclude_same_residue` or `exclude_same_chain`
parameter using a custom `filter_function` and expect to get the same results
compared to using these parameters directly.
"""
# Cut the model into two chains to test 'exclude_same_chain'
models = models.copy()
models.chain_id[models.res_id >= 11] = "B"
reference = models[0]
subject = models[1]

if exclude_same_residue and exclude_same_chain:

def filter_function(contacts):
return (
reference.res_id[contacts[:, 0]] != reference.res_id[contacts[:, 1]]
) & (
reference.chain_id[contacts[:, 0]] != reference.chain_id[contacts[:, 1]]
)

elif exclude_same_residue:

def filter_function(contacts):
return reference.res_id[contacts[:, 0]] != reference.res_id[contacts[:, 1]]

elif exclude_same_chain:

def filter_function(contacts):
return (
reference.chain_id[contacts[:, 0]] != reference.chain_id[contacts[:, 1]]
)

else:

def filter_function(contacts):
return np.full(contacts.shape[0], True)

# Do not aggregate to make the test more strict
ref_lddt = struc.lddt(
reference,
subject,
exclude_same_residue=exclude_same_residue,
exclude_same_chain=exclude_same_chain,
aggregation="atom",
)

test_lddt = struc.lddt(
reference,
subject,
exclude_same_residue=False,
exclude_same_chain=False,
filter_function=filter_function,
aggregation="atom",
)

assert test_lddt.tolist() == pytest.approx(ref_lddt.tolist())


def test_custom_lddt_symmetric(models):
"""
Check that in :func:`lddt()` with ``symmetric=True`` the *lDDT* score is independent
Expand Down

0 comments on commit d979def

Please sign in to comment.