diff --git a/pyproject.toml b/pyproject.toml index 17532fea..5ef5795b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ ] [project.optional-dependencies] -cuda = [ +mamba_cuda = [ "mamba_ssm==1.2.0.post1", "causal-conv1d==1.2.0.post2", ] diff --git a/requirements.txt b/requirements.txt index 1757bc9e..10c94e31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ # this references dependencies in pyproject.toml (including optional cuda dependencies) . -.[cuda] \ No newline at end of file +.[mamba_cuda] \ No newline at end of file