diff --git a/src/molearn/data/prepare.py b/src/molearn/data/prepare.py index 8a43903..7b7252f 100644 --- a/src/molearn/data/prepare.py +++ b/src/molearn/data/prepare.py @@ -120,7 +120,7 @@ def read_traj(self) -> None: try: # do not enforce topology file on this formats fext = os.path.splitext(self.traj_path)[-1] - if any([fext == ".pdb", fext == "h5", fext == "lh5"]): + if any([fext == ".pdb", fext == ".h5", fext == ".lh5"]): self.traj = md.load(self.traj_path) else: self.traj = md.load(self.traj_path, self.topo_path) @@ -132,7 +132,8 @@ def read_traj(self) -> None: elif isinstance(self.traj_path, list): if isinstance(self.traj_path, list) and self.topo_path is None: fext = os.path.splitext(self.traj_path[0])[-1] - if any([fext == ".pdb", fext == "h5", fext == "lh5"]): + # file type doesn't need a topo but zip needs equally long list + if any([fext == ".pdb", fext == ".h5", fext == ".lh5"]): self.topo_path = [None] * len(self.traj_path) assert isinstance( self.topo_path, list @@ -144,12 +145,12 @@ def read_traj(self) -> None: top0 = None ucell0 = None for ct, (trp, top) in enumerate(zip(self.traj_path, self.topo_path)): - print(f"\tLoading Trajectory {ct}") + print(f"\tLoading {os.path.basename(trp)}") loaded = None try: # do not enforce topology file on this formats - trp_ext = os.path.splitext(trp) - if any([trp_ext == ".pdb", trp_ext == "h5", trp_ext == "lh5"]): + trp_ext = os.path.splitext(trp)[-1] + if any([trp_ext == ".pdb", trp_ext == ".h5", trp_ext == ".lh5"]): loaded = md.load(trp) else: loaded = md.load(trp, top)