From 8700a7efc0533bbf0a0eb95675209140491d1b93 Mon Sep 17 00:00:00 2001
From: gwirn <71886945+gwirn@users.noreply.github.com>
Date: Tue, 6 Aug 2024 12:04:45 +0200
Subject: [PATCH] added trainer, model and examples

---
 examples/transformer_example.py             |  49 ++++++
 examples/transformer_sample_example.py      | 112 +++++++++++++
 src/molearn/models/transformer.py           | 120 ++++++++++++++
 src/molearn/trainers/__init__.py            |  23 ++-
 src/molearn/trainers/transformer_trainer.py | 168 ++++++++++++++++++++
 5 files changed, 465 insertions(+), 7 deletions(-)
 create mode 100644 examples/transformer_example.py
 create mode 100644 examples/transformer_sample_example.py
 create mode 100644 src/molearn/models/transformer.py
 create mode 100644 src/molearn/trainers/transformer_trainer.py

diff --git a/examples/transformer_example.py b/examples/transformer_example.py
new file mode 100644
index 0000000..feb6ddf
--- /dev/null
+++ b/examples/transformer_example.py
@@ -0,0 +1,49 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.abspath(os.pardir), "src"))
+from molearn.data import PDBData
+from molearn.trainers import Transformer_Trainer
+from molearn.models.transformer import TransformerCoordGen
+import torch
+
+
+def main():
+    ##### Load Data #####
+    data = PDBData()
+    data.import_pdb(
+        "./clustered/MurDopen_CLUSTER_aggl_train.dcd",
+        "./clustered/MurDopen_NEW_TOPO.pdb",
+    )
+    data.fix_terminal()
+    data.atomselect(atoms=["CA", "C", "N", "CB", "O"])
+
+    ##### Prepare Trainer #####
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    trainer = Transformer_Trainer(device=device)
+
+    trainer.set_data(data, batch_size=8, validation_split=0.1, manual_seed=25)
+
+    trainer.set_autoencoder(TransformerCoordGen)
+    trainer.prepare_optimiser()
+
+    ##### Training Loop #####
+    # keep training in blocks of 32 epochs until the best loss stops improving
+
+    runkwargs = dict(
+        log_filename="log_file.dat",
+        log_folder="transformer_checkpoints",
+        checkpoint_folder="transformer_checkpoints",
+    )
+
+    best = 1e24
+    while True:
+        trainer.run(max_epochs=32 + trainer.epoch, **runkwargs)
+        if trainer.best >= best:
+            break
+        best = trainer.best
+    print(f"best {trainer.best}, best_filename {trainer.best_name}")
+
+
+if __name__ == "__main__":
+    main()
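
Note on the loop above: the best checkpoint (lowest validation loss) is reported as trainer.best_name. A minimal sketch of reloading it for inference, assuming the checkpoint dictionary layout read in transformer_sample_example.py below (keys model_state_dict, optimizer_state_dict, epoch, loss); the filename here is a placeholder:

    import torch
    from molearn.models.transformer import TransformerCoordGen

    model = TransformerCoordGen()
    # placeholder path: use the file printed as best_filename after training
    checkpoint = torch.load("transformer_checkpoints/checkpoint_best.ckpt", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
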
diff --git a/examples/transformer_sample_example.py b/examples/transformer_sample_example.py
new file mode 100644
index 0000000..39b982d
--- /dev/null
+++ b/examples/transformer_sample_example.py
@@ -0,0 +1,112 @@
+import sys
+import os
+import argparse
+
+import torch
+
+from torch.optim import Adam
+import numpy as np
+
+sys.path.insert(0, os.path.join(os.path.abspath(os.pardir), "src"))
+from molearn.data import PDBData
+from molearn.models.transformer import (
+    TransformerCoordGen,
+    generate_coordinates,
+    gen_xyz,
+)
+
+torch.manual_seed(42)
+np.random.seed(42)
+# Check if GPU is available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+
+def qr_full(num_samples=1):
+    """
+    Draw random 3x3 rotation matrices (determinant +1) via QR decomposition
+    of Gaussian random matrices.
+
+    :param int num_samples: number of rotation matrices to generate
+    :return: array of shape (num_samples, 3, 3)
+    """
+    z = np.random.randn(num_samples, 3, 3)
+    q, r = np.linalg.qr(z)
+    # fix the signs so the decomposition is unique
+    sign = 2 * (np.diagonal(r, axis1=-2, axis2=-1) >= 0) - 1
+    rot = q
+    rot *= sign[..., None, :]
+    # flip one row where necessary so the determinant is +1 (a proper rotation)
+    rot[:, 0, :] *= np.linalg.det(rot)[..., None]
+    return rot
+
+
+# Initialise the model; a trained checkpoint is loaded below
+model = TransformerCoordGen().to(device)
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument(
+    "-tr", "--trajectory", type=str, required=True, help="path to trajectory"
+)
+parser.add_argument(
+    "-to", "--topology", type=str, required=True, help="path to topology"
+)
+parser.add_argument(
+    "-o",
+    "--outpath",
+    type=str,
+    required=True,
+    help="path to the directory where the generated structures are stored",
+)
+
+args = parser.parse_args()
+
+if not os.path.isdir(args.outpath):
+    os.mkdir(args.outpath)
+
+##### Load Data #####
+data = PDBData()
+data.import_pdb(args.trajectory, args.topology)
+data.fix_terminal()
+data.atomselect(atoms=["CA", "C", "N", "CB", "O"])
+data.prepare_dataset()
+data.dataset = data.dataset.permute(0, 2, 1)
+train_loader, test_loader = data.get_dataloader(
+    batch_size=32, validation_split=0.1, manual_seed=25
+)
+
+checkpoint = torch.load(
+    "./transformer_checkpoints/checkpoint_epoch1110_loss0.0009468746138736606.ckpt"
+)
+model.load_state_dict(checkpoint["model_state_dict"])
+optimizer = Adam(model.parameters())
+optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+epoch = checkpoint["epoch"]
+loss = checkpoint["loss"]
+model.eval()
+
+
+seq_len = data.dataset.shape[1]
+# five atoms (CA, C, N, CB, O) are kept per residue
+n_residues = [5, 20, 50, 100, 150, 200, 320]
+for g in n_residues:
+    given_seq = g * 5
+    for bs, batch in enumerate(test_loader):
+        for cs, sample in enumerate(batch[0]):
+            print(f"Batch {bs} protein {cs}")
+            sample = sample.cpu().numpy()
+            start_sequence = sample[:given_seq, :]
+            n_rand = start_sequence.shape[0] * start_sequence.shape[1]
+            noise = (np.random.random(n_rand) * 0.05).reshape(start_sequence.shape)
+            # add noise (and optionally a random rotation) to the start coordinates
+            new_coords = start_sequence + noise  # @ qr_full()
+            generated_coords = (
+                # use this when rotation is used
+                # generate_coordinates(model, new_coords[0], device, seq_len - given_seq)
+                generate_coordinates(model, new_coords, device, seq_len - given_seq)
+                * data.std
+                + data.mean
+            )
+            sample = sample * data.std + data.mean
+            gen_xyz(generated_coords, f"{args.outpath}/{g}lss_{cs}_pred.xyz")
+            gen_xyz(sample, f"{args.outpath}/{g}lss_{cs}_gt.xyz")
+            # only sample a handful of structures from the first batch
+            if cs > 5:
+                break
+        break
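
The random-rotation branch in the script above is left commented out (`# @ qr_full()`). A minimal sketch of how it could be applied, assuming qr_full as defined above (it returns a stack of 3x3 rotation matrices, so one matrix has to be selected):

    # rotate the noisy start coordinates with a single random rotation matrix
    rot = qr_full(num_samples=1)[0]              # 3x3 rotation matrix, det +1
    new_coords = (start_sequence + noise) @ rot  # (given_seq, 3) @ (3, 3)
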
diff --git a/src/molearn/models/transformer.py b/src/molearn/models/transformer.py
new file mode 100644
index 0000000..9f5966f
--- /dev/null
+++ b/src/molearn/models/transformer.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+
+class PositionalEncoding(nn.Module):
+    """
+    Sinusoidal positional encoding for the transformer.
+    """
+
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        return x + self.pe[: x.size(0), :]
+
+
+class TransformerCoordGen(nn.Module):
+    """
+    Transformer for autoregressive coordinate generation.
+    """
+
+    def __init__(
+        self,
+        input_dim=3,
+        d_model=64,
+        nhead=4,
+        num_encoder_layers=3,
+        num_decoder_layers=3,
+    ):
+        super(TransformerCoordGen, self).__init__()
+        self.input_embedding = nn.Linear(input_dim, d_model)
+        self.positional_encoding = PositionalEncoding(d_model)
+        self.transformer = nn.Transformer(
+            d_model=d_model,
+            nhead=nhead,
+            num_encoder_layers=num_encoder_layers,
+            num_decoder_layers=num_decoder_layers,
+        )
+        self.fc_out = nn.Linear(d_model, input_dim)
+
+    def forward(self, src, tgt, tgt_mask=None):
+        src = self.input_embedding(src)
+        tgt = self.input_embedding(tgt)
+        src = self.positional_encoding(src)
+        tgt = self.positional_encoding(tgt)
+        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
+        return self.fc_out(output)
+
+
+def generate_square_subsequent_mask(sz: int):
+    """
+    Generate the additive attention mask for the tgt sequence.
+
+    :param int sz: sequence length
+    """
+    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+    mask = (
+        mask.float()
+        .masked_fill(mask == 0, float("-inf"))
+        .masked_fill(mask == 1, float(0.0))
+    )
+    return mask
+
+
+# Generation function
+def generate_coordinates(
+    model,
+    start_sequence: np.ndarray[tuple[int, int], np.dtype[np.float64]],
+    device: torch.device,
+    num_steps: int = 10,
+):
+    """
+    Use the model to autoregressively generate coordinates.
+
+    :param molearn.models.transformer.TransformerCoordGen model: the transformer used
+        to generate the coordinates
+    :param np.ndarray[tuple[int, int], np.dtype[np.float64]] start_sequence: the
+        given coordinates used to initialise the coordinate generation
+    :param torch.device device: device on which the generation is run
+    :param int num_steps: how many coordinates should be generated
+    """
+    model.eval()
+    generated = start_sequence.tolist()
+    with torch.no_grad():
+        src = torch.tensor(start_sequence, dtype=torch.float32).unsqueeze(1).to(device)
+
+        for _ in range(num_steps):
+            tgt = torch.tensor(generated, dtype=torch.float32).unsqueeze(1).to(device)
+            next_point = model(src, tgt)
+            # append the prediction for the last position as the next coordinate
+            generated.append(next_point[-1, 0].cpu().numpy().tolist())
+
+    return np.array(generated)
+
+
+def gen_xyz(
+    generated_coords: np.ndarray[tuple[int, int], np.dtype[np.float64]], path: str
+):
+    """
+    Generate an xyz file from an array of coordinates.
+
+    :param np.ndarray[tuple[int, int], np.dtype[np.float64]] generated_coords: coordinates
+    :param str path: path where the xyz file should be stored
+    """
+    with open(path, "w+") as cfile:
+        # xyz format: atom count, a comment line, then one atom per line
+        cfile.write(f"{len(generated_coords)}\n\n")
+        for j in generated_coords:
+            cfile.write(f"C\t{j[0]}\t{j[1]}\t{j[2]}\n")
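
For illustration, the additive mask returned by generate_square_subsequent_mask for a sequence of length 3; each position may attend to itself and to earlier positions only:

    mask = generate_square_subsequent_mask(3)
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])
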
diff --git a/src/molearn/trainers/__init__.py b/src/molearn/trainers/__init__.py
index c7da066..eaa94dd 100644
--- a/src/molearn/trainers/__init__.py
+++ b/src/molearn/trainers/__init__.py
@@ -14,20 +14,29 @@
 from .trainer import *
 from .torch_physics_trainer import *
-
+from .transformer_trainer import Transformer_Trainer
 
 class RaiseErrorOnInit:
-    module = 'unknown module is creating an ImportError'
-    def __init__(self,*args, **kwargs):
-        raise ImportError(f'{self.module}. Therefore {self.__class__.__name__} can not be used')
+    module = "unknown module is creating an ImportError"
+
+    def __init__(self, *args, **kwargs):
+        raise ImportError(
+            f"{self.module}. Therefore {self.__class__.__name__} can not be used"
+        )
+
+
 try:
     from .openmm_physics_trainer import *
 except ImportError as e:
     import warnings
-    warnings.warn(f"{e}. OpenMM or openmmtorchplugin are not installed. If this is needed please install with `mamba install -c conda-forge openmmtorchplugin=1.1.3 openmm`")
+
+    warnings.warn(
+        f"{e}. OpenMM or openmmtorchplugin are not installed. If this is needed please install with `mamba install -c conda-forge openmmtorchplugin=1.1.3 openmm`"
+    )
 
 try:
     from .sinkhorn_trainer import *
 except ImportError as e:
-    warnings.warn(f"{e}. sinkhorn is not installed. If this is needed please install with `pip install geomloss`")
-
+    import warnings
+
+    warnings.warn(
+        f"{e}. sinkhorn is not installed. If this is needed please install with `pip install geomloss`"
+    )
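
The trainer added below trains the model with teacher forcing on atom sequences. As a sketch of the tensor shapes involved (B frames with A atoms each; see common_step below):

    # batch after the permute in train_epoch:     [B, A, 3]
    # src = batch[:, :-1].transpose(0, 1)      -> [A-1, B, 3]
    # tgt = batch[:, 1:].transpose(0, 1)       -> [A-1, B, 3]
    # output = model(src, tgt[:-1], tgt_mask)  -> [A-2, B, 3]
    # mse_loss compares output with tgt[1:], hence the mask size A - 2
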
diff --git a/src/molearn/trainers/transformer_trainer.py b/src/molearn/trainers/transformer_trainer.py
new file mode 100644
index 0000000..47f2fda
--- /dev/null
+++ b/src/molearn/trainers/transformer_trainer.py
@@ -0,0 +1,168 @@
+import os
+
+from .trainer import Trainer
+
+from molearn.models.transformer import (
+    gen_xyz,
+    generate_coordinates,
+    generate_square_subsequent_mask,
+)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class Transformer_Trainer(Trainer):
+    """
+    Transformer_Trainer subclasses Trainer and replaces train_epoch, train_step,
+    common_step, and valid_epoch to train a
+    :class:`TransformerCoordGen <molearn.models.transformer.TransformerCoordGen>`
+    autoregressively with teacher forcing.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_network_summary(self):
+        """
+        Return the number of trainable parameters of the autoencoder.
+        """
+        return sum(p.numel() for p in self.autoencoder.parameters() if p.requires_grad)
+
+    def train_epoch(self, epoch):
+        """
+        Train one epoch. Called once an epoch from :func:`trainer.run`.
+        This method performs the following functions:
+        - Sets network to train mode via ``self.autoencoder.train()``
+        - for each batch in self.train_dataloader implements typical pytorch training protocol:
+
+          * zero gradients with call ``self.optimiser.zero_grad()``
+          * use training implemented in trainer.train_step ``result = self.train_step(batch)``
+          * determine gradients using keyword ``'loss'`` e.g. ``result['loss'].backward()``
+          * update network weights with ``self.optimiser.step()``
+
+        - All results are aggregated via averaging and returned with ``'train_'`` prepended on the dictionary key
+
+        :param int epoch: The epoch is passed as an argument, however the epoch number can also be accessed from self.epoch.
+        :returns: All results from train_step averaged. These results will be printed and/or logged in :func:`trainer.run` via a call to :func:`self.log`
+        :rtype: dict
+        """
+        self.autoencoder.train()
+        # the mask only needs to be created once per training run, but creating it
+        # here keeps the trainer usable with the unmodified Trainer.run; its size is
+        # n_atoms - 2 because common_step feeds tgt[:-1] to the decoder
+        self.tgt_mask = generate_square_subsequent_mask(
+            self._data.dataset.shape[-1] - 2
+        ).to(self.device)
+        N = 0
+        results = {}
+        for i, batch in enumerate(self.train_dataloader):
+            batch = batch[0].to(self.device).permute(0, 2, 1)
+            self.optimiser.zero_grad()
+
+            train_result = self.train_step(batch)
+
+            train_result["loss"].backward()
+            self.optimiser.step()
+            if i == 0:
+                results = {
+                    key: value.item() * len(batch)
+                    for key, value in train_result.items()
+                }
+            else:
+                for key in train_result.keys():
+                    results[key] += train_result[key].item() * len(batch)
+            N += len(batch)
+        return {f"train_{key}": results[key] / N for key in results.keys()}
+
+    def train_step(self, batch):
+        """
+        Called from :func:`Trainer.train_epoch`.
+
+        :param torch.Tensor batch: Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data multiply by ``self.std`` and add ``self.mean``.
+        :returns: The dictionary must contain an entry with key ``'loss'`` on which :func:`self.train_epoch` calls ``result['loss'].backward()`` to obtain gradients.
+        :rtype: dict
+        """
+        results = self.common_step(batch)
+        results["loss"] = results["mse_loss"] + results["similarity_loss"]
+        return results
+
+    def common_step(self, batch):
+        """
+        Called from both train_step and valid_step.
+        Calculates the teacher-forced mean squared error and a pairwise
+        cosine-similarity loss for self.autoencoder.
+
+        :param torch.Tensor batch: Tensor of shape [Batch size, Number of Atoms, 3]. A mini-batch of normalised protein frames. To recover the original data multiply by ``self.std`` and add ``self.mean``.
+        :returns: The calculated mse_loss and similarity_loss
+        :rtype: dict
+        """
+        # Shape: (n_atoms - 1, batch_size, feature_size)
+        src = batch[:, :-1].transpose(0, 1)
+        tgt = batch[:, 1:].transpose(0, 1)
+        # Teacher forcing: the decoder sees the ground truth shifted by one position
+        output = self.autoencoder(src, tgt[:-1], tgt_mask=self.tgt_mask)
+        mse_loss = nn.MSELoss()(output, tgt[1:])
+
+        # additional loss to push the model towards generating more diverse
+        # structures: pairwise cosine similarity between all members of the batch,
+        # computed on tgt[1:] so both terms cover the same positions as the mse loss
+        cs_tgt = F.cosine_similarity(
+            tgt[1:][..., None, :, :], tgt[1:][..., :, None, :], dim=0
+        )
+        cs_output = F.cosine_similarity(
+            output[..., None, :, :], output[..., :, None, :], dim=0
+        )
+        # scaled by 1/1500 to be roughly in the same range as the mse loss
+        similarity_loss = (cs_tgt - cs_output).abs().sum() / 1500
+
+        return {"mse_loss": mse_loss, "similarity_loss": similarity_loss}
+
+    def valid_epoch(self, epoch):
+        """
+        Called once an epoch from :func:`trainer.run` within a no_grad context.
+        This method performs the following functions:
+        - Sets network to eval mode via ``self.autoencoder.eval()``
+        - for each batch in ``self.valid_dataloader`` calls :func:`trainer.valid_step` to retrieve validation loss
+        - All results are aggregated via averaging and returned with ``'valid_'`` prepended on the dictionary key
+
+          * The loss with key ``'loss'`` is returned as ``'valid_loss'``. This is the loss value by which the best checkpoint is determined.
+
+        :param int epoch: The epoch is passed as an argument, however the epoch number can also be accessed from self.epoch.
+        :returns: All results from valid_step averaged. These results will be printed and/or logged in :func:`Trainer.run` via a call to :func:`self.log`
+        :rtype: dict
+        """
+        self.autoencoder.eval()
+        N = 0
+        results = {}
+        for i, batch in enumerate(self.valid_dataloader):
+            batch = batch[0].to(self.device).permute(0, 2, 1)
+            valid_result = self.valid_step(batch)
+            if i == 0:
+                results = {
+                    key: value.item() * len(batch)
+                    for key, value in valid_result.items()
+                }
+            else:
+                for key in valid_result.keys():
+                    results[key] += valid_result[key].item() * len(batch)
+            N += len(batch)
+        # generate one structure every epoch, seeded with the first 100 atoms of an
+        # example from the last validation batch and completed to full length
+        start_sequence = batch[0][:100, :].cpu().numpy()
+        generated_coords = (
+            generate_coordinates(
+                self.autoencoder,
+                start_sequence,
+                self.device,
+                self._data.dataset.shape[-1] - start_sequence.shape[0],
+            )
+            * self.std
+            + self.mean
+        )
+        # save the generated structure as an xyz file next to the log file
+        gen_xyz(
+            generated_coords, f"{os.path.dirname(self.log_filename)}/epoch{epoch}.xyz"
+        )
+        return {f"valid_{key}": results[key] / N for key in results.keys()}
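
For reference, a standalone sketch of the diversity term used in common_step above, on a toy tensor with the assumed [sequence, batch, 3] layout:

    import torch
    from torch.nn import functional as F

    tgt = torch.randn(10, 4, 3)  # toy tensor: 10 positions, batch of 4, xyz
    cs = F.cosine_similarity(tgt[..., None, :, :], tgt[..., :, None, :], dim=0)
    print(cs.shape)  # torch.Size([4, 4, 3]): pairwise similarities across the batch
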