Skip to content

Commit

Permalink
fixing file extension check
Browse files Browse the repository at this point in the history
  • Loading branch information
gwirn committed Jun 6, 2024
1 parent c6517e3 commit 9eecc8f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/molearn/data/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def read_traj(self) -> None:
try:
# do not enforce topology file on this formats
fext = os.path.splitext(self.traj_path)[-1]
if any([fext == ".pdb", fext == "h5", fext == "lh5"]):
if any([fext == ".pdb", fext == ".h5", fext == ".lh5"]):
self.traj = md.load(self.traj_path)
else:
self.traj = md.load(self.traj_path, self.topo_path)
Expand All @@ -132,7 +132,8 @@ def read_traj(self) -> None:
elif isinstance(self.traj_path, list):
if isinstance(self.traj_path, list) and self.topo_path is None:
fext = os.path.splitext(self.traj_path[0])[-1]
if any([fext == ".pdb", fext == "h5", fext == "lh5"]):
# file type doesn't need a topo but zip needs equally long list
if any([fext == ".pdb", fext == ".h5", fext == ".lh5"]):
self.topo_path = [None] * len(self.traj_path)
assert isinstance(
self.topo_path, list
Expand All @@ -144,12 +145,12 @@ def read_traj(self) -> None:
top0 = None
ucell0 = None
for ct, (trp, top) in enumerate(zip(self.traj_path, self.topo_path)):
print(f"\tLoading Trajectory {ct}")
print(f"\tLoading {os.path.basename(trp)}")
loaded = None
try:
# do not enforce topology file on this formats
trp_ext = os.path.splitext(trp)
if any([trp_ext == ".pdb", trp_ext == "h5", trp_ext == "lh5"]):
trp_ext = os.path.splitext(trp)[-1]
if any([trp_ext == ".pdb", trp_ext == ".h5", trp_ext == ".lh5"]):
loaded = md.load(trp)
else:
loaded = md.load(trp, top)
Expand Down

0 comments on commit 9eecc8f

Please sign in to comment.