diff --git a/examples/prepare_example.py b/examples/prepare_example.py index 6e4dfca..8ee33d7 100644 --- a/examples/prepare_example.py +++ b/examples/prepare_example.py @@ -21,9 +21,10 @@ def main(): "./data/preparation/topo_MurDclosed1F.pdb", ], test_size=0.0, - n_cluster=1500, + n_cluster=15, outpath=storage_path, verbose=True, + image_mol=True, ) # reading in the trajectories and removing of all atoms apart from protein atoms tm.read_traj() diff --git a/src/molearn/data/prepare.py b/src/molearn/data/prepare.py index 7b7252f..e0663ff 100644 --- a/src/molearn/data/prepare.py +++ b/src/molearn/data/prepare.py @@ -102,11 +102,13 @@ def _loading_fallback(self, traj_path, topo_path): """ return load_func(traj_path, topo_path) - def read_traj(self) -> None: + def read_traj(self, atom_indices=None, ref_atom_indices=None) -> None: """ Read in one or multiple trajectories, remove everything but protein atoms and image the molecule to center it in the water box, and create a training/validation and test split. + :param array_like | None atom_indices: The indices of the atoms to superpose. If not supplied, all atoms will be used. + :param array_like | None ref_atom_indices: Use these atoms on the reference structure. If not supplied, the same atom indices will be used for this trajectory and the reference one. 
""" if self.verbose: print("Reading trajectory") @@ -145,7 +147,8 @@ def read_traj(self) -> None: top0 = None ucell0 = None for ct, (trp, top) in enumerate(zip(self.traj_path, self.topo_path)): - print(f"\tLoading {os.path.basename(trp)}") + if self.verbose: + print(f"\tLoading {os.path.basename(trp)}") loaded = None try: # do not enforce topology file on this formats @@ -171,24 +174,42 @@ def read_traj(self) -> None: loaded.unitcell_vectors = ucell0 multi_traj.append(loaded) self.traj = md.join(multi_traj) + # Recenter and apply periodic boundary + if self.image_mol: + try: + self.traj.image_molecules(inplace=True) + except ValueError: + if self.verbose: + print("Imaging failed - retrying with supplying anchor molecules") + try: + self.traj.image_molecules( + inplace=True, + anchor_molecules=[set(self.traj.topology.residue(0).atoms)], + ) + except ValueError as e: + print( + f"Unable to image molecule due to '{e}' - will just recenter it" + ) + self.traj.superpose( + self.traj[0], + atom_indices=atom_indices, + ref_atom_indices=ref_atom_indices, + ) + # maybe not needed after image_molecules + self.traj.center_coordinates() # converts ELEMENT names from eg "Cd" -> "C" to avoid later complications topo_table, topo_bonds = self.traj.topology.to_dataframe() topo_table["element"] = topo_table["element"].apply( lambda x: x if len(x.strip()) <= 1 else x.strip()[0] ) + if self.verbose: + print("Saving new topology") self.traj.topology = md.Topology.from_dataframe(topo_table, topo_bonds) # save new topology self.traj[0].save_pdb( os.path.join(self.outpath, f"./{self.traj_name}_NEW_TOPO.pdb") ) - # Recenter and apply periodic boundary - if self.image_mol: - try: - self.traj.image_molecules(inplace=True) - except ValueError as e: - print(f"Unable to image molecule due to '{e}' - will just recenter it") - # maybe not needed after image_molecules - self.traj.center_coordinates() + n_frames = self.traj.n_frames # which index separated indices from training and test dataset 
self.test_border = int(n_frames * (1.0 - self.test_size)) @@ -205,6 +226,8 @@ def read_traj(self) -> None: atom_indices = [ a.index for a in train_traj.topology.atoms if a.element.symbol != "H" ] + if self.verbose: + print("Calculating distance matrix") # distance matrix between all frames self.traj_dists = np.empty((n_train_frames, n_train_frames)) for i in range(n_train_frames): diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py index 905a3ef..e66c849 100644 --- a/src/molearn/trainers/trainer.py +++ b/src/molearn/trainers/trainer.py @@ -31,10 +31,11 @@ class Trainer: """ - def __init__(self, device=None, log_filename="log_file.dat"): + def __init__(self, device=None, log_filename="log_file.dat", json_log=False): """ :param torch.Device device: if not given will be determinined automatically based on torch.cuda.is_available() :param str log_filename: (default: 'default_log_filename.json') file used to log outputs to + :param bool json_log: True to use json.dump to save the log file """ if not device: self.device = ( @@ -52,6 +53,7 @@ def __init__(self, device=None, log_filename="log_file.dat"): self.verbose = True self.log_filename = "default_log_filename.csv" self.scheduler_key = None + self.json_log = json_log def get_network_summary(self): """ @@ -151,17 +153,22 @@ def log(self, log_dict, verbose=None): print(f"{k: <{max_key_len+1}}: {v:.6f}") print() - # create header if file doesn't exist => first epoch - if not os.path.isfile(self.log_filename): - with open(self.log_filename, "a") as f: - f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n") + if not self.json_log: + # create header if file doesn't exist => first epoch + if not os.path.isfile(self.log_filename): + with open(self.log_filename, "a") as f: + f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n") - with open(self.log_filename, "a") as f: - # just try to format if it is not a Failure - if "Failure" not in log_dict.values(): - f.write(f"{','.join([str(v) for v 
in log_dict.values()])}\n") - else: - dump = json.dumps(log_dict) + with open(self.log_filename, "a") as f: + # just try to format if it is not a Failure + if "Failure" not in log_dict.values(): + f.write(f"{','.join([str(v) for v in log_dict.values()])}\n") + else: + dump = json.dumps(log_dict) + f.write(dump + "\n") + else: + dump = json.dumps(log_dict) + with open(self.log_filename, "a") as f: f.write(dump + "\n") def scheduler_step(self, logs):