Skip to content

Commit

Permalink
style: reorganize code
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshLoecker committed Dec 17, 2024
1 parent d620960 commit 9541824
Showing 1 changed file with 45 additions and 27 deletions.
72 changes: 45 additions & 27 deletions main/como/rnaseq_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ async def _build_matrix_results(
:param taxon: The NCBI Taxon ID
:return: A dataclass `ReadMatrixResults`
"""
gene_info = gene_info_migrations(gene_info)
conversion = await ensembl_to_gene_id_and_symbol(ids=matrix["ensembl_gene_id"].tolist(), taxon=taxon)
conversion["ensembl_gene_id"] = conversion["ensembl_gene_id"].str.split(",")
conversion = conversion.explode("ensembl_gene_id")
conversion.reset_index(inplace=True, drop=True)
matrix = matrix.merge(conversion, on="ensembl_gene_id", how="left")

# Only include Entrez and Ensembl Gene IDs that are present in `gene_info`
Expand All @@ -181,6 +183,7 @@ async def _build_matrix_results(
matrix = matrix.replace(to_replace="-", value=pd.NA).dropna()
matrix["entrez_gene_id"] = matrix["entrez_gene_id"].astype(int)

gene_info = gene_info_migrations(gene_info)
gene_info = gene_info.replace(to_replace="-", value=pd.NA).dropna()
gene_info["entrez_gene_id"] = gene_info["entrez_gene_id"].astype(int)

Expand All @@ -189,6 +192,7 @@ async def _build_matrix_results(
on=["entrez_gene_id", "ensembl_gene_id"],
how="inner",
)

gene_info = gene_info.merge(
counts_matrix[["entrez_gene_id", "ensembl_gene_id"]],
on=["entrez_gene_id", "ensembl_gene_id"],
Expand Down Expand Up @@ -552,7 +556,7 @@ def tpm_quantile_filter(*, metrics: NamedMetrics, filtering_options: _FilteringO

# Only keep `entrez_gene_ids` that pass `min_genes`
metric.entrez_gene_ids = [gene for gene, keep in zip(entrez_ids, min_genes) if keep]
metric.gene_sizes = [gene for gene, keep in zip(gene_size, min_genes) if keep]
metric.gene_sizes = np.array(gene for gene, keep in zip(gene_size, min_genes) if keep)
metric.count_matrix = metric.count_matrix.iloc[min_genes, :]
metric.normalization_matrix = metrics[sample].normalization_matrix.iloc[min_genes, :]

Expand All @@ -569,15 +573,13 @@ def zfpkm_filter(*, metrics: NamedMetrics, filtering_options: _FilteringOptions,
min_sample_expression = filtering_options.replicate_ratio
high_confidence_sample_expression = filtering_options.high_replicate_ratio
cut_off = filtering_options.cut_off

if calcualte_fpkm:
metrics = calculate_fpkm(metrics)
metrics = calculate_fpkm(metrics) if calcualte_fpkm else metrics

metric: _StudyMetrics
for metric in metrics.values():
# if fpkm was not calculated, the normalization matrix will be empty; collect the count matrix instead
matrix = metric.count_matrix if metric.normalization_matrix.empty else metric.normalization_matrix
matrix = matrix[matrix.sum(axis=1) > 0]
matrix = matrix[matrix.sum(axis=1) > 0] # remove rows (genes) that have no counts

minimums = matrix == 0
results, zfpkm_df = zfpkm_transform(matrix)
Expand Down Expand Up @@ -628,7 +630,6 @@ async def _save_rnaseq_tests(
rnaseq_matrix: pd.DataFrame,
metadata_df: pd.DataFrame,
gene_info_df: pd.DataFrame,
output_filepath: Path,
prep: RNAPrepMethod,
taxon: int,
replicate_ratio: float,
Expand All @@ -637,6 +638,8 @@ async def _save_rnaseq_tests(
high_batch_ratio: float,
technique: FilteringTechnique,
cut_off: int | float,
output_boolean_activity_filepath: Path,
output_zscore_normalization_filepath: Path,
):
"""Save the results of the RNA-Seq tests to a CSV file."""
filtering_options = _FilteringOptions(
Expand All @@ -656,19 +659,32 @@ async def _save_rnaseq_tests(
metrics = read_counts_results.metrics
entrez_gene_ids = read_counts_results.entrez_gene_ids

metrics = filter_counts(
metrics: NamedMetrics = filter_counts(
context_name=context_name,
metrics=metrics,
technique=technique,
filtering_options=filtering_options,
prep=prep,
)

merged_zscore_df = pd.DataFrame()
expressed_genes: list[str] = []
top_genes: list[str] = []
for metric in metrics.values():
expressed_genes.extend(metric.entrez_gene_ids)
top_genes.extend(metric.high_confidence_entrez_gene_ids)
if metric.normalization_matrix is not None:
merged_zscore_df = (
metric.normalization_matrix
if merged_zscore_df.empty
else pd.concat(
[merged_zscore_df, metric.normalization_matrix],
axis=1,
)
)
merged_zscore_df.index = pd.Series(entrez_gene_ids, name="entrez_gene_id")
merged_zscore_df.to_csv(output_zscore_normalization_filepath, index=True)
logger.success(f"Wrote z-score normalization matrix to {output_zscore_normalization_filepath}")

expression_frequency = pd.Series(expressed_genes).value_counts()
expression_df = pd.DataFrame(
Expand All @@ -692,11 +708,11 @@ async def _save_rnaseq_tests(
expressed_count = len(boolean_matrix[boolean_matrix["expressed"] == 1])
high_confidence_count = len(boolean_matrix[boolean_matrix["high"] == 1])

boolean_matrix.to_csv(output_filepath, index=False)
boolean_matrix.to_csv(output_boolean_activity_filepath, index=False)
logger.info(
f"{context_name} - Found {expressed_count} expressed and {high_confidence_count} confidently expressed genes"
)
logger.success(f"Wrote boolean matrix to {output_filepath}")
logger.success(f"Wrote boolean matrix to {output_boolean_activity_filepath}")


async def _create_metadata_df(path: Path) -> pd.DataFrame:
Expand All @@ -708,15 +724,15 @@ async def _create_metadata_df(path: Path) -> pd.DataFrame:
return pd.read_excel(path)


async def rnaseq_gen( # noqa: C901, allow complex function
async def rnaseq_gen(
context_name: str,
input_rnaseq_filepath: Path,
input_gene_info_filepath: Path,
output_rnaseq_filepath: Path,
prep: RNAPrepMethod,
taxon: int,
input_metadata_filepath: Path | None = None,
input_metadata_df: pd.DataFrame | None = None,
taxon_id: int,
output_boolean_activity_filepath: Path,
output_zscore_normalization_filepath: Path,
input_metadata_filepath_or_df: Path | pd.DataFrame,
replicate_ratio: float = 0.5,
high_replicate_ratio: float = 1.0,
batch_ratio: float = 0.5,
Expand All @@ -733,11 +749,11 @@ async def rnaseq_gen( # noqa: C901, allow complex function
:param context_name: The name of the context being processed
:param input_rnaseq_filepath: The filepath to the gene count matrix
:param input_gene_info_filepath: The filepath to the gene info file
:param output_rnaseq_filepath: The filepath to write the output gene count matrix
:param output_boolean_activity_filepath: The filepath to write the output boolean gene activity matrix
:param output_zscore_normalization_filepath: The filepath to write the output z-score normalization matrix
:param prep: The preparation method
:param taxon: The NCBI Taxon ID
:param input_metadata_filepath: The filepath to the metadata file
:param input_metadata_df: The metadata dataframe
:param taxon_id: The NCBI Taxon ID
:param input_metadata_filepath_or_df: The filepath or dataframe containing metadata information
:param replicate_ratio: The percentage of replicates that a gene must
appear in for a gene to be marked as "active" in a batch/study
:param batch_ratio: The percentage of batches that a gene must appear in for a gene to be marked as "active"
Expand All @@ -749,9 +765,6 @@ async def rnaseq_gen( # noqa: C901, allow complex function
:param cutoff: The cutoff value to use for the provided filtering technique
:return: None
"""
if not input_metadata_df and not input_metadata_filepath:
raise ValueError("At least one of input_metadata_filepath or input_metadata_df must be provided")

technique = (
FilteringTechnique.from_string(str(technique.lower())) if isinstance(technique, (str, int)) else technique
)
Expand Down Expand Up @@ -787,20 +800,25 @@ async def rnaseq_gen( # noqa: C901, allow complex function
)

logger.debug(f"Starting '{context_name}'")
output_rnaseq_filepath.parent.mkdir(parents=True, exist_ok=True)

output_boolean_activity_filepath.parent.mkdir(parents=True, exist_ok=True)
metadata_df = (
input_metadata_filepath_or_df
if isinstance(input_metadata_filepath_or_df, pd.DataFrame)
else await _create_metadata_df(input_metadata_filepath_or_df)
)
await _save_rnaseq_tests(
context_name=context_name,
rnaseq_matrix=await _read_counts(input_rnaseq_filepath),
metadata_df=input_metadata_df or await _create_metadata_df(input_metadata_filepath),
metadata_df=metadata_df,
gene_info_df=pd.read_csv(input_gene_info_filepath),
output_filepath=output_rnaseq_filepath,
prep=prep,
taxon=taxon,
taxon=taxon_id,
replicate_ratio=replicate_ratio,
batch_ratio=batch_ratio,
high_replicate_ratio=high_replicate_ratio,
high_batch_ratio=high_batch_ratio,
technique=technique,
cut_off=cutoff,
output_boolean_activity_filepath=output_boolean_activity_filepath,
output_zscore_normalization_filepath=output_zscore_normalization_filepath,
)

0 comments on commit 9541824

Please sign in to comment.