Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support concatenation of more than two AtomArray objects #712

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/apidoc.json
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
"Atom",
"AtomArray",
"AtomArrayStack",
"concatenate",
"array",
"stack",
"repeat",
Expand Down
150 changes: 113 additions & 37 deletions src/biotite/structure/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"Atom",
"AtomArray",
"AtomArrayStack",
"concatenate",
"array",
"stack",
"repeat",
Expand All @@ -22,6 +23,7 @@

import abc
import numbers
from collections.abc import Sequence
import numpy as np
from biotite.copyable import Copyable
from biotite.structure.bonds import BondList
Expand Down Expand Up @@ -420,42 +422,7 @@ def __len__(self):
return self._array_length

def __add__(self, array):
    """
    Concatenate this object with `array` using the ``+`` operator.

    Parameters
    ----------
    array : AtomArray or AtomArrayStack
        The atoms to append to this object.

    Returns
    -------
    concatenated_atoms : AtomArray or AtomArrayStack
        The result of ``concatenate([self, array])``.
    """
    # Delegate to the module-level function, so the binary operator and
    # the n-ary concatenation share a single implementation instead of
    # two diverging copies of the same logic
    return concatenate([self, array])

def __copy_fill__(self, clone):
super().__copy_fill__(clone)
Expand Down Expand Up @@ -619,6 +586,7 @@ class AtomArray(_AtomArrayBase):
:class:`AtomArray` is done with the '+' operator.
Only the annotation categories, which are existing in both arrays,
are transferred to the new array.
For a list of :class:`AtomArray` objects, use :func:`concatenate()`.

Optionally, an :class:`AtomArray` can store chemical bond
information via a :class:`BondList` object.
Expand Down Expand Up @@ -891,7 +859,9 @@ class AtomArrayStack(_AtomArrayBase):
:class:`AtomArray` instance.

Concatenation of atoms for each array in the stack is done using the
'+' operator. For addition of atom arrays onto the stack use the
'+' operator.
For a list of :class:`AtomArrayStack` objects, use :func:`concatenate()`.
For addition of atom arrays onto the stack use the
:func:`stack()` method.

The :attr:`box` attribute has the shape *m x 3 x 3*, as the cell
Expand Down Expand Up @@ -1305,6 +1275,112 @@ def stack(arrays):
return array_stack


