diff --git a/netam/common.py b/netam/common.py
index d4b9a511..06274d16 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -333,32 +333,6 @@ def chunked(iterable, n):
         yield chunk
 
 
-def assume_single_sequence_is_heavy_chain(seq_arg_idx=0):
-    """Wraps a function that takes a heavy/light sequence pair as its first argument and
-    returns a tuple of results.
-
-    The wrapped function will assume that if the first argument is a string, it is a
-    heavy chain sequence, and in that case will return only the heavy chain result.
-    """
-
-    def decorator(function):
-        @wraps(function)
-        def wrapper(*args, **kwargs):
-            seq = args[seq_arg_idx]
-            if isinstance(seq, str):
-                seq = (seq, "")
-                args = list(args)
-                args[seq_arg_idx] = seq
-                res = function(*args, **kwargs)
-                return res[0]
-            else:
-                return function(*args, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 def heavy_chain_shim(paired_evaluator):
     """Returns a function that evaluates only heavy chains given a paired evaluator."""
 
diff --git a/netam/dasm.py b/netam/dasm.py
index 33c1c812..bebb87eb 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -1,7 +1,5 @@
 """Defining the deep natural selection model (DNSM)."""
 
-import copy
-
 import torch
 import torch.nn.functional as F
 
@@ -72,13 +70,6 @@ def update_neutral_probs(self):
             self.nt_cspss,
             self._branch_lengths,
         ):
-            mask = mask.to("cpu")
-            nt_rates = nt_rates.to("cpu")
-            nt_csps = nt_csps.to("cpu")
-            if self.multihit_model is not None:
-                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
-            else:
-                multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
             parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -93,11 +84,11 @@ def update_neutral_probs(self):
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 nt_csps.reshape(-1, 3, 4),
-                multihit_model=multihit_model,
+                multihit_model=self.multihit_model,
             )
 
             if not torch.isfinite(neutral_codon_probs).all():
-                print(f"Found a non-finite neutral_codon_prob")
+                print("Found a non-finite neutral_codon_prob")
                 print(f"nt_parent: {nt_parent}")
                 print(f"mask: {mask}")
                 print(f"nt_rates: {nt_rates}")
@@ -137,7 +128,7 @@ def __getitem__(self, idx):
             "nt_csps": self.nt_cspss[idx],
         }
 
-    def to(self, device):
+    def move_data_to_device(self, device):
         self.codon_parents_idxss = self.codon_parents_idxss.to(device)
         self.codon_children_idxss = self.codon_children_idxss.to(device)
         self.aa_parents_idxss = self.aa_parents_idxss.to(device)
diff --git a/netam/ddsm.py b/netam/ddsm.py
index fce19470..2d6de130 100644
--- a/netam/ddsm.py
+++ b/netam/ddsm.py
@@ -8,7 +8,6 @@
 import netam.framework as framework
 import netam.molevol as molevol
 import netam.sequences as sequences
-import copy
 
 from typing import Tuple
 
@@ -24,13 +23,6 @@ def update_neutral_probs(self):
             self.nt_cspss,
             self._branch_lengths,
         ):
-            mask = mask.to("cpu")
-            nt_rates = nt_rates.to("cpu")
-            nt_csps = nt_csps.to("cpu")
-            if self.multihit_model is not None:
-                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
-            else:
-                multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -45,11 +37,11 @@ def update_neutral_probs(self):
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 nt_csps.reshape(-1, 3, 4),
-                multihit_model=multihit_model,
+                multihit_model=self.multihit_model,
             )
 
             if not torch.isfinite(neutral_aa_probs).all():
-                print(f"Found a non-finite neutral_aa_probs")
+                print("Found a non-finite neutral_aa_probs")
                 print(f"nt_parent: {nt_parent}")
                 print(f"mask: {mask}")
                 print(f"nt_rates: {nt_rates}")
@@ -85,7 +77,7 @@ def __getitem__(self, idx):
             "nt_csps": self.nt_cspss[idx],
         }
 
-    def to(self, device):
+    def move_data_to_device(self, device):
         self.aa_parents_idxss = self.aa_parents_idxss.to(device)
         self.aa_children_idxss = self.aa_children_idxss.to(device)
         self.aa_subs_indicators = self.aa_subs_indicators.to(device)
diff --git a/netam/dnsm.py b/netam/dnsm.py
index 568b3798..744da75c 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -1,7 +1,5 @@
 """Defining the deep natural selection model (DNSM)."""
 
-import copy
-
 import torch
 import torch.nn.functional as F
 
@@ -37,13 +35,6 @@ def update_neutral_probs(self):
             self.nt_cspss,
             self._branch_lengths,
         ):
-            mask = mask.to("cpu")
-            nt_rates = nt_rates.to("cpu")
-            nt_csps = nt_csps.to("cpu")
-            if self.multihit_model is not None:
-                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
-            else:
-                multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
@@ -60,11 +51,11 @@ def update_neutral_probs(self):
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 nt_csps.reshape(-1, 3, 4),
-                multihit_model=multihit_model,
+                multihit_model=self.multihit_model,
             )
 
             if not torch.isfinite(neutral_aa_mut_probs).all():
-                print(f"Found a non-finite neutral_aa_mut_prob")
+                print("Found a non-finite neutral_aa_mut_prob")
                 print(f"nt_parent: {nt_parent}")
                 print(f"mask: {mask}")
                 print(f"nt_rates: {nt_rates}")
@@ -101,7 +92,7 @@ def __getitem__(self, idx):
             "nt_csps": self.nt_cspss[idx],
         }
 
-    def to(self, device):
+    def move_data_to_device(self, device):
         self.aa_parents_idxss = self.aa_parents_idxss.to(device)
         self.aa_subs_indicators = self.aa_subs_indicators.to(device)
         self.masks = self.masks.to(device)
diff --git a/netam/dxsm.py b/netam/dxsm.py
index 507318ed..4ac7a7a4 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -67,9 +67,10 @@ def __init__(
         self.masks = masks
         self.aa_subs_indicators = aa_subs_indicators
         self.multihit_model = copy.deepcopy(multihit_model)
-        if multihit_model is not None:
+        if self.multihit_model is not None:
             # We want these parameters to act like fixed data. This is essential
             # for multithreaded branch length optimization to work.
+            self.multihit_model = self.multihit_model.to("cpu")
             self.multihit_model.values.requires_grad_(False)
 
         assert len(self.nt_parents) == len(self.nt_children)
@@ -84,6 +85,14 @@ def __init__(
         self._branch_lengths = branch_lengths
         self.update_neutral_probs()
 
+    def __post_init__(self):
+        self.move_data_to_device("cpu")
+
+    @abstractmethod
+    def move_data_to_device(self, device):
+        """Move all tensors stored by the dataset to the given device."""
+        pass
+
     @classmethod
     def of_seriess(
         cls,
@@ -284,6 +293,9 @@ def branch_lengths(self, new_branch_lengths):
         self._branch_lengths = new_branch_lengths
         self.update_neutral_probs()
 
+    def to(self, device):
+        self.device = device
+
     @abstractmethod
     def update_neutral_probs(self):
         pass
diff --git a/netam/framework.py b/netam/framework.py
index 5a7d8064..5ea199ec 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -768,8 +768,8 @@ def standardize_and_optimize_branch_lengths(self, **optimization_kwargs):
             dataset.branch_lengths = self.find_optimal_branch_lengths(
                 dataset, **optimization_kwargs
            )
-            dataset.to(device)
         self.model.to(device)
+        dataset.to(device)
 
     def standardize_and_use_yun_approx_branch_lengths(self):
         """Yun Song's approximation to the branch lengths.
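Note (not part of the diff): the hunks above converge on one device-handling convention. Dataset tensors and any multihit model are pinned to the CPU at construction time, the abstract move_data_to_device method does the actual tensor movement, and the dataset's to(device) now only records a target device instead of moving tensors. Below is a minimal runnable sketch of that convention; DeviceRecordingDataset and ToyDataset are hypothetical stand-ins for the dxsm dataset base class and its subclasses, not code from this PR. One caveat worth checking: __post_init__ is only invoked automatically for dataclasses, so the sketch calls the mover directly from __init__.

import copy
from abc import ABC, abstractmethod

import torch


class DeviceRecordingDataset(ABC):
    """Sketch: tensors live on the CPU; `to` records where batches should go."""

    def __init__(self, multihit_model=None):
        self.multihit_model = copy.deepcopy(multihit_model)
        if self.multihit_model is not None:
            # Pin the multihit model to the CPU and freeze its parameters so it
            # acts like fixed data during multithreaded branch length optimization.
            self.multihit_model = self.multihit_model.to("cpu")
            self.multihit_model.values.requires_grad_(False)
        self.device = torch.device("cpu")
        # Plain classes never call __post_init__ automatically, so move the
        # data explicitly here.
        self.move_data_to_device("cpu")

    @abstractmethod
    def move_data_to_device(self, device):
        """Move all tensors stored by the dataset to the given device."""

    def to(self, device):
        # Unlike torch.nn.Module.to, this moves nothing; it only records the
        # device on which batches should be placed later.
        self.device = device


class ToyDataset(DeviceRecordingDataset):
    def __init__(self, masks, multihit_model=None):
        self.masks = masks
        super().__init__(multihit_model)

    def move_data_to_device(self, device):
        self.masks = self.masks.to(device)


dataset = ToyDataset(torch.ones(2, 3, dtype=torch.bool))
dataset.to("cuda" if torch.cuda.is_available() else "cpu")
assert dataset.masks.device.type == "cpu"  # data stays put until batching

This mirrors the framework.py hunk: branch lengths are optimized with everything on the CPU, then the model is moved back to the original device while the dataset merely records that device.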