Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep datasets on CPU #120

Merged
merged 7 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions netam/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,32 +333,6 @@ def chunked(iterable, n):
yield chunk


def assume_single_sequence_is_heavy_chain(seq_arg_idx=0):
    """Build a decorator letting a paired-chain function accept a lone heavy chain.

    The decorated function is expected to take a heavy/light sequence pair as
    the positional argument at index ``seq_arg_idx`` and to return a tuple of
    results.

    When the argument at that position is a string, it is assumed to be a heavy
    chain sequence: it is replaced by the pair ``(seq, "")`` before the call,
    and only the first element of the result (the heavy chain result) is
    returned. Any other value is forwarded unchanged and the full result is
    returned.

    Args:
        seq_arg_idx: Positional index of the sequence argument in calls to the
            decorated function. The sequence must be passed positionally; the
            wrapper reads it from ``args`` and does not inspect ``kwargs``.

    Returns:
        A decorator to apply to the paired-chain function.
    """

    def decorator(function):
        @wraps(function)
        def wrapper(*args, **kwargs):
            seq = args[seq_arg_idx]
            if not isinstance(seq, str):
                # Already a heavy/light pair (or other non-string): pass through.
                return function(*args, **kwargs)

            # Promote the lone heavy chain to a pair with an empty light chain,
            # leaving every other positional argument where it was.
            paired_args = list(args)
            paired_args[seq_arg_idx] = (seq, "")
            # The caller gave a single sequence, so hand back a single result.
            return function(*paired_args, **kwargs)[0]

        return wrapper

    return decorator


def heavy_chain_shim(paired_evaluator):
"""Returns a function that evaluates only heavy chains given a paired evaluator."""

Expand Down
15 changes: 3 additions & 12 deletions netam/dasm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Defining the deep natural selection model (DNSM)."""

import copy

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -72,13 +70,6 @@ def update_neutral_probs(self):
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
Expand All @@ -93,11 +84,11 @@ def update_neutral_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_codon_probs).all():
print(f"Found a non-finite neutral_codon_prob")
print("Found a non-finite neutral_codon_prob")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"nt_rates: {nt_rates}")
Expand Down Expand Up @@ -137,7 +128,7 @@ def __getitem__(self, idx):
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
def move_data_to_device(self, device):
self.codon_parents_idxss = self.codon_parents_idxss.to(device)
self.codon_children_idxss = self.codon_children_idxss.to(device)
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
Expand Down
14 changes: 3 additions & 11 deletions netam/ddsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
import copy
from typing import Tuple


Expand All @@ -24,13 +23,6 @@ def update_neutral_probs(self):
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
Expand All @@ -45,11 +37,11 @@ def update_neutral_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_aa_probs).all():
print(f"Found a non-finite neutral_aa_probs")
print("Found a non-finite neutral_aa_probs")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"nt_rates: {nt_rates}")
Expand Down Expand Up @@ -85,7 +77,7 @@ def __getitem__(self, idx):
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
def move_data_to_device(self, device):
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_children_idxss = self.aa_children_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
Expand Down
15 changes: 3 additions & 12 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Defining the deep natural selection model (DNSM)."""

import copy

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -37,13 +35,6 @@ def update_neutral_probs(self):
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
multihit_model = None
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
Expand All @@ -60,11 +51,11 @@ def update_neutral_probs(self):
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
multihit_model=self.multihit_model,
)

if not torch.isfinite(neutral_aa_mut_probs).all():
print(f"Found a non-finite neutral_aa_mut_prob")
print("Found a non-finite neutral_aa_mut_prob")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"nt_rates: {nt_rates}")
Expand Down Expand Up @@ -101,7 +92,7 @@ def __getitem__(self, idx):
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
def move_data_to_device(self, device):
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
self.masks = self.masks.to(device)
Expand Down
14 changes: 13 additions & 1 deletion netam/dxsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ def __init__(
self.masks = masks
self.aa_subs_indicators = aa_subs_indicators
self.multihit_model = copy.deepcopy(multihit_model)
if multihit_model is not None:
if self.multihit_model is not None:
# We want these parameters to act like fixed data. This is essential
# for multithreaded branch length optimization to work.
self.multihit_model = self.multihit_model.to("cpu")
self.multihit_model.values.requires_grad_(False)

assert len(self.nt_parents) == len(self.nt_children)
Expand All @@ -84,6 +85,14 @@ def __init__(
self._branch_lengths = branch_lengths
self.update_neutral_probs()

def __post_init__(self):
self.move_data_to_device("cpu")

@abstractmethod
def move_data_to_device(self, device):
"""Move all tensors stored by the dataset to the given device."""
pass

@classmethod
def of_seriess(
cls,
Expand Down Expand Up @@ -284,6 +293,9 @@ def branch_lengths(self, new_branch_lengths):
self._branch_lengths = new_branch_lengths
self.update_neutral_probs()

def to(self, device):
self.device = device

@abstractmethod
def update_neutral_probs(self):
pass
Expand Down
2 changes: 1 addition & 1 deletion netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,8 +768,8 @@ def standardize_and_optimize_branch_lengths(self, **optimization_kwargs):
dataset.branch_lengths = self.find_optimal_branch_lengths(
dataset, **optimization_kwargs
)
dataset.to(device)
self.model.to(device)
dataset.to(device)

def standardize_and_use_yun_approx_branch_lengths(self):
"""Yun Song's approximation to the branch lengths.
Expand Down