Skip to content

Commit

Permalink
♻️ Update Changes from New Engine Design (#882)
Browse files Browse the repository at this point in the history
- Add changes from New engine design #578. This will not only simplify the PR but also keep the main repo up to date.
- Refactor `model_to` to `model_abc`
- Instead of `on_gpu` use `device` as an input in line with `PyTorch`.
- `infer_batch` uses `device` as an input instead of `on_gpu`
  • Loading branch information
shaneahmed authored Nov 21, 2024
1 parent 32cae0b commit ca13e7f
Show file tree
Hide file tree
Showing 42 changed files with 342 additions and 290 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,6 @@ ENV/

# vim/vi generated
*.swp

# output zarr generated
*.zarr
20 changes: 18 additions & 2 deletions tests/models/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import pytest
import torch
import torchvision.models as torch_models
from torch import nn

from tiatoolbox import rcParam
from tiatoolbox import rcParam, utils
from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.models.models_abc import ModelABC, model_to
from tiatoolbox.utils import env_detection as toolbox_env

if TYPE_CHECKING:
Expand Down Expand Up @@ -149,3 +150,18 @@ def test_model_abc() -> None:
weights_path = fetch_pretrained_weights("alexnet-kather100k")
with pytest.raises(RuntimeError, match=r".*loading state_dict*"):
_ = model.load_weights_from_file(weights_path)


def test_model_to() -> None:
"""Test for placing model on device."""
# Test on GPU
# no GPU on GitHub Actions so this will crash
if not utils.env_detection.has_gpu():
model = torch_models.resnet18()
with pytest.raises((AssertionError, RuntimeError)):
_ = model_to(device="cuda", model=model)

# Test on CPU
model = torch_models.resnet18()
model = model_to(device="cpu", model=model)
assert isinstance(model, nn.Module)
4 changes: 2 additions & 2 deletions tests/models/test_arch_mapde.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None:
model = _load_mapde(name="mapde-conic")
patch = model.preproc(patch)
batch = torch.from_numpy(patch)[None]
model = model.to(select_device(on_gpu=ON_GPU))
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
model = model.to()
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
output = model.postproc(output[0])
assert np.all(output[0:2] == [[19, 171], [53, 89]])
2 changes: 1 addition & 1 deletion tests/models/test_arch_micronet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_functionality(
model = model.to(map_location)
pretrained = torch.load(weights_path, map_location=map_location)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
output = model.infer_batch(model, batch, device=map_location)
output, _ = model.postproc(output[0])
assert np.max(np.unique(output)) == 46

Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_arch_nuclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tiatoolbox.models import NuClick
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import imread
from tiatoolbox.utils.misc import select_device

ON_GPU = False

Expand Down Expand Up @@ -53,7 +54,7 @@ def test_functional_nuclick(
model = NuClick(num_input_channels=5, num_output_channels=1)
pretrained = torch.load(weights_path, map_location="cpu")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
postproc_masks = model.postproc(
output,
do_reconstruction=True,
Expand Down
17 changes: 13 additions & 4 deletions tests/models/test_arch_sccnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
import numpy as np
import torch

from tiatoolbox import utils
from tiatoolbox.models import SCCNN
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import env_detection
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader


def _load_sccnn(name: str) -> torch.nn.Module:
"""Loads SCCNN model with specified weights."""
model = SCCNN()
weights_path = fetch_pretrained_weights(name)
map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu())
map_location = select_device(on_gpu=env_detection.has_gpu())
pretrained = torch.load(weights_path, map_location=map_location)
model.load_state_dict(pretrained)

Expand All @@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None:
)
batch = torch.from_numpy(patch)[None]
model = _load_sccnn(name="sccnn-crchisto")
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(
model,
batch,
device=select_device(on_gpu=env_detection.has_gpu()),
)
output = model.postproc(output[0])
assert np.all(output == [[8, 7]])

model = _load_sccnn(name="sccnn-conic")
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(
model,
batch,
device=select_device(on_gpu=env_detection.has_gpu()),
)
output = model.postproc(output[0])
assert np.all(output == [[7, 8]])
5 changes: 3 additions & 2 deletions tests/models/test_arch_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.unet import UNetModel
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = False
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=ON_GPU)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
_ = output[0]

# run untrained network to test for architecture
Expand All @@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None:
encoder_levels=[32, 64],
skip_type="concat",
)
_ = model.infer_batch(model, batch, on_gpu=ON_GPU)
_ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
11 changes: 6 additions & 5 deletions tests/models/test_arch_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import torch

