diff --git a/src/molearn/data/pdb_data.py b/src/molearn/data/pdb_data.py index 7c3806d..1d25074 100644 --- a/src/molearn/data/pdb_data.py +++ b/src/molearn/data/pdb_data.py @@ -59,14 +59,21 @@ def __init__(self, filename=None, topology=None, fix_terminal=False, atoms=None) if atoms is not None: self.atomselect(atoms=atoms) - def import_pdb(self, filename: str | list[str], topology: str): + def import_pdb(self, filename: str | list[str], topology: str | None): """ Load one or multiple trajectory files - :parameter - * filename: the path the the trajectory as a str or a list of filepaths to multiple trajectories + :param str | list[str] filename: the path the trajectory as a str or a list of filepaths to multiple trajectories + :param str | None topology: the path the topology file for the trajector(y)ies """ - self._mol = mda.Universe(topology, filename) + + if isinstance(filename, list) and topology is None: + first_universe = mda.Universe(filename[0]) + self._mol = mda.Universe(first_universe._topology, filename) + elif topology is None: + self._mol = mda.Universe(filename) + else: + self._mol = mda.Universe(topology, filename) def fix_terminal(self): """ @@ -118,7 +125,7 @@ def prepare_dataset(self): self.dataset /= self.std self.dataset = torch.from_numpy(self.dataset).float() self.dataset = self.dataset.permute(0, 2, 1) - print(f"Dataset shape: {', '.join([str(i) for i in self.dataset.shape])}") + print(f"Dataset shape: {self.dataset.shape}") print(f"mean: {str(self.mean)}\n std: {str(self.std)}") def get_atominfo(self):