Skip to content

Commit

Permalink
Merge pull request #21 from Degiacomi-Lab/prepare_fix
Browse files Browse the repository at this point in the history
added flag to save json log file and improved trajectory preparation
  • Loading branch information
degiacom authored Jun 27, 2024
2 parents 92e0f31 + 28aff03 commit 6754592
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 22 deletions.
3 changes: 2 additions & 1 deletion examples/prepare_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def main():
"./data/preparation/topo_MurDclosed1F.pdb",
],
test_size=0.0,
n_cluster=1500,
n_cluster=15,
outpath=storage_path,
verbose=True,
image_mol=True,
)
# reading in the trajectories and removing of all atoms apart from protein atoms
tm.read_traj()
Expand Down
43 changes: 33 additions & 10 deletions src/molearn/data/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,13 @@ def _loading_fallback(self, traj_path, topo_path):
"""
return load_func(traj_path, topo_path)

def read_traj(self) -> None:
def read_traj(self, atom_indices=None, ref_atom_indices=None) -> None:
"""
Read in one or multiple trajectories, remove everything but protein atoms and
image the molecule to center it in the water box, and create a training/validation
and test split.
:param array_like | None atom_indices: The indices of the atoms to superpose. If not supplied, all atoms will be used.
:param array_like | None ref_atom_indices: Use these atoms on the reference structure. If not supplied, the same atom indices will be used for this trajectory and the reference one.
"""
if self.verbose:
print("Reading trajectory")
Expand Down Expand Up @@ -145,7 +147,8 @@ def read_traj(self) -> None:
top0 = None
ucell0 = None
for ct, (trp, top) in enumerate(zip(self.traj_path, self.topo_path)):
print(f"\tLoading {os.path.basename(trp)}")
if self.verbose:
print(f"\tLoading {os.path.basename(trp)}")
loaded = None
try:
# do not enforce topology file on this formats
Expand All @@ -171,24 +174,42 @@ def read_traj(self) -> None:
loaded.unitcell_vectors = ucell0
multi_traj.append(loaded)
self.traj = md.join(multi_traj)
# Recenter and apply periodic boundary
if self.image_mol:
try:
if self.verbose:
print("Imaging faild - retrying with supplying anchor molecules")
self.traj.image_molecules(inplace=True)
except ValueError:
try:
self.traj.image_molecules(
inplace=True,
anchor_molecules=[set(self.traj.topology.residue(0).atoms)],
)
except ValueError as e:
print(
f"Unable to image molecule due to '{e}' - will just recenter it"
)
self.traj.superpose(
self.traj[0],
atom_indices=atom_indices,
ref_atom_indices=ref_atom_indices,
)
# maybe not needed after image_molecules
self.traj.center_coordinates()
# converts ELEMENT names from eg "Cd" -> "C" to avoid later complications
topo_table, topo_bonds = self.traj.topology.to_dataframe()
topo_table["element"] = topo_table["element"].apply(
lambda x: x if len(x.strip()) <= 1 else x.strip()[0]
)
if self.verbose:
print("Saving new topology")
self.traj.topology = md.Topology.from_dataframe(topo_table, topo_bonds)
# save new topology
self.traj[0].save_pdb(
os.path.join(self.outpath, f"./{self.traj_name}_NEW_TOPO.pdb")
)
# Recenter and apply periodic boundary
if self.image_mol:
try:
self.traj.image_molecules(inplace=True)
except ValueError as e:
print(f"Unable to image molecule due to '{e}' - will just recenter it")
# maybe not needed after image_molecules
self.traj.center_coordinates()

n_frames = self.traj.n_frames
# which index separated indices from training and test dataset
self.test_border = int(n_frames * (1.0 - self.test_size))
Expand All @@ -205,6 +226,8 @@ def read_traj(self) -> None:
atom_indices = [
a.index for a in train_traj.topology.atoms if a.element.symbol != "H"
]
if self.verbose:
print("Calculating disance matrix")
# distance matrix between all frames
self.traj_dists = np.empty((n_train_frames, n_train_frames))
for i in range(n_train_frames):
Expand Down
29 changes: 18 additions & 11 deletions src/molearn/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ class Trainer:
"""

def __init__(self, device=None, log_filename="log_file.dat"):
def __init__(self, device=None, log_filename="log_file.dat", json_log=False):
"""
:param torch.Device device: if not given will be determinined automatically based on torch.cuda.is_available()
:param str log_filename: (default: 'default_log_filename.json') file used to log outputs to
:param bool json_log: True to use json.dump to save the log file
"""
if not device:
self.device = (
Expand All @@ -52,6 +53,7 @@ def __init__(self, device=None, log_filename="log_file.dat"):
self.verbose = True
self.log_filename = "default_log_filename.csv"
self.scheduler_key = None
self.json_log = json_log

def get_network_summary(self):
"""
Expand Down Expand Up @@ -151,17 +153,22 @@ def log(self, log_dict, verbose=None):
print(f"{k: <{max_key_len+1}}: {v:.6f}")
print()

# create header if file doesn't exist => first epoch
if not os.path.isfile(self.log_filename):
with open(self.log_filename, "a") as f:
f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n")
if not self.json_log:
# create header if file doesn't exist => first epoch
if not os.path.isfile(self.log_filename):
with open(self.log_filename, "a") as f:
f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n")

with open(self.log_filename, "a") as f:
# just try to format if it is not a Failure
if "Failure" not in log_dict.values():
f.write(f"{','.join([str(v) for v in log_dict.values()])}\n")
else:
dump = json.dumps(log_dict)
with open(self.log_filename, "a") as f:
# just try to format if it is not a Failure
if "Failure" not in log_dict.values():
f.write(f"{','.join([str(v) for v in log_dict.values()])}\n")
else:
dump = json.dumps(log_dict)
f.write(dump + "\n")
else:
dump = json.dumps(log_dict)
with open(self.log_filename, "a") as f:
f.write(dump + "\n")

def scheduler_step(self, logs):
Expand Down

0 comments on commit 6754592

Please sign in to comment.