Skip to content

Commit

Permalink
⚡️Add torch.compile Functionality (#716)
Browse files Browse the repository at this point in the history
- Integrates PyTorch 2.0's [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html) functionality to demonstrate performance improvements in torch code. This PR focuses on adding `torch.compile` to `PatchPredictor`.

**Notes:**
- According to the [documentation](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), noticeable performance gains can be achieved when using modern NVIDIA GPUs (H100, A100, or V100).

**TODO:**
- [x] Resolve compilation errors related to using `torch.compile` in running models
- [x] Initial config
- [x] Add to patch predictor
- [x] Add to registration
- [x] Add to segmentation
- [x] Test on custom models
- [x] Test on `torch.compile` compatible GPUs

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Shan E Ahmed Raza <[email protected]>
Co-authored-by: Jiaqi-Lv <[email protected]>
  • Loading branch information
4 people authored Nov 15, 2024
1 parent 9113996 commit 32cae0b
Show file tree
Hide file tree
Showing 16 changed files with 383 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
- name: Test with pytest
run: |
pytest --basetemp={envtmpdir} \
--cov=tiatoolbox --cov-report=term --cov-report=xml \
--cov=tiatoolbox --cov-report=term --cov-report=xml --cov-config=pyproject.toml \
--capture=sys \
--durations=10 --durations-min=1.0 \
--maxfail=1
Expand Down
38 changes: 37 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import os
import shutil
import time
from pathlib import Path
from typing import Callable

import pytest
import torch

import tiatoolbox
from tiatoolbox import logger
from tiatoolbox.data import _fetch_remote_sample
from tiatoolbox.utils.env_detection import running_on_ci
from tiatoolbox.utils.env_detection import has_gpu, running_on_ci

# -------------------------------------------------------------------------------------
# Generate Parameterized Tests
Expand Down Expand Up @@ -608,3 +610,37 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
(tmp_path / "slides").mkdir()
(tmp_path / "overlays").mkdir()
return {"base_path": tmp_path}


# -------------------------------------------------------------------------------------
# Utility functions
# -------------------------------------------------------------------------------------


def timed(fn: Callable, *args: object) -> tuple[object, float]:
    """Call a function once and measure how long the call takes.

    This is a plain helper rather than a decorator: it invokes
    ``fn(*args)`` immediately and returns the result together with the
    elapsed time.

    Args:
        fn (Callable):
            The function to be timed.
        args (object):
            Positional arguments to be passed to the function.

    Returns:
        tuple:
            A tuple containing the result of the function and the time
            taken to execute it in seconds.

    """
    if has_gpu():
        # Time with CUDA events recorded around the call; synchronise
        # before reading them so that both events have completed.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = fn(*args)
        end_event.record()
        torch.cuda.synchronize()
        # elapsed_time reports milliseconds; convert to seconds.
        return result, start_event.elapsed_time(end_event) / 1000

    # perf_counter is monotonic, so the interval cannot be distorted by
    # system clock adjustments the way a time.time() difference can.
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start
8 changes: 7 additions & 1 deletion tests/models/test_nucleus_instance_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import joblib
import numpy as np
import pytest
import torch
import yaml
from click.testing import CliRunner

from tiatoolbox import cli
from tiatoolbox import cli, rcParam
from tiatoolbox.models import (
IOSegmentorConfig,
NucleusInstanceSegmentor,
Expand Down Expand Up @@ -44,7 +45,12 @@ def _crash_func(_x: object) -> None:

def helper_tile_info() -> list:
"""Helper function for tile information."""
torch._dynamo.reset()
current_torch_compile_mode = rcParam["torch_compile_mode"]
rcParam["torch_compile_mode"] = "disable"
predictor = NucleusInstanceSegmentor(model="A")
torch._dynamo.reset()
rcParam["torch_compile_mode"] = current_torch_compile_mode
# ! assuming the tiles organized as follows (coming out from
# ! PatchExtractor). If this is broken, need to check back
# ! PatchExtractor output ordering first
Expand Down
53 changes: 52 additions & 1 deletion tests/models/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import torch
from click.testing import CliRunner

from tiatoolbox import cli
from tests.conftest import timed
from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
from tiatoolbox.models.architecture.vanilla import CNNModel
from tiatoolbox.models.dataset import (
Expand Down Expand Up @@ -1226,3 +1227,53 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
assert tmp_path.joinpath("2.merged.npy").exists()
assert tmp_path.joinpath("2.raw.json").exists()
assert tmp_path.joinpath("results.json").exists()


# -------------------------------------------------------------------------------------
# torch.compile
# -------------------------------------------------------------------------------------


def test_patch_predictor_torch_compile(
    sample_patch1: Path,
    sample_patch2: Path,
    tmp_path: Path,
) -> None:
    """Test PatchPredictor with torch.compile functionality.

    Runs the PatchPredictor API test once per torch.compile mode and logs
    the time each run takes.

    Args:
        sample_patch1 (Path): Path to sample patch 1.
        sample_patch2 (Path): Path to sample patch 2.
        tmp_path (Path): Path to temporary directory.

    """
    torch_compile_mode = rcParam["torch_compile_mode"]
    try:
        for mode in ("default", "reduce-overhead", "max-autotune"):
            # Start each mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            _, compile_time = timed(
                test_patch_predictor_api,
                sample_patch1,
                sample_patch2,
                tmp_path,
            )
            logger.info("torch.compile %s mode: %s", mode, compile_time)
    finally:
        # Restore the global configuration even when a run fails, so the
        # temporary mode does not leak into tests that run afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = torch_compile_mode
48 changes: 47 additions & 1 deletion tests/models/test_semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from click.testing import CliRunner
from torch import nn

from tiatoolbox import cli
from tests.conftest import timed
from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import SemanticSegmentor
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.utils import centre_crop
Expand Down Expand Up @@ -897,3 +898,48 @@ def test_cli_semantic_segmentation_multi_file(
_test_pred = (_test_pred[..., 1] > 0.50) * 255

assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3


# -------------------------------------------------------------------------------------
# torch.compile
# -------------------------------------------------------------------------------------


def test_semantic_segmentor_torch_compile(
    remote_sample: Callable,
    tmp_path: Path,
) -> None:
    """Test SemanticSegmentor using pretrained model with torch.compile functionality.

    Runs the pretrained-model functional test once per torch.compile mode
    and logs the time each run takes.

    Args:
        remote_sample (Callable): Callable object used to extract remote sample.
        tmp_path (Path): Path to temporary directory.

    """
    torch_compile_mode = rcParam["torch_compile_mode"]
    try:
        for mode in ("default", "reduce-overhead", "max-autotune"):
            # Start each mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            _, compile_time = timed(
                test_functional_pretrained,
                remote_sample,
                tmp_path,
            )
            logger.info("torch.compile %s mode: %s", mode, compile_time)
    finally:
        # Restore the global configuration even when a run fails, so the
        # temporary mode does not leak into tests that run afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = torch_compile_mode
41 changes: 40 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
import numpy as np
import pandas as pd
import pytest
import torch
from PIL import Image
from requests import HTTPError
from shapely.geometry import Polygon

from tests.test_annotation_stores import cell_polygon
from tiatoolbox import utils
from tiatoolbox import rcParam, utils
from tiatoolbox.annotation.storage import DictionaryStore, SQLiteStore
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import FileNotSupportedError
from tiatoolbox.utils.transforms import locsize2bounds
Expand Down Expand Up @@ -1827,3 +1829,40 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.dict_to_store(patch_output, (1.0, 1.0))


def test_torch_compile_already_compiled() -> None:
    """Test that torch_compile does not recompile a model that is already compiled.

    For every supported mode, a model is compiled and the compiled model is
    passed to ``compile_model`` again; the second call must return a model
    that compares equal to the first result.

    """
    torch_compile_modes = [
        "default",
        "reduce-overhead",
        "max-autotune",
        "max-autotune-no-cudagraphs",
    ]
    current_torch_compile_mode = rcParam["torch_compile_mode"]
    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

    try:
        for mode in torch_compile_modes:
            # Start each mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            compiled_model = compile_model(model, mode=mode)
            recompiled_model = compile_model(compiled_model, mode=mode)
            assert compiled_model == recompiled_model
    finally:
        # Restore the global configuration even when an assertion fails, so
        # the temporary mode does not leak into tests that run afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = current_torch_compile_mode


def test_torch_compile_disable() -> None:
    """Check that the "disable" mode returns a model equal to the input."""
    layers = (torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
    original = torch.nn.Sequential(*layers)
    result = compile_model(original, mode="disable")
    assert original == result


def test_torch_compile_compatibility(caplog: pytest.LogCaptureFixture) -> None:
    """Check that the compatibility probe logs a message about torch.compile."""
    from tiatoolbox.models.architecture import utils as architecture_utils

    architecture_utils.is_torch_compile_compatible()
    captured_log = caplog.text
    assert "torch.compile" in captured_log
70 changes: 70 additions & 0 deletions tests/test_wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import cv2
import numpy as np
import pytest
import torch

from tests.conftest import timed
from tiatoolbox import logger, rcParam
from tiatoolbox.tools.registration.wsi_registration import (
AffineWSITransformer,
DFBRegister,
Expand Down Expand Up @@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)

assert np.sum(expected - output) == 0


def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
    """Test DFBRFeatureExtractor with torch.compile functionality.

    Extracts features once per torch.compile mode, compares them with the
    stored expected features and logs the time each extraction takes.

    Args:
        dfbr_features (Path): Path to the expected features.

    """

    def _extract_features() -> tuple:
        """Extract the three pooled feature maps from a synthetic image."""
        dfbr = DFBRegister()
        # 64x64x3 uint8 image in which every pixel of row i has value i.
        fixed_img = np.repeat(
            np.expand_dims(
                np.repeat(
                    np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
                    64,
                    axis=1,
                ),
                axis=2,
            ),
            3,
            axis=2,
        )
        output = dfbr.extract_features(fixed_img, fixed_img)
        pool3_feat = output["block3_pool"][0, :].detach().numpy()
        pool4_feat = output["block4_pool"][0, :].detach().numpy()
        pool5_feat = output["block5_pool"][0, :].detach().numpy()

        return pool3_feat, pool4_feat, pool5_feat

    # The expected features do not depend on the compile mode: load them once.
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # presumably acceptable because this is a test fixture file - confirm the
    # file always comes from a trusted source.
    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
        str(dfbr_features),
        allow_pickle=True,
    )

    torch_compile_mode = rcParam["torch_compile_mode"]
    try:
        for mode in ("default", "reduce-overhead", "max-autotune"):
            # Start each mode from a clean dynamo state.
            torch._dynamo.reset()
            rcParam["torch_compile_mode"] = mode
            (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(
                _extract_features,
            )
            assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
            assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
            assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
            logger.info("torch.compile %s mode: %s", mode, compile_time)
    finally:
        # Restore the global configuration even when an assertion fails, so
        # the temporary mode does not leak into tests that run afterwards.
        torch._dynamo.reset()
        rcParam["torch_compile_mode"] = torch_compile_mode
5 changes: 5 additions & 0 deletions tiatoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class _RcParam(TypedDict):

TIATOOLBOX_HOME: Path
pretrained_model_info: dict[str, dict]
torch_compile_mode: str


def read_registry_files(path_to_registry: str | Path) -> dict:
Expand Down Expand Up @@ -102,6 +103,10 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
"pretrained_model_info": read_registry_files(
"data/pretrained_model.yaml",
), # Load a dictionary of sample files data (names and urls)
"torch_compile_mode": "default",
# Set the `torch.compile` mode; the default is `default`.
# Options: `disable`, `default`, `reduce-overhead`, `max-autotune`,
# or `max-autotune-no-cudagraphs`.
}


Expand Down
1 change: 1 addition & 0 deletions tiatoolbox/models/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def get_pretrained_model(
model.load_state_dict(saved_state_dict, strict=True)

# !

io_info = info["ioconfig"]
creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")

Expand Down
Loading

0 comments on commit 32cae0b

Please sign in to comment.