from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
from tiatoolbox.utils.misc import model_to
from tiatoolbox.models.models_abc import model_to

ON_GPU = False
RNG = np.random.default_rng() # Numpy Random Generator
device = "cuda" if ON_GPU else "cpu"


def test_functional() -> None:
Expand Down Expand Up @@ -43,8 +44,8 @@ def test_functional() -> None:
try:
for backbone in backbones:
model = CNNModel(backbone, num_classes=1)
model_ = model_to(on_gpu=ON_GPU, model=model)
model.infer_batch(model_, samples, on_gpu=ON_GPU)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand All @@ -70,8 +71,8 @@ def test_timm_functional() -> None:
try:
for backbone in backbones:
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
model_ = model_to(on_gpu=ON_GPU, model=model)
model.infer_batch(model_, samples, on_gpu=ON_GPU)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
IOSegmentorConfig,
)
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader

ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
Expand All @@ -35,7 +36,7 @@ def test_engine(remote_sample: Callable, tmp_path: Path) -> None:
output_list = extractor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -82,7 +83,7 @@ def test_full_inference(
[mini_wsi_svs],
mode="wsi",
ioconfig=ioconfig,
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down
9 changes: 5 additions & 4 deletions tests/models/test_hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ResidualBlock,
TFSamepaddingLayer,
)
from tiatoolbox.utils.misc import select_device
from tiatoolbox.wsicore.wsireader import WSIReader


Expand All @@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand All @@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand All @@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_original-consep")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand All @@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernet_original-kumar")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
output = [v[0] for v in output]
output = model.postproc(output)
assert len(output[1]) > 0, "Must have some nuclei."
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_hovernetplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tiatoolbox.models import HoVerNetPlus
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.utils import imread
from tiatoolbox.utils.misc import select_device
from tiatoolbox.utils.transforms import imresize


Expand All @@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None:
weights_path = fetch_pretrained_weights("hovernetplus-oed")
pretrained = torch.load(weights_path)
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, on_gpu=False)
output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
output = [v[0] for v in output]
output = model.postproc(output)
Expand Down
23 changes: 12 additions & 11 deletions tests/models/test_multi_task_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils import imwrite
from tiatoolbox.utils.metrics import f1_detection
from tiatoolbox.utils.misc import select_device

ON_GPU = toolbox_env.has_gpu()
BATCH_SIZE = 1 if not ON_GPU else 8 # 16
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand All @@ -83,7 +84,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -117,7 +118,7 @@ def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) ->
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -148,7 +149,7 @@ def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -195,7 +196,7 @@ def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
masks=[sample_wsi_msk],
mode="wsi",
ioconfig=ioconfig,
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -230,7 +231,7 @@ def test_functionality_process_instance_predictions(
output = semantic_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -268,7 +269,7 @@ def test_empty_image(tmp_path: Path) -> None:
_ = multi_segmentor.predict(
[sample_patch_path],
mode="tile",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand All @@ -284,7 +285,7 @@ def test_empty_image(tmp_path: Path) -> None:
_ = multi_segmentor.predict(
[sample_patch_path],
mode="tile",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Expand Down Expand Up @@ -312,7 +313,7 @@ def test_empty_image(tmp_path: Path) -> None:
_ = multi_segmentor.predict(
[sample_patch_path],
mode="tile",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
ioconfig=bcc_wsi_ioconfig,
Expand Down Expand Up @@ -361,7 +362,7 @@ def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None
output = multi_segmentor.predict(
[mini_wsi_svs],
mode="wsi",
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
ioconfig=bcc_wsi_ioconfig,
Expand Down Expand Up @@ -413,7 +414,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
masks=[sample_wsi_msk],
mode="wsi",
ioconfig=ioconfig,
on_gpu=ON_GPU,
device=select_device(on_gpu=ON_GPU),
crash_on_exception=True,
save_dir=save_dir,
)
Loading

0 comments on commit ca13e7f

Please sign in to comment.