diff --git a/scarf/readers.py b/scarf/readers.py index 69ef4f9..148059f 100644 --- a/scarf/readers.py +++ b/scarf/readers.py @@ -424,15 +424,15 @@ def process_batch(self, dfs: List[pd.DataFrame], filtering_cutoff: int) -> np.ar """ pl_dfs = [pl.DataFrame(df) for df in dfs] pl_dfs = pl.concat(pl_dfs) - dfs_ = pl_dfs.group_by('barcode').agg(pl.sum('count')) - dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff) - return np.sort(dfs_['barcode']) + dfs_ = pl_dfs.group_by("barcode").agg(pl.sum("count")) + dfs_ = dfs_.filter(pl.col("count") > filtering_cutoff) + return np.sort(dfs_["barcode"]) def _get_valid_barcodes( self, filtering_cutoff: int, batch_size: int = int(10e3), - lines_in_mem: int = int(10e6), + lines_in_mem: int = int(10e5), ) -> np.ndarray: """Returns a list of valid barcodes after filtering out background barcodes. @@ -524,7 +524,8 @@ def cell_names(self) -> List[str]: vals = vals[(self.validBarcodeIdx + self.indexOffset)] return list(vals) - def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> List: + def rename_batches(self, collect: List[pd.DataFrame]) -> List: + collect = [pl.DataFrame(df) for df in collect] df = pl.concat(collect) barcodes = np.array(df["barcode"]) count_hash = {} @@ -548,44 +549,38 @@ def consume( lines_in_mem: The number of lines to read into memory. dtype: The data type of the matrix. """ - matrixIO = pl.read_csv_batched( + matrixIO = pd.read_csv( self.matFn, - has_header=False, - separator=self.sep, - comment_prefix="%", - skip_rows_after_header=1, - new_columns=["gene", "barcode", "count"], - schema_overrides={"gene": pl.Int64, "barcode": pl.Int64, "count": pl.Int64}, - batch_size=lines_in_mem, + comment="%", + sep=self.sep, + header=0, + chunksize=lines_in_mem, + names=["gene", "barcode", "count"], ) unique_list = [] collect = [] - while True: - chunk = matrixIO.next_batches(1) - if chunk is None: - break - chunk = chunk[0] - chunk = chunk.filter(pl.col("barcode").is_in(self.validBarcodeIdx)) - in_uniques = np.unique(chunk["barcode"]) + for chunk in matrixIO: + chunk = chunk[chunk["barcode"].isin(self.validBarcodeIdx)] + in_uniques = np.unique(chunk["barcode"].values) unique_list.extend(in_uniques) unique_list = list(set(unique_list)) if len(unique_list) > batch_size: diff = batch_size - (len(unique_list) - len(in_uniques)) mask_pos = in_uniques[:diff] mask_neg = in_uniques[diff:] - extra = chunk.filter(pl.col("barcode").is_in(mask_pos)) + extra = chunk[chunk["barcode"].isin(mask_pos)] collect.append(extra) - collect = self.rename_batches(collect, batch_size) + collect = self.rename_batches(collect) mtx = self.to_sparse(np.array(collect), dtype=dtype) yield mtx - left_out = chunk.filter(pl.col("barcode").is_in(mask_neg)) + left_out = chunk[chunk["barcode"].isin(mask_neg)] collect = [] unique_list = list(mask_neg) collect.append(left_out) else: collect.append(chunk) if len(collect) > 0: - collect = self.rename_batches(collect, batch_size) + collect = self.rename_batches(collect) mtx = self.to_sparse(np.array(collect), dtype=dtype) yield mtx