Skip to content

Commit

Permalink
Initial chromatin state model
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 3, 2023
1 parent 1069da5 commit e6d49bd
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 0 deletions.
88 changes: 88 additions & 0 deletions supirfactor_dynamical/_utils/chromatin_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import torch
from scipy.sparse import isspmatrix


class ChromatinDataset(torch.utils.data.Dataset):
    """
    Dataset of paired gene expression and chromatin state observations.

    Both inputs are indexed by observation along their first axis and
    must have the same number of rows. Inputs that are neither torch
    tensors nor scipy sparse matrices are converted with
    ``torch.Tensor`` at construction. Scipy sparse matrices are stored
    as given and only the requested row is densified in
    ``__getitem__``.
    """

    # Data containers, populated by __init__
    expression = None
    chromatin = None

    # Number of observations (rows) shared by both containers
    n = None

    @property
    def expression_sparse(self):
        """
        Report whether the expression data is a scipy sparse matrix.

        :return: None if no expression data is set, otherwise the
            result of ``isspmatrix`` on the expression data
        :rtype: bool, None
        """
        if self.expression is None:
            return None
        else:
            return isspmatrix(self.expression)

    @property
    def chromatin_sparse(self):
        """
        Report whether the chromatin data is a scipy sparse matrix.

        :return: None if no chromatin data is set, otherwise the
            result of ``isspmatrix`` on the chromatin data
        :rtype: bool, None
        """
        if self.chromatin is None:
            return None
        else:
            return isspmatrix(self.chromatin)

    def __init__(
        self,
        gene_expression_data,
        chromatin_state_data
    ):
        """
        Store paired expression and chromatin data.

        :param gene_expression_data: Expression data with observations
            on the first axis
        :param chromatin_state_data: Chromatin data with observations
            on the first axis
        :raises ValueError: If the two inputs differ in their number
            of observations
        """
        # Validate alignment before storing anything
        self.n = self._n_samples(
            gene_expression_data,
            chromatin_state_data
        )

        self.chromatin = chromatin_state_data
        self.expression = gene_expression_data

        # Sparse matrices and existing tensors are kept as-is;
        # anything else is converted to a tensor once, up front
        if not torch.is_tensor(self.expression) and not self.expression_sparse:
            self.expression = torch.Tensor(self.expression)

        if not torch.is_tensor(self.chromatin) and not self.chromatin_sparse:
            self.chromatin = torch.Tensor(self.chromatin)

    def __getitem__(self, i):
        """
        Return the expression and chromatin rows for observation ``i``.

        :param i: Row index
        :return: Tuple of (expression row, chromatin row)
        :rtype: tuple
        """
        e = self.expression[i, :]
        c = self.chromatin[i, :]

        # A row sliced from a sparse matrix is still a sparse matrix.
        # toarray() is the supported dense conversion; the deprecated
        # `.A` shorthand is not available on recent scipy releases.
        if isspmatrix(e):
            e = torch.Tensor(e.toarray().ravel())

        if isspmatrix(c):
            c = torch.Tensor(c.toarray().ravel())

        return e, c

    def __len__(self):
        """Return the number of observations."""
        return self.n

    @staticmethod
    def _n_samples(expr, peaks):
        """
        Return the shared number of observations in both inputs.

        :param expr: Expression data exposing ``.shape``
        :param peaks: Chromatin peak data exposing ``.shape``
        :return: Size of the first axis
        :rtype: int
        :raises ValueError: If the first-axis sizes differ
        """
        _n_expr = expr.shape[0]
        _n_peaks = peaks.shape[0]

        if _n_expr != _n_peaks:
            raise ValueError(
                f"Expression data {expr.shape} and peak data {peaks.shape} "
                "do not have the same number of observations"
            )

        return _n_expr


class ChromatinDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader that always batches with ``chromatin_collate_fn``.
    """

    def __init__(self, args, **kwargs):
        """
        :param args: Forwarded to ``torch.utils.data.DataLoader`` as
            its first positional argument
        :param kwargs: Keyword arguments forwarded to the parent
            class; any ``collate_fn`` supplied here is replaced
        """
        kwargs.update(collate_fn=chromatin_collate_fn)
        super().__init__(args, **kwargs)


def chromatin_collate_fn(data):
    """
    Collate a sequence of two-field samples into two stacked tensors.

    :param data: Sequence of indexable samples; fields 0 and 1 of
        every sample must be tensors that ``torch.stack`` can combine
    :return: Tuple of (stacked field 0, stacked field 1)
    :rtype: tuple
    """

    def _stack_field(position):
        # Gather one field from every sample and stack on a new axis 0
        return torch.stack([sample[position] for sample in data])

    return _stack_field(0), _stack_field(1)
97 changes: 97 additions & 0 deletions supirfactor_dynamical/models/chromatin_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import torch

from ._base_trainer import (
_TrainingMixin
)


class ChromatinModule(
    torch.nn.Module,
    _TrainingMixin
):
    """
    Feed-forward network mapping a ``g``-wide input to a ``p``-wide
    output through two ``k``-wide hidden layers.

    Layer stack: input dropout, Linear(g, k), Tanh, hidden dropout,
    Linear(k, k), Softplus, Linear(k, p), Sigmoid. No linear layer
    has a bias term.
    """

    type_name = 'chromatin'

    hidden_state = None

    # Layer widths; set in __init__
    g = None
    k = None
    p = None

    def __init__(
        self,
        g,
        p,
        k=50,
        input_dropout_rate=0.5,
        hidden_dropout_rate=0.0
    ):
        """
        Initialize a chromatin state model

        :param g: Number of genes/transcripts (model input width)
        :type g: int
        :param p: Number of peaks (model output width)
        :type p: int
        :param k: Number of internal model nodes, defaults to 50
        :type k: int, optional
        :param input_dropout_rate: Dropout probability applied to the
            model input before the first linear layer, defaults to 0.5
        :type input_dropout_rate: float, optional
        :param hidden_dropout_rate: Dropout probability applied after
            the Tanh activation of the first hidden layer,
            defaults to 0.0
        :type hidden_dropout_rate: float, optional
        """
        super().__init__()

        self.g, self.k, self.p = g, k, p

        # set_dropouts is not defined in this class
        # NOTE(review): presumably provided by _TrainingMixin - confirm
        self.set_dropouts(input_dropout_rate, hidden_dropout_rate)

        # Order matters: it fixes the Sequential indices and therefore
        # the parameter names in the state dict
        layers = [
            torch.nn.Dropout(input_dropout_rate),
            torch.nn.Linear(g, k, bias=False),
            torch.nn.Tanh(),
            torch.nn.Dropout(hidden_dropout_rate),
            torch.nn.Linear(k, k, bias=False),
            torch.nn.Softplus(threshold=5),
            torch.nn.Linear(k, p, bias=False),
            torch.nn.Sigmoid()
        ]

        self.model = torch.nn.Sequential(*layers)

    def forward(
        self,
        x
    ):
        """
        Run ``x`` through the layer stack.

        :param x: Input passed to the sequential model
        :return: Output of the final Sigmoid layer
        """
        return self.model(x)

    def _slice_data_and_forward(self, train_x):
        """
        Select the model input from ``train_x``, run the model on it,
        and pass the result through ``output_model``.

        :param train_x: Indexable training data; element 0 is used as
            the model input
        :return: Model output after ``output_model``
        """
        model_input = self.input_data(train_x)
        prediction = self(model_input)

        return self.output_model(prediction)

    def input_data(self, x):
        """Return element 0 of ``x`` (the model input)."""
        return x[0]

    def output_data(self, x, **kwargs):
        """Return element 1 of ``x``; keyword arguments are ignored."""
        return x[1]

    def output_model(self, x):
        """Return ``x`` unchanged."""
        return x
5 changes: 5 additions & 0 deletions supirfactor_dynamical/tests/_stubs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import numpy as np
from scipy.sparse import csr_matrix

from supirfactor_dynamical._utils._trunc_robust_scaler import TruncRobustScaler

Expand All @@ -18,6 +19,10 @@
# Reorder the rows of X so they are ascending by the first column
X = X[np.argsort(X[:, 0]), :]
# Matrix product of X with A (A is defined outside this view)
Y = X @ A
# Values 0..3, each repeated 25 times (100 entries in total)
T = np.repeat(np.arange(4), 25)
# scipy CSR version of X
X_SP = csr_matrix(X)

# 100 x 25 array of 0/1 values drawn with probabilities 0.95 / 0.05
# NOTE(review): _rng is defined outside this view; presumably a numpy
# random Generator - confirm
PEAKS = _rng.choice([0, 1], size=(100, 25), p=[0.95, 0.05])
# scipy CSR version of PEAKS
PEAKS_SP = csr_matrix(PEAKS)

# torch.Tensor version of X
X_tensor = torch.Tensor(X)

Expand Down
69 changes: 69 additions & 0 deletions supirfactor_dynamical/tests/test_chromatin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest

from ._stubs import (
X,
X_SP,
PEAKS,
PEAKS_SP
)

from supirfactor_dynamical._utils.chromatin_dataset import (
ChromatinDataset,
ChromatinDataLoader
)

from supirfactor_dynamical.models.chromatin_model import (
ChromatinModule
)


class TestChromatinTraining(unittest.TestCase):
    """Trains a ChromatinModule on the dense stub data."""

    def setUp(self) -> None:
        # Each test gets its own copies of the stub data
        self.peaks = PEAKS.copy()
        self.X = X.copy()

        self.data = ChromatinDataset(self.X, self.peaks)
        self.dataloader = ChromatinDataLoader(self.data, batch_size=2)

    def test_train(self):
        # Model sized for the stubs: 4 inputs, 25 outputs, 10 hidden nodes
        model = ChromatinModule(4, 25, k=10)

        model.train_model(self.dataloader, 10)

        # train_model(..., 10) is expected to leave 10 recorded losses
        recorded_losses = len(model.training_loss)
        self.assertEqual(recorded_losses, 10)


class TestChromatinTrainingSparse(TestChromatinTraining):
    """Runs the inherited training test on the CSR stub data."""

    def setUp(self) -> None:
        # Same fixture layout as the parent, built from sparse stubs
        self.peaks = PEAKS_SP.copy()
        self.X = X_SP.copy()

        self.data = ChromatinDataset(self.X, self.peaks)
        self.dataloader = ChromatinDataLoader(self.data, batch_size=2)
69 changes: 69 additions & 0 deletions supirfactor_dynamical/tests/test_chromatin_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest

from ._stubs import (
X,
X_SP,
PEAKS,
PEAKS_SP
)

from supirfactor_dynamical._utils.chromatin_dataset import (
ChromatinDataset,
ChromatinDataLoader
)


class TestChromatinDataset(unittest.TestCase):
    """Checks ChromatinDataset / ChromatinDataLoader on dense stubs."""

    def setUp(self) -> None:
        # Each test gets its own copies of the stub data
        self.peaks = PEAKS.copy()
        self.X = X.copy()

    def test_init(self):
        dataset = ChromatinDataset(self.X, self.peaks)
        self.assertEqual(len(dataset), 100)

        loader = ChromatinDataLoader(dataset, batch_size=2)
        batches = list(loader)

        # 100 observations in batches of 2
        self.assertEqual(len(batches), 50)

        # Each batch is (expression, peaks) with batch size on axis 0
        first_batch = batches[0]
        self.assertEqual(first_batch[0].shape, (2, 4))
        self.assertEqual(first_batch[1].shape, (2, 25))

    def test_misaligned(self):
        # Transposed peaks no longer share the first-axis size with X
        with self.assertRaises(ValueError):
            ChromatinDataset(self.X, self.peaks.T)


class TestChromatinDatasetSparse(TestChromatinDataset):
    """Runs the inherited dataset tests on the CSR stub data."""

    def setUp(self) -> None:
        # Swap in the sparse stubs; the test methods are inherited
        self.X = X_SP.copy()
        self.peaks = PEAKS_SP.copy()

0 comments on commit e6d49bd

Please sign in to comment.