Skip to content

Commit

Permalink
switched from R-deseq2 to python-deseq2
Browse files Browse the repository at this point in the history
  • Loading branch information
vicpaton committed May 7, 2024
1 parent 109188c commit c3ea543
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 65 deletions.
86 changes: 28 additions & 58 deletions networkcommons/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from rpy2.robjects.packages import importr
from rpy2.robjects.conversion import localconverter
import decoupler as dc
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats

def get_available_datasets():
public_link="https://oc.embl.de/index.php/s/6KsHfeoqJOKLF6B"
Expand Down Expand Up @@ -54,71 +57,38 @@ def download_url(url, save_path, chunk_size=128):
fd.write(chunk)


def deseq2_analysis(counts,
metadata,
covariates="",
deseq2_test='Wald',
deseq2_fitType='parametric',
deseq2_betaprior=False,
deseq2_quiet=False,
deseq2_minReplicatesForReplace=7,
):
"""
Perform DESeq2 analysis using rpy2.
Parameters:
counts (DataFrame): A pandas DataFrame containing raw count data.
metadata (DataFrame): A pandas DataFrame containing metadata.
additional_args (dict): Additional arguments for DESeq2 analysis.
Returns:
DESeq2 results as a DataFrame.
"""
# Importing required R packages
DESeq2 = importr("DESeq2")
base = importr("base")

# Set genesymbol as rownames
counts.set_index('gene_symbol', inplace=True)
metadata.set_index('sample_ID', inplace=True)

# Convert pandas DataFrames to R DataFrames
pandas2ri.activate()
gene_counts = pandas2ri.py2rpy(counts)
metadata_r = pandas2ri.py2rpy(metadata)

if covariates != "" and len(covariates)>=1:
covariates = ["" + covariates]

# Create design formula
design_formula = robjects.Formula("~ 0 + group" + " + ".join(covariates))
def run_deseq2_analysis(counts,
metadata,
test_group,
ref_group,
covariates=[]):

counts.set_index('gene_symbol', inplace=True)
metadata.set_index('sample_ID', inplace=True)

# Create DESeqDataSet object
formatted_data = DESeq2.DESeqDataSetFromMatrix(countData=gene_counts,
colData=metadata_r,
design=design_formula)

# Get study groups
studygroups = list(set(metadata['group']))
design_factors = ['group']

if len(covariates) > 0:
if isinstance(covariates, str):
covariates = [covariates]
design_factors += covariates

inference = DefaultInference(n_cpus=8)
dds = DeseqDataSet(
counts=counts.T,
metadata=metadata,
design_factors=design_factors,
refit_cooks=True,
inference=inference
)
dds.deseq2()

# Run DESeq2 analysis
results = DESeq2.DESeq(formatted_data,
test=deseq2_test,
fitType=deseq2_fitType,
betaPrior=deseq2_betaprior,
quiet=deseq2_quiet,
minReplicatesForReplace=deseq2_minReplicatesForReplace)
results = DESeq2.results(results, contrast=robjects.StrVector(['group', studygroups[0], studygroups[1]]))
results = base.as_data_frame(results)
results = DeseqStats(dds, contrast=["group", test_group, ref_group], inference=inference)
results.summary()
return results.results_df.astype('float64')

# Convert DESeq2 results to pandas DataFrame
with localconverter(robjects.default_converter + pandas2ri.converter):
results_df = robjects.conversion.rpy2py(results)


return results_df



Expand Down
15 changes: 8 additions & 7 deletions networkcommons/test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_deseq2_analysis():
})

# Call the deseq2_analysis function
result = deseq2_analysis(counts, metadata)
result = run_deseq2_analysis(counts, metadata, ref_group='Control', test_group='Treatment')

# Assert that the returned value is a pandas DataFrame
assert isinstance(result, pd.DataFrame)
Expand All @@ -55,13 +55,14 @@ def test_deseq2_analysis():

# Assert that the DataFrame has the expected content
data = {
'baseMean': [93.233027, 101.285704, 11.793541],
'log2FoldChange': [-0.218172, 0.682183, 0.052954],
'lfcSE': [0.328036, 0.352393, 0.521659],
'stat': [-0.665087, 1.935862, 0.101510],
'pvalue': [0.505995, 0.052885, 0.919146],
'padj': [0.758992, 0.158654, 0.919146]
'baseMean': [93.233032, 101.285698, 11.793541],
'log2FoldChange': [0.222414, -0.682183, -0.052951],
'lfcSE': [0.150059, 0.352411, 0.521689],
'stat': [1.482173, -1.935763, -0.101499],
'pvalue': [0.138294, 0.052897, 0.919154],
'padj': [0.207441, 0.158690, 0.919154]
}

expected_result = pd.DataFrame(data, index=['Gene1', 'Gene2', 'Gene3'])
expected_result.index.name = 'gene_symbol'
pd.testing.assert_frame_equal(result, expected_result, check_exact=False)

0 comments on commit c3ea543

Please sign in to comment.