Chromatin model first shot
asistradition committed Oct 17, 2023
1 parent e6d49bd commit fea1c25
Showing 6 changed files with 275 additions and 4 deletions.
1 change: 1 addition & 0 deletions supirfactor_dynamical/__init__.py
@@ -5,6 +5,7 @@
TFLSTMDecoder,
TFGRUDecoder,
SupirFactorBiophysical,
ChromatinAwareModel,
get_model
)

16 changes: 16 additions & 0 deletions supirfactor_dynamical/_utils/_utils.py
@@ -2,6 +2,8 @@

import numpy as np
import pandas as pd
import anndata as ad
from scipy import sparse


def _process_weights_to_tensor(
@@ -23,6 +25,20 @@ def _process_weights_to_tensor(
dtype=torch.float32
)

elif isinstance(prior_network, sparse.csr_matrix):
labels = (None, None)
data = torch.sparse_csr_tensor(
prior_network.indptr,
prior_network.indices,
prior_network.data.astype(np.float32)
).to_dense()

elif isinstance(prior_network, ad.AnnData):
labels = (prior_network.obs_names, prior_network.var_names)
data = _process_weights_to_tensor(
prior_network.X
)[0]

else:
labels = (None, None)
data = prior_network
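For context, a minimal sketch of how the two new input types might be exercised, assuming only what the hunk above shows: CSR input is densified and returns (None, None) labels, while AnnData input carries its obs_names / var_names through as the label tuple. The toy matrix is illustrative.

import numpy as np
import anndata as ad
from scipy import sparse

from supirfactor_dynamical._utils import _process_weights_to_tensor

# Illustrative prior matrix (e.g. genes x TFs)
prior = np.eye(5, 3, dtype=np.float32)

# scipy CSR input: converted to a dense torch tensor, no labels attached
tensor, (rows, cols) = _process_weights_to_tensor(
    sparse.csr_matrix(prior)
)

# AnnData input: obs_names / var_names come back as the labels
adata = ad.AnnData(prior)
tensor, (obs_labels, var_labels) = _process_weights_to_tensor(adata)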
4 changes: 3 additions & 1 deletion supirfactor_dynamical/models/__init__.py
@@ -11,6 +11,7 @@

from .biophysical_model import SupirFactorBiophysical
from .decay_model import DecayModule
from .chromatin_model import ChromatinAwareModel

# Standard mixins
from ._base_velocity_model import (
@@ -27,7 +28,8 @@
TFGRUDecoder.type_name: TFGRUDecoder,
TFLSTMDecoder.type_name: TFLSTMDecoder,
SupirFactorBiophysical.type_name: SupirFactorBiophysical,
DecayModule.type_name: DecayModule
DecayModule.type_name: DecayModule,
ChromatinAwareModel.type_name: ChromatinAwareModel
}

_not_velocity = [
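If get_model (re-exported in the package __init__ above) dispatches on these type_name keys — an assumption here, since its signature is not part of this diff — the newly registered class could be looked up by name:

from supirfactor_dynamical import get_model, ChromatinAwareModel

# Hypothetical lookup by the registered type_name string
model_class = get_model('chromatin_aware')
assert model_class is ChromatinAwareModel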
28 changes: 26 additions & 2 deletions supirfactor_dynamical/models/_model_mixins.py
@@ -118,12 +118,36 @@ def set_decoder(
def mask_input_weights(
self,
mask,
module=None,
use_mask_weights=False,
layer_name='weight_ih_l0',
layer_name='weight',
weight_vstack=None
):
"""
Apply a mask to layer weights
:param mask: Mask tensor. Non-zero values will be retained,
and zero values will be masked to zero in the layer weights
:type mask: torch.Tensor
:param module: Module to mask; if None, the encoder (or its first
layer when the encoder is Sequential) is used, defaults to None
:type module: torch.nn.Module, optional
:param use_mask_weights: Set the weights equal to values in mask,
defaults to False
:type use_mask_weights: bool, optional
:param layer_name: Module weight name,
defaults to 'weight'
:type layer_name: str, optional
:param weight_vstack: Number of times to stack the mask, for cases
where the layer weights are also stacked, defaults to None
:type weight_vstack: int, optional
:raises ValueError: Raise error if the mask and module weights are
different sizes
"""

if isinstance(self.encoder, torch.nn.Sequential):
if module is not None:
encoder = module
elif isinstance(self.encoder, torch.nn.Sequential):
encoder = self.encoder[0]
else:
encoder = self.encoder
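The hunk above generalizes mask_input_weights from the recurrent-layer default ('weight_ih_l0') to a plain 'weight' and adds a module argument so the mask can target an arbitrary layer. A rough standalone illustration of the behaviour the docstring describes, written with plain torch ops rather than the repository's implementation (layer sizes are made up):

import torch

layer = torch.nn.Linear(20, 8, bias=False)

# 0/1 connectivity mask with the same shape as layer.weight
mask = torch.zeros(8, 20)
mask[:, :10] = 1

# Per the docstring: non-zero mask entries retain the existing weights,
# zero entries force the corresponding weights to zero
with torch.no_grad():
    layer.weight.mul_((mask != 0).to(layer.weight.dtype))

# Through the mixin, the equivalent request would be roughly:
# model.mask_input_weights(mask, module=layer, layer_name='weight')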
114 changes: 114 additions & 0 deletions supirfactor_dynamical/models/chromatin_model.py
@@ -4,6 +4,120 @@
_TrainingMixin
)

from ._base_model import (
_TFMixin
)

from .._utils import _process_weights_to_tensor


class ChromatinAwareModel(
_TFMixin,
_TrainingMixin
):

type_name = 'chromatin_aware'

g = None
k = None
p = None

peak_encoder = None
tf_encoder = None
chromatin_encoder = None

def __init__(
self,
gene_peak_mask,
peak_tf_mask,
chromatin_model=None,
input_dropout_rate=0.5,
hidden_dropout_rate=0.0
):

super().__init__()

self.g, self.p = gene_peak_mask.shape
self.k, _ = peak_tf_mask.shape

self.set_dropouts(
input_dropout_rate,
hidden_dropout_rate
)

self.peak_encoder = torch.nn.Sequential(
torch.nn.Linear(self.g, self.p, bias=False),
torch.nn.Softplus(threshold=5)
)

gene_peak_mask, _ = _process_weights_to_tensor(
gene_peak_mask
)

self.mask_input_weights(
gene_peak_mask,
module=self.peak_encoder[0],  # the Linear(g, p) layer
layer_name='weight'
)

peak_tf_mask, _ = _process_weights_to_tensor(
peak_tf_mask
)

self.tf_encoder = torch.nn.Sequential(
self.input_dropout,
torch.nn.Linear(self.p, self.k, bias=False),
torch.nn.Softplus(threshold=5)
)

self.mask_input_weights(
peak_tf_mask,
module=self.tf_encoder[1],  # the Linear(p, k) layer
layer_name='weight'
)

self.decoder = torch.nn.Sequential(
self.hidden_dropout,
torch.nn.Linear(self.k, self.k, bias=False),
torch.nn.Softplus(threshold=5),
torch.nn.Linear(self.k, self.g, bias=False),
torch.nn.Softplus(threshold=5)
)

if chromatin_model is not None:

if isinstance(chromatin_model, str):
from .._utils._loader import read
chromatin_model = read(chromatin_model)

self.chromatin_model = chromatin_model

else:

self.chromatin_model = ChromatinModule(
self.g,
self.p,
self.k,
input_dropout_rate=input_dropout_rate,
hidden_dropout_rate=hidden_dropout_rate,
)

def forward(self, x, return_tfa=False):

peak_state = self.chromatin_model(x)

peak_activity = self.peak_encoder(x)
peak_activity = torch.mul(
peak_activity,
peak_state
)

tfa = self.tf_encoder(peak_activity)
x_out = self.decoder(tfa)

if return_tfa:
return x_out, tfa
else:
return x_out


class ChromatinModule(
torch.nn.Module,
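For orientation, a hedged sketch of the shapes ChromatinAwareModel appears to expect, based only on the constructor and forward above; the mask sizes and contents are invented for illustration, and the second axis of peak_tf_mask is assumed to index peaks:

import numpy as np
import torch

from supirfactor_dynamical import ChromatinAwareModel

g, p, k = 100, 400, 25  # genes, peaks, TFs (illustrative sizes)

rng = np.random.default_rng(0)
gene_peak_mask = (rng.random((g, p)) > 0.9).astype(np.float32)  # (g, p)
peak_tf_mask = (rng.random((k, p)) > 0.9).astype(np.float32)    # (k, p)

model = ChromatinAwareModel(gene_peak_mask, peak_tf_mask)

x = torch.rand(16, g)  # a batch of 16 gene expression profiles

# return_tfa=True also returns the TF activity layer (16 x k)
x_reconstructed, tfa = model(x, return_tfa=True)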
116 changes: 115 additions & 1 deletion supirfactor_dynamical/tests/test_utils.py
@@ -2,24 +2,138 @@
import torch
import numpy.testing as npt
import numpy as np
import pandas as pd
import anndata as ad

from supirfactor_dynamical._utils import (
_calculate_erv,
_calculate_rss,
_calculate_tss,
_calculate_r2,
_aggregate_r2
_aggregate_r2,
_process_weights_to_tensor
)

from scipy.linalg import pinv

from ._stubs import (
X,
X_SP,
A,
X_tensor
)


class TestTensorUtils(unittest.TestCase):

def test_array_to_tensor(self):

x_t, (a, b) = _process_weights_to_tensor(
X
)

torch.testing.assert_close(x_t, torch.transpose(X_tensor, 0, 1))
self.assertIsNone(a)
self.assertIsNone(b)

x_t, (a, b) = _process_weights_to_tensor(
X,
transpose=False
)

torch.testing.assert_close(x_t, X_tensor)
self.assertIsNone(a)
self.assertIsNone(b)

def test_dataframe_to_tensor(self):

x = pd.DataFrame(X)

x_t, _ = _process_weights_to_tensor(
x
)

torch.testing.assert_close(x_t, torch.transpose(X_tensor, 0, 1))

x_t, _ = _process_weights_to_tensor(
x,
transpose=False
)

torch.testing.assert_close(x_t, X_tensor)

def test_sparse_to_tensor(self):

x_t, (a, b) = _process_weights_to_tensor(
X_SP
)

torch.testing.assert_close(
x_t.to_dense(),
torch.transpose(X_tensor, 0, 1)
)
self.assertIsNone(a)
self.assertIsNone(b)

x_t, (a, b) = _process_weights_to_tensor(
X_SP,
transpose=False
)

torch.testing.assert_close(
x_t.to_dense(),
X_tensor
)
self.assertIsNone(a)
self.assertIsNone(b)

def test_adata_to_tensor(self):

adata = ad.AnnData(X)

x_t, _ = _process_weights_to_tensor(
adata
)

torch.testing.assert_close(
x_t.to_dense(),
torch.transpose(X_tensor, 0, 1)
)

x_t, _ = _process_weights_to_tensor(
adata,
transpose=False
)

torch.testing.assert_close(
x_t.to_dense(),
X_tensor
)

def test_adata_sparse_to_tensor(self):

adata = ad.AnnData(X_SP)

x_t, _ = _process_weights_to_tensor(
adata
)

torch.testing.assert_close(
x_t.to_dense(),
torch.transpose(X_tensor, 0, 1)
)

x_t, _ = _process_weights_to_tensor(
adata,
transpose=False
)

torch.testing.assert_close(
x_t.to_dense(),
X_tensor
)


class TestMathUtils(unittest.TestCase):

def test_erv(self):