def concatenate(atoms):
    """
    Concatenate multiple :class:`AtomArray` or :class:`AtomArrayStack` objects into
    a single :class:`AtomArray` or :class:`AtomArrayStack`, respectively.

    Parameters
    ----------
    atoms : iterable object of AtomArray or AtomArrayStack
        The atoms to be concatenated.
        :class:`AtomArray` cannot be mixed with :class:`AtomArrayStack`.
        At least one element is required.

    Returns
    -------
    concatenated_atoms : AtomArray or AtomArrayStack
        The concatenated atoms, i.e. its ``array_length()`` is the sum of the
        ``array_length()`` of the input ``atoms``.

    Raises
    ------
    ValueError
        If `atoms` is empty.
    TypeError
        If the first element is neither an :class:`AtomArray` nor an
        :class:`AtomArrayStack`, or if a later element is not an instance of
        the type of the first element.
    IndexError
        If :class:`AtomArrayStack` objects with different stack depths are given.

    Notes
    -----
    The following rules apply:

    - Only the annotation categories that exist in all elements are transferred.
      They are transferred in the order given by the first element.
    - The box of the first element that has a box is transferred, if any.
    - The bonds of all elements are concatenated, if any element has associated bonds.
      For elements without a :class:`BondList` an empty :class:`BondList` is assumed.

    Examples
    --------

    >>> atoms1 = array([
    ...     Atom([1,2,3], res_id=1, atom_name="N"),
    ...     Atom([4,5,6], res_id=1, atom_name="CA"),
    ...     Atom([7,8,9], res_id=1, atom_name="C")
    ... ])
    >>> atoms2 = array([
    ...     Atom([1,2,3], res_id=2, atom_name="N"),
    ...     Atom([4,5,6], res_id=2, atom_name="CA"),
    ...     Atom([7,8,9], res_id=2, atom_name="C")
    ... ])
    >>> print(concatenate([atoms1, atoms2]))
    1 N 1.000 2.000 3.000
    1 CA 4.000 5.000 6.000
    1 C 7.000 8.000 9.000
    2 N 1.000 2.000 3.000
    2 CA 4.000 5.000 6.000
    2 C 7.000 8.000 9.000
    """
    # Ensure that the atoms can be iterated over multiple times
    if not isinstance(atoms, Sequence):
        atoms = list(atoms)
    # Without any element it is undecidable whether an array or a stack
    # should be returned -> fail early with a clear message
    if not atoms:
        raise ValueError("At least one element is required for concatenation")

    first = atoms[0]
    element_type = type(first)
    # 'isinstance()' instead of a type equality check,
    # so that subclasses are handled like their base class
    if isinstance(first, AtomArrayStack):
        is_stack = True
        depth = first.stack_depth()
    elif isinstance(first, AtomArray):
        is_stack = False
        depth = None
    else:
        raise TypeError(f"Cannot concatenate '{element_type.__name__}' objects")

    length = 0
    common_categories = set(first.get_annotation_categories())
    box = None
    has_bonds = False
    for element in atoms:
        if not isinstance(element, element_type):
            raise TypeError(
                f"Cannot concatenate '{type(element).__name__}' "
                f"with '{element_type.__name__}'"
            )
        if is_stack and element.stack_depth() != depth:
            raise IndexError("The stack depths are not equal")
        length += element.array_length()
        common_categories &= set(element.get_annotation_categories())
        if box is None and element.box is not None:
            box = element.box
        if element.bonds is not None:
            has_bonds = True

    if is_stack:
        concat_atoms = AtomArrayStack(depth, length)
    else:
        concat_atoms = AtomArray(length)
    concat_atoms.coord = np.concatenate([element.coord for element in atoms], axis=-2)
    # Iterate in the order of the first element instead of over the set,
    # as the iteration order of a set of strings differs between runs
    for category in first.get_annotation_categories():
        if category not in common_categories:
            continue
        concat_atoms.set_annotation(
            category,
            np.concatenate(
                [element.get_annotation(category) for element in atoms], axis=0
            ),
        )
    # Copy the box, so the result never shares it with an input element
    concat_atoms.box = None if box is None else np.copy(box)
    if has_bonds:
        # Concatenate bonds of all elements
        concat_atoms.bonds = BondList.concatenate(
            [
                element.bonds
                if element.bonds is not None
                else BondList(element.array_length())
                for element in atoms
            ]
        )

    return concat_atoms


