From cd6f3474e84bd017683eb7b009c497c8b6a5ff78 Mon Sep 17 00:00:00 2001 From: gwirn <71886945+gwirn@users.noreply.github.com> Date: Wed, 26 Jun 2024 11:07:45 +0200 Subject: [PATCH] added flag to save json log file and improved trajectory preparation --- src/molearn/data/prepare.py | 35 ++++++++++++++++++++++++--------- src/molearn/trainers/trainer.py | 29 ++++++++++++++++----------- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/src/molearn/data/prepare.py b/src/molearn/data/prepare.py index 7b7252f..21f7ca5 100644 --- a/src/molearn/data/prepare.py +++ b/src/molearn/data/prepare.py @@ -145,7 +145,8 @@ def read_traj(self) -> None: top0 = None ucell0 = None for ct, (trp, top) in enumerate(zip(self.traj_path, self.topo_path)): - print(f"\tLoading {os.path.basename(trp)}") + if self.verbose: + print(f"\tLoading {os.path.basename(trp)}") loaded = None try: # do not enforce topology file on this formats @@ -171,24 +172,38 @@ def read_traj(self) -> None: loaded.unitcell_vectors = ucell0 multi_traj.append(loaded) self.traj = md.join(multi_traj) + # Recenter and apply periodic boundary + if self.image_mol: + try: + if self.verbose: + print("Imaging faild - retrying with supplying anchor molecules") + self.traj.image_molecules(inplace=True) + except ValueError: + try: + self.traj.image_molecules( + inplace=True, + anchor_molecules=[set(self.traj.topology.residue(0).atoms)], + ) + except ValueError as e: + print( + f"Unable to image molecule due to '{e}' - will just recenter it" + ) + self.traj.superpose(self.traj[0]) + # maybe not needed after image_molecules + self.traj.center_coordinates() # converts ELEMENT names from eg "Cd" -> "C" to avoid later complications topo_table, topo_bonds = self.traj.topology.to_dataframe() topo_table["element"] = topo_table["element"].apply( lambda x: x if len(x.strip()) <= 1 else x.strip()[0] ) + if self.verbose: + print("Saving new topology") self.traj.topology = md.Topology.from_dataframe(topo_table, topo_bonds) # save new topology self.traj[0].save_pdb( os.path.join(self.outpath, f"./{self.traj_name}_NEW_TOPO.pdb") ) - # Recenter and apply periodic boundary - if self.image_mol: - try: - self.traj.image_molecules(inplace=True) - except ValueError as e: - print(f"Unable to image molecule due to '{e}' - will just recenter it") - # maybe not needed after image_molecules - self.traj.center_coordinates() + n_frames = self.traj.n_frames # which index separated indices from training and test dataset self.test_border = int(n_frames * (1.0 - self.test_size)) @@ -205,6 +220,8 @@ def read_traj(self) -> None: atom_indices = [ a.index for a in train_traj.topology.atoms if a.element.symbol != "H" ] + if self.verbose: + print("Calculating disance matrix") # distance matrix between all frames self.traj_dists = np.empty((n_train_frames, n_train_frames)) for i in range(n_train_frames): diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py index 905a3ef..e66c849 100644 --- a/src/molearn/trainers/trainer.py +++ b/src/molearn/trainers/trainer.py @@ -31,10 +31,11 @@ class Trainer: """ - def __init__(self, device=None, log_filename="log_file.dat"): + def __init__(self, device=None, log_filename="log_file.dat", json_log=False): """ :param torch.Device device: if not given will be determinined automatically based on torch.cuda.is_available() :param str log_filename: (default: 'default_log_filename.json') file used to log outputs to + :param bool json_log: True to use json.dump to save the log file """ if not device: self.device = ( @@ -52,6 +53,7 @@ def __init__(self, device=None, log_filename="log_file.dat"): self.verbose = True self.log_filename = "default_log_filename.csv" self.scheduler_key = None + self.json_log = json_log def get_network_summary(self): """ @@ -151,17 +153,22 @@ def log(self, log_dict, verbose=None): print(f"{k: <{max_key_len+1}}: {v:.6f}") print() - # create header if file doesn't exist => first epoch - if not os.path.isfile(self.log_filename): - with open(self.log_filename, "a") as f: - f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n") + if not self.json_log: + # create header if file doesn't exist => first epoch + if not os.path.isfile(self.log_filename): + with open(self.log_filename, "a") as f: + f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n") - with open(self.log_filename, "a") as f: - # just try to format if it is not a Failure - if "Failure" not in log_dict.values(): - f.write(f"{','.join([str(v) for v in log_dict.values()])}\n") - else: - dump = json.dumps(log_dict) + with open(self.log_filename, "a") as f: + # just try to format if it is not a Failure + if "Failure" not in log_dict.values(): + f.write(f"{','.join([str(v) for v in log_dict.values()])}\n") + else: + dump = json.dumps(log_dict) + f.write(dump + "\n") + else: + dump = json.dumps(log_dict) + with open(self.log_filename, "a") as f: f.write(dump + "\n") def scheduler_step(self, logs):