Skip to content

Commit

Permalink
bump torch and remove extra
Browse files Browse the repository at this point in the history
  • Loading branch information
eegli committed Feb 3, 2025
1 parent 9b96d48 commit 8498712
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 449 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
TORCH_VERSION = 2.4.1
TORCH_VERSION = 2.6.0
CUDA_INDEX_URL = https://download.pytorch.org/whl/cu124

MAMBA_VERSION = 2.2.2
Expand All @@ -18,7 +18,7 @@ all: format lint check_types test
.PHONY: .install_common
.install_common:
@echo "Installing common Python dependencies"
uv sync --all-extras
uv sync

.PHONY: install_common_ci
install_common_ci:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ uv add mblm

### Using Torch and Mamba

You will need to **install a recent PyTorch version manually**. We use `>=2.4.1`. It is best to do this after installing the package since some sub-dependencies might install their own (CPU) PyTorch version.
You will need to **install a recent PyTorch version manually**. We use `>=2.6.0`. It is best to do this after installing the package since some sub-dependencies might install their own (CPU) PyTorch version.

```
pip install 'torch>=2.4.1' --index-url https://download.pytorch.org/whl/cu124
pip install 'torch>=2.6.0' --index-url https://download.pytorch.org/whl/cu124
```

Finally, in order to use the efficient [Mamba-SSM](https://github.com/state-spaces/mamba), follow their instructions on the homepage. You'll need Linux and a GPU available during installation.
Expand Down
12 changes: 1 addition & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,6 @@ dev = [
"polars>=1.18.0 ; sys_platform != 'linux'",
]

[project.optional-dependencies]
analysis = [
"rouge-score>=0.1.2",
"tabulate>=0.9.0",
"types-tabulate>=0.9.0.20240106",
"vegafusion[embed]<=2.0.0",
"polars-lts-cpu>=1.18.0 ; sys_platform == 'linux'",
"polars>=1.18.0 ; sys_platform != 'linux'",
]


[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
Expand Down Expand Up @@ -108,6 +97,7 @@ addopts = [
"--cov-report=term-missing",
"--cov=mblm",
]
filterwarnings = ["ignore::UserWarning:mblm.model.mamba_shim"]

[tool.ruff]
line-length = 100
Expand Down
11 changes: 4 additions & 7 deletions src/mblm/model/mamba_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@
"""

import os
import warnings
from functools import partial
from typing import cast

import torch

from mblm.utils.misc import once

Mamba1 = None
Mamba1Config = None
Mamba2Mixer = None # type: ignore[no-redef]
Expand Down Expand Up @@ -182,14 +181,12 @@ def forward(
if sys.platform.startswith("linux"):
reason_failed = err.msg

@once
def warn_mamba_import():
skip_warning = "PYTEST_CURRENT_TEST" in os.environ

if not skip_warning:
warnings.warn(
f"Failed to import Mamba2, falling back to Mamba1 (PyTorch version). Reason: {reason_failed}",
category=ImportWarning,
)

warn_mamba_import()


__all__ = ["Mamba1", "Mamba1Config", "Mamba2Mixer"]
15 changes: 0 additions & 15 deletions src/mblm/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,3 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T | None:
return inner

return wrapper


def once(f: Callable[P, T]) -> Callable[P, T | None]:
"""Run a function only once and return None afterwards"""
has_run = False

def inner(*args: P.args, **kwargs: P.kwargs) -> T | None:
nonlocal has_run
if has_run:
return None

has_run = True
return f(*args, **kwargs)

return inner
1 change: 1 addition & 0 deletions tests/integration/install/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dev = ["pytest>=8.3.3"]
[tool.pytest.ini_options]
testpaths = ["."]
addopts = ["--import-mode=importlib"]
filterwarnings = ["ignore::UserWarning:mblm.model.mamba_shim"]

[tool.ruff]
line-length = 80
Expand Down
16 changes: 0 additions & 16 deletions tests/integration/install/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 1 addition & 9 deletions tests/unit/utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pytest_mock import MockerFixture

from mblm.utils.misc import once, retry
from mblm.utils.misc import retry


class FailThenSuccess:
Expand Down Expand Up @@ -51,11 +51,3 @@ def test_retry(
assert result is expected_result
assert try_func_spy.call_count == expected_calls
assert on_error_stub.call_count == min(n_inner_fails, n_retries + 1)

def test_once(self):
@once
def func():
return True

assert func()
assert func() is None
Loading

0 comments on commit 8498712

Please sign in to comment.