Skip to content

Commit

Permalink
Update grn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject authored Oct 23, 2024
1 parent f9ee0b8 commit 2991337
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions scprint/tasks/grn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
num_genes: int = 3000,
precision: str = "16-mixed",
cell_type_col: str = "cell_type",
max_len: int = 3000,
how: str = "random expr", # random expr, most var within, most var across, given
preprocess: str = "softmax", # sinkhorn, softmax, none
head_agg: str = "mean", # mean, sum, none
Expand All @@ -53,7 +54,6 @@ def __init__(
known_grn: Optional[any] = None,
symmetrize: bool = False,
doplot: bool = True,
max_cells: int = 0,
forward_mode: str = "none",
genes: List[str] = [],
loc: str = "./",
Expand Down Expand Up @@ -190,15 +190,15 @@ def predict(self, model, adata, layer, cell_type=None):
adataset = SimpleAnnDataset(
subadata, obs_to_output=["organism_ontology_term_id"]
)
self.col = Collator(
col = Collator(
organisms=model.organisms,
valid_genes=model.genes,
how="some" if self.how != "random expr" else "random expr",
genelist=self.curr_genes if self.how != "random expr" else [],
)
dataloader = DataLoader(
adataset,
collate_fn=self.col,
collate_fn=col,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
Expand Down

0 comments on commit 2991337

Please sign in to comment.