Commit e6d49bd (1 parent: 1069da5)
Showing 5 changed files with 328 additions and 0 deletions.
supirfactor_dynamical/_utils/chromatin_dataset.py
@@ -0,0 +1,88 @@
import torch
from scipy.sparse import isspmatrix


class ChromatinDataset(torch.utils.data.Dataset):
    """Paired dataset of gene expression and chromatin state observations."""

    expression = None
    chromatin = None

    n = None

    @property
    def expression_sparse(self):
        if self.expression is None:
            return None
        else:
            return isspmatrix(self.expression)

    @property
    def chromatin_sparse(self):
        if self.chromatin is None:
            return None
        else:
            return isspmatrix(self.chromatin)

    def __init__(
        self,
        gene_expression_data,
        chromatin_state_data
    ):
        self.n = self._n_samples(
            gene_expression_data,
            chromatin_state_data
        )

        self.chromatin = chromatin_state_data
        self.expression = gene_expression_data

        # Convert dense inputs to tensors up front; sparse inputs
        # stay sparse and are densified per row in __getitem__
        if not torch.is_tensor(self.expression) and not self.expression_sparse:
            self.expression = torch.Tensor(self.expression)

        if not torch.is_tensor(self.chromatin) and not self.chromatin_sparse:
            self.chromatin = torch.Tensor(self.chromatin)

    def __getitem__(self, i):

        e = self.expression[i, :]
        c = self.chromatin[i, :]

        # Densify a single sparse row into a 1D tensor
        if isspmatrix(e):
            e = torch.Tensor(e.A.ravel())

        if isspmatrix(c):
            c = torch.Tensor(c.A.ravel())

        return e, c

    def __len__(self):
        return self.n

    @staticmethod
    def _n_samples(expr, peaks):

        _n_expr = expr.shape[0]
        _n_peaks = peaks.shape[0]

        if _n_expr != _n_peaks:
            raise ValueError(
                f"Expression data {expr.shape} and peak data {peaks.shape} "
                "do not have the same number of observations"
            )

        return _n_expr


class ChromatinDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, **kwargs):
        # Always collate with chromatin_collate_fn so batches are
        # (expression, chromatin) tensor pairs
        kwargs['collate_fn'] = chromatin_collate_fn
        super().__init__(*args, **kwargs)


def chromatin_collate_fn(data):

    return (
        torch.stack([d[0] for d in data]),
        torch.stack([d[1] for d in data])
    )
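As a quick usage sketch (not part of this commit): the dataset pairs row-aligned expression and peak matrices, and the loader yields stacked tensor batches. The shapes below are illustrative assumptions, chosen to match the stub dimensions used in the tests further down; dense and sparse inputs can be mixed.

import numpy
from scipy.sparse import csr_matrix

from supirfactor_dynamical._utils.chromatin_dataset import (
    ChromatinDataset,
    ChromatinDataLoader
)

# Illustrative data: 100 observations, 4 genes, 25 peaks (assumed shapes)
expression = numpy.random.rand(100, 4).astype(numpy.float32)
peaks = csr_matrix(
    numpy.random.binomial(1, 0.1, (100, 25)).astype(numpy.float32)
)

data = ChromatinDataset(expression, peaks)        # dense expression, sparse peaks
loader = ChromatinDataLoader(data, batch_size=2)

e_batch, c_batch = next(iter(loader))
print(e_batch.shape, c_batch.shape)               # torch.Size([2, 4]) torch.Size([2, 25])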
supirfactor_dynamical/models/chromatin_model.py
@@ -0,0 +1,97 @@
import torch

from ._base_trainer import (
    _TrainingMixin
)


class ChromatinModule(
    torch.nn.Module,
    _TrainingMixin
):

    type_name = 'chromatin'

    hidden_state = None

    g = None
    k = None
    p = None

    def __init__(
        self,
        g,
        p,
        k=50,
        input_dropout_rate=0.5,
        hidden_dropout_rate=0.0
    ):
        """
        Initialize a chromatin state model

        :param g: Number of genes/transcripts
        :type g: int
        :param p: Number of peaks
        :type p: int
        :param k: Number of internal model nodes
        :type k: int
        :param input_dropout_rate: Input layer dropout rate, defaults to 0.5
        :type input_dropout_rate: float, optional
        :param hidden_dropout_rate: Hidden layer dropout rate, defaults to 0.0
        :type hidden_dropout_rate: float, optional
        """
        super().__init__()

        self.g = g
        self.k = k
        self.p = p

        self.set_dropouts(
            input_dropout_rate,
            hidden_dropout_rate
        )

        # Feed-forward network g -> k -> k -> p that maps expression
        # to per-peak probabilities in (0, 1)
        self.model = torch.nn.Sequential(
            torch.nn.Dropout(input_dropout_rate),
            torch.nn.Linear(g, k, bias=False),
            torch.nn.Tanh(),
            torch.nn.Dropout(hidden_dropout_rate),
            torch.nn.Linear(k, k, bias=False),
            torch.nn.Softplus(threshold=5),
            torch.nn.Linear(k, p, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(
        self,
        x
    ):
        return self.model(x)

    def _slice_data_and_forward(self, train_x):
        return self.output_model(
            self(
                self.input_data(train_x)
            )
        )

    def input_data(self, x):
        # Expression is the first element of an (expression, chromatin) batch
        return x[0]

    def output_data(self, x, **kwargs):
        # Chromatin state is the second element and the prediction target
        return x[1]

    def output_model(self, x):
        return x
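A minimal forward-pass sketch (shapes assumed, not part of this commit) showing the g -> k -> k -> p mapping; construction also exercises set_dropouts, which this commit inherits from _TrainingMixin:

import torch

from supirfactor_dynamical.models.chromatin_model import ChromatinModule

model = ChromatinModule(g=4, p=25, k=10)  # 4 genes in, 25 peaks out

x = torch.rand(2, 4)     # a batch of 2 expression vectors
y = model(x)             # sigmoid output: per-peak values in (0, 1)
print(y.shape)           # torch.Size([2, 25])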
@@ -0,0 +1,69 @@
import unittest

from ._stubs import (
    X,
    X_SP,
    PEAKS,
    PEAKS_SP
)

from supirfactor_dynamical._utils.chromatin_dataset import (
    ChromatinDataset,
    ChromatinDataLoader
)

from supirfactor_dynamical.models.chromatin_model import (
    ChromatinModule
)


class TestChromatinTraining(unittest.TestCase):

    def setUp(self) -> None:
        self.peaks = PEAKS.copy()
        self.X = X.copy()

        self.data = ChromatinDataset(
            self.X,
            self.peaks
        )

        self.dataloader = ChromatinDataLoader(
            self.data,
            batch_size=2
        )

    def test_train(self):

        model = ChromatinModule(
            4,
            25,
            k=10
        )

        model.train_model(
            self.dataloader,
            10
        )

        self.assertEqual(
            len(model.training_loss),
            10
        )


class TestChromatinTrainingSparse(TestChromatinTraining):

    def setUp(self) -> None:
        self.peaks = PEAKS_SP.copy()
        self.X = X_SP.copy()

        self.data = ChromatinDataset(
            self.X,
            self.peaks
        )

        self.dataloader = ChromatinDataLoader(
            self.data,
            batch_size=2
        )
@@ -0,0 +1,69 @@
import unittest

from ._stubs import (
    X,
    X_SP,
    PEAKS,
    PEAKS_SP
)

from supirfactor_dynamical._utils.chromatin_dataset import (
    ChromatinDataset,
    ChromatinDataLoader
)


class TestChromatinDataset(unittest.TestCase):

    def setUp(self) -> None:
        self.peaks = PEAKS.copy()
        self.X = X.copy()

    def test_init(self):

        data = ChromatinDataset(
            self.X,
            self.peaks
        )

        self.assertEqual(
            len(data),
            100
        )

        dataloader = ChromatinDataLoader(
            data,
            batch_size=2
        )

        ld = [d for d in dataloader]

        self.assertEqual(
            len(ld),
            50
        )

        self.assertEqual(
            ld[0][0].shape,
            (2, 4)
        )

        self.assertEqual(
            ld[0][1].shape,
            (2, 25)
        )

    def test_misaligned(self):

        with self.assertRaises(ValueError):
            ChromatinDataset(
                self.X,
                self.peaks.T
            )


class TestChromatinDatasetSparse(TestChromatinDataset):

    def setUp(self) -> None:
        self.peaks = PEAKS_SP.copy()
        self.X = X_SP.copy()
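Taken together, the tests imply an end-to-end flow like the sketch below. train_model and training_loss come from _TrainingMixin (not shown in this commit); the (dataloader, epochs) signature and the one-loss-per-epoch behavior are inferred from test_train above, and the data shapes are illustrative stand-ins for the X / PEAKS stubs:

import numpy

from supirfactor_dynamical._utils.chromatin_dataset import (
    ChromatinDataset,
    ChromatinDataLoader
)
from supirfactor_dynamical.models.chromatin_model import ChromatinModule

# Illustrative stand-ins for the test stubs: 100 x 4 expression, 100 x 25 peaks
expression = numpy.random.rand(100, 4).astype(numpy.float32)
peaks = numpy.random.binomial(1, 0.1, (100, 25)).astype(numpy.float32)

loader = ChromatinDataLoader(
    ChromatinDataset(expression, peaks),
    batch_size=2
)

model = ChromatinModule(4, 25, k=10)
model.train_model(loader, 10)       # signature inferred from test_train
print(len(model.training_loss))     # 10 -- one loss value per epoch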