diff --git a/supirfactor_dynamical/_utils/chromatin_dataset.py b/supirfactor_dynamical/_utils/chromatin_dataset.py new file mode 100644 index 0000000..44d0d58 --- /dev/null +++ b/supirfactor_dynamical/_utils/chromatin_dataset.py @@ -0,0 +1,88 @@ +import torch +from scipy.sparse import isspmatrix + + +class ChromatinDataset(torch.utils.data.Dataset): + + expression = None + chromatin = None + + n = None + + @property + def expression_sparse(self): + if self.expression is None: + return None + else: + return isspmatrix(self.expression) + + @property + def chromatin_sparse(self): + if self.chromatin is None: + return None + else: + return isspmatrix(self.chromatin) + + def __init__( + self, + gene_expression_data, + chromatin_state_data + ): + self.n = self._n_samples( + gene_expression_data, + chromatin_state_data + ) + + self.chromatin = chromatin_state_data + self.expression = gene_expression_data + + if not torch.is_tensor(self.expression) and not self.expression_sparse: + self.expression = torch.Tensor(self.expression) + + if not torch.is_tensor(self.chromatin) and not self.chromatin_sparse: + self.chromatin = torch.Tensor(self.chromatin) + + def __getitem__(self, i): + + e = self.expression[i, :] + c = self.chromatin[i, :] + + if isspmatrix(e): + e = torch.Tensor(e.A.ravel()) + + if isspmatrix(c): + c = torch.Tensor(c.A.ravel()) + + return e, c + + def __len__(self): + return self.n + + @staticmethod + def _n_samples(expr, peaks): + + _n_expr = expr.shape[0] + _n_peaks = peaks.shape[0] + + if _n_expr != _n_peaks: + raise ValueError( + f"Expression data {expr.shape} and peak data {peaks.shape} " + "do not have the same number of observations" + ) + + return _n_expr + + +class ChromatinDataLoader(torch.utils.data.DataLoader): + + def __init__(self, args, **kwargs): + kwargs['collate_fn'] = chromatin_collate_fn + super().__init__(args, **kwargs) + + +def chromatin_collate_fn(data): + + return ( + torch.stack([d[0] for d in data]), + torch.stack([d[1] for d in data]) + ) diff --git a/supirfactor_dynamical/models/chromatin_model.py b/supirfactor_dynamical/models/chromatin_model.py new file mode 100644 index 0000000..95b2f04 --- /dev/null +++ b/supirfactor_dynamical/models/chromatin_model.py @@ -0,0 +1,97 @@ +import torch + +from ._base_trainer import ( + _TrainingMixin +) + + +class ChromatinModule( + torch.nn.Module, + _TrainingMixin +): + + type_name = 'chromatin' + + hidden_state = None + + g = None + k = None + p = None + + def __init__( + self, + g, + p, + k=50, + input_dropout_rate=0.5, + hidden_dropout_rate=0.0 + ): + """ + Initialize a chromatin state model + + :param g: Number of genes/transcripts + :type g: int + :param p: Number of peaks + :type p: p + :param k: Number of internal model nodes + :type k: int + :param input_dropout_rate: _description_, defaults to 0.5 + :type input_dropout_rate: float, optional + :param hidden_dropout_rate: _description_, defaults to 0.0 + :type hidden_dropout_rate: float, optional + """ + super().__init__() + + self.g = g + self.k = k + self.p = p + + self.set_dropouts( + input_dropout_rate, + hidden_dropout_rate + ) + + self.model = torch.nn.Sequential( + torch.nn.Dropout(input_dropout_rate), + torch.nn.Linear( + g, + k, + bias=False + ), + torch.nn.Tanh(), + torch.nn.Dropout(hidden_dropout_rate), + torch.nn.Linear( + k, + k, + bias=False + ), + torch.nn.Softplus(threshold=5), + torch.nn.Linear( + k, + p, + bias=False + ), + torch.nn.Sigmoid() + ) + + def forward( + self, + x + ): + return self.model(x) + + def _slice_data_and_forward(self, train_x): + return self.output_model( + self( + self.input_data(train_x) + ) + ) + + def input_data(self, x): + return x[0] + + def output_data(self, x, **kwargs): + return x[1] + + def output_model(self, x): + return x diff --git a/supirfactor_dynamical/tests/_stubs.py b/supirfactor_dynamical/tests/_stubs.py index ff58876..9db6a89 100644 --- a/supirfactor_dynamical/tests/_stubs.py +++ b/supirfactor_dynamical/tests/_stubs.py @@ -1,5 +1,6 @@ import torch import numpy as np +from scipy.sparse import csr_matrix from supirfactor_dynamical._utils._trunc_robust_scaler import TruncRobustScaler @@ -18,6 +19,10 @@ X = X[np.argsort(X[:, 0]), :] Y = X @ A T = np.repeat(np.arange(4), 25) +X_SP = csr_matrix(X) + +PEAKS = _rng.choice([0, 1], size=(100, 25), p=[0.95, 0.05]) +PEAKS_SP = csr_matrix(PEAKS) X_tensor = torch.Tensor(X) diff --git a/supirfactor_dynamical/tests/test_chromatin.py b/supirfactor_dynamical/tests/test_chromatin.py new file mode 100644 index 0000000..b0a3388 --- /dev/null +++ b/supirfactor_dynamical/tests/test_chromatin.py @@ -0,0 +1,69 @@ +import unittest + +from ._stubs import ( + X, + X_SP, + PEAKS, + PEAKS_SP +) + +from supirfactor_dynamical._utils.chromatin_dataset import ( + ChromatinDataset, + ChromatinDataLoader +) + +from supirfactor_dynamical.models.chromatin_model import ( + ChromatinModule +) + + +class TestChromatinTraining(unittest.TestCase): + + def setUp(self) -> None: + self.peaks = PEAKS.copy() + self.X = X.copy() + + self.data = ChromatinDataset( + self.X, + self.peaks + ) + + self.dataloader = ChromatinDataLoader( + self.data, + batch_size=2 + ) + + def test_train(self): + + model = ChromatinModule( + 4, + 25, + k=10 + ) + + model.train_model( + self.dataloader, + 10 + ) + + self.assertEqual( + len(model.training_loss), + 10 + ) + + +class TestChromatinTrainingSparse(TestChromatinTraining): + + def setUp(self) -> None: + self.peaks = PEAKS_SP.copy() + self.X = X_SP.copy() + + self.data = ChromatinDataset( + self.X, + self.peaks + ) + + self.dataloader = ChromatinDataLoader( + self.data, + batch_size=2 + ) diff --git a/supirfactor_dynamical/tests/test_chromatin_loader.py b/supirfactor_dynamical/tests/test_chromatin_loader.py new file mode 100644 index 0000000..410a26d --- /dev/null +++ b/supirfactor_dynamical/tests/test_chromatin_loader.py @@ -0,0 +1,69 @@ +import unittest + +from ._stubs import ( + X, + X_SP, + PEAKS, + PEAKS_SP +) + +from supirfactor_dynamical._utils.chromatin_dataset import ( + ChromatinDataset, + ChromatinDataLoader +) + + +class TestChromatinDataset(unittest.TestCase): + + def setUp(self) -> None: + self.peaks = PEAKS.copy() + self.X = X.copy() + + def test_init(self): + + data = ChromatinDataset( + self.X, + self.peaks + ) + + self.assertEqual( + len(data), + 100 + ) + + dataloader = ChromatinDataLoader( + data, + batch_size=2 + ) + + ld = [d for d in dataloader] + + self.assertEqual( + len(ld), + 50 + ) + + self.assertEqual( + ld[0][0].shape, + (2, 4) + ) + + self.assertEqual( + ld[0][1].shape, + (2, 25) + ) + + def test_misaligned(self): + + with self.assertRaises(ValueError): + ChromatinDataset( + self.X, + self.peaks.T + ) + + +class TestChromatinDatasetSparse(TestChromatinDataset): + + def setUp(self) -> None: + self.peaks = PEAKS_SP.copy() + self.X = X_SP.copy()