From 4d05f3cffd95644271ec154bf220a79100624c36 Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Fri, 24 Jan 2025 12:03:17 +0100 Subject: [PATCH 1/2] Allow more custom contact filtering in `lddt()` --- src/biotite/structure/compare.py | 64 +++++++++++++++++++++++++++----- tests/structure/test_compare.py | 61 ++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 10 deletions(-) diff --git a/src/biotite/structure/compare.py b/src/biotite/structure/compare.py index 34d8cb5ca..ed111ed55 100644 --- a/src/biotite/structure/compare.py +++ b/src/biotite/structure/compare.py @@ -256,6 +256,8 @@ def lddt( inclusion_radius=15, distance_bins=(0.5, 1.0, 2.0, 4.0), exclude_same_residue=True, + exclude_same_chain=False, + filter_function=None, symmetric=False, ): """ @@ -299,9 +301,16 @@ def lddt( The distance bins for the score calculation, i.e if a distance deviation is within the first bin, the score is 1, if it is outside all bins, the score is 0. exclude_same_residue : bool, optional - If set to False, distances between atoms of the same residue are also - considered. - By default, only atom distances between different residues are considered. + If true, only atom distances between different residues are considered. + Otherwise, also atom distances within the same residue are included. + exclude_same_chain : bool, optional + If true, only atom distances between different chains are considered. + Otherwise, also atom distances within the same chain are included. + filter_function : Callable(ndarray, shape=(n,2), dtype=int -> ndarray, shape=(n,), dtype=bool), optional + Used for custom contact filtering, if the other parameters are not sufficient. + A function that takes an array of contact atom indices and returns a mask that + is ``True`` for all contacts that should be retained. + All other contacts are not considered for lDDT computation. symmetric : bool, optional If set to true, the *lDDT* score is computed symmetrically. 
This means both contacts found in the `reference` and `subject` structure are @@ -394,7 +403,13 @@ def lddt( ) contacts = _find_contacts( - reference, atom_mask, partner_mask, inclusion_radius, exclude_same_residue + reference, + atom_mask, + partner_mask, + inclusion_radius, + exclude_same_residue, + exclude_same_chain, + filter_function, ) if symmetric: if not isinstance(subject, AtomArray): @@ -403,7 +418,13 @@ def lddt( f"but got '{type(subject).__name__}'" ) subject_contacts = _find_contacts( - subject, atom_mask, partner_mask, inclusion_radius, exclude_same_residue + subject, + atom_mask, + partner_mask, + inclusion_radius, + exclude_same_residue, + exclude_same_chain, + filter_function, ) contacts = np.concatenate((contacts, subject_contacts), axis=0) # Adding additional contacts may introduce duplicates between the existing and @@ -532,7 +553,9 @@ def _find_contacts( atom_mask=None, partner_mask=None, inclusion_radius=15, - exclude_same_residue=True, + exclude_same_residue=False, + exclude_same_chain=True, + filter_function=None, ): """ Find contacts between the atoms in the given structure. @@ -555,9 +578,16 @@ def _find_contacts( inclusion_radius : float, optional Pairwise atom distances are considered within this radius. exclude_same_residue : bool, optional - If set to False, distances between atoms of the same residue are also - considered. - By default, only atom distances between different residues are considered. + If true, only atom distances between different residues are considered. + Otherwise, also atom distances within the same residue are included. + exclude_same_chain : bool, optional + If true, only atom distances between different chains are considered. + Otherwise, also atom distances within the same chain are included. + filter_function : Callable(ndarray, shape=(n,2), dtype=int -> ndarray, shape=(n,), dtype=bool), optional + Used for custom contact filtering, if the other parameters are not sufficient. 
+ A function that takes an array of contact atom indices and returns a mask that + is ``True`` for all contacts that should be retained. + All other contacts are not considered for lDDT computation. Returns ------- @@ -588,7 +618,13 @@ def _find_contacts( # Convert into pairs of indices contacts = _to_sparse_indices(all_contacts) - if exclude_same_residue: + if exclude_same_chain: + # Do the same for the chain level + chain_indices = get_chain_positions(atoms, contacts.flatten()).reshape( + contacts.shape + ) + contacts = contacts[chain_indices[:, 0] != chain_indices[:, 1]] + elif exclude_same_residue: # Find the index of the residue for each atom residue_indices = get_residue_positions(atoms, contacts.flatten()).reshape( contacts.shape @@ -598,6 +634,14 @@ def _find_contacts( else: # In any case self-contacts should not be considered contacts = contacts[contacts[:, 0] != contacts[:, 1]] + if filter_function is not None: + mask = filter_function(contacts) + if mask.shape != (contacts.shape[0],): + raise IndexError( + f"Mask returned from filter function has shape {mask.shape}, " + f"but expected ({contacts.shape[0]},)" + ) + contacts = contacts[mask, :] return contacts diff --git a/tests/structure/test_compare.py b/tests/structure/test_compare.py index 70ea33513..0d9f98170 100644 --- a/tests/structure/test_compare.py +++ b/tests/structure/test_compare.py @@ -359,6 +359,67 @@ def test_lddt_mask(models, seed): assert test_lddt.tolist() == pytest.approx(ref_lddt.tolist()) +@pytest.mark.parametrize("exclude_same_chain", [False, True]) +@pytest.mark.parametrize("exclude_same_residue", [False, True]) +def test_lddt_filter_function(models, exclude_same_residue, exclude_same_chain): + """ + In :func:`lddt()`, mimic the `exclude_same_residue` or `exclude_same_chain` + parameter using a custom `filter_function` and expect to get the same results + compared to using these parameters directly.
+ """ + # Cut the model into two chains to test 'exclude_same_chain' + models = models.copy() + models.chain_id[models.res_id >= 11] = "B" + reference = models[0] + subject = models[1] + + if exclude_same_residue and exclude_same_chain: + + def filter_function(contacts): + return ( + reference.res_id[contacts[:, 0]] != reference.res_id[contacts[:, 1]] + ) & ( + reference.chain_id[contacts[:, 0]] != reference.chain_id[contacts[:, 1]] + ) + + elif exclude_same_residue: + + def filter_function(contacts): + return reference.res_id[contacts[:, 0]] != reference.res_id[contacts[:, 1]] + + elif exclude_same_chain: + + def filter_function(contacts): + return ( + reference.chain_id[contacts[:, 0]] != reference.chain_id[contacts[:, 1]] + ) + + else: + + def filter_function(contacts): + return np.full(contacts.shape[0], True) + + # Do not aggregate to make the test more strict + ref_lddt = struc.lddt( + reference, + subject, + exclude_same_residue=exclude_same_residue, + exclude_same_chain=exclude_same_chain, + aggregation="atom", + ) + + test_lddt = struc.lddt( + reference, + subject, + exclude_same_residue=False, + exclude_same_chain=False, + filter_function=filter_function, + aggregation="atom", + ) + + assert test_lddt.tolist() == pytest.approx(ref_lddt.tolist()) + + def test_custom_lddt_symmetric(models): """ Check that in :func:`lddt()` with ``symmetric=True`` the *lDDT* score is independent From 8ee77792f43eeacfcba8bc46e7574a9de3e7df05 Mon Sep 17 00:00:00 2001 From: Patrick Kunzmann Date: Sat, 25 Jan 2025 11:10:21 +0100 Subject: [PATCH 2/2] Handle case of empty structure --- src/biotite/structure/chains.py | 3 +++ src/biotite/structure/residues.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/biotite/structure/chains.py b/src/biotite/structure/chains.py index c4bbd4996..c4c9d59a3 100644 --- a/src/biotite/structure/chains.py +++ b/src/biotite/structure/chains.py @@ -64,6 +64,9 @@ def get_chain_starts(array, add_exclusive_stop=False): -------- 
get_residue_starts """ + if array.array_length() == 0: + return np.array([], dtype=int) + diff = np.diff(array.res_id) res_id_decrement = diff < 0 # This mask is 'true' at indices where the value changes diff --git a/src/biotite/structure/residues.py b/src/biotite/structure/residues.py index 61ae1712a..4ac6039cb 100644 --- a/src/biotite/structure/residues.py +++ b/src/biotite/structure/residues.py @@ -69,6 +69,9 @@ def get_residue_starts(array, add_exclusive_stop=False): [ 0 16 35 56 75 92 116 135 157 169 176 183 197 208 219 226 250 264 278 292 304] """ + if array.array_length() == 0: + return np.array([], dtype=int) + # These mask are 'true' at indices where the value changes chain_id_changes = array.chain_id[1:] != array.chain_id[:-1] res_id_changes = array.res_id[1:] != array.res_id[:-1]