diff --git a/gears/data_utils.py b/gears/data_utils.py index f97972f..463b917 100644 --- a/gears/data_utils.py +++ b/gears/data_utils.py @@ -81,7 +81,12 @@ def get_dropout_non_zero_genes(adata): for i, j in conditions2index.items(): condition2mean_expression[i] = np.mean(adata.X[j], axis = 0) pert_list = np.array(list(condition2mean_expression.keys())) - mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.toarray().shape[1]) + # to handle the non-sparse data input + try : + adata.X = adata.X.toarray() + mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.shape[1]) + except: + mean_expression = np.array(list(condition2mean_expression.values())).reshape(len(adata.obs.condition.unique()), adata.X.shape[1]) ctrl = mean_expression[np.where(pert_list == 'ctrl')[0]] ## in silico modeling and upperbounding @@ -408,4 +413,4 @@ def get_genes_from_perts(self, perts): gene_list = [p.split('+') for p in np.unique(perts)] gene_list = [item for sublist in gene_list for item in sublist] gene_list = [g for g in gene_list if g != 'ctrl'] - return np.unique(gene_list) \ No newline at end of file + return np.unique(gene_list) diff --git a/gears/pertdata.py b/gears/pertdata.py index 801f289..232387c 100644 --- a/gears/pertdata.py +++ b/gears/pertdata.py @@ -594,8 +594,13 @@ def create_cell_graph_dataset(self, split_adata, pert_category, # Create cell graphs cell_graphs = [] for X, y in zip(Xs, ys): - cell_graphs.append(self.create_cell_graph(X.toarray(), - y.toarray(), de_idx, pert_category, pert_idx)) + try: + cell_graphs.append(self.create_cell_graph(X.toarray(), + y, de_idx, pert_category, pert_idx)) + except: + y = y.toarray() + cell_graphs.append(self.create_cell_graph(X.toarray(), + y, de_idx, pert_category, pert_idx)) return cell_graphs