def repeat(atoms, coord):
"""
Repeat atoms (:class:`AtomArray` or :class:`AtomArrayStack`)
Expand Down
71 changes: 57 additions & 14 deletions src/biotite/structure/bonds.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cimport cython
cimport numpy as np
from libc.stdlib cimport free, realloc

from collections.abc import Sequence
import itertools
import numbers
from enum import IntEnum
Expand Down Expand Up @@ -309,6 +310,61 @@ class BondList(Copyable):
self._bonds = np.zeros((0, 3), dtype=np.uint32)
self._max_bonds_per_atom = 0

@staticmethod
def concatenate(bonds_lists):
    """
    Join multiple :class:`BondList` objects into a single
    :class:`BondList`.

    The atom indices of each bond list are shifted by the summed atom
    count of all bond lists that precede it in `bonds_lists`.

    Parameters
    ----------
    bonds_lists : iterable object of BondList
        The bond lists to be concatenated.

    Returns
    -------
    concatenated_bonds : BondList
        The concatenated bond lists.

    Examples
    --------

    >>> bonds1 = BondList(2, np.array([(0, 1)]))
    >>> bonds2 = BondList(3, np.array([(0, 1), (0, 2)]))
    >>> merged_bonds = BondList.concatenate([bonds1, bonds2])
    >>> print(merged_bonds.get_atom_count())
    5
    >>> print(merged_bonds.as_array()[:, :2])
    [[0 1]
     [2 3]
     [2 4]]
    """
    # The input is traversed several times below,
    # hence one-shot iterables are materialized first
    if not isinstance(bonds_lists, Sequence):
        bonds_lists = list(bonds_lists)

    cdef np.ndarray merged_bonds = np.concatenate(
        [bond_list._bonds for bond_list in bonds_lists]
    )
    # Shift the atom indices (first two columns) of each block of rows
    # by the number of atoms contributed by all previous bond lists
    # (consistent with addition of AtomArray)
    cdef int row_begin = 0, row_end = 0
    cdef int atom_offset = 0
    for bond_list in bonds_lists:
        row_end = row_begin + bond_list._bonds.shape[0]
        merged_bonds[row_begin : row_end, :2] += atom_offset
        atom_offset += bond_list._atom_count
        row_begin = row_end

    # After the loop the offset equals the total atom count
    cdef merged_bond_list = BondList(atom_offset)
    # The array is assigned directly instead of being passed to the
    # constructor to prevent unnecessary maximum and redundant bond
    # calculation
    merged_bond_list._bonds = merged_bonds
    merged_bond_list._max_bonds_per_atom = max(
        [bond_list._max_bonds_per_atom for bond_list in bonds_lists]
    )
    return merged_bond_list

def __copy_create__(self):
# Create empty bond list to prevent
# unnecessary removal of redundant atoms
Expand Down Expand Up @@ -1002,20 +1058,7 @@ class BondList(Copyable):
)

def __add__(self, bond_list):
    """
    Concatenate this bond list with `bond_list` using the ``+`` operator.

    Parameters
    ----------
    bond_list : BondList
        The bond list to append to this one.

    Returns
    -------
    merged_bond_list : BondList
        The result of ``BondList.concatenate([self, bond_list])``.
    """
    # Delegate to the n-ary implementation, so the index offsetting
    # logic exists only once
    return BondList.concatenate([self, bond_list])

def __getitem__(self, index):
## Variables for both, integer and boolean index arrays
Expand Down
8 changes: 8 additions & 0 deletions tests/structure/test_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ def test_stack_indexing(stack):
assert filtered_stack.array_length() == 1


def test_concatenate_single(array, stack):
    """
    Concatenating a list that contains only a single array or stack should
    give atoms that compare equal to that single input.
    """
    for atoms in (array, stack):
        assert atoms == struc.concatenate([atoms])


def test_concatenation(array, stack):
concat_array = array[2:] + array[:2]
assert concat_array.chain_id.tolist() == ["B", "B", "B", "A", "A"]
Expand Down
36 changes: 36 additions & 0 deletions tests/structure/test_bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further
# information.

import itertools
from os.path import join
import numpy as np
import pytest
Expand Down Expand Up @@ -132,6 +133,41 @@ def test_modification(bond_list):
assert bond_list.as_array().tolist() == [[1, 3, 1], [3, 4, 0], [4, 6, 0], [1, 4, 0]]


@pytest.mark.parametrize("seed", range(10))
def test_concatenation_and_splitting(seed):
    """
    Concatenate randomly generated `BondList` objects and slice the result
    at the original atom boundaries again: each slice is expected to be equal
    to the `BondList` it originated from.
    """
    N_BOND_LISTS = 5
    MAX_ATOMS = 10
    MAX_BONDS = 10

    rng = np.random.default_rng(seed)
    ref_bond_lists = []
    atom_counts = []
    for _ in range(N_BOND_LISTS):
        # Keep the order of the random draws: it determines the test data
        n_atoms = rng.integers(1, MAX_ATOMS)
        bonds = rng.integers(0, n_atoms, size=(MAX_BONDS, 2))
        bond_types = rng.integers(0, len(struc.BondType), size=MAX_BONDS)
        bond_array = np.column_stack((bonds, bond_types))
        ref_bond_lists.append(struc.BondList(n_atoms, bond_array))
        atom_counts.append(n_atoms)
    # Running sum of atom counts -> first atom index of each bond list
    # plus the exclusive stop index of the last one
    boundaries = list(itertools.accumulate(atom_counts, initial=0))

    merged = struc.BondList.concatenate(ref_bond_lists)
    test_bond_lists = [
        merged[start:stop] for start, stop in itertools.pairwise(boundaries)
    ]

    for ref_bond_list, test_bond_list in zip(
        ref_bond_lists, test_bond_lists, strict=True
    ):
        assert ref_bond_list == test_bond_list


def test_add_two_bond_list():
"""
Test adding two `BondList` objects.
Expand Down
Loading