From 32cae0be3ec371111fef53e2a5b10ccc3865f609 Mon Sep 17 00:00:00 2001
From: Abdol
Date: Fri, 15 Nov 2024 20:29:34 +0000
Subject: [PATCH] ⚡️ Add `torch.compile` Functionality (#716)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Integrates PyTorch 2.0's
  [torch.compile](https://pytorch.org/docs/stable/generated/torch.compile.html)
  functionality to demonstrate performance improvements in torch code.
  This PR focuses on adding `torch.compile` to `PatchPredictor`.

**Notes:**
- According to the
  [documentation](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html),
  noticeable performance gains can be achieved when using modern NVIDIA GPUs
  (H100, A100, or V100).

**TODO:**
- [x] Resolve compilation errors related to using `torch.compile` in running models
- [x] Initial config
- [x] Add to patch predictor
- [x] Add to registration
- [x] Add to segmentation
- [x] Test on custom models
- [x] Test on `torch.compile` compatible GPUs
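A minimal usage sketch (not part of the diff): the compile mode is read from
`rcParam` when an engine is constructed, so it should be set before the engine
is created. The patch file names below are placeholders; `resnet18-kather100k`
is one of the toolbox's existing pretrained models.

```python
from tiatoolbox import rcParam
from tiatoolbox.models import PatchPredictor

# Select a torch.compile mode before constructing the engine; the engine
# wraps its model with compile_model(...) in __init__.
# Options: "disable", "default", "reduce-overhead", "max-autotune",
# "max-autotune-no-cudagraphs".
rcParam["torch_compile_mode"] = "default"

predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32)
output = predictor.predict(
    imgs=["patch1.png", "patch2.png"],  # placeholder image paths
    mode="patch",
    on_gpu=True,
)
```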
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
Co-authored-by: Jiaqi-Lv <60471431+Jiaqi-Lv@users.noreply.github.com>
---
 .github/workflows/python-package.yml          |  2 +-
 tests/conftest.py                             | 38 ++++++++-
 .../models/test_nucleus_instance_segmentor.py |  8 +-
 tests/models/test_patch_predictor.py          | 53 +++++++++++-
 tests/models/test_semantic_segmentation.py    | 48 ++++++++++-
 tests/test_utils.py                           | 41 +++++++++-
 tests/test_wsi_registration.py                | 70 ++++++++++++++++
 tiatoolbox/__init__.py                        |  5 ++
 tiatoolbox/models/architecture/__init__.py    |  1 +
 tiatoolbox/models/architecture/utils.py       | 82 +++++++++++++++
 .../models/engine/multi_task_segmentor.py     |  2 +-
 tiatoolbox/models/engine/patch_predictor.py   | 10 ++-
 .../models/engine/semantic_segmentor.py       | 18 ++--
 tiatoolbox/models/models_abc.py               |  4 +
 .../tools/registration/wsi_registration.py    | 13 ++-
 tiatoolbox/utils/__init__.py                  |  9 +-
 16 files changed, 383 insertions(+), 21 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index ef64d04d8..5235ebe28 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -58,7 +58,7 @@ jobs:
     - name: Test with pytest
       run: |
         pytest --basetemp={envtmpdir} \
-        --cov=tiatoolbox --cov-report=term --cov-report=xml \
+        --cov=tiatoolbox --cov-report=term --cov-report=xml --cov-config=pyproject.toml \
         --capture=sys \
         --durations=10 --durations-min=1.0 \
         --maxfail=1

diff --git a/tests/conftest.py b/tests/conftest.py
index e3d676587..e2d5bab6a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,15 +4,17 @@

 import os
 import shutil
+import time
 from pathlib import Path
 from typing import Callable

 import pytest
+import torch

 import tiatoolbox
 from tiatoolbox import logger
 from tiatoolbox.data import _fetch_remote_sample
-from tiatoolbox.utils.env_detection import running_on_ci
+from tiatoolbox.utils.env_detection import has_gpu, running_on_ci

 # -------------------------------------------------------------------------------------
 # Generate Parameterized Tests
@@ -608,3 +610,37 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
     (tmp_path / "slides").mkdir()
     (tmp_path / "overlays").mkdir()
     return {"base_path": tmp_path}
+
+
+# -------------------------------------------------------------------------------------
+# Utility functions
+# -------------------------------------------------------------------------------------
+
+
+def timed(fn: Callable, *args: object) -> tuple[object, float]:
+    """Time the execution of a function.
+
+    Args:
+        fn (Callable): The function to be timed.
+        args (object): Arguments to be passed to the function.
+
+    Returns:
+        A tuple containing the result of the function
+        and the time taken to execute it in seconds.
+
+    """
+    compile_time = 0.0
+    if has_gpu():
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        result = fn(*args)
+        end.record()
+        torch.cuda.synchronize()
+        compile_time = start.elapsed_time(end) / 1000
+    else:
+        start = time.time()
+        result = fn(*args)
+        end = time.time()
+        compile_time = end - start
+    return result, compile_time
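For context, `timed` is a plain wrapper rather than a decorator: it calls the
function once and reports elapsed seconds (CUDA events when a GPU is present,
wall-clock time otherwise). A sketch with a hypothetical workload:

```python
from tests.conftest import timed


def workload(n: int) -> int:
    """Hypothetical function to time; any callable works."""
    return sum(i * i for i in range(n))


result, elapsed = timed(workload, 1_000_000)
print(f"workload -> {result} in {elapsed:.3f} s")
```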
diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py
index 2a493c671..ff6b9a4cc 100644
--- a/tests/models/test_nucleus_instance_segmentor.py
+++ b/tests/models/test_nucleus_instance_segmentor.py
@@ -11,10 +11,11 @@
 import joblib
 import numpy as np
 import pytest
+import torch
 import yaml
 from click.testing import CliRunner

-from tiatoolbox import cli
+from tiatoolbox import cli, rcParam
 from tiatoolbox.models import (
     IOSegmentorConfig,
     NucleusInstanceSegmentor,
@@ -44,7 +45,12 @@ def _crash_func(_x: object) -> None:

 def helper_tile_info() -> list:
     """Helper function for tile information."""
+    torch._dynamo.reset()
+    current_torch_compile_mode = rcParam["torch_compile_mode"]
+    rcParam["torch_compile_mode"] = "disable"
     predictor = NucleusInstanceSegmentor(model="A")
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = current_torch_compile_mode
     # ! assuming the tiles organized as follows (coming out from
     # ! PatchExtractor). If this is broken, need to check back
     # ! PatchExtractor output ordering first

diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py
index ab59efc53..5fd930138 100644
--- a/tests/models/test_patch_predictor.py
+++ b/tests/models/test_patch_predictor.py
@@ -13,7 +13,8 @@
 import torch
 from click.testing import CliRunner

-from tiatoolbox import cli
+from tests.conftest import timed
+from tiatoolbox import cli, logger, rcParam
 from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
 from tiatoolbox.models.architecture.vanilla import CNNModel
 from tiatoolbox.models.dataset import (
@@ -1226,3 +1227,53 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
     assert tmp_path.joinpath("2.merged.npy").exists()
     assert tmp_path.joinpath("2.raw.json").exists()
     assert tmp_path.joinpath("results.json").exists()
+
+
+# -------------------------------------------------------------------------------------
+# torch.compile
+# -------------------------------------------------------------------------------------
+
+
+def test_patch_predictor_torch_compile(
+    sample_patch1: Path,
+    sample_patch2: Path,
+    tmp_path: Path,
+) -> None:
+    """Test PatchPredictor with torch.compile functionality.
+
+    Args:
+        sample_patch1 (Path): Path to sample patch 1.
+        sample_patch2 (Path): Path to sample patch 2.
+        tmp_path (Path): Path to temporary directory.
+
+    """
+    torch_compile_mode = rcParam["torch_compile_mode"]
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "default"
+    _, compile_time = timed(
+        test_patch_predictor_api,
+        sample_patch1,
+        sample_patch2,
+        tmp_path,
+    )
+    logger.info("torch.compile default mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "reduce-overhead"
+    _, compile_time = timed(
+        test_patch_predictor_api,
+        sample_patch1,
+        sample_patch2,
+        tmp_path,
+    )
+    logger.info("torch.compile reduce-overhead mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "max-autotune"
+    _, compile_time = timed(
+        test_patch_predictor_api,
+        sample_patch1,
+        sample_patch2,
+        tmp_path,
+    )
+    logger.info("torch.compile max-autotune mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = torch_compile_mode

diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py
index 9efe758c6..8fee41a9b 100644
--- a/tests/models/test_semantic_segmentation.py
+++ b/tests/models/test_semantic_segmentation.py
@@ -20,7 +20,8 @@
 from click.testing import CliRunner
 from torch import nn

-from tiatoolbox import cli
+from tests.conftest import timed
+from tiatoolbox import cli, logger, rcParam
 from tiatoolbox.models import SemanticSegmentor
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.models.architecture.utils import centre_crop
@@ -897,3 +898,48 @@ def test_cli_semantic_segmentation_multi_file(
     _test_pred = (_test_pred[..., 1] > 0.50) * 255

     assert np.mean(np.abs(_cache_pred - _test_pred) / 255) < 1e-3
+
+
+# -------------------------------------------------------------------------------------
+# torch.compile
+# -------------------------------------------------------------------------------------
+
+
+def test_semantic_segmentor_torch_compile(
+    remote_sample: Callable,
+    tmp_path: Path,
+) -> None:
+    """Test SemanticSegmentor using pretrained model with torch.compile functionality.
+
+    Args:
+        remote_sample (Callable): Callable object used to extract remote sample.
+        tmp_path (Path): Path to temporary directory.
+
+    """
+    torch_compile_mode = rcParam["torch_compile_mode"]
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "default"
+    _, compile_time = timed(
+        test_functional_pretrained,
+        remote_sample,
+        tmp_path,
+    )
+    logger.info("torch.compile default mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "reduce-overhead"
+    _, compile_time = timed(
+        test_functional_pretrained,
+        remote_sample,
+        tmp_path,
+    )
+    logger.info("torch.compile reduce-overhead mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "max-autotune"
+    _, compile_time = timed(
+        test_functional_pretrained,
+        remote_sample,
+        tmp_path,
+    )
+    logger.info("torch.compile max-autotune mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = torch_compile_mode

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1d33de1ca..fe18e0d36 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -13,14 +13,16 @@
 import numpy as np
 import pandas as pd
 import pytest
+import torch
 from PIL import Image
 from requests import HTTPError
 from shapely.geometry import Polygon

 from tests.test_annotation_stores import cell_polygon
-from tiatoolbox import utils
+from tiatoolbox import rcParam, utils
 from tiatoolbox.annotation.storage import DictionaryStore, SQLiteStore
 from tiatoolbox.models.architecture import fetch_pretrained_weights
+from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.utils import misc
 from tiatoolbox.utils.exceptions import FileNotSupportedError
 from tiatoolbox.utils.transforms import locsize2bounds
@@ -1827,3 +1829,40 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
     # check correct error is raised if coordinates are missing
     with pytest.raises(ValueError, match="coordinates"):
         misc.dict_to_store(patch_output, (1.0, 1.0))
+
+
+def test_torch_compile_already_compiled() -> None:
+    """Test that torch_compile does not recompile a model that is already compiled."""
+    torch_compile_modes = [
+        "default",
+        "reduce-overhead",
+        "max-autotune",
+        "max-autotune-no-cudagraphs",
+    ]
+    current_torch_compile_mode = rcParam["torch_compile_mode"]
+    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
+
+    for mode in torch_compile_modes:
+        torch._dynamo.reset()
+        rcParam["torch_compile_mode"] = mode
+        compiled_model = compile_model(model, mode=mode)
+        recompiled_model = compile_model(compiled_model, mode=mode)
+        assert compiled_model == recompiled_model
+
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = current_torch_compile_mode
+
+
+def test_torch_compile_disable() -> None:
+    """Test torch_compile's disable mode."""
+    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
+    compiled_model = compile_model(model, mode="disable")
+    assert model == compiled_model
+
+
+def test_torch_compile_compatibility(caplog: pytest.LogCaptureFixture) -> None:
+    """Test if torch-compile compatibility is checked correctly."""
+    from tiatoolbox.models.architecture.utils import is_torch_compile_compatible
+
+    is_torch_compile_compatible()
+    assert "torch.compile" in caplog.text
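The behaviour pinned down by these tests can be summarised in a short sketch
using the helper introduced in `tiatoolbox/models/architecture/utils.py`
further down this patch:

```python
import torch

from tiatoolbox.models.architecture.utils import compile_model

model = torch.nn.Linear(10, 10)

# "disable" is a no-op: the original module is returned unchanged.
assert compile_model(model, mode="disable") is model

# Compiling an already compiled model returns it as-is instead of
# wrapping it in a second OptimizedModule.
compiled = compile_model(model, mode="default")
assert compile_model(compiled, mode="default") is compiled
```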
diff --git a/tests/test_wsi_registration.py b/tests/test_wsi_registration.py
index 79abd3855..4e0f07366 100644
--- a/tests/test_wsi_registration.py
+++ b/tests/test_wsi_registration.py
@@ -5,7 +5,10 @@
 import cv2
 import numpy as np
 import pytest
+import torch

+from tests.conftest import timed
+from tiatoolbox import logger, rcParam
 from tiatoolbox.tools.registration.wsi_registration import (
     AffineWSITransformer,
     DFBRegister,
@@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
         expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)

     assert np.sum(expected - output) == 0
+
+
+def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
+    """Test DFBRFeatureExtractor with torch.compile functionality.
+
+    Args:
+        dfbr_features (Path): Path to the expected features.
+
+    """
+
+    def _extract_features() -> tuple:
+        dfbr = DFBRegister()
+        fixed_img = np.repeat(
+            np.expand_dims(
+                np.repeat(
+                    np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
+                    64,
+                    axis=1,
+                ),
+                axis=2,
+            ),
+            3,
+            axis=2,
+        )
+        output = dfbr.extract_features(fixed_img, fixed_img)
+        pool3_feat = output["block3_pool"][0, :].detach().numpy()
+        pool4_feat = output["block4_pool"][0, :].detach().numpy()
+        pool5_feat = output["block5_pool"][0, :].detach().numpy()
+
+        return pool3_feat, pool4_feat, pool5_feat
+
+    torch_compile_mode = rcParam["torch_compile_mode"]
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "default"
+    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
+    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
+        str(dfbr_features),
+        allow_pickle=True,
+    )
+    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
+    logger.info("torch.compile default mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "reduce-overhead"
+    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
+    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
+        str(dfbr_features),
+        allow_pickle=True,
+    )
+    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
+    logger.info("torch.compile reduce-overhead mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "max-autotune"
+    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
+    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
+        str(dfbr_features),
+        allow_pickle=True,
+    )
+    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
+    logger.info("torch.compile max-autotune mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = torch_compile_mode

diff --git a/tiatoolbox/__init__.py b/tiatoolbox/__init__.py
index d87e63177..80ad1b4ff 100644
--- a/tiatoolbox/__init__.py
+++ b/tiatoolbox/__init__.py
@@ -73,6 +73,7 @@ class _RcParam(TypedDict):

     TIATOOLBOX_HOME: Path
     pretrained_model_info: dict[str, dict]
+    torch_compile_mode: str


 def read_registry_files(path_to_registry: str | Path) -> dict:
@@ -102,6 +103,10 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
     "pretrained_model_info": read_registry_files(
         "data/pretrained_model.yaml",
     ),  # Load a dictionary of sample files data (names and urls)
+    "torch_compile_mode": "default",
+    # Set `torch.compile` mode to `default`
+    # Options: `disable`, `default`, `reduce-overhead`, `max-autotune`,
+    # or `max-autotune-no-cudagraphs`
 }
diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py
index 9ac1bbd82..6fac9b08b 100644
--- a/tiatoolbox/models/architecture/__init__.py
+++ b/tiatoolbox/models/architecture/__init__.py
@@ -150,6 +150,7 @@ def get_pretrained_model(
     model.load_state_dict(saved_state_dict, strict=True)

     # !
+
     io_info = info["ioconfig"]
     creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")

diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py
index cefeca1c3..9df4dd56f 100644
--- a/tiatoolbox/models/architecture/utils.py
+++ b/tiatoolbox/models/architecture/utils.py
@@ -2,10 +2,92 @@

 from __future__ import annotations

+import sys
+from typing import Callable
+
 import numpy as np
 import torch
 from torch import nn

+from tiatoolbox import logger
+
+
+def is_torch_compile_compatible() -> None:
+    """Check if the current GPU is compatible with torch.compile.
+
+    Logs a warning if the GPU is not compatible with `torch.compile`.
+
+    """
+    if torch.cuda.is_available():  # pragma: no cover
+        device_cap = torch.cuda.get_device_capability()
+        if device_cap not in ((7, 0), (8, 0), (9, 0)):
+            logger.warning(
+                "GPU is not compatible with torch.compile. "
+                "Compatible GPUs include NVIDIA V100, A100, and H100. "
+                "Speedup numbers may be lower than expected.",
+                stacklevel=2,
+            )
+    else:
+        logger.warning(
+            "No GPU detected or CUDA not installed; "
+            "torch.compile is only supported on selected NVIDIA GPUs. "
+            "Speedup numbers may be lower than expected.",
+            stacklevel=2,
+        )
+
+
+def compile_model(
+    model: nn.Module | None = None,
+    *,
+    mode: str = "default",
+) -> Callable:
+    """Compile a model using `torch.compile`.
+
+    Args:
+        model (torch.nn.Module):
+            Model to be compiled.
+        mode (str):
+            Mode to be used for `torch.compile`. Available modes are:
+
+            - `disable` disables torch.compile
+            - `default` balances performance and overhead
+            - `reduce-overhead` reduces overhead of CUDA graphs (useful for small
+              batches)
+            - `max-autotune` leverages Triton/template based matrix multiplications
+              on GPUs
+            - `max-autotune-no-cudagraphs` similar to `max-autotune` but without
+              CUDA graphs
+
+    Returns:
+        Callable:
+            Compiled model.
+
+    """
+    if mode == "disable":
+        return model
+
+    # Check if GPU is compatible with torch.compile
+    is_torch_compile_compatible()
+
+    # This check will be removed when torch.compile is supported in Python 3.12+
+    if sys.version_info >= (3, 12):  # pragma: no cover
+        logger.warning(
+            "torch.compile is currently not supported in Python 3.12+.",
+        )
+        return model
+
+    if isinstance(  # pragma: no cover
+        model,
+        torch._dynamo.eval_frame.OptimizedModule,  # skipcq: PYL-W0212 # noqa: SLF001
+    ):
+        logger.info(
+            "The model is already compiled.",
+        )
+        return model
+
+    return torch.compile(model, mode=mode)  # pragma: no cover
+

 def centre_crop(
     img: np.ndarray | torch.tensor,
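To make the new helper concrete, here is a standalone sketch of compiling a
toy module directly (the model is illustrative, not part of the toolbox;
compilation is triggered lazily on the first forward pass):

```python
import torch
from torch import nn

from tiatoolbox.models.architecture.utils import compile_model


class ToyModel(nn.Module):
    """Illustrative two-layer network; not part of tiatoolbox."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = compile_model(ToyModel(), mode="default")
with torch.inference_mode():
    out = model(torch.randn(8, 16))  # first call triggers compilation
print(out.shape)  # torch.Size([8, 4])
```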
diff --git a/tiatoolbox/models/engine/multi_task_segmentor.py b/tiatoolbox/models/engine/multi_task_segmentor.py
index c0e5ff337..3e5cf97e9 100644
--- a/tiatoolbox/models/engine/multi_task_segmentor.py
+++ b/tiatoolbox/models/engine/multi_task_segmentor.py
@@ -49,7 +49,7 @@
 # Python is yet to be able to natively pickle Object method/static method.
 # Only top-level function is passable to multi-processing as caller.
 # May need 3rd party libraries to use method/static method otherwise.
-def _process_tile_predictions(
+def _process_tile_predictions(  # skipcq: PY-R1000
     ioconfig: IOSegmentorConfig,
     tile_bounds: IntBounds,
     tile_flag: list,

diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py
index f68e11f4d..da4420cb0 100644
--- a/tiatoolbox/models/engine/patch_predictor.py
+++ b/tiatoolbox/models/engine/patch_predictor.py
@@ -11,8 +11,9 @@
 import torch
 import tqdm

-from tiatoolbox import logger
+from tiatoolbox import logger, rcParam
 from tiatoolbox.models.architecture import get_pretrained_model
+from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
 from tiatoolbox.utils import misc, save_as_json
@@ -250,7 +251,12 @@ def __init__(

         self.ioconfig = ioconfig  # for storing original
         self._ioconfig = None  # for storing runtime
-        self.model = model  # for runtime, such as after wrapping with nn.DataParallel
+        self.model = (
+            compile_model(  # for runtime, such as after wrapping with nn.DataParallel
+                model,
+                mode=rcParam["torch_compile_mode"],
+            )
+        )
         self.pretrained_model = pretrained_model
         self.batch_size = batch_size
         self.num_loader_worker = num_loader_workers

diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index 3a5f3475f..271d49150 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -17,8 +17,9 @@
 import torch.utils.data as torch_data
 import tqdm

-from tiatoolbox import logger
+from tiatoolbox import logger, rcParam
 from tiatoolbox.models.architecture import get_pretrained_model
+from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.models.models_abc import IOConfigABC
 from tiatoolbox.tools.patchextraction import PatchExtractor
 from tiatoolbox.utils import imread, misc
@@ -563,7 +564,10 @@ def __init__(
         self.masks = None
         self.dataset_class: WSIStreamDataset = dataset_class
-        self.model = model  # original copy
+        self.model = compile_model(
+            model,
+            mode=rcParam["torch_compile_mode"],
+        )
         self.pretrained_model = pretrained_model
         self.batch_size = batch_size
         self.num_loader_workers = num_loader_workers
@@ -573,9 +577,9 @@

     @staticmethod
     def get_coordinates(
-        image_shape: list[int] | np.ndarray,
+        image_shape: tuple[int, int] | np.ndarray,
         ioconfig: IOSegmentorConfig,
-    ) -> tuple[list, list]:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """Calculate patch tiling coordinates.

By default, internally, it will call the @@ -619,13 +623,13 @@ def get_coordinates( >>> segmentor.get_coordinates = func """ - (patch_inputs, patch_outputs) = PatchExtractor.get_coordinates( + results = PatchExtractor.get_coordinates( + patch_output_shape=ioconfig.patch_output_shape, image_shape=image_shape, patch_input_shape=ioconfig.patch_input_shape, - patch_output_shape=ioconfig.patch_output_shape, stride_shape=ioconfig.stride_shape, ) - return patch_inputs, patch_outputs + return results[0], results[1] @staticmethod def filter_coordinates( diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 74954e59e..e16540c87 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -6,8 +6,12 @@ from typing import TYPE_CHECKING, Any, Callable import torch +import torch._dynamo from torch import device as torch_device +torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001 + + if TYPE_CHECKING: # pragma: no cover from pathlib import Path diff --git a/tiatoolbox/tools/registration/wsi_registration.py b/tiatoolbox/tools/registration/wsi_registration.py index 8d42365b3..a333d77b5 100644 --- a/tiatoolbox/tools/registration/wsi_registration.py +++ b/tiatoolbox/tools/registration/wsi_registration.py @@ -16,7 +16,8 @@ from skimage.util import img_as_float from torchvision.models import VGG16_Weights -from tiatoolbox import logger +from tiatoolbox import logger, rcParam +from tiatoolbox.models.architecture.utils import compile_model from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.metrics import dice from tiatoolbox.utils.transforms import imresize @@ -338,8 +339,9 @@ def __init__(self: DFBRFeatureExtractor) -> None: output_layers_id: list[str] = ["16", "23", "30"] output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"] self.features: dict = dict.fromkeys(output_layers_key, None) - self.pretrained: torch.nn.Sequential = torchvision.models.vgg16( - weights=VGG16_Weights.IMAGENET1K_V1, + self.pretrained: torch.nn.Sequential = compile_model( + torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1), + mode=rcParam["torch_compile_mode"], ).features self.f_hooks = [ getattr(self.pretrained, layer).register_forward_hook( @@ -431,7 +433,10 @@ class DFBRegister: """ - def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None: + def __init__( + self: DFBRegister, + patch_size: tuple[int, int] = (224, 224), + ) -> None: """Initialize :class:`DFBRegister`.""" self.patch_size = patch_size self.x_scale: np.ndarray diff --git a/tiatoolbox/utils/__init__.py b/tiatoolbox/utils/__init__.py index 3653b0321..e29f28c12 100644 --- a/tiatoolbox/utils/__init__.py +++ b/tiatoolbox/utils/__init__.py @@ -10,7 +10,14 @@ visualization, ) -from .misc import download_data, imread, imwrite, save_as_json, save_yaml, unzip_data +from .misc import ( + download_data, + imread, + imwrite, + save_as_json, + save_yaml, + unzip_data, +) __all__ = [ "imread",