Updated consume() in CrDirReader [Polars -> Pandas] (#134)
* Updated consume() in CrDirReader. Fixed Polars issue when reading compressed files

* Comment Cleanup
Gautam8387 authored Dec 17, 2024
1 parent fde515c commit fd40926
Showing 1 changed file with 19 additions and 24 deletions.
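
For context: with chunksize set, pd.read_csv returns an iterator of DataFrames, and its default compression="infer" handles gzipped (and similar) files based on the file extension, which is presumably what resolves the compressed-file issue mentioned in the commit message. Below is a minimal sketch of that read pattern; the file path and barcode whitelist are illustrative placeholders, not part of the commit.

    import numpy as np
    import pandas as pd

    # Hypothetical inputs, for illustration only.
    mat_fn = "matrix.mtx.gz"              # compression inferred from the extension
    valid_barcodes = np.array([1, 5, 9])  # whitelist of barcode indices

    reader = pd.read_csv(
        mat_fn,
        comment="%",                      # skip MTX comment lines
        sep=" ",
        header=0,                         # drop the dimensions row that follows the comments
        chunksize=100_000,                # yield DataFrames of at most this many rows
        names=["gene", "barcode", "count"],
    )
    for chunk in reader:
        chunk = chunk[chunk["barcode"].isin(valid_barcodes)]
        print(chunk.shape)                # each chunk is an ordinary DataFrame
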
43 changes: 19 additions & 24 deletions scarf/readers.py
@@ -424,15 +424,15 @@ def process_batch(self, dfs: List[pd.DataFrame], filtering_cutoff: int) -> np.array:
         """
         pl_dfs = [pl.DataFrame(df) for df in dfs]
         pl_dfs = pl.concat(pl_dfs)
-        dfs_ = pl_dfs.group_by('barcode').agg(pl.sum('count'))
-        dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff)
-        return np.sort(dfs_['barcode'])
+        dfs_ = pl_dfs.group_by("barcode").agg(pl.sum("count"))
+        dfs_ = dfs_.filter(pl.col("count") > filtering_cutoff)
+        return np.sort(dfs_["barcode"])
 
     def _get_valid_barcodes(
         self,
         filtering_cutoff: int,
         batch_size: int = int(10e3),
-        lines_in_mem: int = int(10e6),
+        lines_in_mem: int = int(10e5),
     ) -> np.ndarray:
         """Returns a list of valid barcodes after filtering out background barcodes.
@@ -524,7 +524,8 @@ def cell_names(self) -> List[str]:
         vals = vals[(self.validBarcodeIdx + self.indexOffset)]
         return list(vals)
 
-    def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> List:
+    def rename_batches(self, collect: List[pd.DataFrame]) -> List:
+        collect = [pl.DataFrame(df) for df in collect]
         df = pl.concat(collect)
         barcodes = np.array(df["barcode"])
         count_hash = {}
@@ -548,44 +549,38 @@ def consume(
             lines_in_mem: The number of lines to read into memory.
             dtype: The data type of the matrix.
         """
-        matrixIO = pl.read_csv_batched(
+        matrixIO = pd.read_csv(
             self.matFn,
-            has_header=False,
-            separator=self.sep,
-            comment_prefix="%",
-            skip_rows_after_header=1,
-            new_columns=["gene", "barcode", "count"],
-            schema_overrides={"gene": pl.Int64, "barcode": pl.Int64, "count": pl.Int64},
-            batch_size=lines_in_mem,
+            comment="%",
+            sep=self.sep,
+            header=0,
+            chunksize=lines_in_mem,
+            names=["gene", "barcode", "count"],
         )
         unique_list = []
         collect = []
-        while True:
-            chunk = matrixIO.next_batches(1)
-            if chunk is None:
-                break
-            chunk = chunk[0]
-            chunk = chunk.filter(pl.col("barcode").is_in(self.validBarcodeIdx))
-            in_uniques = np.unique(chunk["barcode"])
+        for chunk in matrixIO:
+            chunk = chunk[chunk["barcode"].isin(self.validBarcodeIdx)]
+            in_uniques = np.unique(chunk["barcode"].values)
             unique_list.extend(in_uniques)
            unique_list = list(set(unique_list))
             if len(unique_list) > batch_size:
                 diff = batch_size - (len(unique_list) - len(in_uniques))
                 mask_pos = in_uniques[:diff]
                 mask_neg = in_uniques[diff:]
-                extra = chunk.filter(pl.col("barcode").is_in(mask_pos))
+                extra = chunk[chunk["barcode"].isin(mask_pos)]
                 collect.append(extra)
-                collect = self.rename_batches(collect, batch_size)
+                collect = self.rename_batches(collect)
                 mtx = self.to_sparse(np.array(collect), dtype=dtype)
                 yield mtx
-                left_out = chunk.filter(pl.col("barcode").is_in(mask_neg))
+                left_out = chunk[chunk["barcode"].isin(mask_neg)]
                 collect = []
                 unique_list = list(mask_neg)
                 collect.append(left_out)
             else:
                 collect.append(chunk)
         if len(collect) > 0:
-            collect = self.rename_batches(collect, batch_size)
+            collect = self.rename_batches(collect)
             mtx = self.to_sparse(np.array(collect), dtype=dtype)
             yield mtx
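
The new consume() loop follows an "accumulate until batch_size unique barcodes, then flush" pattern: chunks are collected until more than batch_size distinct barcodes have been seen, the boundary chunk is split so the batch closes at exactly batch_size barcodes, the batch is yielded, and the leftover rows seed the next batch. A stripped-down sketch of that idea with hypothetical names, leaving out the rename_batches() and to_sparse() steps the real method applies before yielding:

    from typing import Iterator, List

    import numpy as np
    import pandas as pd

    def batch_by_unique(
        chunks: Iterator[pd.DataFrame], key: str, batch_size: int
    ) -> Iterator[pd.DataFrame]:
        seen: set = set()                  # unique key values accumulated so far
        collect: List[pd.DataFrame] = []
        for chunk in chunks:
            uniques = np.unique(chunk[key].values)
            seen.update(uniques)
            if len(seen) > batch_size:
                # Split the boundary chunk so this batch closes at batch_size keys.
                n_keep = batch_size - (len(seen) - len(uniques))
                keep, carry = uniques[:n_keep], uniques[n_keep:]
                collect.append(chunk[chunk[key].isin(keep)])
                yield pd.concat(collect)
                # The remainder of the boundary chunk seeds the next batch.
                collect = [chunk[chunk[key].isin(carry)]]
                seen = set(carry)
            else:
                collect.append(chunk)
        if collect:
            yield pd.concat(collect)
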

