diff --git a/.coverage_others b/.coverage_others
new file mode 100644
index 00000000..c1b7fcbe
Binary files /dev/null and b/.coverage_others differ
diff --git a/.coverage_tilestitcher b/.coverage_tilestitcher
new file mode 100644
index 00000000..2f637b86
Binary files /dev/null and b/.coverage_tilestitcher differ
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 00000000..bf500e1f
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+tests/testdata/tilestitching_testdata/*.tif filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml
index dd5dd367..ad759638 100644
--- a/.github/workflows/publish-to-pypi.yml
+++ b/.github/workflows/publish-to-pypi.yml
@@ -1,13 +1,14 @@
name: Publish PathML distribution to PyPI and TestPyPI
on:
+ workflow_dispatch:
release:
types: [published]
jobs:
build-n-publish:
name: Build and publish PathML distribution to PyPI and TestPyPI
- runs-on: ubuntu-18.04
+ runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- name: Set up Python 3.9
diff --git a/.github/workflows/tests-conda.yml b/.github/workflows/tests-conda.yml
index 05401f41..c7282f10 100644
--- a/.github/workflows/tests-conda.yml
+++ b/.github/workflows/tests-conda.yml
@@ -1,6 +1,7 @@
name: Tests
-on:
+on:
+ workflow_dispatch:
pull_request:
branches:
- dev
@@ -17,6 +18,7 @@ jobs:
max-parallel: 5
matrix:
python-version: [3.8, 3.9]
+ timeout-minutes: 60 # add a timeout
steps:
- uses: actions/checkout@v2
@@ -24,53 +26,124 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- # Test matrix by printing the current Python version
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
shell: bash -l {0}
run: |
sudo apt-get update
- # install openslide
sudo apt-get install openslide-tools
- # install pandoc for making documentation
sudo apt-get install pandoc
- name: Setup Miniconda
- # You may pin to the exact commit or the version.
- # uses: conda-incubator/setup-miniconda@f4c00b0ec69bdc87b1ab4972613558dd9f4f36f3
- uses: conda-incubator/setup-miniconda@v2.0.0
+ uses: conda-incubator/setup-miniconda@v2
with:
- add_pip_as_python_dependency: false
- environment-file: environment.yml
+ auto-activate-base: false
activate-environment: pathml
+ environment-file: environment.yml
+ mamba-version: "*"
python-version: ${{ matrix.python-version }}
+ - name: Debugging
+ run: |
+ echo "Printing the environment.yml file..."
+ cat environment.yml
+ echo "Checking the status of mamba..."
+ mamba --version
+ echo "Checking the available disk space..."
+ df -h
+ - name: Install dependencies with mamba
+ shell: bash -l {0}
+ run: mamba env update --file environment.yml --name pathml
- name: Conda info
shell: bash -l {0}
run: |
conda info
conda list
- - name: pip install pathml
- shell: bash -l {0}
+
+ - name: Set default Temurin JDK 17
+ run: |
+ sudo update-java-alternatives --set temurin-17-jdk-amd64 || true
+ java -version
+
+ - name: Install PathML
+ shell: bash -l {0}
run: pip install -e .
+
+ - name: Install python-spams
+ shell: bash -l {0}
+ run: pip install spams
+
- name: disk usage
shell: bash -l {0}
run: |
sudo df -h
sudo du -h
- - name: Test with pytest and generate coverage report
+
+ # - name: Check Coverage Command
+ # run: |
+ # which coverage
+ # coverage --version
+
+ # - name: Test with pytest for tile_stitcher
+ # run: |
+ # java -version
+ # python -m pytest tests/preprocessing_tests/test_tilestitcher.py
+
+ # - name: Test with pytest and generate coverage report
+ # shell: bash -l {0}
+ # run: |
+ # coverage run -m pytest -m "not slow and not exclude"
+ # coverage xml
+ # - name: Upload coverage to Codecov
+ # uses: codecov/codecov-action@v2
+ # with:
+ # token: ${{ secrets.CODECOV_TOKEN }}
+ # env_vars: OS,PYTHON
+ # fail_ci_if_error: true
+ # files: ./coverage.xml
+ # name: codecov-umbrella
+ # verbose: true
+
+ - name: Test other modules with pytest and generate coverage
+ shell: bash -l {0}
+ run: |
+ COVERAGE_FILE=.coverage_others coverage run -m pytest -m "not slow and not exclude"
+
+ - name: Test tile_stitcher with pytest and generate coverage
shell: bash -l {0}
run: |
- coverage run -m pytest -m "not slow"
- coverage xml
- - name: Upload coverage to Codecov
+ COVERAGE_FILE=.coverage_tilestitcher coverage run -m pytest tests/preprocessing_tests/test_tilestitcher.py
+
+ - name: List Files in Directory
+ shell: bash -l {0}
+ run: |
+ ls -la
+
+ - name: Combine Coverage Data
+ shell: bash -l {0}
+ run: |
+ coverage combine .coverage_tilestitcher .coverage_others
+
+ - name: Generate Combined Coverage Report
+ shell: bash -l {0}
+ run: |
+ coverage xml -o combined_coverage.xml
+
+ # - name: Combine coverage data
+ # shell: bash -l {0}
+ # run: |
+ # coverage combine coverage_tilestitcher.xml coverage_others.xml
+ # coverage xml -o coverage_combined.xml
+
+ - name: Upload combined coverage to Codecov
uses: codecov/codecov-action@v2
with:
token: ${{ secrets.CODECOV_TOKEN }}
env_vars: OS,PYTHON
fail_ci_if_error: true
- files: ./coverage.xml
+ files: ./combined_coverage.xml
name: codecov-umbrella
verbose: true
+
- name: Compile docs
shell: bash -l {0}
run: |
diff --git a/.gitignore b/.gitignore
index 2fe6662a..ab777666 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,6 @@ scratch.ipynb
# dask
dask-worker-space/
+
+# tile stitching files
+tools/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7320bf2f..e096e8b3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
- id: black
- repo: https://github.com/timothycrosley/isort
- rev: 5.10.1
+ rev: 5.11.5
hooks:
- id: isort
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 86d61526..02377793 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -5,6 +5,11 @@
# Required
version: 2
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.8"
+
# Build documentation with Sphinx
sphinx:
configuration: docs/source/conf.py
@@ -12,6 +17,5 @@ sphinx:
fail_on_warning: false
python:
- version: "3.8"
install:
- requirements: docs/readthedocs-requirements.txt
diff --git a/README.md b/README.md
index 745e0fc2..e229cf6e 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,27 @@
-
-
-
+🤖🔬 **PathML: Tools for computational pathology**
+[![Downloads](https://static.pepy.tech/badge/pathml)](https://pepy.tech/project/pathml)
[![Documentation Status](https://readthedocs.org/projects/pathml/badge/?version=latest)](https://pathml.readthedocs.io/en/latest/?badge=latest)
+[![codecov](https://codecov.io/gh/Dana-Farber-AIOS/pathml/branch/master/graph/badge.svg?token=UHSQPTM28Y)](https://codecov.io/gh/Dana-Farber-AIOS/pathml)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![PyPI version](https://img.shields.io/pypi/v/pathml)](https://pypi.org/project/pathml/)
-[![Downloads](https://pepy.tech/badge/pathml)](https://pepy.tech/project/pathml)
-[![codecov](https://codecov.io/gh/Dana-Farber-AIOS/pathml/branch/master/graph/badge.svg?token=UHSQPTM28Y)](https://codecov.io/gh/Dana-Farber-AIOS/pathml)
+
+⭐ **PathML's objective is to lower the barrier to entry to digital pathology**
+
+Imaging datasets in cancer research are growing exponentially in both quantity and information density. These massive datasets may enable derivation of insights for cancer research and clinical care, but only if researchers are equipped with the tools to leverage advanced computational analysis approaches such as machine learning and artificial intelligence. In this work, we highlight three themes to guide development of such computational tools: scalability, standardization, and ease of use. We then apply these principles to develop PathML, a general-purpose research toolkit for computational pathology. We describe the design of the PathML framework and demonstrate applications in diverse use cases.
+
+🚀 **The fastest way to get started?**
+
+ docker pull pathml/pathml && docker run -it -p 8888:8888 pathml/pathml
| Branch | Test status |
| ------ | ------------- |
| master | ![tests](https://github.com/Dana-Farber-AIOS/pathml/actions/workflows/tests-conda.yml/badge.svg?branch=master) |
| dev | ![tests](https://github.com/Dana-Farber-AIOS/pathml/actions/workflows/tests-conda.yml/badge.svg?branch=dev) |
-An open-source toolkit for computational pathology and machine learning.
+
+
+
**View [documentation](https://pathml.readthedocs.io/en/latest/)**
@@ -125,6 +133,24 @@ Note that these instructions assume that there are no other processes using port
Please refer to the `Docker run` [documentation](https://docs.docker.com/engine/reference/run/) for further instructions
on accessing the container, e.g. for mounting volumes to access files on a local machine from within the container.
+## Option 4: Google Colab
+
+To get PathML running in a Colab environment:
+
+````
+import os
+!pip install openslide-python
+!apt-get install openslide-tools
+!apt-get install openjdk-8-jdk-headless -qq > /dev/null
+os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
+!update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java
+!java -version
+!pip install pathml
+````
+
+*Thanks to all of our open-source collaborators for helping maintain these installation instructions!*
+*Please open an issue for any bugs or other problems you encounter during the installation process.*
+
## CUDA
To use GPU acceleration for model training or other tasks, you must install CUDA.
@@ -191,12 +217,36 @@ See [contributing](https://github.com/Dana-Farber-AIOS/pathml/blob/master/CONTRI
# Citing
-If you use `PathML` in your work, please cite our paper:
+If you use `PathML`, please cite:
+
+- [**J. Rosenthal et al., "Building tools for machine learning and artificial intelligence in cancer research: best practices and a case study with the PathML toolkit for computational pathology." Molecular Cancer Research, 2022.**](https://doi.org/10.1158/1541-7786.MCR-21-0665)
+
+So far, PathML has been used in the following manuscripts:
+
+- [J. Linares et al. **Molecular Cell** 2021](https://www.cell.com/molecular-cell/fulltext/S1097-2765(21)00729-2)
+- [A. Shmatko et al. **Nature Cancer** 2022](https://www.nature.com/articles/s43018-022-00436-4)
+- [J. Pocock et al. **Nature Communications Medicine** 2022](https://www.nature.com/articles/s43856-022-00186-5)
+- [S. Orsulic et al. **Frontiers in Oncology** 2022](https://www.frontiersin.org/articles/10.3389/fonc.2022.924945/full)
+- [D. Brundage et al. **arXiv** 2022](https://arxiv.org/abs/2203.13888)
+- [A. Marcolini et al. **SoftwareX** 2022](https://www.sciencedirect.com/science/article/pii/S2352711022001558)
+- [M. Rahman et al. **Bioengineering** 2022](https://www.mdpi.com/2306-5354/9/8/335)
+- [C. Lama et al. **bioRxiv** 2022](https://www.biorxiv.org/content/10.1101/2022.09.28.509751v1.full)
+- the list continues [**here 🔗 for 2023 and onwards**](https://scholar.google.com/scholar?oi=bibs&hl=en&cites=1157052756975292108)
+
+# Users
+
+
This is where in the world our most enthusiastic supporters are located:
+
+
+ |
+and this is where they work:
+
+
+ |
+
+
-Rosenthal J, Carelli R, Omar M, Brundage D, Halbert E, Nyman J, Hari SN, Van Allen EM, Marchionni L, Umeton R, Loda M.
-Building tools for machine learning and artificial intelligence in cancer research: best practices and a case study
-with the PathML toolkit for computational pathology. *Molecular Cancer Research*, 2021.
-DOI: [10.1158/1541-7786.MCR-21-0665](https://doi.org/10.1158/1541-7786.MCR-21-0665)
+Source: https://ossinsight.io/analyze/Dana-Farber-AIOS/pathml#people
# License
@@ -209,6 +259,6 @@ Commercial license options are available also.
Questions? Comments? Suggestions? Get in touch!
-[PathML@dfci.harvard.edu](mailto:PathML@dfci.harvard.edu)
+[pathml@dfci.harvard.edu](mailto:pathml@dfci.harvard.edu)
diff --git a/docs/readthedocs-requirements.txt b/docs/readthedocs-requirements.txt
index 831e134d..0e8f22d3 100644
--- a/docs/readthedocs-requirements.txt
+++ b/docs/readthedocs-requirements.txt
@@ -1,7 +1,7 @@
-sphinx==4.3.2
+sphinx==7.1.2
nbsphinx==0.8.8
nbsphinx-link==1.3.0
-sphinx-rtd-theme==1.0.0
-sphinx-autoapi==1.8.4
-ipython==7.31.1
+sphinx-rtd-theme==1.3.0
+sphinx-autoapi==3.0.0
+ipython==8.10.0
sphinx-copybutton==0.4.0
diff --git a/environment.yml b/environment.yml
index 22754c02..513592b7 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,33 +1,37 @@
name: pathml
channels:
- - conda-forge
- pytorch
+ - conda-forge
dependencies:
- pip==21.3.1
- - python==3.8
- - numpy==1.19.5
- - scipy==1.7.3
- - scikit-image==0.18.3
+ - numpy # orig = 1.19.5
+ - scipy # orig = 1.7.3
+ - scikit-image # orig 0.18.3
- matplotlib==3.5.1
- - python-spams==2.6.1
- - openjdk==8.0.152
- - pytorch==1.10.1
+ - openjdk<=18.0.0
+ - pytorch==1.13.1 # orig = 1.10.1
- h5py==3.1.0
- dask==2021.12.0
- pydicom==2.2.2
- - pytest==6.2.5
+ - pytest==7.4.0 # orig = 6.2.5
- pre-commit==2.16.0
- coverage==5.5
+ - networkx==3.1
- pip:
- python-bioformats==4.0.0
- python-javabridge==4.0.0
- - protobuf==3.20.1
- - deepcell==0.11.0
+ - protobuf==3.20.3
+ - deepcell==0.12.7 # orig = 0.11.0
+ - onnx==1.14.0
+ - onnxruntime==1.15.1
- opencv-contrib-python==4.5.3.56
- - openslide-python==1.1.2
+ - openslide-python==1.2.0
- scanpy==1.8.2
- anndata==0.7.8
- tqdm==4.62.3
- loguru==0.5.3
+ - pandas==1.5.2 # orig no req
+ - torch-geometric==2.3.1
+ - jpype1
diff --git a/examples/construct_graphs.ipynb b/examples/construct_graphs.ipynb
new file mode 100644
index 00000000..0b10bc84
--- /dev/null
+++ b/examples/construct_graphs.ipynb
@@ -0,0 +1,494 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "14070544-7803-40fb-8f4b-99724b49f224",
+ "metadata": {},
+ "source": [
+ "# PathML Graph construction and processing "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8886bc5f-83db-4abe-97e0-b8bc9b3aab56",
+ "metadata": {},
+ "source": [
+ "In this notebook, we will demonstrate the ability of the new pathml.graph API to construct cell and tissue graphs. Specifically, we will do the following:\n",
+ "\n",
+    "1. Use a pre-trained HoVer-Net model to detect cells in a given Region of Interest (ROI)\n",
+ "2. Use boundary detection techniques to detect tissues in a given ROI\n",
+ "3. Featurize the detected cell and tissue patches using a ResNet model\n",
+ "4. Construct both tissue and cell graphs using k-Nearest Neighbour (k-NN) and Region-Adjacency Graph (RAG) methods and save them as torch tensors.\n",
+ "\n",
+    "To get the full functionality of this notebook for a real-world dataset, we suggest you download the BRACS ROI set from the [BRACS dataset](https://www.bracs.icar.cnr.it/download/). To do so, you will have to sign up and create an account. Next, you will just have to set the root folder in this tutorial to whatever directory you downloaded the BRACS dataset to. \n",
+ "\n",
+ "In this notebook, we will use a representative image from this [link](https://github.com/histocartography/hact-net/tree/main/data) stored in `data/`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2399d96-abf7-46b6-b783-c4c292259bdc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from glob import glob\n",
+ "import argparse\n",
+ "from PIL import Image\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm\n",
+ "import torch \n",
+ "import h5py\n",
+ "import warnings\n",
+ "import math\n",
+ "from skimage.measure import regionprops, label\n",
+ "import networkx as nx\n",
+ "import traceback\n",
+ "from glob import glob\n",
+ "\n",
+ "from pathml.core import HESlide, SlideData\n",
+ "import matplotlib.pyplot as plt \n",
+ "from pathml.preprocessing.transforms import Transform\n",
+ "from pathml.core import HESlide\n",
+ "from pathml.preprocessing import Pipeline, BoxBlur, TissueDetectionHE, NucleusDetectionHE\n",
+ "import pathml.core.tile\n",
+ "from pathml.ml import HoVerNet, loss_hovernet, post_process_batch_hovernet\n",
+ "\n",
+ "from pathml.datasets.utils import DeepPatchFeatureExtractor\n",
+ "from pathml.preprocessing import StainNormalizationHE\n",
+ "from pathml.graph import RAGGraphBuilder, KNNGraphBuilder\n",
+ "from pathml.graph import ColorMergedSuperpixelExtractor\n",
+ "from pathml.graph.utils import _valid_image, _exists, plot_graph_on_image, get_full_instance_map, build_assignment_matrix"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ee47ddb6-2edf-42b5-9227-65c9ac1ecbf1",
+ "metadata": {},
+ "source": [
+    "## Building a HoverNetNucleusDetectionHE class using the pathml.preprocessing.transforms API "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5d58d8d0-e7a9-4e92-bfd0-2ec1cecafcf6",
+ "metadata": {},
+ "source": [
+    "First, we will use a pre-trained HoVer-Net model to detect cells and return an instance map containing masks that correspond to cells. We will use the `Transform` class from `pathml.preprocessing.transforms`, which applies a function over each ROI. The new `HoverNetNucleusDetectionHE` class simply inherits from `Transform` and applies a HoVer-Net model to each ROI passed to it. \n",
+ "\n",
+ "To obtain the pre-trained HoVer-Net model, we follow the steps in this [tutorial](https://pathml.readthedocs.io/en/latest/examples/link_train_hovernet.html). For simplicity, we provide a pre-trained model under the `pretrained_models` folder. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e0358a83-e93a-4525-a7bb-f697e2252ce7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class HoverNetNucleusDetectionHE(Transform):\n",
+ " \n",
+ " \"\"\"\n",
+ " Nucleus detection algorithm for H&E stained images using pre-trained HoverNet Model.\n",
+ "\n",
+ " Args:\n",
+ " mask_name (str): Name of mask that is created.\n",
+ " model_path (str): Path to the pretrained model. \n",
+ " \n",
+ " References:\n",
+ " Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N., 2019. \n",
+ " Hover-net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images. \n",
+ " Medical image analysis, 58, p.101563.\n",
+ " \n",
+ " \"\"\"\n",
+ " \n",
+ " def __init__(\n",
+ " self,\n",
+ " mask_name,\n",
+ " model_path = None\n",
+ " ):\n",
+ " self.mask_name = mask_name\n",
+ " \n",
+ " cuda = torch.cuda.is_available()\n",
+ " self.device = torch.device(\"cuda:0\" if cuda else \"cpu\")\n",
+ " \n",
+ " if model_path is None:\n",
+ " raise NotImplementedError(\"Downloadable models not available\")\n",
+ " else:\n",
+ " checkpoint = torch.load(model_path)\n",
+ " self.model = HoVerNet(n_classes=6)\n",
+ " self.model.load_state_dict(checkpoint)\n",
+ " \n",
+ " self.model = self.model.to(self.device)\n",
+ " self.model.eval()\n",
+ "\n",
+ " def F(self, image):\n",
+ " assert (\n",
+ " image.dtype == np.uint8\n",
+ " ), f\"Input image dtype {image.dtype} must be np.uint8\"\n",
+ " \n",
+ " image = torch.from_numpy(image).float()\n",
+ " image = image.permute(2, 0, 1)\n",
+ " image = image.unsqueeze(0)\n",
+ " image = image.to(self.device)\n",
+ " with torch.no_grad():\n",
+ " out = self.model(image)\n",
+ " preds_detection, _ = post_process_batch_hovernet(out, n_classes=6)\n",
+ " preds_detection = preds_detection.transpose(1,2,0)\n",
+ " return preds_detection\n",
+ "\n",
+ " def apply(self, tile):\n",
+ " assert isinstance(\n",
+ " tile, pathml.core.tile.Tile\n",
+ " ), f\"tile is type {type(tile)} but must be pathml.core.tile.Tile\"\n",
+ " assert (\n",
+ " self.mask_name is not None\n",
+ " ), \"mask_name is None. Must supply a valid mask name\"\n",
+ " assert (\n",
+ " tile.slide_type.stain == \"HE\"\n",
+ " ), f\"Tile has slide_type.stain={tile.slide_type.stain}, but must be 'HE'\"\n",
+ " \n",
+ " nucleus_mask = self.F(tile.image)\n",
+ " tile.masks[self.mask_name] = nucleus_mask"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3c2e1bc1-6825-44b6-8909-8d47574295ed",
+ "metadata": {},
+ "source": [
+    "A simple example of using this class is given below. The input image for this is in the `data` folder. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2bbe1db-17b0-4ddf-8e74-c338bae2c740",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wsi = SlideData('../data/example_0_N_0.png', name = 'example', backend = \"openslide\", stain = 'HE')\n",
+ "region = wsi.slide.extract_region(location = (900, 800), size = (256, 256))\n",
+ "plt.imshow(region)\n",
+ "plt.title('Input image', fontsize=11)\n",
+ "plt.gca().set_xticks([])\n",
+ "plt.gca().set_yticks([])\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d098e3eb-a32f-4d67-88bc-a056fe0207e6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nuclei_detect = HoverNetNucleusDetectionHE(mask_name = 'cell', model_path = '../pretrained_models/hovernet_fully_trained.pt')\n",
+ "cell_mask = nuclei_detect.F(region)\n",
+ "plt.imshow(cell_mask)\n",
+ "plt.title('Nuclei mask', fontsize=11)\n",
+ "plt.gca().set_xticks([])\n",
+ "plt.gca().set_yticks([])\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b84c3b68-9028-4724-97ad-6d9934e0ce2e",
+ "metadata": {},
+ "source": [
+ "## Cell and Tissue graph construction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "993b2866-1c88-485c-810b-60f0be998174",
+ "metadata": {},
+ "source": [
+ "Next, we can move on to applying a function that uses the new `pathml.graph` API to construct cell and tissue graphs.\n",
+ "\n",
+ "We have to first define some constants. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b9377c10-d38c-4bea-ada9-aaff38cd92a9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Convert the tumor type given in the filename to a label\n",
+ "TUMOR_TYPE_TO_LABEL = {\n",
+ " 'N': 0,\n",
+ " 'PB': 1,\n",
+ " 'UDH': 2,\n",
+ " 'ADH': 3,\n",
+ " 'FEA': 4,\n",
+ " 'DCIS': 5,\n",
+ " 'IC': 6\n",
+ "}\n",
+ "\n",
+ "# Define minimum and maximum pixels for processing a ROI\n",
+ "MIN_NR_PIXELS = 50000\n",
+ "MAX_NR_PIXELS = 50000000 \n",
+ "\n",
+ "# Define the reference image and HoVer-Net model path\n",
+ "ref_path = '../data/example_0_N_0.png'\n",
+ "hovernet_model_path = '../pretrained_models/hovernet_fully_trained.pt'\n",
+ "\n",
+ "# Define the patch size for applying HoverNetNucleusDetectionHE \n",
+ "PATCH_SIZE = 256"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d7b5798f-3d93-4179-9bdc-8ae80b38573a",
+ "metadata": {},
+ "source": [
+ "Next, we write the main preprocessing loop as a function. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2aa61a77-b882-4161-8fb5-72e57ad2be17",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def process(image_path, save_path, split, plot=True, overwrite=False):\n",
+ " # 1. get image path\n",
+ " subdirs = os.listdir(image_path)\n",
+ " image_fnames = []\n",
+ " for subdir in (subdirs + ['']): \n",
+ " image_fnames += glob(os.path.join(image_path, subdir, '*.png'))\n",
+ " \n",
+ " image_ids_failing = []\n",
+ " \n",
+ " print('*** Start analysing {} image(s) ***'.format(len(image_fnames)))\n",
+ " \n",
+ " ref_image = np.array(Image.open(ref_path))\n",
+ " norm = StainNormalizationHE(stain_estimation_method='vahadane')\n",
+ " norm.fit_to_reference(ref_image)\n",
+ " ref_stain_matrix = norm.stain_matrix_target_od \n",
+ " ref_max_C = norm.max_c_target \n",
+ " \n",
+ " for image_path in tqdm(image_fnames):\n",
+ " \n",
+ " # a. load image & check if already there \n",
+ " _, image_name = os.path.split(image_path)\n",
+ " image = np.array(Image.open(image_path))\n",
+ "\n",
+ " # Compute number of pixels in image and check the label of the image\n",
+ " nr_pixels = image.shape[0] * image.shape[1]\n",
+ " image_label = TUMOR_TYPE_TO_LABEL[image_name.split('_')[2]]\n",
+ "\n",
+ " # Get the output file paths of cell graphs, tissue graphs and assignment matrices\n",
+ " cg_out = os.path.join(save_path, 'cell_graphs', split, image_name.replace('.png', '.pt'))\n",
+ " tg_out = os.path.join(save_path, 'tissue_graphs', split, image_name.replace('.png', '.pt'))\n",
+ " assign_out = os.path.join(save_path, 'assignment_matrices', split, image_name.replace('.png', '.pt')) \n",
+ "\n",
+    "        # Process this image only if the output files do not already exist and the image size is within bounds\n",
+ " if not _exists(cg_out, tg_out, assign_out, overwrite) and _valid_image(nr_pixels):\n",
+ " \n",
+ " print(f'Image size: {image.shape[0], image.shape[1]}')\n",
+ "\n",
+ " if plot:\n",
+ " print('Input ROI:')\n",
+ " plt.imshow(image)\n",
+ " plt.show()\n",
+ " \n",
+ " try:\n",
+ " # Read the image as a pathml.core.SlideData class\n",
+ " print('\\nReading image')\n",
+ " wsi = SlideData(image_path, name = image_path, backend = \"openslide\", stain = 'HE')\n",
+ "\n",
+ " # Apply our HoverNetNucleusDetectionHE as a pathml.preprocessing.Pipeline over all patches\n",
+ " print('Detecting nuclei')\n",
+ " pipeline = Pipeline([HoverNetNucleusDetectionHE(mask_name='cell', \n",
+ " model_path=hovernet_model_path)])\n",
+ " \n",
+ " # Run the Pipeline \n",
+ " wsi.run(pipeline, overwrite_existing_tiles=True, distributed=False, tile_pad=True, tile_size=PATCH_SIZE)\n",
+ "\n",
+ " # Extract the ROI, nuclei instance maps as an np.array from a pathml.core.SlideData object\n",
+ " image, nuclei_map, nuclei_centroid = get_full_instance_map(wsi, patch_size = PATCH_SIZE)\n",
+ "\n",
+ " # Use a ResNet-34 to extract the features from each detected cell in the ROI\n",
+ " print('Extracting features from cells')\n",
+ " extractor = DeepPatchFeatureExtractor(patch_size=64, \n",
+ " batch_size=64, \n",
+ " entity = 'cell',\n",
+ " architecture='resnet34', \n",
+ " fill_value=255, \n",
+ " resize_size=224,\n",
+ " threshold=0)\n",
+ " features = extractor.process(image, nuclei_map)\n",
+ "\n",
+ " # Build a kNN graph with nodes as cells, node features as ResNet-34 computed features, and edges within\n",
+ " # a threshold of 50\n",
+ " print('Building graphs')\n",
+ " knn_graph_builder = KNNGraphBuilder(k=5, thresh=50, add_loc_feats=True)\n",
+ " cell_graph = knn_graph_builder.process(nuclei_map, features, target = image_label)\n",
+ "\n",
+ " # Plot cell graph on ROI image \n",
+ " if plot:\n",
+ " print('Cell graph on ROI:')\n",
+ " plot_graph_on_image(cell_graph, image)\n",
+ "\n",
+ " # Save the cell graph \n",
+ " torch.save(cell_graph, cg_out)\n",
+ "\n",
+ " # Detect tissue using pathml.graph.ColorMergedSuperpixelExtractor class\n",
+ " print('\\nDetecting tissue')\n",
+ " tissue_detector = ColorMergedSuperpixelExtractor(superpixel_size=200,\n",
+ " compactness=20,\n",
+ " blur_kernel_size=1,\n",
+ " threshold=0.05,\n",
+ " downsampling_factor=4)\n",
+ "\n",
+ " superpixels, _ = tissue_detector.process(image)\n",
+ "\n",
+ " # Use a ResNet-34 to extract the features from each detected tissue in the ROI\n",
+ " print('Extracting features from tissues')\n",
+ " tissue_feature_extractor = DeepPatchFeatureExtractor(architecture='resnet34',\n",
+ " patch_size=144,\n",
+ " entity = 'tissue',\n",
+ " resize_size=224,\n",
+ " fill_value=255,\n",
+ " batch_size=32,\n",
+ " threshold = 0.25)\n",
+ " features = tissue_feature_extractor.process(image, superpixels)\n",
+ "\n",
+ " # Build a RAG with tissues as nodes, node features as ResNet-34 computed features, and edges using the \n",
+ " # RAG algorithm\n",
+ " print('Building graphs')\n",
+ " rag_graph_builder = RAGGraphBuilder(add_loc_feats=True)\n",
+ " tissue_graph = rag_graph_builder.process(superpixels, features, target = image_label)\n",
+ "\n",
+ " # Plot tissue graph on ROI image\n",
+ " if plot:\n",
+ " print('Tissue graph on ROI:')\n",
+ " plot_graph_on_image(tissue_graph, image)\n",
+ "\n",
+ " # Save the tissue graph \n",
+ " torch.save(tissue_graph, tg_out) \n",
+ "\n",
+    "                # Build an assignment matrix that maps each cell to the tissue it is a part of\n",
+ " assignment = build_assignment_matrix(nuclei_centroid, superpixels)\n",
+ "\n",
+ " # Save the assignment matrix\n",
+ " torch.save(torch.tensor(assignment), assign_out)\n",
+ " \n",
+    "            except Exception:\n",
+ " print(f'Failed {image_path}')\n",
+ " image_ids_failing.append(image_path)\n",
+ " \n",
+ " print('\\nOut of {} images, {} successful graph generations.'.format(\n",
+ " len(image_fnames),\n",
+ " len(image_fnames) - len(image_ids_failing)\n",
+ " ))\n",
+ " print('Failing IDs are:', image_ids_failing)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7fe9a5cd-4e01-4f4f-b5cd-91f1b204835b",
+ "metadata": {},
+ "source": [
+    "Finally, we write a main function that calls the process function for a specified root and output directory, along with the name of the split ('train', 'val' or 'test' if using BRACS). "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5e6a71aa-babc-47e7-a641-d9688962350b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def main(base_path, save_path, split=None):\n",
+ " if split is not None:\n",
+ " root_path = os.path.join(base_path, split)\n",
+ " else:\n",
+ " root_path = base_path\n",
+ " \n",
+ " print(root_path)\n",
+ " \n",
+ " os.makedirs(os.path.join(save_path, 'cell_graphs', split), exist_ok=True)\n",
+ " os.makedirs(os.path.join(save_path, 'tissue_graphs', split), exist_ok=True)\n",
+ " os.makedirs(os.path.join(save_path, 'assignment_matrices', split), exist_ok=True)\n",
+ " \n",
+ " process(root_path, save_path, split, plot=True, overwrite=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "54d660e9-b32c-4101-9d64-7fb9cdc14c93",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Folder containing all images\n",
+ "base = '../data/'\n",
+ "\n",
+ "# Output path \n",
+ "save_path = '../data/output/'\n",
+ "\n",
+ "# Start preprocessing\n",
+ "main(base, save_path, split='')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8c7cfdb1-e016-4d9c-acff-b3eb0f40ce55",
+ "metadata": {},
+ "source": [
+ "## References\n",
+ "\n",
+ "* Pati, Pushpak, Guillaume Jaume, Antonio Foncubierta-Rodriguez, Florinda Feroce, Anna Maria Anniciello, Giosue Scognamiglio, Nadia Brancati et al. \"Hierarchical graph representations in digital pathology.\" Medical image analysis 75 (2022): 102264.\n",
+ "* Brancati, Nadia, Anna Maria Anniciello, Pushpak Pati, Daniel Riccio, Giosuè Scognamiglio, Guillaume Jaume, Giuseppe De Pietro et al. \"Bracs: A dataset for breast carcinoma subtyping in h&e histology images.\" Database 2022 (2022): baac093."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2307dfb9-19cc-4f6b-aa96-7f5f40322e0b",
+ "metadata": {},
+ "source": [
+ "## Session info"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1e1d903d-01ef-492f-a806-d974a0940c4c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import IPython\n",
+ "print(IPython.sys_info())\n",
+ "print(f\"torch version: {torch.__version__}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pathml_graph_dev",
+ "language": "python",
+ "name": "pathml_graph_dev"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/train_hactnet.ipynb b/examples/train_hactnet.ipynb
new file mode 100644
index 00000000..72d83e20
--- /dev/null
+++ b/examples/train_hactnet.ipynb
@@ -0,0 +1,318 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "655f37d3-1591-4a7c-8f54-f7fc8148dcfd",
+ "metadata": {},
+ "source": [
+ "# Training a HACTNet model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "88157b8c-0368-42a4-af82-800b7bab74d7",
+ "metadata": {},
+ "source": [
+ "In this notebook, we will train the HACTNet graph neural network (GNN) model on input cell and tissue graphs using the new `pathml.graph` API.\n",
+ "\n",
+ "To run the notebook and train the model, you will have to first download the BRACS ROI set from the [BRACS dataset](https://www.bracs.icar.cnr.it/download/). To do so, you will have to sign up and create an account. Next, you will have to construct the cell and tissue graphs using the tutorial in `examples/construct_graphs.ipynb`. Use the output directory specified there as the input to the functions in this tutorial. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "f682f702-b590-4e3d-8c97-a362411acade",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from glob import glob\n",
+ "import argparse\n",
+ "from PIL import Image\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm\n",
+ "import torch \n",
+ "import h5py\n",
+ "import warnings\n",
+ "import math\n",
+ "from skimage.measure import regionprops, label\n",
+ "import networkx as nx\n",
+ "import traceback\n",
+ "from glob import glob\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch_geometric.data import Batch\n",
+ "from torch_geometric.data import Data\n",
+ "from torch.utils.data import Dataset\n",
+ "from torch_geometric.loader import DataLoader\n",
+ "from torch.optim.lr_scheduler import StepLR\n",
+ "from sklearn.metrics import f1_score\n",
+ "\n",
+ "from pathml.datasets import EntityDataset\n",
+ "from pathml.ml.utils import get_degree_histogram, get_class_weights\n",
+ "from pathml.ml import HACTNet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6fb4ee6f-7b17-424d-9cdd-710e36c7341c",
+ "metadata": {},
+ "source": [
+ "## Model Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10601c10-d069-481d-b502-f98f76e18e3c",
+ "metadata": {},
+ "source": [
+ "Here we define the main training loop for loading the constructed graphs, initializing and training the model. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "00cb474e-0441-4ff0-a495-709d3df3759d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train_hactnet(root_dir, load_histogram=True, histogram_dir=None, calc_class_weights=True):\n",
+ "\n",
+ " # Read the train, validation and test dataset into the pathml.datasets.EntityDataset class \n",
+ " train_dataset = EntityDataset(os.path.join(root_dir, 'cell_graphs/train/'),\n",
+ " os.path.join(root_dir, 'tissue_graphs/train/'),\n",
+ " os.path.join(root_dir, 'assignment_matrices/train/'))\n",
+ " val_dataset = EntityDataset(os.path.join(root_dir, 'cell_graphs/val/'),\n",
+ " os.path.join(root_dir, 'tissue_graphs/val/'),\n",
+ " os.path.join(root_dir, 'assignment_matrices/val/'))\n",
+ " test_dataset = EntityDataset(os.path.join(root_dir, 'cell_graphs/test/'),\n",
+ " os.path.join(root_dir, 'tissue_graphs/test/'),\n",
+ " os.path.join(root_dir, 'assignment_matrices/test/'))\n",
+ "\n",
+ " # Print the lengths of each dataset split\n",
+ " print(f\"Length of training dataset: {len(train_dataset)}\")\n",
+ " print(f\"Length of validation dataset: {len(val_dataset)}\")\n",
+ " print(f\"Length of test dataset: {len(test_dataset)}\")\n",
+ "\n",
+ " # Define the torch_geometric.DataLoader object for each dataset split with a batch size of 4\n",
+    "    train_batch = DataLoader(train_dataset, batch_size=4, shuffle=True, follow_batch=['x_cell', 'x_tissue'], drop_last=True)\n",
+    "    val_batch = DataLoader(val_dataset, batch_size=4, shuffle=False, follow_batch=['x_cell', 'x_tissue'], drop_last=True)\n",
+    "    test_batch = DataLoader(test_dataset, batch_size=4, shuffle=False, follow_batch=['x_cell', 'x_tissue'], drop_last=True)\n",
+ "\n",
+ " # The GNN layer we use in this model, PNAConv, requires the computation of a node degree histogram of the \n",
+ " # train dataset. We only need to compute it once. If it is precomputed already, set the load_histogram=True.\n",
+ " # Else, the degree histogram is calculated. \n",
+ " if load_histogram:\n",
+    "        histogram_dir = histogram_dir or \"./\"\n",
+ " cell_deg = torch.load(os.path.join(histogram_dir, 'cell_degree_norm.pt'))\n",
+ " tissue_deg = torch.load(os.path.join(histogram_dir, 'tissue_degree_norm.pt'))\n",
+ " else:\n",
+ " train_batch_hist = DataLoader(train_dataset, batch_size=20, shuffle=True, follow_batch =['x_cell', 'x_tissue'])\n",
+ " print('Calculating degree histogram for cell graph')\n",
+ " cell_deg = get_degree_histogram(train_batch_hist, 'edge_index_cell', 'x_cell')\n",
+ " print('Calculating degree histogram for tissue graph')\n",
+ " tissue_deg = get_degree_histogram(train_batch_hist, 'edge_index_tissue', 'x_tissue')\n",
+ " torch.save(cell_deg, 'cell_degree_norm.pt')\n",
+ " torch.save(tissue_deg, 'tissue_degree_norm.pt')\n",
+ "\n",
+ " # Since the BRACS dataset has unbalanced data, it is important to calculate the class weights in the training set\n",
+ " # and provide that as an argument to our loss function. \n",
+ " if calc_class_weights:\n",
+ " train_w = get_class_weights(train_batch)\n",
+ " torch.save(torch.tensor(train_w), 'loss_weights_norm.pt')\n",
+ "\n",
+ " # Here we define the keyword arguments for the PNAConv layer in the model for both cell and tissue processing \n",
+ " # layers. \n",
+ " kwargs_pna_cell = {'aggregators': [\"mean\", \"max\", \"min\", \"std\"],\n",
+ " \"scalers\": [\"identity\", \"amplification\", \"attenuation\"],\n",
+ " \"deg\": cell_deg}\n",
+ " kwargs_pna_tissue = {'aggregators': [\"mean\", \"max\", \"min\", \"std\"],\n",
+ " \"scalers\": [\"identity\", \"amplification\", \"attenuation\"],\n",
+ " \"deg\": tissue_deg}\n",
+ " \n",
+ " cell_params = {'layer':'PNAConv', 'in_channels':514, 'hidden_channels':64, \n",
+ " 'num_layers':3, 'out_channels':64, 'readout_op':'lstm', \n",
+ " 'readout_type':'mean', 'kwargs':kwargs_pna_cell}\n",
+ " \n",
+ " tissue_params = {'layer':'PNAConv', 'in_channels':514, 'hidden_channels':64, \n",
+ " 'num_layers':3, 'out_channels':64, 'readout_op':'lstm', \n",
+ " 'readout_type':'mean', 'kwargs':kwargs_pna_tissue}\n",
+ " \n",
+ " classifier_params = {'in_channels':128, 'hidden_channels':128,\n",
+ " 'out_channels':7, 'num_layers': 2}\n",
+ "\n",
+ " # Transfer the model to GPU\n",
+ " device = torch.device(\"cuda\")\n",
+ "\n",
+ " # Initialize the pathml.ml.HACTNet model\n",
+ " model = HACTNet(cell_params, tissue_params, classifier_params)\n",
+ "\n",
+ " # Set up optimizer\n",
+ " opt = torch.optim.Adam(model.parameters(), lr = 0.0005)\n",
+ "\n",
+    "    # Learning rate scheduler to reduce LR by a factor of 10 every 25 epochs\n",
+ " scheduler = StepLR(opt, step_size=25, gamma=0.1)\n",
+ "\n",
+ " # Send the model to GPU\n",
+ " model = model.to(device)\n",
+ "\n",
+ " # Define number of epochs \n",
+ " n_epochs = 60\n",
+ "\n",
+    "    # Keep track of the best epoch and metric so that only the best model is saved\n",
+ " best_epoch = 0\n",
+ " best_metric = 0\n",
+ "\n",
+ " # Load the computed class weights if calc_class_weights = True\n",
+ " if calc_class_weights:\n",
+ " loss_weights = torch.load('loss_weights_norm.pt')\n",
+ "\n",
+ " # Define the loss function\n",
+ " loss_fn = nn.CrossEntropyLoss(weight=loss_weights.float().to(device) if calc_class_weights else None)\n",
+ "\n",
+ " # Define the evaluate function to compute metrics for validation and test set to keep track of performance.\n",
+ " # The metrics used are per-class and weighted F1 score. \n",
+ " def evaluate(data_loader):\n",
+ " model.eval()\n",
+ " y_true = []\n",
+ " y_pred = []\n",
+ " with torch.no_grad():\n",
+ " for data in tqdm(data_loader):\n",
+ " data = data.to(device)\n",
+ " outputs = model(data)\n",
+    "                y_pred.append(torch.argmax(outputs.detach().cpu().softmax(dim=1), dim=-1).numpy())\n",
+    "                y_true.append(data.target.cpu().numpy())\n",
+ " y_true = np.array(y_true).ravel()\n",
+ " y_pred = np.array(y_pred).ravel()\n",
+ " per_class = f1_score(y_true, y_pred, average=None)\n",
+ " weighted = f1_score(y_true, y_pred, average='weighted')\n",
+ " print(f'Per class F1: {per_class}')\n",
+ " print(f'Weighted F1: {weighted}')\n",
+ " return np.append(per_class, weighted)\n",
+ "\n",
+ " # Start the training loop\n",
+ " for i in range(n_epochs):\n",
+ " print(f'\\n>>>>>>>>>>>>>>>>Epoch number {i}>>>>>>>>>>>>>>>>')\n",
+ " minibatch_train_losses = []\n",
+ " \n",
+ " # Put model in training mode\n",
+ " model.train()\n",
+ " \n",
+ " print('Training')\n",
+ " \n",
+ " for data in tqdm(train_batch):\n",
+ " \n",
+ " # Send the data to the GPU\n",
+ " data = data.to(device)\n",
+ " \n",
+ " # Zero out gradient\n",
+ " opt.zero_grad()\n",
+ " \n",
+ " # Forward pass\n",
+ " outputs = model(data)\n",
+ " \n",
+ " # Compute loss\n",
+ " loss = loss_fn(outputs, data.target)\n",
+ " \n",
+ " # Compute gradients\n",
+ " loss.backward()\n",
+ " \n",
+ " # Step optimizer and scheduler\n",
+ " opt.step() \n",
+ "\n",
+ " # Track loss\n",
+ " minibatch_train_losses.append(loss.detach().cpu().numpy())\n",
+ " \n",
+ " print(f'Loss: {np.array(minibatch_train_losses).ravel().mean()}')\n",
+ "\n",
+ " # Print performance metrics on validation set\n",
+ " print('\\nEvaluating on validation')\n",
+ " val_metrics = evaluate(val_batch)\n",
+ "\n",
+ " # Save the model only if it is better than previous checkpoint in validation metrics\n",
+ " if val_metrics[-1] > best_metric:\n",
+ " print('Saving checkpoint')\n",
+ " torch.save(model.state_dict(), \"hact_net_norm.pt\")\n",
+ " best_metric = val_metrics[-1]\n",
+ "\n",
+ " # Print performance metrics on test set\n",
+ " print('\\nEvaluating on test')\n",
+ " _ = evaluate(test_batch)\n",
+ " \n",
+ " # Step LR scheduler\n",
+ " scheduler.step()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7f408188-5804-4ac4-9a1e-642c6e5e6d09",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "root_dir = '../../../../mnt/disks/data/varun/BRACS_RoI/latest_version/pathml_graph_data_norm/'\n",
+ "train_hactnet(root_dir, load_histogram=True, calc_class_weights=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5c4d6c48-c5be-4dcd-a38e-6277d1fd5956",
+ "metadata": {},
+ "source": [
+ "## References\n",
+ "\n",
+ "* Pati, Pushpak, Guillaume Jaume, Antonio Foncubierta-Rodriguez, Florinda Feroce, Anna Maria Anniciello, Giosue Scognamiglio, Nadia Brancati et al. \"Hierarchical graph representations in digital pathology.\" Medical image analysis 75 (2022): 102264.\n",
+ "* Brancati, Nadia, Anna Maria Anniciello, Pushpak Pati, Daniel Riccio, Giosuè Scognamiglio, Guillaume Jaume, Giuseppe De Pietro et al. \"Bracs: A dataset for breast carcinoma subtyping in h&e histology images.\" Database 2022 (2022): baac093."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "363ea74a-da2b-4e92-8d29-cbf7f59792cb",
+ "metadata": {},
+ "source": [
+ "## Session info"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b0c5f9f8-cb8c-4d61-9147-7d82bcb45c9c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import IPython\n",
+ "print(IPython.sys_info())\n",
+ "print(f\"torch version: {torch.__version__}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7872fb6-7b62-422b-a7db-0b7e169d82c7",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "pathml_graph_dev",
+ "language": "python",
+ "name": "pathml_graph_dev"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/train_hovernet.ipynb b/examples/train_hovernet.ipynb
index 5d152ee3..5384dd85 100644
--- a/examples/train_hovernet.ipynb
+++ b/examples/train_hovernet.ipynb
@@ -736,9 +736,9 @@
"uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.1-6:m59"
},
"kernelspec": {
- "display_name": "hovernet",
+ "display_name": "pathml_graph_dev",
"language": "python",
- "name": "hovernet"
+ "name": "pathml_graph_dev"
},
"language_info": {
"codemirror_mode": {
@@ -750,7 +750,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.12"
+ "version": "3.9.18"
}
},
"nbformat": 4,
diff --git a/pathml/.coverage_others b/pathml/.coverage_others
new file mode 100644
index 00000000..c60bd5c7
Binary files /dev/null and b/pathml/.coverage_others differ
diff --git a/pathml/_version.py b/pathml/_version.py
index b91684f4..2704c284 100644
--- a/pathml/_version.py
+++ b/pathml/_version.py
@@ -3,4 +3,4 @@
License: GNU GPL 2.0
"""
-__version__ = "2.1.0"
+__version__ = "2.1.1"
diff --git a/pathml/datasets/__init__.py b/pathml/datasets/__init__.py
index 1f5ececd..6784fde3 100644
--- a/pathml/datasets/__init__.py
+++ b/pathml/datasets/__init__.py
@@ -3,5 +3,6 @@
License: GNU GPL 2.0
"""
+from .datasets import EntityDataset, TileDataset
from .deepfocus import DeepFocusDataModule
from .pannuke import PanNukeDataModule
diff --git a/pathml/datasets/datasets.py b/pathml/datasets/datasets.py
new file mode 100644
index 00000000..d6dbbe01
--- /dev/null
+++ b/pathml/datasets/datasets.py
@@ -0,0 +1,446 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import copy
+import os
+import warnings
+from glob import glob
+
+import h5py
+import numpy as np
+import torch
+from skimage.measure import regionprops
+from skimage.transform import resize
+
+from pathml.graph.utils import HACTPairData
+
+
+class TileDataset(torch.utils.data.Dataset):
+ """
+ PyTorch Dataset class for h5path files
+
+ Each item is a tuple of (``tile_image``, ``tile_masks``, ``tile_labels``, ``slide_labels``) where:
+
+ - ``tile_image`` is a torch.Tensor of shape (C, H, W) or (T, Z, C, H, W)
+ - ``tile_masks`` is a torch.Tensor of shape (n_masks, tile_height, tile_width)
+ - ``tile_labels`` is a dict
+ - ``slide_labels`` is a dict
+
+ This is designed to be wrapped in a PyTorch DataLoader for feeding tiles into ML models.
+ Note that label dictionaries are not standardized, as users are free to store whatever labels they want.
+ For that reason, PyTorch cannot automatically stack labels into batches.
+ When creating a DataLoader from a TileDataset, it may therefore be necessary to create a custom ``collate_fn`` to
+ specify how to create batches of labels. See: https://discuss.pytorch.org/t/how-to-use-collate-fn/27181
+
+ Args:
+ file_path (str): Path to .h5path file on disk
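+
+    Example:
+        A minimal sketch of a custom ``collate_fn`` (hypothetical; adapt the label handling
+        to whatever labels are stored in your h5path file)::
+
+            import torch
+            from torch.utils.data import DataLoader
+
+            def tile_collate(batch):
+                # stack the tile images and keep the per-tile label dicts as a plain list
+                images = torch.stack([torch.as_tensor(item[0]) for item in batch])
+                labels = [item[2] for item in batch]
+                return images, labels
+
+            dataset = TileDataset("slide.h5path")
+            loader = DataLoader(dataset, batch_size=8, collate_fn=tile_collate)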
+ """
+
+ def __init__(self, file_path):
+ self.file_path = file_path
+ self.h5 = None
+ with h5py.File(self.file_path, "r") as file:
+ self.tile_shape = eval(file["tiles"].attrs["tile_shape"])
+ self.tile_keys = list(file["tiles"].keys())
+ self.dataset_len = len(self.tile_keys)
+ self.slide_level_labels = {
+ key: val
+ for key, val in file["fields"]["labels"].attrs.items()
+ if val is not None
+ }
+
+ def __len__(self):
+ return self.dataset_len
+
+ def __getitem__(self, ix):
+ if self.h5 is None:
+ self.h5 = h5py.File(self.file_path, "r")
+
+ k = self.tile_keys[ix]
+ # this part copied from h5manager.get_tile()
+ tile_image = self.h5["tiles"][str(k)]["array"][:]
+
+ # get corresponding masks if there are masks
+ if "masks" in self.h5["tiles"][str(k)].keys():
+ masks = {
+ mask: self.h5["tiles"][str(k)]["masks"][mask][:]
+ for mask in self.h5["tiles"][str(k)]["masks"]
+ }
+ else:
+ masks = None
+
+ labels = {
+ key: val for key, val in self.h5["tiles"][str(k)]["labels"].attrs.items()
+ }
+
+ if tile_image.ndim == 3:
+ # swap axes from HWC to CHW for pytorch
+ im = tile_image.transpose(2, 0, 1)
+ elif tile_image.ndim == 5:
+ # in this case, we assume that we have XYZCT channel order (OME-TIFF)
+ # so we swap axes to TCZYX for batching
+ im = tile_image.transpose(4, 3, 2, 1, 0)
+ else:
+ raise NotImplementedError(
+ f"tile image has shape {tile_image.shape}. Expecting an image with 3 dims (HWC) or 5 dims (XYZCT)"
+ )
+
+ masks = np.stack(list(masks.values()), axis=0) if masks else None
+
+ return im, masks, labels, self.slide_level_labels
+
+
+class EntityDataset(torch.utils.data.Dataset):
+ """
+ Torch Geometric Dataset class for storing cell or tissue graphs. Each item returns a
+ pathml.graph.utils.HACTPairData object.
+
+ Args:
+ cell_dir (str): Path to folder containing cell graphs
+ tissue_dir (str): Path to folder containing tissue graphs
+ assign_dir (str): Path to folder containing assignment matrices
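+
+    Example:
+        A minimal sketch of wrapping the dataset in a torch_geometric ``DataLoader``, mirroring
+        the usage in ``examples/train_hactnet.ipynb`` (the directory paths are placeholders)::
+
+            from torch_geometric.loader import DataLoader
+
+            dataset = EntityDataset(
+                "cell_graphs/train/", "tissue_graphs/train/", "assignment_matrices/train/"
+            )
+            loader = DataLoader(dataset, batch_size=4, follow_batch=["x_cell", "x_tissue"])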
+ """
+
+ def __init__(self, cell_dir=None, tissue_dir=None, assign_dir=None):
+ self.cell_dir = cell_dir
+ self.tissue_dir = tissue_dir
+ self.assign_dir = assign_dir
+
+ if self.cell_dir is not None:
+ if not os.path.exists(cell_dir):
+ raise FileNotFoundError(f"Directory not found: {self.cell_dir}")
+ self.cell_graphs = glob(os.path.join(cell_dir, "*.pt"))
+
+ if self.tissue_dir is not None:
+ if not os.path.exists(tissue_dir):
+ raise FileNotFoundError(f"Directory not found: {self.tissue_dir}")
+ self.tissue_graphs = glob(os.path.join(tissue_dir, "*.pt"))
+
+ if self.assign_dir is not None:
+ if not os.path.exists(assign_dir):
+ raise FileNotFoundError(f"Directory not found: {self.assign_dir}")
+ self.assigns = glob(os.path.join(assign_dir, "*.pt"))
+
+ def __len__(self):
+ return len(self.cell_graphs)
+
+ def __getitem__(self, index):
+
+ # Load cell graphs, tissue graphs and assignments if they are provided
+ if self.cell_dir is not None:
+ cell_graph = torch.load(self.cell_graphs[index])
+ target = cell_graph["target"]
+
+ if self.tissue_dir is not None:
+ tissue_graph = torch.load(self.tissue_graphs[index])
+ target = tissue_graph["target"]
+
+ if self.assign_dir is not None:
+ assignment = torch.load(self.assigns[index])
+
+        # Create a pathml.graph.utils.HACTPairData object with the provided objects
+ data = HACTPairData(
+ x_cell=cell_graph.node_features if self.cell_dir is not None else None,
+ edge_index_cell=cell_graph.edge_index
+ if self.cell_dir is not None
+ else None,
+ x_tissue=tissue_graph.node_features
+ if self.tissue_dir is not None
+ else None,
+ edge_index_tissue=tissue_graph.edge_index
+ if self.tissue_dir is not None
+ else None,
+ assignment=assignment[1, :] if self.assign_dir is not None else None,
+ target=target,
+ )
+ return data
+
+
+class InstanceMapPatchDataset(torch.utils.data.Dataset):
+ """
+    Create a dataset of (patch_size, patch_size, 3) patches from a given image and its
+    extracted instance map.
+ Args:
+ image (np.ndarray): RGB input image.
+        instance_map (np.ndarray): Extracted instance map.
+ entity (str): Entity to be processed. Must be one of 'cell' or 'tissue'. Defaults to 'cell'.
+ patch_size (int): Desired size of patch.
+        threshold (float): Minimum fraction of the patch area that must overlap an instance for the patch to be kept. Defaults to 0.2.
+ resize_size (int): Desired resized size to input the network. If None, no resizing is done and the
+ patches of size patch_size are provided to the network. Defaults to None.
+ fill_value (Optional[int]): Value to fill outside the instance maps. Defaults to 255.
+ mean (list[float], optional): Channel-wise mean for image normalization.
+ std (list[float], optional): Channel-wise std for image normalization.
+ with_instance_masking (bool): If pixels outside instance should be masked. Defaults to False.
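+
+    Example:
+        A minimal sketch, assuming ``image`` is an RGB ``np.ndarray`` and ``instance_map`` is an
+        integer-labelled map of the same height and width (e.g. from nucleus detection)::
+
+            dataset = InstanceMapPatchDataset(
+                image, instance_map, entity="cell", patch_size=64, resize_size=224
+            )
+            patch, region_index = dataset[0]  # patch tensor and the index of its parent region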
+ """
+
+ def __init__(
+ self,
+ image,
+ instance_map,
+ entity="cell",
+ patch_size=64,
+ threshold=0.2,
+ resize_size=None,
+ fill_value=255,
+ mean=None,
+ std=None,
+ with_instance_masking=False,
+ ):
+
+ self.image = image
+ self.instance_map = instance_map
+ self.entity = entity
+ self.patch_size = patch_size
+ self.with_instance_masking = with_instance_masking
+ self.fill_value = fill_value
+ self.resize_size = resize_size
+ self.mean = mean
+ self.std = std
+
+ self.patch_size_2 = int(self.patch_size // 2)
+
+ self.image = np.pad(
+ self.image,
+ (
+ (self.patch_size_2, self.patch_size_2),
+ (self.patch_size_2, self.patch_size_2),
+ (0, 0),
+ ),
+ mode="constant",
+ constant_values=self.fill_value,
+ )
+ self.instance_map = np.pad(
+ self.instance_map,
+ (
+ (self.patch_size_2, self.patch_size_2),
+ (self.patch_size_2, self.patch_size_2),
+ ),
+ mode="constant",
+ constant_values=0,
+ )
+
+ self.threshold = int(self.patch_size * self.patch_size * threshold)
+ self.warning_threshold = 0.75
+
+ try:
+ from torchvision import transforms
+
+ self.use_torchvision = True
+ except ImportError:
+ print(
+ "Torchvision is not installed, using base modules for resizing patches and skipping normalization"
+ )
+ self.use_torchvision = False
+
+ if self.use_torchvision:
+ basic_transforms = [transforms.ToPILImage()]
+ if self.resize_size is not None:
+ basic_transforms.append(transforms.Resize(self.resize_size))
+ basic_transforms.append(transforms.ToTensor())
+ if self.mean is not None and self.std is not None:
+ basic_transforms.append(transforms.Normalize(self.mean, self.std))
+ self.dataset_transform = transforms.Compose(basic_transforms)
+
+ if self.entity not in ["cell", "tissue"]:
+ raise ValueError(
+ "Invalid value for entity. Expected 'cell' or 'tissue', got '{}'.".format(
+ self.entity
+ )
+ )
+
+ if self.entity == "cell":
+ self._precompute_cell()
+ elif self.entity == "tissue":
+ self._precompute_tissue()
+
+ def _add_patch(self, center_x, center_y, instance_index, region_count):
+ """Extract and include patch information."""
+
+ # Get a patch for each entity in the instance map
+ mask = self.instance_map[
+ center_y - self.patch_size_2 : center_y + self.patch_size_2,
+ center_x - self.patch_size_2 : center_x + self.patch_size_2,
+ ]
+
+ # Check the overlap between the extracted patch and the entity
+ overlap = np.sum(mask == instance_index)
+
+        # Add patch coordinates if overlap is greater than the threshold
+ if overlap > self.threshold:
+ loc = [center_x - self.patch_size_2, center_y - self.patch_size_2]
+ self.patch_coordinates.append(loc)
+ self.patch_region_count.append(region_count)
+ self.patch_instance_ids.append(instance_index)
+ self.patch_overlap.append(overlap)
+
+ def _get_patch_tissue(self, loc, region_id=None):
+ """Extract tissue patches from image."""
+
+ # Get bounding box of given location
+ min_x = loc[0]
+ min_y = loc[1]
+ max_x = min_x + self.patch_size
+ max_y = min_y + self.patch_size
+
+ patch = copy.deepcopy(self.image[min_y:max_y, min_x:max_x])
+
+ # Fill background pixels with instance masking value
+ if self.with_instance_masking:
+ instance_mask = ~(self.instance_map[min_y:max_y, min_x:max_x] == region_id)
+ patch[instance_mask, :] = self.fill_value
+ return patch
+
+ def _get_patch_cell(self, loc, region_id):
+ """Extract cell patches from image."""
+
+ # Get bounding box of given location
+ min_y, min_x = loc
+ patch = self.image[
+ min_y : min_y + self.patch_size, min_x : min_x + self.patch_size, :
+ ]
+
+ # Fill background pixels with instance masking value
+ if self.with_instance_masking:
+ instance_mask = ~(
+ self.instance_map[
+ min_y : min_y + self.patch_size, min_x : min_x + self.patch_size
+ ]
+ == region_id
+ )
+ patch[instance_mask, :] = self.fill_value
+
+ return patch
+
+ def _precompute_cell(self):
+ """Precompute instance-wise patch information for all cell instances in the input image."""
+
+ # Get location of all entities from the instance map
+ self.entities = regionprops(self.instance_map)
+ self.patch_coordinates = []
+ self.patch_overlap = []
+ self.patch_region_count = []
+ self.patch_instance_ids = []
+
+ # Get coordinates for all entities and add them to the pile
+ for region_count, region in enumerate(self.entities):
+ min_y, min_x, max_y, max_x = region.bbox
+
+ cy, cx = region.centroid
+ cy, cx = int(cy), int(cx)
+
+ coord = [cy - self.patch_size_2, cx - self.patch_size_2]
+
+ instance_mask = self.instance_map[
+ coord[0] : coord[0] + self.patch_size,
+ coord[1] : coord[1] + self.patch_size,
+ ]
+ overlap = np.sum(instance_mask == region.label)
+ if overlap >= self.threshold:
+ self.patch_coordinates.append(coord)
+ self.patch_region_count.append(region_count)
+ self.patch_instance_ids.append(region.label)
+ self.patch_overlap.append(overlap)
+
+ def _precompute_tissue(self):
+ """Precompute instance-wise patch information for all tissue instances in the input image."""
+
+ # Get location of all entities from the instance map
+ self.patch_coordinates = []
+ self.patch_region_count = []
+ self.patch_instance_ids = []
+ self.patch_overlap = []
+
+ self.entities = regionprops(self.instance_map)
+ self.stride = self.patch_size
+
+ # Get coordinates for all entities and add them to the pile
+ for region_count, region in enumerate(self.entities):
+
+ # Extract centroid
+ center_y, center_x = region.centroid
+ center_x = int(round(center_x))
+ center_y = int(round(center_y))
+
+ # Extract bounding box
+ min_y, min_x, max_y, max_x = region.bbox
+
+ # Extract patch information around the centroid patch
+ # quadrant 1 (includes centroid patch)
+ y_ = copy.deepcopy(center_y)
+ while y_ >= min_y:
+ x_ = copy.deepcopy(center_x)
+ while x_ >= min_x:
+ self._add_patch(x_, y_, region.label, region_count)
+ x_ -= self.stride
+ y_ -= self.stride
+
+ # quadrant 4
+ y_ = copy.deepcopy(center_y)
+ while y_ >= min_y:
+ x_ = copy.deepcopy(center_x) + self.stride
+ while x_ <= max_x:
+ self._add_patch(x_, y_, region.label, region_count)
+ x_ += self.stride
+ y_ -= self.stride
+
+ # quadrant 2
+ y_ = copy.deepcopy(center_y) + self.stride
+ while y_ <= max_y:
+ x_ = copy.deepcopy(center_x)
+ while x_ >= min_x:
+ self._add_patch(x_, y_, region.label, region_count)
+ x_ -= self.stride
+ y_ += self.stride
+
+ # quadrant 3
+ y_ = copy.deepcopy(center_y) + self.stride
+ while y_ <= max_y:
+ x_ = copy.deepcopy(center_x) + self.stride
+ while x_ <= max_x:
+ self._add_patch(x_, y_, region.label, region_count)
+ x_ += self.stride
+ y_ += self.stride
+
+ def _warning(self):
+ """Check patch coverage statistics to identify if provided patch size includes too much background."""
+
+ self.patch_overlap = np.array(self.patch_overlap) / (
+ self.patch_size * self.patch_size
+ )
+ if np.mean(self.patch_overlap) < self.warning_threshold:
+ warnings.warn("Provided patch size is large")
+ warnings.warn("Suggestion: Reduce patch size to include relevant context.")
+
+ def __getitem__(self, index):
+ """Loads an image for a given patch index."""
+
+ if self.entity == "cell":
+ patch = self._get_patch_cell(
+ self.patch_coordinates[index], self.patch_instance_ids[index]
+ )
+ elif self.entity == "tissue":
+ patch = self._get_patch_tissue(
+ self.patch_coordinates[index], self.patch_instance_ids[index]
+ )
+ else:
+ raise ValueError(
+ "Invalid value for entity. Expected 'cell' or 'tissue', got '{}'.".format(
+ self.entity
+ )
+ )
+
+ if self.use_torchvision:
+ patch = self.dataset_transform(patch)
+ else:
+ patch = patch / 255.0 if patch.max() > 1 else patch
+ patch = resize(patch, (self.resize_size, self.resize_size))
+ patch = torch.from_numpy(patch).permute(2, 0, 1).float()
+
+ return patch, self.patch_region_count[index]
+
+ def __len__(self):
+ """Returns the length of the dataset."""
+ return len(self.patch_coordinates)
diff --git a/pathml/datasets/pannuke.py b/pathml/datasets/pannuke.py
index 63bb35e8..9e801515 100644
--- a/pathml/datasets/pannuke.py
+++ b/pathml/datasets/pannuke.py
@@ -17,7 +17,7 @@
from pathml.datasets.base_data_module import BaseDataModule
from pathml.datasets.utils import pannuke_multiclass_mask_to_nucleus_mask
-from pathml.ml.hovernet import compute_hv_map
+from pathml.ml.models.hovernet import compute_hv_map
from pathml.utils import download_from_url
diff --git a/pathml/datasets/utils.py b/pathml/datasets/utils.py
index daf408f5..cd18c522 100644
--- a/pathml/datasets/utils.py
+++ b/pathml/datasets/utils.py
@@ -3,7 +3,15 @@
License: GNU GPL 2.0
"""
+import importlib
+
import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from pathml.datasets.datasets import InstanceMapPatchDataset
def pannuke_multiclass_mask_to_nucleus_mask(multiclass_mask):
@@ -30,3 +38,230 @@ def pannuke_multiclass_mask_to_nucleus_mask(multiclass_mask):
# ignore last channel
out = np.sum(multiclass_mask[:-1, :, :], axis=0)
return out
+
+
+def _remove_modules(model, last_layer):
+ """
+ Remove all modules in the model that come after a given layer.
+
+ Args:
+ model (nn.Module): A PyTorch model.
+ last_layer (str): Last layer to keep in the model.
+
+ Returns:
+        The model (nn.Module) with all modules after `last_layer` replaced by empty nn.Sequential modules.
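+
+    Example (an illustrative sketch; assumes torchvision is installed):
+        >>> import torchvision
+        >>> model = _remove_modules(torchvision.models.resnet18(), "layer4")
+        >>> # 'avgpool' and 'fc' have been replaced by empty nn.Sequential modules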
+ """
+ modules = [n for n, _ in model.named_children()]
+ modules_to_remove = modules[modules.index(last_layer) + 1 :]
+ for mod in modules_to_remove:
+ setattr(model, mod, nn.Sequential())
+ return model
+
+
+class DeepPatchFeatureExtractor:
+ """
+    Patch feature extractor that runs a given architecture (on GPU if available) over patches
+    generated by pathml.datasets.InstanceMapPatchDataset.
+
+ Args:
+ patch_size (int): Desired size of patch.
+ batch_size (int): Desired size of batch.
+        architecture (str or nn.Module): Network architecture, given either as a torchvision.models
+            name, a path to a locally saved model (.pth file), or an nn.Module instance directly.
+        entity (str): Entity to be processed. Must be one of 'cell' or 'tissue'. Defaults to 'cell'.
+        device (torch.device): Torch device used for inference. Defaults to 'cpu'.
+        fill_value (int): Value used to fill pixels outside the instance maps. Defaults to 255.
+        threshold (float): Minimum overlap between a patch and its instance required for the patch to be
+            processed. Defaults to 0.2.
+        resize_size (int): Size to which patches are resized before being passed to the network. If None,
+            no resizing is done and patches of size patch_size are provided to the network. Defaults to 224.
+        with_instance_masking (bool): Whether pixels outside the instance should be masked. Defaults to False.
+        extraction_layer (str): Name of the network module from which the features are
+            extracted. Defaults to None.
+
+ Returns:
+ Tensor of features computed for each entity.
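+
+    Example (a minimal sketch; ``image`` and ``inst_map`` stand in for an RGB image and a
+    labeled instance map as numpy arrays, and are not provided by this module):
+        >>> extractor = DeepPatchFeatureExtractor(
+        ...     patch_size=64, batch_size=32, architecture="resnet34", entity="cell"
+        ... )
+        >>> features = extractor.process(image, inst_map)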
+ """
+
+ def __init__(
+ self,
+ patch_size,
+ batch_size,
+ architecture,
+ device="cpu",
+ entity="cell",
+ fill_value=255,
+ threshold=0.2,
+ resize_size=224,
+ with_instance_masking=False,
+ extraction_layer=None,
+ ):
+
+ self.fill_value = fill_value
+ self.patch_size = patch_size
+ self.batch_size = batch_size
+ self.resize_size = resize_size
+ self.threshold = threshold
+ self.with_instance_masking = with_instance_masking
+ self.entity = entity
+ self.device = device
+
+ if isinstance(architecture, nn.Module):
+ self.model = architecture.to(self.device)
+ elif architecture.endswith(".pth"):
+ model = self._get_local_model(path=architecture)
+ self._validate_model(model)
+ self.model = self._remove_layers(model, extraction_layer)
+ else:
+ try:
+ global torchvision
+ import torchvision
+
+ model = self._get_torchvision_model(architecture).to(self.device)
+ self._validate_model(model)
+ self.model = self._remove_layers(model, extraction_layer)
+ except (ImportError, ModuleNotFoundError):
+ raise Exception(
+ "Using online models require torchvision to be installed"
+ )
+
+ self.normalizer_mean = [0.485, 0.456, 0.406]
+ self.normalizer_std = [0.229, 0.224, 0.225]
+
+ self.num_features = self._get_num_features(patch_size)
+ self.model.eval()
+
+ @staticmethod
+ def _validate_model(model):
+ """Raise an error if the model does not have the required attributes."""
+
+ if not isinstance(model, torchvision.models.resnet.ResNet):
+ if not hasattr(model, "classifier"):
+ raise ValueError(
+ "Please provide either a ResNet-type architecture or"
+ + ' an architecture that has the attribute "classifier".'
+ )
+
+ if not (hasattr(model, "features") or hasattr(model, "model")):
+ raise ValueError(
+ "Please provide an architecture that has the attribute"
+ + ' "features" or "model".'
+ )
+
+ def _get_num_features(self, patch_size):
+ """Get the number of features of a given model."""
+ dummy_patch = torch.zeros(1, 3, self.resize_size, self.resize_size).to(
+ self.device
+ )
+ features = self.model(dummy_patch)
+ return features.shape[-1]
+
+ def _get_local_model(self, path):
+ """Load a model from a local path."""
+ model = torch.load(path, map_location=self.device)
+ return model
+
+ def _get_torchvision_model(self, architecture):
+ """Returns a torchvision model from a given architecture string."""
+
+ module = importlib.import_module("torchvision.models")
+ model_class = getattr(module, architecture)
+ model = model_class(weights="IMAGENET1K_V1")
+ model = model.to(self.device)
+ return model
+
+ @staticmethod
+ def _remove_layers(model, extraction_layer=None):
+ """Returns the model without the unused layers to get embeddings."""
+
+ if hasattr(model, "model"):
+ model = model.model
+ if extraction_layer is not None:
+ model = _remove_modules(model, extraction_layer)
+ if isinstance(model, torchvision.models.resnet.ResNet):
+ if extraction_layer is None:
+ # remove classifier
+ model.fc = nn.Sequential()
+ else:
+ # remove all layers after the extraction layer
+ model = _remove_modules(model, extraction_layer)
+ else:
+ # remove classifier
+ model.classifier = nn.Sequential()
+ if extraction_layer is not None:
+ # remove average pooling layer if necessary
+ if hasattr(model, "avgpool"):
+ model.avgpool = nn.Sequential()
+ # remove all layers in the feature extractor after the extraction layer
+ model.features = _remove_modules(model.features, extraction_layer)
+ return model
+
+ @staticmethod
+ def _preprocess_architecture(architecture):
+ """Preprocess the architecture string to avoid characters that are not allowed as paths."""
+ if architecture.endswith(".pth"):
+ return f"Local({architecture.replace('/', '_')})"
+ else:
+ return architecture
+
+ def _collate_patches(self, batch):
+ """Patch collate function"""
+
+ instance_indices = [item[1] for item in batch]
+ patches = [item[0] for item in batch]
+ patches = torch.stack(patches)
+ return instance_indices, patches
+
+ def process(self, input_image, instance_map):
+ """Main processing function that takes in an input image and an instance map and returns features for all
+ entities in the instance map"""
+
+ # Create a pathml.datasets.datasets.InstanceMapPatchDataset class
+ image_dataset = InstanceMapPatchDataset(
+ image=input_image,
+ instance_map=instance_map,
+ entity=self.entity,
+ patch_size=self.patch_size,
+ threshold=self.threshold,
+ resize_size=self.resize_size,
+ fill_value=self.fill_value,
+ mean=self.normalizer_mean,
+ std=self.normalizer_std,
+ with_instance_masking=self.with_instance_masking,
+ )
+
+ # Create a torch DataLoader
+ image_loader = DataLoader(
+ image_dataset,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=0,
+ collate_fn=self._collate_patches,
+ )
+
+ # Initialize feature tensor
+ features = torch.zeros(
+ size=(len(image_dataset.entities), self.num_features),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ embeddings = {}
+
+ # Get features for batches of patches and add to feature tensor
+ for instance_indices, patches in tqdm(image_loader, total=len(image_loader)):
+
+ # Send to device
+ patches = patches.to(self.device)
+
+ # Inference mode
+ with torch.no_grad():
+ emb = self.model(patches).squeeze()
+ for j, key in enumerate(instance_indices):
+
+ # If entity already exists, add features on top of previous features
+ if key in embeddings:
+ embeddings[key][0] += emb[j]
+ embeddings[key][1] += 1
+ else:
+ embeddings[key] = [emb[j], 1]
+ for k, v in embeddings.items():
+ features[k, :] = v[0] / v[1]
+ return features.cpu().detach()
diff --git a/pathml/graph/__init__.py b/pathml/graph/__init__.py
new file mode 100644
index 00000000..b8a871ed
--- /dev/null
+++ b/pathml/graph/__init__.py
@@ -0,0 +1,11 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+from .preprocessing import (
+ ColorMergedSuperpixelExtractor,
+ KNNGraphBuilder,
+ RAGGraphBuilder,
+)
+from .utils import Graph, HACTPairData, build_assignment_matrix, get_full_instance_map
diff --git a/pathml/graph/preprocessing.py b/pathml/graph/preprocessing.py
new file mode 100644
index 00000000..e38a44d0
--- /dev/null
+++ b/pathml/graph/preprocessing.py
@@ -0,0 +1,663 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import math
+from abc import abstractmethod
+
+import cv2
+import networkx as nx
+import numpy as np
+import pandas as pd
+import skimage
+import torch
+
+if skimage.__version__ < "0.20.0":
+ from skimage.future import graph
+else:
+ from skimage import graph
+
+from skimage.color.colorconv import rgb2hed
+from skimage.measure import regionprops
+from skimage.segmentation import slic
+from sklearn.neighbors import kneighbors_graph
+
+from pathml.graph.utils import Graph, two_hop
+
+
+class GraphFeatureExtractor:
+ """
+ Extracts features from a networkx graph object.
+
+ Args:
+ use_weight (bool, optional): Whether to use edge weights for feature computation. Defaults to False.
+ alpha (float, optional): Alpha value for personalized page-rank. Defaults to 0.85.
+
+ Returns:
+        Dictionary with feature names as keys and the computed features as values.
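+
+    Example (a minimal sketch; ``G`` is assumed to be a connected networkx graph):
+        >>> extractor = GraphFeatureExtractor(use_weight=False)
+        >>> features = extractor.process(G)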
+ """
+
+ def __init__(self, use_weight=False, alpha=0.85):
+ self.use_weight = use_weight
+ self.feature_dict = {}
+ self.alpha = alpha
+
+ def get_stats(self, dct, prefix="add_pre"):
+ local_dict = {}
+ lst = list(dct.values())
+ local_dict[f"{prefix}_mean"] = np.mean(lst)
+ local_dict[f"{prefix}_median"] = np.median(lst)
+ local_dict[f"{prefix}_max"] = np.max(lst)
+ local_dict[f"{prefix}_min"] = np.min(lst)
+ local_dict[f"{prefix}_sum"] = np.sum(lst)
+ local_dict[f"{prefix}_std"] = np.std(lst)
+ return local_dict
+
+ def process(self, G):
+ if self.use_weight:
+ if "weight" in list(list(G.edges(data=True))[0][-1].keys()):
+ weight = "weight"
+ else:
+ raise ValueError(
+ "No edge attribute called 'weight' when use_weight is True"
+ )
+ else:
+ weight = None
+
+ self.feature_dict["diameter"] = nx.diameter(G)
+ self.feature_dict["radius"] = nx.radius(G)
+ self.feature_dict["assortativity_degree"] = nx.degree_assortativity_coefficient(
+ G
+ )
+ self.feature_dict["density"] = nx.density(G)
+ self.feature_dict["transitivity_undir"] = nx.transitivity(G)
+
+ self.feature_dict.update(self.get_stats(nx.hits(G)[0], prefix="hubs"))
+ self.feature_dict.update(self.get_stats(nx.hits(G)[1], prefix="authorities"))
+ self.feature_dict.update(
+ self.get_stats(nx.constraint(G, weight=weight), prefix="constraint")
+ )
+ self.feature_dict.update(self.get_stats(nx.core_number(G), prefix="coreness"))
+ self.feature_dict.update(
+ self.get_stats(
+ nx.eigenvector_centrality(G, weight=weight), prefix="egvec_centr"
+ )
+ )
+ self.feature_dict.update(
+ self.get_stats(
+ {node: val for (node, val) in G.degree(weight=weight)}, prefix="degree"
+ )
+ )
+ self.feature_dict.update(
+ self.get_stats(
+ nx.pagerank(G, alpha=self.alpha), prefix="personalized_pgrank"
+ )
+ )
+
+ return self.feature_dict
+
+
+class BaseGraphBuilder:
+ """Base interface class for graph building.
+
+ Args:
+ nr_annotation_classes (int): Number of classes in annotation. Used only if setting node labels.
+ annotation_background_class (int): Background class label in annotation. Used only if setting node labels.
+ add_loc_feats (bool): Flag to include location-based features (ie normalized centroids) in node feature representation.
+ Defaults to False.
+ """
+
+ def __init__(
+ self,
+ nr_annotation_classes: int = 5,
+ annotation_background_class=None,
+ add_loc_feats=False,
+ **kwargs,
+ ):
+ """Base Graph Builder constructor."""
+ self.nr_annotation_classes = nr_annotation_classes
+ self.annotation_background_class = annotation_background_class
+ self.add_loc_feats = add_loc_feats
+ super().__init__(**kwargs)
+
+ def process( # type: ignore[override]
+ self, instance_map, features, annotation=None, target=None
+ ):
+ """Generates a graph from a given instance_map and features"""
+ # add nodes
+ self.num_nodes = features.shape[0]
+
+ # add image size as graph data
+ image_size = (instance_map.shape[1], instance_map.shape[0]) # (x, y)
+
+ # get instance centroids
+ self.centroids = self._get_node_centroids(instance_map)
+
+ # add node content
+ node_features = self._compute_node_features(features, image_size)
+
+ if annotation is not None:
+            node_labels = self._set_node_labels(instance_map, annotation)
+ else:
+ node_labels = None
+
+ # build edges
+ edges = self._build_topology(instance_map)
+
+ return Graph(
+ node_centroids=self.centroids,
+ node_features=node_features,
+ edge_index=edges,
+ node_labels=node_labels,
+            target=torch.tensor(target) if target is not None else None,
+ )
+
+ def _get_node_centroids(self, instance_map):
+ """Get the centroids of the graphs"""
+ regions = regionprops(instance_map)
+ centroids = np.empty((len(regions), 2))
+ for i, region in enumerate(regions):
+ center_y, center_x = region.centroid # (y, x)
+ center_x = int(round(center_x))
+ center_y = int(round(center_y))
+ centroids[i, 0] = center_x
+ centroids[i, 1] = center_y
+ return torch.tensor(centroids)
+
+ def _compute_node_features(self, features, image_size):
+ """Set the provided node features"""
+ if not torch.is_tensor(features):
+ features = torch.FloatTensor(features)
+ if not self.add_loc_feats:
+ return features
+ elif self.add_loc_feats and image_size is not None:
+ # compute normalized centroid features
+
+ normalized_centroids = torch.empty_like(self.centroids) # (x, y)
+ normalized_centroids[:, 0] = self.centroids[:, 0] / image_size[0]
+ normalized_centroids[:, 1] = self.centroids[:, 1] / image_size[1]
+
+ if features.ndim == 3:
+ normalized_centroids = normalized_centroids.unsqueeze(dim=1).repeat(
+ 1, features.shape[1], 1
+ )
+ concat_dim = 2
+ elif features.ndim == 2:
+ concat_dim = 1
+
+ concat_features = torch.cat(
+ (features, normalized_centroids),
+ dim=concat_dim,
+ )
+ return concat_features
+ else:
+ raise ValueError(
+ "Please provide image size to add the normalized centroid to the node features."
+ )
+
+ @abstractmethod
+ def _set_node_labels(self, instance_map, annotation):
+ """Set the node labels of the graphs"""
+
+ @abstractmethod
+ def _build_topology(self, instance_map):
+ """Generate the graph topology from the provided instance_map"""
+
+
+class KNNGraphBuilder(BaseGraphBuilder):
+ """
+ k-Nearest Neighbors Graph class for graph building.
+
+ Args:
+ k (int, optional): Number of neighbors. Defaults to 5.
+ thresh (int, optional): Maximum allowed distance between 2 nodes. Defaults to None (no thresholding).
+
+ Returns:
+ A pathml.graph.utils.Graph object containing node and edge information.
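+
+    Example (a minimal sketch; ``inst_map`` and ``feats`` stand in for a labeled instance map
+    and a per-instance feature array, respectively):
+        >>> builder = KNNGraphBuilder(k=5, thresh=50)
+        >>> cell_graph = builder.process(inst_map, feats)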
+ """
+
+ def __init__(self, k=5, thresh=None, **kwargs):
+ """Create a graph builder that uses the (thresholded) kNN algorithm to define the graph topology."""
+
+ self.k = k
+ self.thresh = thresh
+ super().__init__(**kwargs)
+
+ def _set_node_labels(self, instance_map, annotation):
+ """Set the node labels of the graphs using annotation"""
+ regions = regionprops(instance_map)
+ assert annotation.shape[0] == len(
+ regions
+ ), "Number of annotations do not match number of nodes"
+ return torch.FloatTensor(annotation.astype(float))
+
+ def _build_topology(self, instance_map):
+ """Build topology using (thresholded) kNN"""
+
+ # build kNN adjacency
+ adjacency = kneighbors_graph(
+ self.centroids,
+ self.k,
+ mode="distance",
+ include_self=False,
+ metric="euclidean",
+ ).toarray()
+
+ # filter edges that are too far (ie larger than thresh)
+ if self.thresh is not None:
+ adjacency[adjacency > self.thresh] = 0
+
+ edge_list = torch.tensor(np.array(np.nonzero(adjacency)))
+ return edge_list
+
+
+class RAGGraphBuilder(BaseGraphBuilder):
+ """
+ Region Adjacency Graph builder class.
+
+ Args:
+ kernel_size (int, optional): Size of the kernel to detect connectivity. Defaults to 5.
+ hops (int, optional): Number of hops in a multi-hop neighbourhood. Defaults to 1.
+
+ Returns:
+ A pathml.graph.utils.Graph object containing node and edge information.
+
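+    Example (a minimal sketch; ``superpixel_map`` and ``tissue_feats`` stand in for a tissue
+    instance map and per-instance features, respectively):
+        >>> builder = RAGGraphBuilder(kernel_size=3, hops=1)
+        >>> tissue_graph = builder.process(superpixel_map, tissue_feats)
+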
+ """
+
+ def __init__(self, kernel_size=3, hops=1, **kwargs):
+ """Create a graph builder that uses a provided kernel size to detect connectivity"""
+ assert hops > 0 and isinstance(
+ hops, int
+ ), f"Invalid hops {hops} ({type(hops)}). Must be integer >= 0"
+ self.kernel_size = kernel_size
+ self.hops = hops
+ super().__init__(**kwargs)
+
+ def _build_topology(self, instance_map):
+ """Create the graph topology from the instance connectivty in the instance_map"""
+
+ regions = regionprops(instance_map)
+ instance_ids = torch.empty(len(regions), dtype=torch.uint8)
+
+ kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8)
+ adjacency = np.zeros(shape=(len(instance_ids), len(instance_ids)))
+
+ for instance_id in np.arange(1, len(instance_ids) + 1):
+ mask = (instance_map == instance_id).astype(np.uint8)
+ dilation = cv2.dilate(mask, kernel, iterations=1)
+ boundary = dilation - mask
+ idx = pd.unique(instance_map[boundary.astype(bool)])
+ instance_id -= 1 # because instance_map id starts from 1
+ idx -= 1 # because instance_map id starts from 1
+ adjacency[instance_id, idx] = 1
+
+ edge_list = torch.tensor(np.array(np.nonzero(adjacency)))
+
+ for _ in range(self.hops - 1):
+ edge_list = two_hop(edge_list, self.num_nodes)
+ return edge_list
+
+
+class SuperpixelExtractor:
+ """Helper class to extract superpixels from images
+
+ Args:
+        nr_superpixels (int, optional): The number of superpixels before any merging.
+        superpixel_size (int, optional): The size of superpixels before any merging.
+        max_nr_superpixels (int, optional): Upper bound for the number of superpixels.
+            Useful when providing a superpixel size.
+        blur_kernel_size (float, optional): Size of the blur kernel. Defaults to 1.
+        compactness (int, optional): Compactness of the superpixels. Defaults to 20.
+        max_iterations (int, optional): Number of iterations of the slic algorithm. Defaults to 10.
+        threshold (float, optional): Connectivity threshold. Defaults to 0.03.
+        connectivity (int, optional): Connectivity for merging graph. Defaults to 2.
+        color_space (str, optional): Color space in which the SLIC algorithm operates ('rgb' or 'hed').
+            Defaults to 'rgb'.
+        downsampling_factor (int, optional): Downsampling factor from the input image
+            resolution. Defaults to 1.
+ """
+
+ def __init__(
+ self,
+ nr_superpixels: int = None,
+ superpixel_size: int = None,
+ max_nr_superpixels=None,
+ blur_kernel_size=1,
+ compactness=20,
+ max_iterations=10,
+ threshold=0.03,
+ connectivity=2,
+ color_space="rgb",
+ downsampling_factor=1,
+ **kwargs,
+ ):
+ """Abstract class that extracts superpixels from RGB Images"""
+
+ assert (nr_superpixels is None and superpixel_size is not None) or (
+ nr_superpixels is not None and superpixel_size is None
+ ), "Provide value for either nr_superpixels or superpixel_size"
+ self.nr_superpixels = nr_superpixels
+ self.superpixel_size = superpixel_size
+ self.max_nr_superpixels = max_nr_superpixels
+ self.blur_kernel_size = blur_kernel_size
+ self.compactness = compactness
+ self.max_iterations = max_iterations
+ self.threshold = threshold
+ self.connectivity = connectivity
+ self.color_space = color_space
+ self.downsampling_factor = downsampling_factor
+ super().__init__(**kwargs)
+
+ def process(self, input_image, tissue_mask=None): # type: ignore[override]
+ """Return the superpixels of a given input image"""
+ original_height, original_width, _ = input_image.shape
+ if self.downsampling_factor != 1:
+ input_image = self._downsample(input_image, self.downsampling_factor)
+ if tissue_mask is not None:
+ tissue_mask = self._downsample(tissue_mask, self.downsampling_factor)
+ superpixels = self._extract_superpixels(
+ image=input_image, tissue_mask=tissue_mask
+ )
+ if self.downsampling_factor != 1:
+ superpixels = self._upsample(superpixels, original_height, original_width)
+ return superpixels
+
+ @abstractmethod
+ def _extract_superpixels(self, image, tissue_mask=None):
+ """Perform the superpixel extraction"""
+
+ @staticmethod
+ def _downsample(image, downsampling_factor):
+ """Downsample an input image with a given downsampling factor"""
+ height, width = image.shape[0], image.shape[1]
+ new_height = math.floor(height / downsampling_factor)
+ new_width = math.floor(width / downsampling_factor)
+ downsampled_image = cv2.resize(
+ image, (new_width, new_height), interpolation=cv2.INTER_NEAREST
+ )
+ return downsampled_image
+
+ @staticmethod
+ def _upsample(image, new_height, new_width):
+ """Upsample an input image to a speficied new height and width"""
+ upsampled_image = cv2.resize(
+ image, (new_width, new_height), interpolation=cv2.INTER_NEAREST
+ )
+ return upsampled_image
+
+
+class SLICSuperpixelExtractor(SuperpixelExtractor):
+ """Use the SLIC algorithm to extract superpixels."""
+
+ def __init__(self, **kwargs):
+ """Extract superpixels with the SLIC algorithm"""
+ super().__init__(**kwargs)
+
+ def _get_nr_superpixels(self, image):
+ """Compute the number of superpixels for initial segmentation"""
+ if self.superpixel_size is not None:
+ nr_superpixels = int(
+ (image.shape[0] * image.shape[1] / self.superpixel_size)
+ )
+ elif self.nr_superpixels is not None:
+ nr_superpixels = self.nr_superpixels
+ if self.max_nr_superpixels is not None:
+ nr_superpixels = min(nr_superpixels, self.max_nr_superpixels)
+ return nr_superpixels
+
+ def _extract_superpixels(self, image, *args, **kwargs):
+ """Perform the superpixel extraction"""
+ if self.color_space == "hed":
+ image = rgb2hed(image)
+ nr_superpixels = self._get_nr_superpixels(image)
+
+ slic_args = {
+ "image": image,
+ "sigma": self.blur_kernel_size,
+ "n_segments": nr_superpixels,
+ "compactness": self.compactness,
+ "start_label": 1,
+ }
+ if skimage.__version__ < "0.20.0":
+ slic_args["max_iter"] = self.max_iterations
+ else:
+ slic_args["max_num_iter"] = self.max_iterations
+
+ superpixels = slic(**slic_args)
+ return superpixels
+
+
+class MergedSuperpixelExtractor(SuperpixelExtractor):
+ """Use the SLIC algorithm to extract superpixels and a merging function to merge superpixels"""
+
+ def __init__(self, **kwargs):
+ """Extract superpixels with the SLIC algorithm and then merge"""
+ super().__init__(**kwargs)
+
+ def _get_nr_superpixels(self, image):
+ """Compute the number of superpixels for initial segmentation"""
+ if self.superpixel_size is not None:
+ nr_superpixels = int(
+ (image.shape[0] * image.shape[1] / self.superpixel_size)
+ )
+ elif self.nr_superpixels is not None:
+ nr_superpixels = self.nr_superpixels
+ if self.max_nr_superpixels is not None:
+ nr_superpixels = min(nr_superpixels, self.max_nr_superpixels)
+ return nr_superpixels
+
+ def _extract_initial_superpixels(self, image):
+ """Extract initial superpixels using SLIC"""
+ nr_superpixels = self._get_nr_superpixels(image)
+
+ slic_args = {
+ "image": image,
+ "sigma": self.blur_kernel_size,
+ "n_segments": nr_superpixels,
+ "compactness": self.compactness,
+ "start_label": 1,
+ }
+ if skimage.__version__ < "0.20.0":
+ slic_args["max_iter"] = self.max_iterations
+ else:
+ slic_args["max_num_iter"] = self.max_iterations
+
+ superpixels = slic(**slic_args)
+ return superpixels
+
+ def _merge_superpixels(self, input_image, initial_superpixels, tissue_mask=None):
+ """Merge the initial superpixels to return merged superpixels"""
+ if tissue_mask is not None:
+ # Remove superpixels belonging to background or having < 10% tissue
+ # content
+ ids_initial = np.unique(initial_superpixels, return_counts=True)
+ ids_masked = np.unique(
+ tissue_mask * initial_superpixels, return_counts=True
+ )
+
+ ctr = 1
+ superpixels = np.zeros_like(initial_superpixels)
+ for i in range(len(ids_initial[0])):
+ id = ids_initial[0][i]
+ if id in ids_masked[0]:
+ idx = np.where(id == ids_masked[0])[0]
+ ratio = ids_masked[1][idx] / ids_initial[1][i]
+ if ratio >= 0.1:
+ superpixels[initial_superpixels == id] = ctr
+ ctr += 1
+
+ initial_superpixels = superpixels
+
+ # Merge superpixels within tissue region
+ g = self._generate_graph(input_image, initial_superpixels)
+
+ merged_superpixels = graph.merge_hierarchical(
+ initial_superpixels,
+ g,
+ thresh=self.threshold,
+ rag_copy=False,
+ in_place_merge=True,
+ merge_func=self._merging_function,
+ weight_func=self._weighting_function,
+ )
+ merged_superpixels += 1 # Handle regionprops that ignores all values of 0
+ mask = np.zeros_like(initial_superpixels)
+ mask[initial_superpixels != 0] = 1
+ merged_superpixels = merged_superpixels * mask
+ return merged_superpixels
+
+ @abstractmethod
+ def _generate_graph(self, input_image, superpixels):
+ """Generate a graph based on the input image and initial superpixel segmentation."""
+
+ @abstractmethod
+ def _weighting_function(self, graph, src, dst, n):
+ """Handle merging of nodes of a region boundary region adjacency graph."""
+
+ @abstractmethod
+ def _merging_function(self, graph, src, dst):
+ """Call back called before merging 2 nodes."""
+
+ def _extract_superpixels(self, image, tissue_mask=None):
+ """Perform superpixel extraction"""
+ initial_superpixels = self._extract_initial_superpixels(image)
+ merged_superpixels = self._merge_superpixels(
+ image, initial_superpixels, tissue_mask
+ )
+
+ return merged_superpixels, initial_superpixels
+
+ def process(self, input_image, tissue_mask=None): # type: ignore[override]
+ """Return the superpixels of a given input image"""
+ original_height, original_width, _ = input_image.shape
+ if self.downsampling_factor is not None and self.downsampling_factor != 1:
+ input_image = self._downsample(input_image, self.downsampling_factor)
+ if tissue_mask is not None:
+ tissue_mask = self._downsample(tissue_mask, self.downsampling_factor)
+ merged_superpixels, initial_superpixels = self._extract_superpixels(
+ input_image, tissue_mask
+ )
+ if self.downsampling_factor != 1:
+ merged_superpixels = self._upsample(
+ merged_superpixels, original_height, original_width
+ )
+ initial_superpixels = self._upsample(
+ initial_superpixels, original_height, original_width
+ )
+ return merged_superpixels, initial_superpixels
+
+
+class ColorMergedSuperpixelExtractor(MergedSuperpixelExtractor):
+ """Superpixel merger based on color attibutes taken from the HACT-Net Implementation
+ Args:
+ w_hist (float, optional): Weight of the histogram features for merging. Defaults to 0.5.
+ w_mean (float, optional): Weight of the mean features for merging. Defaults to 0.5.
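+
+    Example (a minimal sketch; ``image`` stands in for an RGB numpy array):
+        >>> extractor = ColorMergedSuperpixelExtractor(superpixel_size=200, compactness=20)
+        >>> merged_superpixels, initial_superpixels = extractor.process(image)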
+ """
+
+ def __init__(self, w_hist: float = 0.5, w_mean: float = 0.5, **kwargs):
+ self.w_hist = w_hist
+ self.w_mean = w_mean
+ super().__init__(**kwargs)
+
+ def _color_features_per_channel(self, img_ch: np.ndarray) -> np.ndarray:
+ """Extract color histograms from image channel"""
+        hist, _ = np.histogram(img_ch, bins=np.arange(0, 257, 64))  # 4 bins of width 64
+ return hist
+
+ def _generate_graph(self, input_image, superpixels):
+ """Construct RAG graph using initial superpixel instance map"""
+ g = graph.RAG(superpixels, connectivity=self.connectivity)
+ if 0 in g.nodes:
+ g.remove_node(n=0) # remove background node
+
+ for n in g:
+ g.nodes[n].update(
+ {
+ "labels": [n],
+ "N": 0,
+ "x": np.array([0, 0, 0]),
+ "y": np.array([0, 0, 0]),
+ "r": np.array([]),
+ "g": np.array([]),
+ "b": np.array([]),
+ }
+ )
+
+ for index in np.ndindex(superpixels.shape):
+ current = superpixels[index]
+ if current == 0:
+ continue
+ g.nodes[current]["N"] += 1
+ g.nodes[current]["x"] += input_image[index]
+ g.nodes[current]["y"] = np.vstack(
+ (g.nodes[current]["y"], input_image[index])
+ )
+
+ for n in g:
+ g.nodes[n]["mean"] = g.nodes[n]["x"] / g.nodes[n]["N"]
+ g.nodes[n]["mean"] = g.nodes[n]["mean"] / np.linalg.norm(g.nodes[n]["mean"])
+
+ g.nodes[n]["y"] = np.delete(g.nodes[n]["y"], 0, axis=0)
+ g.nodes[n]["r"] = self._color_features_per_channel(g.nodes[n]["y"][:, 0])
+ g.nodes[n]["g"] = self._color_features_per_channel(g.nodes[n]["y"][:, 1])
+ g.nodes[n]["b"] = self._color_features_per_channel(g.nodes[n]["y"][:, 2])
+
+ g.nodes[n]["r"] = g.nodes[n]["r"] / np.linalg.norm(g.nodes[n]["r"])
+ g.nodes[n]["g"] = g.nodes[n]["r"] / np.linalg.norm(g.nodes[n]["g"])
+ g.nodes[n]["b"] = g.nodes[n]["r"] / np.linalg.norm(g.nodes[n]["b"])
+
+ for x, y, d in g.edges(data=True):
+ diff_mean = np.linalg.norm(g.nodes[x]["mean"] - g.nodes[y]["mean"]) / 2
+
+ diff_r = np.linalg.norm(g.nodes[x]["r"] - g.nodes[y]["r"]) / 2
+ diff_g = np.linalg.norm(g.nodes[x]["g"] - g.nodes[y]["g"]) / 2
+ diff_b = np.linalg.norm(g.nodes[x]["b"] - g.nodes[y]["b"]) / 2
+ diff_hist = (diff_r + diff_g + diff_b) / 3
+
+ diff = self.w_hist * diff_hist + self.w_mean * diff_mean
+
+ d["weight"] = diff
+
+ return g
+
+ def _weighting_function(self, graph, src, dst, n):
+ diff_mean = np.linalg.norm(graph.nodes[dst]["mean"] - graph.nodes[n]["mean"])
+
+ diff_r = np.linalg.norm(graph.nodes[dst]["r"] - graph.nodes[n]["r"]) / 2
+ diff_g = np.linalg.norm(graph.nodes[dst]["g"] - graph.nodes[n]["g"]) / 2
+ diff_b = np.linalg.norm(graph.nodes[dst]["b"] - graph.nodes[n]["b"]) / 2
+ diff_hist = (diff_r + diff_g + diff_b) / 3
+
+ diff = self.w_hist * diff_hist + self.w_mean * diff_mean
+
+ return {"weight": diff}
+
+ def _merging_function(self, graph, src, dst):
+ graph.nodes[dst]["x"] += graph.nodes[src]["x"]
+ graph.nodes[dst]["N"] += graph.nodes[src]["N"]
+ graph.nodes[dst]["mean"] = graph.nodes[dst]["x"] / graph.nodes[dst]["N"]
+ graph.nodes[dst]["mean"] = graph.nodes[dst]["mean"] / np.linalg.norm(
+ graph.nodes[dst]["mean"]
+ )
+
+ graph.nodes[dst]["y"] = np.vstack(
+ (graph.nodes[dst]["y"], graph.nodes[src]["y"])
+ )
+ graph.nodes[dst]["r"] = self._color_features_per_channel(
+ graph.nodes[dst]["y"][:, 0]
+ )
+ graph.nodes[dst]["g"] = self._color_features_per_channel(
+ graph.nodes[dst]["y"][:, 1]
+ )
+ graph.nodes[dst]["b"] = self._color_features_per_channel(
+ graph.nodes[dst]["y"][:, 2]
+ )
+
+ graph.nodes[dst]["r"] = graph.nodes[dst]["r"] / np.linalg.norm(
+ graph.nodes[dst]["r"]
+ )
+ graph.nodes[dst]["g"] = graph.nodes[dst]["r"] / np.linalg.norm(
+ graph.nodes[dst]["g"]
+ )
+ graph.nodes[dst]["b"] = graph.nodes[dst]["r"] / np.linalg.norm(
+ graph.nodes[dst]["b"]
+ )
diff --git a/pathml/graph/utils.py b/pathml/graph/utils.py
new file mode 100644
index 00000000..bb12aeb3
--- /dev/null
+++ b/pathml/graph/utils.py
@@ -0,0 +1,253 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import importlib
+import math
+import os
+
+import matplotlib.pyplot as plt
+import networkx as nx
+import numpy as np
+import torch
+from skimage.measure import label, regionprops
+from torch_geometric.data import Data
+from torch_geometric.utils import remove_self_loops, to_edge_index, to_torch_csr_tensor
+
+MIN_NR_PIXELS = 50000
+MAX_NR_PIXELS = 50000000
+
+
+class Graph(Data):
+ """Constructs pytorch-geometric data object for saving and loading
+
+ Args:
+ node_centroids (torch.tensor): Coordinates of the centers of each entity (cell or tissue) in the graph
+ node_features (torch.tensor): Computed features of each entity (cell or tissue) in the graph
+ edge_index (torch.tensor): Edge index in sparse format between nodes in the graph
+ node_labels (torch.tensor): Node labels of each entity (cell or tissue) in the graph. Defaults to None.
+ target (torch.tensor): Target label if used in a supervised setting. Defaults to None.
+ """
+
+ def __init__(
+ self, node_centroids, node_features, edge_index, node_labels=None, target=None
+ ):
+ super().__init__()
+ self.node_centroids = node_centroids
+ self.node_features = node_features
+ self.edge_index = edge_index
+ self.node_labels = node_labels
+ self.target = target
+
+ def __inc__(self, key, value, *args, **kwargs):
+ if key == "edge_index":
+ return self.node_features.size(0)
+ elif key == "target":
+ return 0
+ else:
+ return super().__inc__(key, value, *args, **kwargs)
+
+
+class HACTPairData(Data):
+ """Constructs pytorch-geometric data object for handling both cell and tissue data
+
+ Args:
+ x_cell (torch.tensor): Computed features of each cell in the graph
+ edge_index_cell (torch.tensor): Edge index in sparse format between nodes in the cell graph
+ x_tissue (torch.tensor): Computed features of each tissue in the graph
+ edge_index_tissue (torch.tensor): Edge index in sparse format between nodes in the tissue graph
+        assignment (torch.tensor): Assignment matrix that contains the mapping between cells and tissues.
+ target (torch.tensor): Target label if used in a supervised setting.
+ """
+
+ def __init__(
+ self, x_cell, edge_index_cell, x_tissue, edge_index_tissue, assignment, target
+ ):
+ super().__init__()
+ self.x_cell = x_cell
+ self.edge_index_cell = edge_index_cell
+
+ self.x_tissue = x_tissue
+ self.edge_index_tissue = edge_index_tissue
+
+ self.assignment = assignment
+ self.target = target
+
+ def __inc__(self, key, value, *args, **kwargs):
+ if key == "edge_index_cell":
+ return self.x_cell.size(0)
+ if key == "edge_index_tissue":
+ return self.x_tissue.size(0)
+ elif key == "assignment":
+ return self.x_tissue.size(0)
+ elif key == "target":
+ return 0
+ else:
+ return super().__inc__(key, value, *args, **kwargs)
+
+
+def dynamic_import_from(source_file: str, class_name: str):
+ """Do a from source_file import class_name dynamically
+
+ Args:
+ source_file (str): Where to import from
+ class_name (str): What to import
+ Returns:
+ Any: The class to be imported
+ """
+ module = importlib.import_module(source_file)
+ return getattr(module, class_name)
+
+
+def _valid_image(nr_pixels):
+ """
+    Checks whether the number of pixels in the image lies between the minimum and maximum allowed values.
+
+ Args:
+ nr_pixels (int): Number of pixels in given image
+ """
+
+ if nr_pixels > MIN_NR_PIXELS and nr_pixels < MAX_NR_PIXELS:
+ return True
+ return False
+
+
+def plot_graph_on_image(graph, image):
+ """
+ Plots a given graph on the original WSI image
+
+ Args:
+        graph (pathml.graph.utils.Graph): Graph object containing node centroids and a sparse edge index
+        image (numpy.array): Input image
+ """
+
+ from torch_geometric.utils.convert import to_networkx
+
+ pos = graph.node_centroids.numpy()
+ G = to_networkx(graph, to_undirected=True)
+ plt.imshow(image)
+ nx.draw(G, pos, node_size=25)
+ plt.show()
+
+
+def _exists(cg_out, tg_out, assign_out, overwrite):
+ """
+ Checks if given input files exist or not
+
+ Args:
+ cg_out (str): Cell graph file
+ tg_out (str): Tissue graph file
+ assign_out (str): Assignment matrix file
+        overwrite (bool): Whether to overwrite the files or not. If True, this function returns False and the
+            files are overwritten.
+ """
+
+ if overwrite:
+ return False
+ else:
+ if (
+ os.path.isfile(cg_out)
+ and os.path.isfile(tg_out)
+ and os.path.isfile(assign_out)
+ ):
+ return True
+ return False
+
+
+def get_full_instance_map(wsi, patch_size, mask_name="cell"):
+ """
+    Generates and returns the normalized image, cell instance map and cell centroids from a pathml SlideData object
+
+ Args:
+ wsi (pathml.core.SlideData): Normalized WSI object with detected cells in the 'masks' slot
+ patch_size (int): Patch size used for cell detection
+ mask_name (str): Name of the mask slot storing the detected cells. Defaults to 'cell'.
+
+ Returns:
+        The image in np.uint8 format, the instance map for the entity and the instance centroids for each entity in
+ the instance map as numpy arrays.
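+
+    Example (a minimal sketch; ``wsi`` stands in for a SlideData object that has already been
+    tiled and run through nucleus detection):
+        >>> image, inst_map, centroids = get_full_instance_map(wsi, patch_size=256, mask_name="cell")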
+ """
+
+ x = math.ceil(wsi.shape[0] / patch_size) * patch_size
+ y = math.ceil(wsi.shape[1] / patch_size) * patch_size
+ image_norm = np.zeros((x, y, 3))
+ instance_map = np.zeros((x, y))
+ for tile in wsi.tiles:
+ tx, ty = tile.coords
+ image_norm[tx : tx + patch_size, ty : ty + patch_size] = tile.image
+ instance_map[tx : tx + patch_size, ty : ty + patch_size] = tile.masks[
+ mask_name
+ ][:, :, 0]
+ image_norm = image_norm[: wsi.shape[0], : wsi.shape[1], :]
+ instance_map = instance_map[: wsi.shape[0], : wsi.shape[1]]
+ label_instance_map = label(instance_map)
+ regions = regionprops(label_instance_map)
+ instance_centroids = np.empty((len(regions), 2))
+ for i, region in enumerate(regions):
+ center_y, center_x = region.centroid # row, col
+ center_x = int(round(center_x))
+ center_y = int(round(center_y))
+ instance_centroids[i, 0] = center_x
+ instance_centroids[i, 1] = center_y
+ return image_norm.astype("uint8"), label_instance_map, instance_centroids
+
+
+def build_assignment_matrix(low_level_centroids, high_level_map, matrix=False):
+ """
+ Builds an assignment matrix/mapping between low-level centroid locations and a high-level segmentation map
+
+ Args:
+ low_level_centroids (numpy.array): The low-level centroid coordinates in x-y plane
+        high_level_map (numpy.array): The high-level instance segmentation map
+ matrix (bool): Whether to return in a matrix format. If True, returns a N*L matrix where N is the number of low-level
+ instances and L is the number of high-level instances. If False, returns this mapping in sparse format.
+ Defaults to False.
+
+ Returns:
+ The assignment matrix as a numpy array.
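+
+    Example (a minimal sketch; ``cell_centroids`` and ``tissue_map`` stand in for cell centroid
+    coordinates and a tissue-level instance map, respectively):
+        >>> assignment = build_assignment_matrix(cell_centroids, tissue_map)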
+ """
+
+ low_level_centroids = low_level_centroids.astype(int)
+ low_to_high = high_level_map[
+ low_level_centroids[:, 1], low_level_centroids[:, 0]
+ ].astype(int)
+ high_instance_ids = np.sort(np.unique(np.ravel(high_level_map))).astype(int)
+ if 0 in high_instance_ids:
+ high_instance_ids = np.delete(high_instance_ids, 0)
+ assignment_matrix = np.zeros((low_level_centroids.shape[0], len(high_instance_ids)))
+ assignment_matrix[np.arange(low_to_high.size), low_to_high - 1] = 1
+ if not matrix:
+ sparse_matrix = np.nonzero(assignment_matrix)
+ return np.array(sparse_matrix)
+ return assignment_matrix
+
+
+def compute_histogram(input_array: np.ndarray, nr_values: int) -> np.ndarray:
+ """Calculates a histogram of a matrix of the values from 0 up to (excluding) nr_values
+ Args:
+ x (np.array): Input tensor.
+ nr_values (int): Possible values. From 0 up to (exclusing) nr_values.
+
+ Returns:
+ np.array: Output tensor.
+ """
+ output_array = np.empty(nr_values, dtype=int)
+ for i in range(nr_values):
+ output_array[i] = (input_array == i).sum()
+ return output_array
+
+
+def two_hop(edge_index, num_nodes):
+ """Calculates the two-hop graph.
+ Args:
+ edge_index (torch.tensor): The edge index in sparse form of the graph.
+ num_nodes (int): maximum number of nodes.
+ Returns:
+ torch.tensor: Output edge index tensor.
+ """
+ adj = to_torch_csr_tensor(edge_index, size=(num_nodes, num_nodes))
+ edge_index2, _ = to_edge_index(adj @ adj)
+ edge_index2, _ = remove_self_loops(edge_index2)
+ edge_index = torch.cat([edge_index, edge_index2], dim=1)
+ return edge_index
diff --git a/pathml/ml/__init__.py b/pathml/ml/__init__.py
index 69d72914..2da14674 100644
--- a/pathml/ml/__init__.py
+++ b/pathml/ml/__init__.py
@@ -4,4 +4,6 @@
"""
from .dataset import TileDataset
-from .hovernet import HoVerNet, loss_hovernet, post_process_batch_hovernet
+from .layers import GNNLayer
+from .models.hactnet import HACTNet
+from .models.hovernet import HoVerNet, loss_hovernet, post_process_batch_hovernet
diff --git a/pathml/ml/hovernet.py b/pathml/ml/hovernet.py
index 7938c93c..5b34e7ab 100644
--- a/pathml/ml/hovernet.py
+++ b/pathml/ml/hovernet.py
@@ -9,7 +9,7 @@
import torch
from loguru import logger
from matplotlib.colors import TABLEAU_COLORS
-from scipy.ndimage.morphology import binary_fill_holes
+from scipy.ndimage import binary_fill_holes
from skimage.segmentation import watershed
from torch import nn
from torch.nn import functional as F
diff --git a/pathml/ml/layers.py b/pathml/ml/layers.py
new file mode 100644
index 00000000..3250e178
--- /dev/null
+++ b/pathml/ml/layers.py
@@ -0,0 +1,92 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import importlib
+
+import torch
+import torch.nn as nn
+from torch_geometric.nn.pool import global_mean_pool
+
+
+class GNNLayer(nn.Module):
+ """
+ GNN layer for processing graph structures.
+
+ Args:
+ layer (str): Type of torch_geometric GNN layer to be used.
+ See https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers for
+ all available options.
+ in_channels (int): Number of input features supplied to the model.
+ hidden_channels (int): Number of hidden channels used in each layer of the GNN model.
+ num_layers (int): Number of message-passing layers in the model.
+ out_channels (int): Number of output features returned by the model.
+ readout_op (str): Readout operation to summarize features from each layer. Supports 'lstm' and 'concat'.
+ readout_type (str): Type of readout to aggregate node embeddings. Supports 'mean'.
+ kwargs (dict): Extra layer-specific arguments. Must have required keyword arguments of layer from
+ https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#convolutional-layers.
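+
+    Example (illustrative values; ``x``, ``edge_index`` and ``batch`` stand in for tensors from a
+    torch_geometric mini-batch):
+        >>> gnn = GNNLayer(layer="GCNConv", in_channels=512, hidden_channels=64,
+        ...                num_layers=3, out_channels=64, readout_op="concat",
+        ...                readout_type="mean", kwargs={})
+        >>> out = gnn(x, edge_index, batch)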
+ """
+
+ def __init__(
+ self,
+ layer,
+ in_channels,
+ hidden_channels,
+ num_layers,
+ out_channels,
+ readout_op,
+ readout_type,
+ kwargs,
+ ):
+ super().__init__()
+ self.convs = nn.ModuleList()
+ self.batch_norms = nn.ModuleList()
+ self.readout_type = readout_type
+ self.readout_op = readout_op
+
+ # Import user-specified GNN layer from pytorch-geometric
+ conv_module = importlib.import_module("torch_geometric.nn.conv")
+ module = getattr(conv_module, layer)
+
+ # Make multi-layered GNN using imported GNN layer
+ self.convs.append(module(in_channels, hidden_channels, **kwargs))
+ self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
+ for _ in range(1, num_layers - 1):
+ conv = module(hidden_channels, hidden_channels, **kwargs)
+ self.convs.append(conv)
+ self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
+ self.convs.append(module(hidden_channels, out_channels, **kwargs))
+ self.batch_norms.append(nn.BatchNorm1d(out_channels))
+
+ # Define readout operation if using LSTM readout
+ if readout_op == "lstm":
+ self.lstm = nn.LSTM(
+ out_channels,
+ (num_layers * out_channels) // 2,
+ bidirectional=True,
+ batch_first=True,
+ )
+ self.att = nn.Linear(2 * ((num_layers * out_channels) // 2), 1)
+
+ def forward(self, x, edge_index, batch, with_readout=True):
+ h = []
+ x = x.float()
+ for norm, conv in zip(self.batch_norms, self.convs):
+ x = conv(x, edge_index)
+ x = norm(x)
+ h.append(x)
+ if self.readout_op == "concat":
+ out = torch.cat(h, dim=-1)
+ elif self.readout_op == "lstm":
+ x = torch.stack(h, dim=1)
+ alpha, _ = self.lstm(x)
+ alpha = self.att(alpha).squeeze(-1)
+ alpha = torch.softmax(alpha, dim=-1)
+ out = (x * alpha.unsqueeze(-1)).sum(dim=1)
+ else:
+ out = h[-1]
+ if with_readout:
+ if self.readout_type == "mean":
+ out = global_mean_pool(out, batch)
+ return out
diff --git a/pathml/ml/models/hactnet.py b/pathml/ml/models/hactnet.py
new file mode 100644
index 00000000..8a13cd0d
--- /dev/null
+++ b/pathml/ml/models/hactnet.py
@@ -0,0 +1,85 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import torch
+import torch.nn as nn
+from torch_geometric.nn.models import MLP
+
+from pathml.ml.layers import GNNLayer
+from pathml.ml.utils import scatter_sum
+
+
+class HACTNet(nn.Module):
+ """
+ Hierarchical cell-to-tissue model for supervised prediction using cell and tissue graphs.
+
+ Args:
+ cell_params (dict): Dictionary containing parameters for cell graph GNN.
+ tissue_params (dict): Dictionary containing parameters for tissue graph GNN.
+ classifier_params (dict): Dictionary containing parameters for prediction MLP.
+
+ References:
+ Pati, P., Jaume, G., Foncubierta-Rodriguez, A., Feroce, F., Anniciello, A.M., Scognamiglio, G., Brancati, N., Fiche, M.,
+ Dubruc, E., Riccio, D. and Di Bonito, M., 2022. Hierarchical graph representations in digital pathology. Medical image
+ analysis, 75, p.102264.
+
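+    Example (a minimal configuration sketch; all values are illustrative only and assume cell and
+    tissue node features of dimension 514):
+        >>> cell_params = dict(layer="GCNConv", in_channels=514, hidden_channels=64,
+        ...                    num_layers=3, out_channels=64, readout_op="concat",
+        ...                    readout_type="mean", kwargs={})
+        >>> tissue_params = dict(layer="GCNConv", in_channels=514, hidden_channels=64,
+        ...                      num_layers=3, out_channels=64, readout_op="concat",
+        ...                      readout_type="mean", kwargs={})
+        >>> classifier_params = dict(in_channels=None, hidden_channels=128, out_channels=3, num_layers=2)
+        >>> model = HACTNet(cell_params, tissue_params, classifier_params)
+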
+ """
+
+ def __init__(self, cell_params, tissue_params, classifier_params):
+ super().__init__()
+
+ # Get cell and tissue graph readout operations
+ self.cell_readout_op = cell_params["readout_op"]
+ self.tissue_readout_op = tissue_params["readout_op"]
+
+ # Modify tissue GNN parameters
+ if self.cell_readout_op == "concat":
+ tissue_params["in_channels"] = (
+ tissue_params["in_channels"]
+ + cell_params["out_channels"] * cell_params["num_layers"]
+ )
+ else:
+ tissue_params["in_channels"] = (
+ tissue_params["in_channels"] + cell_params["out_channels"]
+ )
+
+ # Main GNN model for cell and tissue graphs
+ self.cell_gnn = GNNLayer(**cell_params)
+ self.tissue_gnn = GNNLayer(**tissue_params)
+
+ # Modify classifier parameters
+ if self.tissue_readout_op == "concat":
+ classifier_params["in_channels"] = (
+ tissue_params["out_channels"] * tissue_params["num_layers"]
+ )
+ else:
+ classifier_params["in_channels"] = tissue_params["out_channels"]
+
+ # Main classifier head
+ self.classifier = MLP(**classifier_params)
+
+ def forward(self, batch):
+
+ x_cell = batch.x_cell
+ x_tissue = batch.x_tissue
+
+ z_cell = self.cell_gnn(
+ x_cell, batch.edge_index_cell, batch.x_cell_batch, with_readout=False
+ )
+
+ out = torch.zeros(
+ (x_tissue.shape[0], z_cell.shape[1]),
+ dtype=z_cell.dtype,
+ device=z_cell.device,
+ )
+
+ z_cell_to_tissue = scatter_sum(z_cell, batch.assignment, dim=0, out=out)
+ x_tissue = torch.cat((z_cell_to_tissue, x_tissue), dim=1)
+
+ z_tissue = self.tissue_gnn(
+ x_tissue, batch.edge_index_tissue, batch.x_tissue_batch
+ )
+ out = self.classifier(z_tissue)
+ return out
diff --git a/pathml/ml/models/hovernet.py b/pathml/ml/models/hovernet.py
new file mode 100644
index 00000000..5b34e7ab
--- /dev/null
+++ b/pathml/ml/models/hovernet.py
@@ -0,0 +1,898 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from loguru import logger
+from matplotlib.colors import TABLEAU_COLORS
+from scipy.ndimage import binary_fill_holes
+from skimage.segmentation import watershed
+from torch import nn
+from torch.nn import functional as F
+
+from pathml.ml.utils import center_crop_im_batch, dice_loss, get_sobel_kernels
+from pathml.utils import segmentation_lines
+
+
+class _BatchNormRelu(nn.Module):
+ """BatchNorm + Relu layer"""
+
+ def __init__(self, n_channels):
+ super(_BatchNormRelu, self).__init__()
+ self.batch_norm = nn.BatchNorm2d(n_channels)
+ self.relu = nn.ReLU()
+
+ def forward(self, inputs):
+ return self.relu(self.batch_norm(inputs))
+
+
+class _HoVerNetResidualUnit(nn.Module):
+ """
+ Residual unit.
+ See: Fig. 2(a) from Graham et al. 2019 HoVer-Net paper.
+ This unit is not preactivated! That's handled when assembling units into blocks.
+ output_channels corresponds to m in the figure
+ """
+
+ def __init__(self, input_channels, output_channels, stride):
+ super(_HoVerNetResidualUnit, self).__init__()
+ internal_channels = output_channels // 4
+ if stride != 1 or input_channels != output_channels:
+ self.convshortcut = nn.Conv2d(
+ input_channels,
+ output_channels,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ dilation=1,
+ bias=False,
+ )
+ else:
+ self.convshortcut = None
+ self.conv1 = nn.Conv2d(
+ input_channels, internal_channels, kernel_size=1, bias=False
+ )
+ self.bnrelu1 = _BatchNormRelu(internal_channels)
+ self.conv2 = nn.Conv2d(
+ internal_channels,
+ internal_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=False,
+ )
+ self.bnrelu2 = _BatchNormRelu(internal_channels)
+ self.conv3 = nn.Conv2d(
+ internal_channels, output_channels, kernel_size=1, bias=False
+ )
+
+ def forward(self, inputs):
+ skip = self.convshortcut(inputs) if self.convshortcut else inputs
+ out = self.conv1(inputs)
+ out = self.bnrelu1(out)
+ out = self.conv2(out)
+ out = self.bnrelu2(out)
+ out = self.conv3(out)
+ out = out + skip
+ return out
+
+
+def _make_HoVerNet_residual_block(input_channels, output_channels, stride, n_units):
+ """
+ Stack multiple residual units into a block.
+ output_channels is given as m in Fig. 2 from Graham et al. 2019 paper
+ """
+ units = []
+ # first unit in block is different
+ units.append(_HoVerNetResidualUnit(input_channels, output_channels, stride))
+
+ for i in range(n_units - 1):
+ units.append(_HoVerNetResidualUnit(output_channels, output_channels, stride=1))
+ # add a final activation ('preact' for the next unit)
+ # This is different from how authors implemented - they added BNRelu before all units except the first, plus
+ # a final one at the end.
+ # I think this is equivalent to just adding a BNRelu after each unit
+ units.append(_BatchNormRelu(output_channels))
+
+ return nn.Sequential(*units)
+
+
+class _HoVerNetEncoder(nn.Module):
+ """
+ Encoder for HoVer-Net.
+ 7x7 conv, then four residual blocks, then 1x1 conv.
+ BatchNormRelu after first convolution, based on code from authors, see:
+ (https://github.com/vqdang/hover_net/blob/5d1560315a3de8e7d4c8122b97b1fe9b9513910b/src/model/graph.py#L67)
+
+    Returns a list of the outputs from each residual block, for later skip connections
+ """
+
+ def __init__(self):
+ super(_HoVerNetEncoder, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=3)
+ self.bnrelu1 = _BatchNormRelu(64)
+ self.block1 = _make_HoVerNet_residual_block(
+ input_channels=64, output_channels=256, stride=1, n_units=3
+ )
+ self.block2 = _make_HoVerNet_residual_block(
+ input_channels=256, output_channels=512, stride=2, n_units=4
+ )
+ self.block3 = _make_HoVerNet_residual_block(
+ input_channels=512, output_channels=1024, stride=2, n_units=6
+ )
+ self.block4 = _make_HoVerNet_residual_block(
+ input_channels=1024, output_channels=2048, stride=2, n_units=3
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels=2048, out_channels=1024, kernel_size=1, padding=0
+ )
+
+ def forward(self, inputs):
+ out1 = self.conv1(inputs)
+ out1 = self.bnrelu1(out1)
+ out1 = self.block1(out1)
+ out2 = self.block2(out1)
+ out3 = self.block3(out2)
+ out4 = self.block4(out3)
+ out4 = self.conv2(out4)
+ return [out1, out2, out3, out4]
+
+
+class _HoVerNetDenseUnit(nn.Module):
+ """
+ Dense unit.
+ See: Fig. 2(b) from Graham et al. 2019 HoVer-Net paper.
+ """
+
+ def __init__(self, input_channels):
+ super(_HoVerNetDenseUnit, self).__init__()
+ self.bnrelu1 = _BatchNormRelu(input_channels)
+ self.conv1 = nn.Conv2d(
+ in_channels=input_channels, out_channels=128, kernel_size=1
+ )
+ self.bnrelu2 = _BatchNormRelu(128)
+ self.conv2 = nn.Conv2d(
+ in_channels=128, out_channels=32, kernel_size=5, padding=2
+ )
+
+ def forward(self, inputs):
+ out = self.bnrelu1(inputs)
+ out = self.conv1(out)
+ out = self.bnrelu2(out)
+ out = self.conv2(out)
+
+ # need to make sure that inputs have same shape as out, so that we can concat
+ cropdims = (inputs.size(2) - out.size(2), inputs.size(3) - out.size(3))
+ inputs_cropped = center_crop_im_batch(inputs, dims=cropdims)
+ out = torch.cat((inputs_cropped, out), dim=1)
+ return out
+
+
+def _make_HoVerNet_dense_block(input_channels, n_units):
+ """
+ Stack multiple dense units into a block.
+ """
+ units = []
+ in_dim = input_channels
+ for i in range(n_units):
+ units.append(_HoVerNetDenseUnit(in_dim))
+ in_dim += 32
+ units.append(_BatchNormRelu(in_dim))
+ return nn.Sequential(*units)
+
+
+class _HoverNetDecoder(nn.Module):
+ """
+ One of the three identical decoder branches.
+ """
+
+ def __init__(self):
+ super(_HoverNetDecoder, self).__init__()
+ self.upsample1 = nn.Upsample(scale_factor=2)
+ self.conv1 = nn.Conv2d(
+ in_channels=1024,
+ out_channels=256,
+ kernel_size=5,
+ padding=2,
+ stride=1,
+ bias=False,
+ )
+ self.dense1 = _make_HoVerNet_dense_block(input_channels=256, n_units=8)
+ self.conv2 = nn.Conv2d(
+ in_channels=512, out_channels=512, kernel_size=1, stride=1, bias=False
+ )
+ self.upsample2 = nn.Upsample(scale_factor=2)
+ self.conv3 = nn.Conv2d(
+ in_channels=512,
+ out_channels=128,
+ kernel_size=5,
+ padding=2,
+ stride=1,
+ bias=False,
+ )
+ self.dense2 = _make_HoVerNet_dense_block(input_channels=128, n_units=4)
+
+ self.conv4 = nn.Conv2d(
+ in_channels=256, out_channels=256, kernel_size=1, stride=1, bias=False
+ )
+ self.upsample3 = nn.Upsample(scale_factor=2)
+ self.conv5 = nn.Conv2d(
+ in_channels=256,
+ out_channels=64,
+ kernel_size=5,
+ stride=1,
+ bias=False,
+ padding=2,
+ )
+
+ def forward(self, inputs):
+ """
+ Inputs should be a list of the outputs from each residual block, so that we can use skip connections
+ """
+ block1_out, block2_out, block3_out, block4_out = inputs
+ out = self.upsample1(block4_out)
+ # skip connection addition
+ out = out + block3_out
+ out = self.conv1(out)
+ out = self.dense1(out)
+ out = self.conv2(out)
+ out = self.upsample2(out)
+ # skip connection
+ out = out + block2_out
+ out = self.conv3(out)
+ out = self.dense2(out)
+ out = self.conv4(out)
+ out = self.upsample3(out)
+ # last skip connection
+ out = out + block1_out
+ out = self.conv5(out)
+ return out
+
+
+class HoVerNet(nn.Module):
+ """
+ Model for simultaneous segmentation and classification based on HoVer-Net.
+ Can also be used for segmentation only, if class labels are not supplied.
+ Each branch returns logits.
+
+ Args:
+ n_classes (int): Number of classes for classification task. If ``None`` then the classification branch is not
+ used.
+
+ References:
+ Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N., 2019.
+ Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images.
+ Medical Image Analysis, 58, p.101563.
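+
+    Example (a quick shape-check sketch on random input; the input size is illustrative only):
+        >>> model = HoVerNet(n_classes=6)
+        >>> np_out, hv_out, nc_out = model(torch.randn(1, 3, 256, 256))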
+ """
+
+ def __init__(self, n_classes=None):
+ super().__init__()
+ self.n_classes = n_classes
+ self.encoder = _HoVerNetEncoder()
+
+ # NP branch (nuclear pixel)
+ self.np_branch = _HoverNetDecoder()
+ # classification head
+ self.np_head = nn.Sequential(
+ # two channels in output - background prob and pixel prob
+ nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
+ )
+
+ # HV branch (horizontal vertical)
+ self.hv_branch = _HoverNetDecoder() # hv = horizontal vertical
+ # classification head
+ self.hv_head = nn.Sequential(
+ # two channels in output - horizontal and vertical
+ nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
+ )
+
+ # NC branch (nuclear classification)
+ # If n_classes is none, then we are in nucleus detection, not classification, so we don't use this branch
+ if self.n_classes is not None:
+ self.nc_branch = _HoverNetDecoder()
+ # classification head
+ self.nc_head = nn.Sequential(
+ # one channel in output for each class
+ nn.Conv2d(in_channels=64, out_channels=self.n_classes, kernel_size=1)
+ )
+
+ def forward(self, inputs):
+ encoded = self.encoder(inputs)
+
+ """for i, block_output in enumerate(encoded):
+ print(f"block {i} output shape: {block_output.shape}")"""
+
+ out_np = self.np_branch(encoded)
+ out_np = self.np_head(out_np)
+
+ out_hv = self.hv_branch(encoded)
+ out_hv = self.hv_head(out_hv)
+
+ outputs = [out_np, out_hv]
+
+ if self.n_classes is not None:
+ out_nc = self.nc_branch(encoded)
+ out_nc = self.nc_head(out_nc)
+ outputs.append(out_nc)
+
+ return outputs
+
+
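+# Example (illustrative sketch, not part of the library): a forward pass on a fake
+# batch. The tile size (256) and n_classes (6) are arbitrary choices; spatial dims
+# should survive three rounds of 2x down/upsampling, e.g. multiples of 8.
+#
+#   model = HoVerNet(n_classes=6)
+#   ims = torch.rand(4, 3, 256, 256)
+#   np_out, hv_out, nc_out = model(ims)   # logits from the NP, HV, and NC branches
+
+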
+# loss functions and associated utils
+
+
+def _convert_multiclass_mask_to_binary(mask):
+ """
+ Input mask of shape (B, n_classes, H, W) is converted to a mask of shape (B, 1, H, W).
+ The last channel is assumed to be background, so the binary mask is computed by taking its inverse.
+ """
+ m = torch.tensor(1) - mask[:, -1, :, :]
+ m = m.unsqueeze(dim=1)
+ return m
+
+
+def _dice_loss_np_head(np_out, true_mask, epsilon=1e-3):
+ """
+ Dice loss term for nuclear pixel branch.
+ This will compute dice loss for the entire batch
+ (not the same as computing dice loss for each image and then averaging!)
+
+ Args:
+ np_out: logit outputs of np branch. Tensor of shape (B, 2, H, W)
+ true_mask: True mask. Tensor of shape (B, n_classes, H, W)
+ epsilon (float): Epsilon passed to ``dice_loss()``
+ """
+ # get logits for only the channel corresponding to prediction of 1
+ # unsqueeze to keep the dimensions the same
+ preds = np_out[:, 1, :, :].unsqueeze(dim=1)
+
+ true_mask = _convert_multiclass_mask_to_binary(true_mask)
+ true_mask = true_mask.type(torch.long)
+ loss = dice_loss(logits=preds, true=true_mask, eps=epsilon)
+ return loss
+
+
+def _dice_loss_nc_head(nc_out, true_mask, epsilon=1e-3):
+ """
+ Dice loss term for nuclear classification branch.
+ Computes dice loss for each channel, and sums up.
+ This will compute dice loss for the entire batch
+ (not the same as computing dice loss for each image and then averaging!)
+
+ Args:
+ nc_out: logit outputs of nc branch. Tensor of shape (B, n_classes, H, W)
+ true_mask: True mask. Tensor of shape (B, n_classes, H, W)
+ epsilon (float): Epsilon passed to ``dice_loss()``
+ """
+ truth = torch.argmax(true_mask, dim=1, keepdim=True).type(torch.long)
+ loss = dice_loss(logits=nc_out, true=truth, eps=epsilon)
+ return loss
+
+
+def _ce_loss_nc_head(nc_out, true_mask):
+ """
+ Cross-entropy loss term for nc branch.
+ Args:
+ nc_out: logit outputs of nc branch. Tensor of shape (B, n_classes, H, W)
+ true_mask: True mask. Tensor of shape (B, n_classes, H, W)
+ """
+ truth = torch.argmax(true_mask, dim=1).type(torch.long)
+ ce = nn.CrossEntropyLoss()
+ loss = ce(nc_out, truth)
+ return loss
+
+
+def _ce_loss_np_head(np_out, true_mask):
+ """
+ Cross-entropy loss term for np branch.
+ Args:
+ np_out: logit outputs of np branch. Tensor of shape (B, 2, H, W)
+ true_mask: True mask. Tensor of shape (B, n_classes, H, W)
+ """
+ truth = (
+ _convert_multiclass_mask_to_binary(true_mask).type(torch.long).squeeze(dim=1)
+ )
+ ce = nn.CrossEntropyLoss()
+ loss = ce(np_out, truth)
+ return loss
+
+
+def compute_hv_map(mask):
+ """
+ Preprocessing step for HoVer-Net architecture.
+ Compute center of mass for each nucleus, then compute distance of each nuclear pixel to its corresponding center
+ of mass.
+ Nuclear pixel distances are normalized to (-1, 1). Background pixels are left as 0.
+ Operates on a single mask.
+ Can be used in Dataset object to make Dataloader compatible with HoVer-Net.
+
+ Based on https://github.com/vqdang/hover_net/blob/195ed9b6cc67b12f908285492796fb5c6c15a000/src/loader/augs.py#L192
+
+ Args:
+ mask (np.ndarray): Mask indicating individual nuclei. Array of shape (H, W),
+ where each pixel is in {0, ..., n} with 0 indicating background pixels and {1, ..., n} indicating
+ n unique nuclei.
+
+ Returns:
+ np.ndarray: array of hv maps of shape (2, H, W). First channel corresponds to horizontal and second vertical.
+ """
+ assert (
+ mask.ndim == 2
+ ), f"Input mask has shape {mask.shape}. Expecting a mask with 2 dimensions (H, W)"
+
+ out = np.zeros((2, mask.shape[0], mask.shape[1]))
+ # each individual nucleus is indexed with a different number
+ inst_list = list(np.unique(mask))
+
+ try:
+ inst_list.remove(0) # 0 is background
+    except ValueError:
+ logger.warning(
+ "No pixels with 0 label. This means that there are no background pixels. This may indicate a problem. Ignore this warning if this is expected/intended."
+ )
+
+ for inst_id in inst_list:
+ # get the mask for the nucleus
+ inst_map = mask == inst_id
+ inst_map = inst_map.astype(np.uint8)
+ contours, _ = cv2.findContours(
+ inst_map, mode=cv2.RETR_LIST, method=cv2.CHAIN_APPROX_NONE
+ )
+
+ # get center of mass coords
+ mom = cv2.moments(contours[0])
+ com_x = mom["m10"] / (mom["m00"] + 1e-6)
+ com_y = mom["m01"] / (mom["m00"] + 1e-6)
+ inst_com = (int(com_y), int(com_x))
+
+ inst_x_range = np.arange(1, inst_map.shape[1] + 1)
+ inst_y_range = np.arange(1, inst_map.shape[0] + 1)
+ # shifting center of pixels grid to instance center of mass
+ inst_x_range -= inst_com[1]
+ inst_y_range -= inst_com[0]
+
+ inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)
+
+ # remove coord outside of instance
+ inst_x[inst_map == 0] = 0
+ inst_y[inst_map == 0] = 0
+ inst_x = inst_x.astype("float32")
+ inst_y = inst_y.astype("float32")
+
+ # normalize min into -1 scale
+ if np.min(inst_x) < 0:
+ inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
+ if np.min(inst_y) < 0:
+ inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
+ # normalize max into +1 scale
+ if np.max(inst_x) > 0:
+ inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
+ if np.max(inst_y) > 0:
+ inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])
+
+ # add to output mask
+ # this works assuming background is 0, and each pixel is assigned to only one nucleus.
+ out[0, :, :] += inst_x
+ out[1, :, :] += inst_y
+ return out
+
+
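+# Example (illustrative sketch): computing an HV map for a toy mask with two nuclei,
+# e.g. inside a Dataset __getitem__ when preparing HoVer-Net training targets.
+#
+#   mask = np.zeros((64, 64), dtype=np.uint8)
+#   mask[10:20, 10:20] = 1      # nucleus 1
+#   mask[40:50, 30:45] = 2      # nucleus 2
+#   hv = compute_hv_map(mask)   # shape (2, 64, 64); nonzero only inside nuclei
+
+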
+def _get_gradient_hv(hv_batch, kernel_size=5):
+ """
+    Calculate the horizontal partial derivative for the horizontal channel
+    and the vertical partial derivative for the vertical channel.
+    The partial derivative is approximated by the central difference,
+    computed by convolving each channel with a Sobel kernel (5x5 by default).
+    The boundary is zero-padded when each channel is convolved with the Sobel kernel.
+
+ Args:
+        hv_batch: tensor of shape (B, 2, H, W). Channel index 0 for horizontal maps and 1 for vertical maps.
+ These maps are distance from each nuclear pixel to center of mass of corresponding nucleus.
+ kernel_size (int): width of kernel to use for gradient approximation.
+
+ Returns:
+ Tuple of (h_grad, v_grad) where each is a Tensor giving horizontal and vertical gradients respectively
+ """
+ assert (
+ hv_batch.shape[1] == 2
+ ), f"inputs have shape {hv_batch.shape}. Expecting tensor of shape (B, 2, H, W)"
+ h_kernel, v_kernel = get_sobel_kernels(kernel_size, dt=hv_batch.dtype)
+
+ # move kernels to same device as batch
+ h_kernel = h_kernel.to(hv_batch.device)
+ v_kernel = v_kernel.to(hv_batch.device)
+
+ # add extra dims so we can convolve with a batch
+ h_kernel = h_kernel.unsqueeze(0).unsqueeze(0)
+ v_kernel = v_kernel.unsqueeze(0).unsqueeze(0)
+
+ # get the inputs for the h and v channels
+ h_inputs = hv_batch[:, 0, :, :].unsqueeze(dim=1)
+ v_inputs = hv_batch[:, 1, :, :].unsqueeze(dim=1)
+
+    # pad by half the kernel width so output spatial dims match the input
+    h_grad = F.conv2d(h_inputs, h_kernel, stride=1, padding=kernel_size // 2)
+    v_grad = F.conv2d(v_inputs, v_kernel, stride=1, padding=kernel_size // 2)
+
+ del h_kernel
+ del v_kernel
+
+ return h_grad, v_grad
+
+
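+# Example (illustrative sketch): approximate gradients of a fake batch of HV maps.
+#
+#   hv = torch.rand(4, 2, 256, 256)
+#   h_grad, v_grad = _get_gradient_hv(hv, kernel_size=5)   # each (4, 1, 256, 256)
+
+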
+def _loss_hv_grad(hv_out, true_hv, nucleus_pixel_mask):
+ """
+ Equation 3 from HoVer-Net paper for calculating loss for HV predictions.
+ Mask is used to compute the hv loss ONLY for nuclear pixels
+
+ Args:
+        hv_out: Output of hv branch. Tensor of shape (B, 2, H, W)
+ true_hv: Ground truth hv maps. Tensor of shape (B, 2, H, W)
+ nucleus_pixel_mask: Boolean mask indicating nuclear pixels. Tensor of shape (B, H, W)
+ """
+ pred_grad_h, pred_grad_v = _get_gradient_hv(hv_out)
+ true_grad_h, true_grad_v = _get_gradient_hv(true_hv)
+
+ # pull out only the values from nuclear pixels
+ pred_h = torch.masked_select(pred_grad_h, mask=nucleus_pixel_mask)
+ true_h = torch.masked_select(true_grad_h, mask=nucleus_pixel_mask)
+ pred_v = torch.masked_select(pred_grad_v, mask=nucleus_pixel_mask)
+ true_v = torch.masked_select(true_grad_v, mask=nucleus_pixel_mask)
+
+ loss_h = F.mse_loss(pred_h, true_h)
+ loss_v = F.mse_loss(pred_v, true_v)
+
+ loss = loss_h + loss_v
+ return loss
+
+
+def _loss_hv_mse(hv_out, true_hv):
+ """
+ Equation 2 from HoVer-Net paper for calculating loss for HV predictions.
+
+ Args:
+        hv_out: Output of hv branch. Tensor of shape (B, 2, H, W)
+ true_hv: Ground truth hv maps. Tensor of shape (B, 2, H, W)
+ """
+ loss = F.mse_loss(hv_out, true_hv)
+ return loss
+
+
+def loss_hovernet(outputs, ground_truth, n_classes=None):
+ """
+ Compute loss for HoVer-Net.
+ Equation (1) in Graham et al.
+
+ Args:
+ outputs: Output of HoVer-Net. Should be a list of [np, hv] if n_classes is None, or a list of [np, hv, nc] if
+ n_classes is not None.
+ Shapes of each should be:
+
+ - np: (B, 2, H, W)
+ - hv: (B, 2, H, W)
+ - nc: (B, n_classes, H, W)
+
+ ground_truth: True labels. Should be a list of [mask, hv], where mask is a Tensor of shape (B, 1, H, W)
+ if n_classes is ``None`` or (B, n_classes, H, W) if n_classes is not ``None``.
+ hv is a tensor of precomputed horizontal and vertical distances
+ of nuclear pixels to their corresponding centers of mass, and is of shape (B, 2, H, W).
+ n_classes (int): Number of classes for classification task. If ``None`` then the classification branch is not
+ used.
+
+ References:
+ Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N., 2019.
+ Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images.
+ Medical Image Analysis, 58, p.101563.
+ """
+ true_mask, true_hv = ground_truth
+ # unpack outputs, and also calculate nucleus masks
+ if n_classes is None:
+ np_out, hv = outputs
+ nucleus_mask = true_mask[:, 0, :, :] == 1
+ else:
+ np_out, hv, nc = outputs
+ # in multiclass setting, last channel of masks indicates background, so
+ # invert that to get a nucleus mask (Based on convention from PanNuke dataset)
+ nucleus_mask = true_mask[:, -1, :, :] == 0
+
+ # from Eq. 1 in HoVer-Net paper, loss function is composed of two terms for each branch.
+ np_loss_dice = _dice_loss_np_head(np_out, true_mask)
+ np_loss_ce = _ce_loss_np_head(np_out, true_mask)
+
+ hv_loss_grad = _loss_hv_grad(hv, true_hv, nucleus_mask)
+ hv_loss_mse = _loss_hv_mse(hv, true_hv)
+
+ # authors suggest using coefficient of 2 for hv gradient loss term
+ hv_loss_grad = 2 * hv_loss_grad
+
+ if n_classes is not None:
+ nc_loss_dice = _dice_loss_nc_head(nc, true_mask)
+ nc_loss_ce = _ce_loss_nc_head(nc, true_mask)
+ else:
+ nc_loss_dice = 0
+ nc_loss_ce = 0
+
+ loss = (
+ np_loss_dice
+ + np_loss_ce
+ + hv_loss_mse
+ + hv_loss_grad
+ + nc_loss_dice
+ + nc_loss_ce
+ )
+ return loss
+
+
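+# Example (illustrative sketch): computing the loss on fake tensors in the
+# classification setting. Real masks come from the dataset (one-hot with background
+# as the last channel) and real hv targets from compute_hv_map().
+#
+#   outputs = model(ims)                                    # [np, hv, nc] logits
+#   masks = torch.randint(0, 2, (4, 6, 256, 256)).float()   # fake one-hot-style masks
+#   hv_true = torch.rand(4, 2, 256, 256)
+#   loss = loss_hovernet(outputs, [masks, hv_true], n_classes=6)
+#   loss.backward()
+
+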
+# Post-processing of HoVer-Net outputs
+
+
+def remove_small_objs(array_in, min_size):
+ """
+ Removes small foreground regions from binary array, leaving only the contiguous regions which are above
+ the size threshold. Pixels in regions below the size threshold are zeroed out.
+
+ Args:
+ array_in (np.ndarray): Input array. Must be binary array with dtype=np.uint8.
+ min_size (int): Minimum size of each region.
+
+ Returns:
+ np.ndarray: Array of labels for regions above the threshold. Each separate contiguous region is labelled with
+ a different integer from 1 to n, where n is the number of total distinct contiguous regions
+ """
+ assert (
+ array_in.dtype == np.uint8
+ ), f"Input dtype is {array_in.dtype}. Must be np.uint8"
+ # remove elements below size threshold
+ # each contiguous nucleus region gets a unique id
+ n_labels, labels = cv2.connectedComponents(array_in)
+ # each integer is a different nucleus, so bincount gives nucleus sizes
+ sizes = np.bincount(labels.flatten())
+ for nucleus_ix, size_ix in zip(range(n_labels), sizes):
+ if size_ix < min_size:
+ # below size threshold - set all to zero
+ labels[labels == nucleus_ix] = 0
+ return labels
+
+
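+# Example (illustrative sketch): cleaning a thresholded probability mask by dropping
+# connected regions smaller than 100 pixels. `np_probs` is a hypothetical (H, W)
+# array of nucleus probabilities.
+#
+#   binary = (np_probs > 0.5).astype(np.uint8)
+#   labeled = remove_small_objs(binary, min_size=100)
+#   binary_clean = (labeled > 0).astype(np.uint8)
+
+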
+def _post_process_single_hovernet(
+ np_out, hv_out, small_obj_size_thresh=10, kernel_size=21, h=0.5, k=0.5
+):
+ """
+ Combine predictions of np channel and hv channel to create final predictions.
+    Works by creating an energy landscape from the gradients, and then applying watershed segmentation.
+ This function works on a single image and is wrapped in ``post_process_batch_hovernet()`` to apply across a batch.
+ See: Section B of HoVer-Net article and
+ https://github.com/vqdang/hover_net/blob/14c5996fa61ede4691e87905775e8f4243da6a62/models/hovernet/post_proc.py#L27
+
+ Args:
+ np_out (torch.Tensor): Output of NP branch. Tensor of shape (2, H, W) of logit predictions for binary classification
+ hv_out (torch.Tensor): Output of HV branch. Tensor of shape (2, H, W) of predictions for horizontal/vertical maps
+ small_obj_size_thresh (int): Minimum number of pixels in regions. Defaults to 10.
+ kernel_size (int): Width of Sobel kernel used to compute horizontal and vertical gradients.
+ h (float): hyperparameter for thresholding nucleus probabilities. Defaults to 0.5.
+ k (float): hyperparameter for thresholding energy landscape to create markers for watershed
+ segmentation. Defaults to 0.5.
+ """
+ # compute pixel probabilities from logits, apply threshold, and get into np array
+ np_preds = F.softmax(np_out, dim=0)[1, :, :]
+ np_preds = np_preds.numpy()
+
+ np_preds[np_preds >= h] = 1
+ np_preds[np_preds < h] = 0
+ np_preds = np_preds.astype(np.uint8)
+
+ np_preds = remove_small_objs(np_preds, min_size=small_obj_size_thresh)
+ # Back to binary. now np_preds corresponds to tau(q, h) from HoVer-Net paper
+ np_preds[np_preds > 0] = 1
+ tau_q_h = np_preds
+
+ # normalize hv predictions, and compute horizontal and vertical gradients, and normalize again
+ hv_out = hv_out.numpy().astype(np.float32)
+ h_out = hv_out[0, ...]
+ v_out = hv_out[1, ...]
+ # https://stackoverflow.com/a/39037135
+ h_normed = cv2.normalize(
+ h_out, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
+ )
+ v_normed = cv2.normalize(
+ v_out, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
+ )
+
+ h_grad = cv2.Sobel(h_normed, cv2.CV_64F, dx=1, dy=0, ksize=kernel_size)
+ v_grad = cv2.Sobel(v_normed, cv2.CV_64F, dx=0, dy=1, ksize=kernel_size)
+
+ h_grad = cv2.normalize(
+ h_grad, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
+ )
+ v_grad = cv2.normalize(
+ v_grad, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
+ )
+
+ # flip the gradient direction so that highest values are steepest gradient
+ h_grad = 1 - h_grad
+ v_grad = 1 - v_grad
+
+ S_m = np.maximum(h_grad, v_grad)
+ S_m[tau_q_h == 0] = 0
+ # energy landscape
+    # note that the paper says that they use E = (1 - tau(S_m, k)) * tau(q, h)
+    # but in the authors' code they actually use: E = (1 - S_m) * tau(q, h)
+    # this actually makes more sense because there is no need to threshold the energy surface
+ energy = (1.0 - S_m) * tau_q_h
+
+ # get markers
+    # In the paper it says they use M = sigma(tau(q, h) - tau(S_m, k))
+    # But it makes more sense to threshold the energy landscape to get the peaks of the hills.
+    # Also, the fact that they used sigma in the paper suggests that this is what they intended.
+ m = np.array(energy >= k, dtype=np.uint8)
+ m = binary_fill_holes(m).astype(np.uint8)
+ m = remove_small_objs(m, min_size=small_obj_size_thresh)
+
+    # nuclei values form mountains, so invert to get basins for watershed
+ energy = -cv2.GaussianBlur(energy, (3, 3), 0)
+ out = watershed(image=energy, markers=m, mask=tau_q_h)
+
+ return out
+
+
+def post_process_batch_hovernet(
+ outputs, n_classes, small_obj_size_thresh=10, kernel_size=21, h=0.5, k=0.5
+):
+ """
+ Post-process HoVer-Net outputs to get a final predicted mask.
+ See: Section B of HoVer-Net article and
+ https://github.com/vqdang/hover_net/blob/14c5996fa61ede4691e87905775e8f4243da6a62/models/hovernet/post_proc.py#L27
+
+ Args:
+ outputs (list): Outputs of HoVer-Net model. List of [np_out, hv_out], or [np_out, hv_out, nc_out]
+ depending on whether model is predicting classification or not.
+
+ - np_out is a Tensor of shape (B, 2, H, W) of logit predictions for binary classification
+ - hv_out is a Tensor of shape (B, 2, H, W) of predictions for horizontal/vertical maps
+ - nc_out is a Tensor of shape (B, n_classes, H, W) of logits for classification
+
+ n_classes (int): Number of classes for classification task. If ``None`` then only segmentation is performed.
+ small_obj_size_thresh (int): Minimum number of pixels in regions. Defaults to 10.
+ kernel_size (int): Width of Sobel kernel used to compute horizontal and vertical gradients.
+ h (float): hyperparameter for thresholding nucleus probabilities. Defaults to 0.5.
+ k (float): hyperparameter for thresholding energy landscape to create markers for watershed
+ segmentation. Defaults to 0.5.
+
+ Returns:
+ np.ndarray: If n_classes is None, returns det_out. In classification setting, returns (det_out, class_out).
+
+ - det_out is np.ndarray of shape (B, H, W)
+ - class_out is np.ndarray of shape (B, n_classes, H, W)
+
+ Each pixel is labelled from 0 to n, where n is the number of individual nuclei detected. 0 pixels indicate
+ background. Pixel values i indicate that the pixel belongs to the ith nucleus.
+ """
+
+ assert len(outputs) in {2, 3}, (
+ f"outputs has size {len(outputs)}. Must have size 2 (for segmentation) or 3 (for "
+ f"classification)"
+ )
+ if n_classes is None:
+ np_out, hv_out = outputs
+        # send outputs to cpu
+ np_out = np_out.detach().cpu()
+ hv_out = hv_out.detach().cpu()
+ classification = False
+ else:
+ assert len(outputs) == 3, (
+ f"n_classes={n_classes} but outputs has {len(outputs)} elements. Expecting a list "
+ f"of length 3, one for each of np, hv, and nc branches"
+ )
+ np_out, hv_out, nc_out = outputs
+        # send outputs to cpu
+ np_out = np_out.detach().cpu()
+ hv_out = hv_out.detach().cpu()
+ nc_out = nc_out.detach().cpu()
+ classification = True
+
+ batchsize = hv_out.shape[0]
+ # first get the nucleus detection preds
+ out_detection_list = []
+ for i in range(batchsize):
+ preds = _post_process_single_hovernet(
+ np_out[i, ...], hv_out[i, ...], small_obj_size_thresh, kernel_size, h, k
+ )
+ out_detection_list.append(preds)
+ out_detection = np.stack(out_detection_list)
+
+ if classification:
+ # need to do last step of majority vote
+ # get the pixel-level class predictions from the logits
+ nc_out_preds = F.softmax(nc_out, dim=1).argmax(dim=1)
+
+ out_classification = np.zeros_like(nc_out.numpy(), dtype=np.uint8)
+
+ for batch_ix, nuc_preds in enumerate(out_detection_list):
+ # get labels of nuclei from nucleus detection
+ nucleus_labels = list(np.unique(nuc_preds))
+ if 0 in nucleus_labels:
+ nucleus_labels.remove(0) # 0 is background
+ nucleus_class_preds = nc_out_preds[batch_ix, ...]
+
+ out_class_preds_single = out_classification[batch_ix, ...]
+
+ # for each nucleus, get the class predictions for the pixels and take a vote
+ for nucleus_ix in nucleus_labels:
+ # get mask for the specific nucleus
+ ix_mask = nuc_preds == nucleus_ix
+ votes = nucleus_class_preds[ix_mask]
+ majority_class = np.argmax(np.bincount(votes))
+ out_class_preds_single[majority_class][ix_mask] = nucleus_ix
+
+ out_classification[batch_ix, ...] = out_class_preds_single
+
+ return out_detection, out_classification
+ else:
+ return out_detection
+
+
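+# Example (illustrative sketch): post-processing a batch of model outputs after
+# inference, in the classification setting.
+#
+#   with torch.no_grad():
+#       outputs = model(ims)
+#   det, classes = post_process_batch_hovernet(outputs, n_classes=6)
+#   # det: (B, H, W) instance labels; classes: (B, 6, H, W) per-class instance labels
+
+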
+# plotting hovernet outputs
+
+
+def _vis_outputs_single(
+ images, preds, n_classes, index=0, ax=None, markersize=5, palette=None
+):
+ """
+ Plot the results of HoVer-Net predictions for a single image, overlayed on the original image.
+
+ Args:
+ images: Input RGB image batch. Tensor of shape (B, 3, H, W).
+ preds: Postprocessed outputs of HoVer-Net. From post_process_batch_hovernet(). Can be either:
+ - Tensor of shape (B, H, W), in the context of nucleus detection.
+ - Tensor of shape (B, n_classes, H, W), in the context of nucleus classification.
+ n_classes (int): Number of classes for classification setting, or None to indicate detection setting.
+ index (int): Index of image to plot.
+ ax: Matplotlib axes object to plot on. If None, creates a new plot. Defaults to None.
+ markersize: Size of markers used to outline nuclei
+ palette (list): list of colors to use for plotting. If None, uses matplotlib.colors.TABLEAU_COLORS.
+ Defaults to None
+ """
+ if palette is None:
+ palette = list(TABLEAU_COLORS.values())
+
+ if n_classes is not None:
+ classification = True
+ n_classes = preds.shape[1]
+ assert (
+ len(palette) >= n_classes
+ ), f"len(palette)={len(palette)} < n_classes={n_classes}."
+ else:
+ classification = False
+
+ assert len(preds.shape) in [
+ 3,
+ 4,
+ ], f"Preds shape is {preds.shape}. Must be (B, H, W) or (B, n_classes, H, W)"
+
+ if ax is None:
+ fig, ax = plt.subplots()
+
+ ax.imshow(images[index, ...].permute(1, 2, 0))
+
+ if classification is False:
+ nucleus_labels = list(np.unique(preds[index, ...]))
+ nucleus_labels.remove(0) # background
+ # plot each individual nucleus
+ for label in nucleus_labels:
+ nuclei_mask = preds[index, ...] == label
+ x, y = segmentation_lines(nuclei_mask.astype(np.uint8))
+ ax.scatter(x, y, color=palette[0], marker=".", s=markersize)
+ else:
+ nucleus_labels = list(np.unique(preds[index, ...]))
+ nucleus_labels.remove(0) # background
+ # plot each individual nucleus
+ for label in nucleus_labels:
+ for i in range(n_classes):
+ nuclei_mask = preds[index, i, ...] == label
+ x, y = segmentation_lines(nuclei_mask.astype(np.uint8))
+ ax.scatter(x, y, color=palette[i], marker=".", s=markersize)
+ ax.axis("off")
diff --git a/pathml/ml/utils.py b/pathml/ml/utils.py
index b0dd36f0..f6eb8577 100644
--- a/pathml/ml/utils.py
+++ b/pathml/ml/utils.py
@@ -7,7 +7,95 @@
# Utilities for ML module
import torch
+from sklearn.utils.class_weight import compute_class_weight
from torch.nn import functional as F
+from torch_geometric.utils import degree
+from tqdm import tqdm
+
+
+def scatter_sum(src, index, dim, out=None, dim_size=None):
+ """
+ Reduces all values from the :attr:`src` tensor into :attr:`out` at the
+ indices specified in the :attr:`index` tensor along a given axis
+ :attr:`dim`.
+
+ For each value in :attr:`src`, its output index is specified by its index
+ in :attr:`src` for dimensions outside of :attr:`dim` and by the
+ corresponding value in :attr:`index` for dimension :attr:`dim`.
+    The applied reduction is summation (scatter-add).
+
+ Args:
+ src: The source tensor.
+ index: The indices of elements to scatter.
+        dim: The axis along which to index.
+ out: The destination tensor.
+ dim_size: If `out` is not given, automatically create output with size `dim_size` at dimension `dim`.
+
+ Reference:
+ https://pytorch-scatter.readthedocs.io/en/latest/_modules/torch_scatter/scatter.html#scatter
+ """
+
+ index = broadcast(index, src, dim)
+ if out is None:
+ size = list(src.size())
+ if dim_size is not None:
+ size[dim] = dim_size
+ elif index.numel() == 0:
+ size[dim] = 0
+ else:
+ size[dim] = int(index.max()) + 1
+ out = torch.zeros(size, dtype=src.dtype, device=src.device)
+ return out.scatter_add_(dim, index, src)
+ else:
+ return out.scatter_add_(dim, index, src)
+
+
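+# Example (illustrative sketch): scatter-sum of node features into two buckets.
+#
+#   src = torch.ones(5, 3)
+#   index = torch.tensor([0, 0, 1, 1, 1])
+#   scatter_sum(src, index, dim=0)   # -> [[2., 2., 2.], [3., 3., 3.]]
+
+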
+def broadcast(src, other, dim):
+ """
+ Broadcast tensors to match output tensor dimension.
+ """
+ if dim < 0:
+ dim = other.dim() + dim
+ if src.dim() == 1:
+ for _ in range(0, dim):
+ src = src.unsqueeze(0)
+ for _ in range(src.dim(), other.dim()):
+ src = src.unsqueeze(-1)
+ src = src.expand(other.size())
+ return src
+
+
+def get_degree_histogram(loader, edge_index_str, x_str):
+ """
+ Returns the degree histogram to be used as input for the `deg` argument in `PNAConv`.
+ """
+
+ deg_histogram = torch.zeros(1, dtype=torch.long)
+ for data in tqdm(loader):
+ d = degree(
+ data[edge_index_str][1], num_nodes=data[x_str].shape[0], dtype=torch.long
+ )
+ d_bincount = torch.bincount(d, minlength=deg_histogram.numel())
+ if d_bincount.size(0) > deg_histogram.size(0):
+ d_bincount[: deg_histogram.size(0)] += deg_histogram
+ deg_histogram = d_bincount
+ else:
+ assert d_bincount.size(0) == deg_histogram.size(0)
+ deg_histogram += d_bincount
+ return deg_histogram
+
+
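+# Example (illustrative sketch): computing the degree histogram from a
+# torch_geometric DataLoader. The attribute names follow the HACTPairData convention
+# used elsewhere in PathML; adjust them to match your Data objects.
+#
+#   deg = get_degree_histogram(train_loader, "edge_index_cell", "x_cell")
+#   # pass `deg` to torch_geometric's PNAConv via its `deg` argument
+
+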
+def get_class_weights(loader):
+ """
+ Returns the per-class weights to be used in weighted loss functions.
+ """
+
+ ys = []
+ for data in tqdm(loader):
+ ys.append(data.target.numpy())
+ ys = np.array(ys).ravel()
+ weights = compute_class_weight("balanced", classes=np.unique(ys), y=np.array(ys))
+ return weights
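+
+
+# Example (illustrative sketch): balanced class weights for a weighted loss, assuming
+# each batch from `train_loader` carries an integer `target` attribute.
+#
+#   weights = get_class_weights(train_loader)
+#   criterion = torch.nn.CrossEntropyLoss(weight=torch.as_tensor(weights, dtype=torch.float))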
def center_crop_im_batch(batch, dims, batch_order="BCHW"):
diff --git a/pathml/preprocessing/tilestitcher.py b/pathml/preprocessing/tilestitcher.py
new file mode 100644
index 00000000..8eb21e79
--- /dev/null
+++ b/pathml/preprocessing/tilestitcher.py
@@ -0,0 +1,385 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+
+import glob
+import os
+import platform
+import subprocess
+import sys
+import traceback
+import urllib.request
+import zipfile
+
+import jpype
+import tifffile
+
+
+class TileStitcher:
+ """
+
+ This class serves as a Python implementation of a script originally authored by Pete Bankhead,
+ available at https://gist.github.com/petebankhead/b5a86caa333de1fdcff6bdee72a20abe
+ The original script is designed to stitch spectrally unmixed images into a pyramidal OME-TIFF format.
+
+ Make sure QuPath and JDK are installed before using this class.
+
+ """
+
+ def __init__(
+ self, qupath_jarpath=[], java_path=None, memory="40g", bfconvert_dir="./"
+ ):
+ """
+ Initialize the TileStitcher by setting up the Java Virtual Machine and QuPath environment.
+ """
+ self.bfconvert_path = self.setup_bfconvert(bfconvert_dir)
+
+ if java_path:
+ os.environ["JAVA_HOME"] = java_path
+ else:
+ self.set_environment_paths()
+ print("Setting Environment Paths")
+
+ # print(qupath_jarpath)
+ self.classpath = os.pathsep.join(qupath_jarpath)
+ self.memory = memory
+ self._start_jvm()
+
+ def __del__(self):
+ """Ensure the JVM is shutdown when the object is deleted."""
+
+ if jpype.isJVMStarted():
+ jpype.shutdownJVM()
+
+ def setup_bfconvert(self, bfconvert_dir):
+ setup_dir = bfconvert_dir
+ parent_dir = os.path.dirname(setup_dir)
+ tools_dir = os.path.join(parent_dir, "tools")
+ self.bfconvert_path = os.path.join(tools_dir, "bftools", "bfconvert")
+ self.bf_sh_path = os.path.join(tools_dir, "bftools", "bf.sh")
+ print(self.bfconvert_path, self.bf_sh_path)
+
+ # Ensure the tools directory exists
+ try:
+ if not os.path.exists(tools_dir):
+ os.makedirs(tools_dir)
+ except PermissionError:
+ raise PermissionError(
+ f"Permission denied: Cannot create directory {tools_dir}"
+ )
+
+ # If bftools folder does not exist, check for bftools.zip or download it
+ if not os.path.exists(os.path.join(tools_dir, "bftools")):
+ zip_path = os.path.join(tools_dir, "bftools.zip")
+
+ if not os.path.exists(zip_path):
+ url = "https://downloads.openmicroscopy.org/bio-formats/latest/artifacts/bftools.zip"
+ print(f"Downloading bfconvert from {url}...")
+ urllib.request.urlretrieve(url, zip_path)
+
+ print(f"Unzipping {zip_path}...")
+ try:
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
+ zip_ref.extractall(tools_dir)
+ except zipfile.BadZipFile:
+ raise zipfile.BadZipFile(f"Invalid ZIP file: {zip_path}")
+
+ if os.path.exists(zip_path):
+ os.remove(zip_path)
+
+ print(f"bfconvert set up at {self.bfconvert_path}")
+
+ system = platform.system().lower()
+ if system == "linux":
+ try:
+ os.chmod(self.bf_sh_path, os.stat(self.bf_sh_path).st_mode | 0o111)
+ os.chmod(
+ self.bfconvert_path, os.stat(self.bfconvert_path).st_mode | 0o111
+ )
+ except PermissionError:
+ raise PermissionError("Permission denied: Cannot chmod files")
+
+ # Print bfconvert version
+ try:
+ version_output = subprocess.check_output([self.bfconvert_path, "-version"])
+ print(f"bfconvert version: {version_output.decode('utf-8').strip()}")
+ except subprocess.CalledProcessError:
+ raise subprocess.CalledProcessError(
+ 1,
+ [self.bfconvert_path, "-version"],
+ output="Failed to get bfconvert version.",
+ )
+
+ return self.bfconvert_path
+
+ def set_environment_paths(self):
+ """
+ Set the JAVA_HOME path based on the OS type.
+ If the path is not found in the predefined paths dictionary, the function tries
+ to automatically find the JAVA_HOME path from the system.
+ """
+ print("Setting Environment Paths")
+ if "JAVA_HOME" in os.environ and os.environ["JAVA_HOME"]:
+ # If JAVA_HOME is already set by the user, use that.
+ print("Java Home is already set")
+ return
+
+ # Try to get the JAVA_HOME from the echo command
+ java_home = self.get_system_java_home()
+ if not java_home:
+ raise EnvironmentError(
+ "JAVA_HOME not found. Please set it before proceeding or provide it explicitly."
+ )
+
+ print(f"Setting Java path to {java_home}")
+ os.environ["JAVA_HOME"] = java_home
+
+ def get_system_java_home(self):
+ """
+ Try to automatically find the JAVA_HOME path from the system.
+ Return it if found, otherwise return an empty string.
+ """
+ try:
+ # Execute the echo command to get the JAVA_HOME
+ java_home = subprocess.getoutput("echo $JAVA_HOME").strip()
+ if not java_home:
+ raise EnvironmentError("Unable to retrieve JAVA_HOME from the system.")
+ return java_home
+ except Exception as e:
+ print("Error retrieving JAVA_HOME from the system", e)
+ return ""
+
+ def run_image_stitching(
+ self, infiles, fileout, downsamples=[1, 8], separate_series=False
+ ):
+ """
+ Perform image stitching on the provided TIFF files and output a stitched OME-TIFF image.
+ """
+ try:
+ infiles = self._collect_tif_files(infiles)
+ fileout, file_jpype = self._get_outfile(fileout)
+
+ if not infiles or not file_jpype:
+ return
+
+ server = self.parse_regions(infiles)
+ server = self.ImageServers.pyramidalize(server)
+ self._write_pyramidal_image_server(server, file_jpype, downsamples)
+
+ server.close()
+ print(f"Image stitching completed. Output file: {file_jpype}")
+
+ if separate_series:
+ self.run_bfconvert(fileout)
+
+ except Exception as e:
+ print(f"Error running image stitching: {e}")
+ traceback.print_exc()
+
+ def _start_jvm(self):
+ """Start the Java Virtual Machine and import necessary QuPath classes."""
+ if not jpype.isJVMStarted():
+ try:
+ # Set memory usage and classpath for the JVM
+ memory_usage = f"-Xmx{self.memory}"
+ class_path_option = "-Djava.class.path=%s" % self.classpath
+
+ # Try to start the JVM with the specified options
+ jpype.startJVM(memory_usage, class_path_option)
+
+ print(f"Using JVM version: {jpype.getJVMVersion()}")
+
+ # Import necessary QuPath classes
+ self._import_qupath_classes()
+
+ except Exception as e:
+ # Catch any exception that occurs during JVM startup and print the traceback
+ print(f"Error occurred while starting JVM: {e}")
+ traceback.print_exc()
+ sys.exit(1)
+ else:
+ print("JVM started successfully")
+ else:
+ print("JVM was already started")
+
+ def _import_qupath_classes(self):
+ """Import necessary QuPath classes after starting JVM."""
+
+ try:
+ print("Importing required qupath classes")
+ self.ImageServerProvider = jpype.JPackage(
+ "qupath.lib.images.servers"
+ ).ImageServerProvider
+ self.ImageServers = jpype.JPackage("qupath.lib.images.servers").ImageServers
+ self.SparseImageServer = jpype.JPackage(
+ "qupath.lib.images.servers"
+ ).SparseImageServer
+ self.OMEPyramidWriter = jpype.JPackage(
+ "qupath.lib.images.writers.ome"
+ ).OMEPyramidWriter
+ self.ImageRegion = jpype.JPackage("qupath.lib.regions").ImageRegion
+ self.ImageIO = jpype.JPackage("javax.imageio").ImageIO
+ self.BaselineTIFFTagSet = jpype.JPackage(
+ "javax.imageio.plugins.tiff"
+ ).BaselineTIFFTagSet
+ self.TIFFDirectory = jpype.JPackage(
+ "javax.imageio.plugins.tiff"
+ ).TIFFDirectory
+ self.BufferedImage = jpype.JPackage("java.awt.image").BufferedImage
+
+ except Exception as e:
+ raise RuntimeError(f"Failed to import QuPath classes: {e}")
+
+ def _collect_tif_files(self, input):
+ """Collect .tif files from the input directory or list."""
+ if isinstance(input, str) and os.path.isdir(input):
+ return glob.glob(os.path.join(input, "**/*.tif"), recursive=True)
+ elif isinstance(input, list):
+ return [file for file in input if file.endswith(".tif")]
+ else:
+ print(
+ f"Input must be a directory path or list of .tif files. Received: {input}"
+ )
+ return []
+
+ def _get_outfile(self, fileout):
+ """Get the output file object for the stitched image."""
+ if not fileout.endswith(".ome.tif"):
+ fileout += ".ome.tif"
+ return fileout, jpype.JClass("java.io.File")(fileout)
+
+ def parseRegion(self, file, z=0, t=0):
+ if self.checkTIFF(file):
+ try:
+ # Extract the image region coordinates and dimensions from the TIFF tags
+ with tifffile.TiffFile(file) as tif:
+ tag_xpos = tif.pages[0].tags.get("XPosition")
+ tag_ypos = tif.pages[0].tags.get("YPosition")
+ tag_xres = tif.pages[0].tags.get("XResolution")
+ tag_yres = tif.pages[0].tags.get("YResolution")
+ if (
+ tag_xpos is None
+ or tag_ypos is None
+ or tag_xres is None
+ or tag_yres is None
+ ):
+ print(f"Could not find required tags for {file}")
+ return None
+ xpos = 10000 * tag_xpos.value[0] / tag_xpos.value[1]
+ xres = tag_xres.value[0] / (tag_xres.value[1] * 10000)
+ ypos = 10000 * tag_ypos.value[0] / tag_ypos.value[1]
+ yres = tag_yres.value[0] / (tag_yres.value[1] * 10000)
+ height = tif.pages[0].tags.get("ImageLength").value
+ width = tif.pages[0].tags.get("ImageWidth").value
+ x = int(round(xpos * xres))
+ y = int(round(ypos * yres))
+ # Create an ImageRegion object representing the extracted image region
+ region = self.ImageRegion.createInstance(x, y, width, height, z, t)
+ return region
+ except Exception as e:
+ print(f"Error occurred while parsing {file}: {e}")
+ traceback.print_exc()
+ raise
+ else:
+ print(f"{file} is not a valid TIFF file")
+
+ # Define a function to check if a file is a valid TIFF file
+ def checkTIFF(self, file):
+ try:
+ with open(file, "rb") as f:
+ bytes = f.read(4)
+ byteOrder = self.toShort(bytes[0], bytes[1])
+ if byteOrder == 0x4949: # Little-endian
+ val = self.toShort(bytes[3], bytes[2])
+ elif byteOrder == 0x4D4D: # Big-endian
+ val = self.toShort(bytes[2], bytes[3])
+ else:
+ return False
+ return val == 42 or val == 43
+ except FileNotFoundError:
+ print(f"Error: File not found {file}")
+ raise FileNotFoundError
+ except IOError:
+ print(f"Error: Could not open file {file}")
+ raise IOError
+ except Exception as e:
+ print(f"Error: {e}")
+
+ # Define a helper function to convert two bytes to a short integer
+ def toShort(self, b1, b2):
+ return (b1 << 8) + b2
+
+ # Define a function to parse TIFF file metadata and extract the image region
+ def parse_regions(self, infiles):
+ builder = self.SparseImageServer.Builder()
+ for f in infiles:
+ try:
+ region = self.parseRegion(f)
+ if region is None:
+ print("WARN: Could not parse region for " + str(f))
+ continue
+ serverBuilder = (
+ self.ImageServerProvider.getPreferredUriImageSupport(
+ self.BufferedImage, jpype.JString(f)
+ )
+ .getBuilders()
+ .get(0)
+ )
+ builder.jsonRegion(region, 1.0, serverBuilder)
+ except Exception as e:
+ print(f"Error parsing regions from file {f}: {e}")
+ traceback.print_exc()
+ return builder.build()
+
+ def _write_pyramidal_image_server(self, server, fileout, downsamples):
+ """Convert the parsed image regions into a pyramidal image server and write to file."""
+ # Convert the parsed regions into a pyramidal image server and write to file
+
+ try:
+ newOME = self.OMEPyramidWriter.Builder(server)
+
+ # Control downsamples
+ if downsamples is None:
+ downsamples = server.getPreferredDownsamples()
+ print(downsamples)
+ newOME.downsamples(downsamples).tileSize(
+ 512
+ ).channelsInterleaved().parallelize().losslessCompression().build().writePyramid(
+ fileout.getAbsolutePath()
+ )
+ except Exception as e:
+ print(f"Error writing pyramidal image server to file {fileout}: {e}")
+ traceback.print_exc()
+
+ def run_bfconvert(self, stitched_image_path, bfconverted_path=None):
+ if not self.is_bfconvert_available():
+ print("bfconvert command not available. Skipping bfconvert step.")
+ return
+
+ if not bfconverted_path:
+ base_path, ext = os.path.splitext(stitched_image_path)
+ bfconverted_path = f"{base_path}_sep.tif"
+
+ bfconvert_command = f"./{self.bfconvert_path} -series 0 -separate '{stitched_image_path}' '{bfconverted_path}'"
+
+ try:
+ subprocess.run(bfconvert_command, shell=True, check=True)
+ print(f"bfconvert completed. Output file: {bfconverted_path}")
+ except subprocess.CalledProcessError:
+ print("Error running bfconvert command.")
+
+ def is_bfconvert_available(self):
+ try:
+ result = subprocess.run(
+ [f"./{self.bfconvert_path}", "-version"],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ if result.returncode == 0:
+ return True
+ else:
+ return False
+ except FileNotFoundError:
+ return False
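+
+
+# Example usage (illustrative sketch): paths below are assumptions; point them at a
+# local QuPath installation (see pathml.utils.setup_qupath) before running.
+#
+#   qupath_jars = glob.glob("/path/to/QuPath/lib/app/*.jar")
+#   stitcher = TileStitcher(qupath_jarpath=qupath_jars, memory="16g")
+#   stitcher.run_image_stitching(
+#       "tiles_dir/", "stitched.ome.tif", downsamples=[1, 8], separate_series=False
+#   )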
diff --git a/pathml/utils.py b/pathml/utils.py
index 7b7a4bc8..6f1e09f6 100644
--- a/pathml/utils.py
+++ b/pathml/utils.py
@@ -5,7 +5,9 @@
import os
import shutil
+import tarfile
import urllib
+from pathlib import Path
import cv2
import matplotlib.pyplot as plt
@@ -32,6 +34,7 @@ def download_from_url(url, download_dir, name=None):
path = os.path.join(download_dir, name)
if os.path.exists(path):
+ print(f"File {name} already exists, skipping download.")
return
else:
os.makedirs(download_dir, exist_ok=True)
@@ -39,6 +42,7 @@ def download_from_url(url, download_dir, name=None):
# Download the file from `url` and save it locally under `file_name`:
with urllib.request.urlopen(url) as response, open(path, "wb") as out_file:
shutil.copyfileobj(response, out_file)
+ return path # added when including qupath utils
def parse_file_size(fs):
@@ -345,3 +349,38 @@ def _test_log(msg):
# passes thru message to pathml logger
# used for testing logging
logger.info(msg)
+
+
+def find_qupath_home(start_path):
+ for root, dirs, files in os.walk(start_path):
+ if any("qupath" in file.lower() and file.endswith(".jar") for file in files):
+ return str(Path(root).parent.parent)
+ return None
+
+
+def setup_qupath(qupath_home=None):
+ default_path = str(Path.home() / "tools/qupath")
+ qupath_home = qupath_home if qupath_home is not None else default_path
+ Path(qupath_home).mkdir(parents=True, exist_ok=True)
+
+ # Check for existing QuPath installation
+ existing_qupath_home = find_qupath_home(qupath_home)
+ if existing_qupath_home:
+ return existing_qupath_home
+
+ print("Downloading")
+ # URL and name of QuPath tarball
+ # qupath_url = "https://github.com/qupath/qupath/releases/download/v0.3.0/QuPath-0.3.0-Linux.tar.xz"
+ qupath_url = "https://github.com/qupath/qupath/releases/download/v0.4.3/QuPath-0.4.3-Linux.tar.xz"
+ qupath_tar_name = "QuPath-0.4.3-Linux.tar.xz"
+ tar_path = download_from_url(qupath_url, qupath_home, qupath_tar_name)
+
+ # Extract QuPath if the tarball was downloaded
+ if tar_path:
+ print("Extracting QuPath...")
+ with tarfile.open(tar_path) as tar:
+ tar.extractall(path=qupath_home)
+ os.remove(tar_path)
+
+ # Find the QuPath home by searching for jar files
+ return find_qupath_home(qupath_home)
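+
+
+# Example (illustrative sketch): locating (or downloading) QuPath and collecting its
+# jars for use with TileStitcher. The install location is an assumption.
+#
+#   qupath_home = setup_qupath(str(Path.home() / "tools/qupath"))
+#   jars = [str(p) for p in Path(qupath_home).glob("lib/app/*.jar")] if qupath_home else []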
diff --git a/pyproject.toml b/pyproject.toml
index 6bb669c3..2a50d3d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,7 @@ build-backend = "setuptools.build_meta"
[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
+ "exclude: marks tests to exclude (deselect with '-m \"not exclude\"')"
]
[tool.isort]
diff --git a/tests/datasets_tests/test_dataset_utils.py b/tests/datasets_tests/test_dataset_utils.py
index 46ac26b4..a0c6688c 100644
--- a/tests/datasets_tests/test_dataset_utils.py
+++ b/tests/datasets_tests/test_dataset_utils.py
@@ -2,3 +2,139 @@
Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
License: GNU GPL 2.0
"""
+import numpy as np
+import pytest
+import torch.nn as nn
+import torch.nn.functional as F
+from skimage.draw import ellipse
+from skimage.measure import label, regionprops
+
+from pathml.datasets.utils import DeepPatchFeatureExtractor
+
+
+class SimpleCNN(nn.Module):
+ def __init__(self, input_shape):
+ super(SimpleCNN, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels=input_shape[0],
+ out_channels=32,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
+ )
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+
+ fc_input_size = (input_shape[1] // 4) * (input_shape[2] // 4) * 64
+
+ self.fc1 = nn.Linear(fc_input_size, 128)
+ self.fc2 = nn.Linear(128, 10)
+
+ def forward(self, x):
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+
+ # Flatten the output for the fully connected layer
+ x = x.view(x.size(0), -1)
+
+ x = F.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+
+def make_fake_instance_maps(num, image_size, ellipse_height, ellipse_width):
+ img = np.zeros(image_size)
+
+ # Draw n ellipses
+ for i in range(num):
+ # Random center for each ellipse
+ center_x = np.random.randint(ellipse_width, image_size[1] - ellipse_width)
+ center_y = np.random.randint(ellipse_height, image_size[0] - ellipse_height)
+
+ # Coordinates for the ellipse
+ rr, cc = ellipse(
+ center_y, center_x, ellipse_height, ellipse_width, shape=image_size
+ )
+
+ # Draw the ellipse
+ img[rr, cc] = 1
+
+ label_img = label(img.astype(int))
+
+ return label_img
+
+
+def make_fake_image(instance_map):
+ image = instance_map[:, :, None]
+ image[image > 0] = 1
+ noised_image = (
+ np.random.rand(instance_map.shape[0], instance_map.shape[1], 3) * 0.15 + image
+ ) * 255
+
+ return noised_image.astype("uint8")
+
+
+@pytest.mark.parametrize("patch_size", [1, 64, 128])
+@pytest.mark.parametrize("entity", ["cell", "tissue"])
+@pytest.mark.parametrize("threshold", [0, 0.1, 0.8])
+def test_feature_extractor(entity, patch_size, threshold):
+
+ image_size = (256, 256)
+
+ instance_map = make_fake_instance_maps(
+ num=20, image_size=image_size, ellipse_height=20, ellipse_width=8
+ )
+ image = make_fake_image(instance_map.copy())
+ regions = regionprops(instance_map)
+
+ model = SimpleCNN(input_shape=(3, 224, 224))
+
+ extractor = DeepPatchFeatureExtractor(
+ patch_size=patch_size,
+ batch_size=1,
+ entity=entity,
+ architecture=model,
+ fill_value=255,
+ resize_size=224,
+ threshold=threshold,
+ )
+ features = extractor.process(image, instance_map)
+
+ if threshold == 0:
+ assert features.shape[0] == len(regions)
+ else:
+ assert features.shape[0] <= len(regions)
+
+
+@pytest.mark.parametrize("patch_size", [1, 64, 128])
+@pytest.mark.parametrize("entity", ["cell", "tissue"])
+@pytest.mark.parametrize("threshold", [0, 0.1, 0.8])
+def test_feature_extractor_torchvision(entity, patch_size, threshold):
+ pytest.importorskip("torchvision")
+
+ image_size = (256, 256)
+
+ instance_map = make_fake_instance_maps(
+ num=20, image_size=image_size, ellipse_height=20, ellipse_width=8
+ )
+ image = make_fake_image(instance_map.copy())
+ regions = regionprops(instance_map)
+
+ extractor = DeepPatchFeatureExtractor(
+ patch_size=patch_size,
+ batch_size=1,
+ entity=entity,
+ architecture="resnet34",
+ fill_value=255,
+ resize_size=224,
+ threshold=threshold,
+ )
+ features = extractor.process(image, instance_map)
+
+ if threshold == 0:
+ assert features.shape[0] == len(regions)
+ else:
+ assert features.shape[0] <= len(regions)
diff --git a/tests/graph_tests/test_graph_building.py b/tests/graph_tests/test_graph_building.py
new file mode 100644
index 00000000..7e5b11f9
--- /dev/null
+++ b/tests/graph_tests/test_graph_building.py
@@ -0,0 +1,86 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import numpy as np
+import pytest
+import torch
+from skimage.draw import ellipse
+from skimage.measure import label, regionprops
+
+from pathml.graph import KNNGraphBuilder, RAGGraphBuilder
+
+
+def make_fake_instance_maps(num, image_size, ellipse_height, ellipse_width):
+ img = np.zeros(image_size)
+
+ # Draw n ellipses
+ for i in range(num):
+ # Random center for each ellipse
+ center_x = np.random.randint(ellipse_width, image_size[1] - ellipse_width)
+ center_y = np.random.randint(ellipse_height, image_size[0] - ellipse_height)
+
+ # Coordinates for the ellipse
+ rr, cc = ellipse(
+ center_y, center_x, ellipse_height, ellipse_width, shape=image_size
+ )
+
+ # Draw the ellipse
+ img[rr, cc] = 1
+
+ label_img = label(img.astype(int))
+
+ return label_img
+
+
+@pytest.mark.parametrize("k", [1, 10, 50])
+@pytest.mark.parametrize("thresh", [0, 10, 200])
+@pytest.mark.parametrize("add_loc_feats", [True, False])
+def test_knn_graph_building(k, thresh, add_loc_feats):
+ image_size = (1024, 2048)
+
+ instance_map = make_fake_instance_maps(
+ num=100, image_size=image_size, ellipse_height=10, ellipse_width=8
+ )
+ regions = regionprops(instance_map)
+
+ features = torch.randn(len(regions), 512)
+
+ graph_builder = KNNGraphBuilder(k=k, thresh=thresh, add_loc_feats=add_loc_feats)
+
+ graph = graph_builder.process(instance_map, features, target=1)
+
+ assert graph.node_centroids.shape == (len(regions), 2)
+ assert graph.edge_index.shape[0] == 2
+ if add_loc_feats:
+ assert graph.node_features.shape == (len(regions), 514)
+ else:
+ assert graph.node_features.shape == (len(regions), 512)
+
+
+@pytest.mark.parametrize("kernel_size", [1, 3, 10])
+@pytest.mark.parametrize("hops", [1, 2, 5])
+@pytest.mark.parametrize("add_loc_feats", [True, False])
+def test_rag_graph_building(kernel_size, hops, add_loc_feats):
+ image_size = (1024, 2048)
+
+ instance_map = make_fake_instance_maps(
+ num=100, image_size=image_size, ellipse_height=10, ellipse_width=8
+ )
+ regions = regionprops(instance_map)
+
+ features = torch.randn(len(regions), 512)
+
+ graph_builder = RAGGraphBuilder(
+ kernel_size=kernel_size, hops=hops, add_loc_feats=add_loc_feats
+ )
+
+ graph = graph_builder.process(instance_map, features, target=1)
+
+ assert graph.node_centroids.shape == (len(regions), 2)
+ assert graph.edge_index.shape[0] == 2
+ if add_loc_feats:
+ assert graph.node_features.shape == (len(regions), 514)
+ else:
+ assert graph.node_features.shape == (len(regions), 512)
diff --git a/tests/graph_tests/test_graph_extractor.py b/tests/graph_tests/test_graph_extractor.py
new file mode 100644
index 00000000..6a187443
--- /dev/null
+++ b/tests/graph_tests/test_graph_extractor.py
@@ -0,0 +1,25 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import networkx as nx
+import pytest
+
+from pathml.graph.preprocessing import GraphFeatureExtractor
+
+
+@pytest.mark.parametrize("use_weight", [True, False])
+@pytest.mark.parametrize("alpha", [0, 0.5, 0.95])
+def test_graph_feature_extractor(use_weight, alpha):
+
+ # Creating a simple graph
+ G = nx.DiGraph()
+
+ # Adding nodes
+ G.add_weighted_edges_from([(1, 2, 1), (2, 3, 1), (3, 4, 1), (4, 5, 1), (5, 1, 1)])
+
+ extractor = GraphFeatureExtractor(use_weight=use_weight, alpha=alpha)
+ features = extractor.process(G)
+
+ assert features
diff --git a/tests/graph_tests/test_graph_utils.py b/tests/graph_tests/test_graph_utils.py
new file mode 100644
index 00000000..0d339545
--- /dev/null
+++ b/tests/graph_tests/test_graph_utils.py
@@ -0,0 +1,52 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import numpy as np
+import pytest
+from skimage.draw import ellipse
+from skimage.measure import label
+
+from pathml.graph import build_assignment_matrix
+
+
+def make_fake_instance_maps(num, image_size, ellipse_height, ellipse_width):
+ img = np.zeros(image_size)
+
+ # Draw n ellipses
+ for i in range(num):
+ # Random center for each ellipse
+ center_x = np.random.randint(ellipse_width, image_size[1] - ellipse_width)
+ center_y = np.random.randint(ellipse_height, image_size[0] - ellipse_height)
+
+ # Coordinates for the ellipse
+ rr, cc = ellipse(
+ center_y, center_x, ellipse_height, ellipse_width, shape=image_size
+ )
+
+ # Draw the ellipse
+ img[rr, cc] = 1
+
+ label_img = label(img.astype(int))
+
+ return label_img
+
+
+@pytest.mark.parametrize("matrix", [True, False])
+def test_build_assignment_matrix(matrix):
+ image_size = (1024, 2048)
+
+ tissue_instance_map = make_fake_instance_maps(
+ num=20, image_size=image_size, ellipse_height=20, ellipse_width=8
+ )
+ cell_centroids = np.random.rand(200, 2)
+
+ assignment = build_assignment_matrix(
+ cell_centroids, tissue_instance_map, matrix=matrix
+ )
+
+ if matrix:
+ assert assignment.shape[0] == 200
+ else:
+ assert assignment.shape[1] == 200
diff --git a/tests/graph_tests/test_tissue_extractor.py b/tests/graph_tests/test_tissue_extractor.py
new file mode 100644
index 00000000..c41668fd
--- /dev/null
+++ b/tests/graph_tests/test_tissue_extractor.py
@@ -0,0 +1,83 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import numpy as np
+import pytest
+from skimage.draw import ellipse
+from skimage.measure import label
+
+from pathml.graph import ColorMergedSuperpixelExtractor
+from pathml.graph.preprocessing import SLICSuperpixelExtractor
+
+
+def make_fake_instance_maps(num, image_size, ellipse_height, ellipse_width):
+ img = np.zeros(image_size)
+
+ # Draw n ellipses
+ for i in range(num):
+ # Random center for each ellipse
+ center_x = np.random.randint(ellipse_width, image_size[1] - ellipse_width)
+ center_y = np.random.randint(ellipse_height, image_size[0] - ellipse_height)
+
+ # Coordinates for the ellipse
+ rr, cc = ellipse(
+ center_y, center_x, ellipse_height, ellipse_width, shape=image_size
+ )
+
+ # Draw the ellipse
+ img[rr, cc] = 1
+
+ label_img = label(img.astype(int))
+
+ return label_img
+
+
+def make_fake_image(instance_map):
+ image = instance_map[:, :, None]
+ image[image > 0] = 1
+ noised_image = (
+ np.random.rand(instance_map.shape[0], instance_map.shape[1], 3) * 0.15 + image
+ ) * 255
+
+ return noised_image.astype("uint8")
+
+
+@pytest.mark.parametrize("superpixel_size", [20, 200])
+@pytest.mark.parametrize("compactness", [50, 100])
+@pytest.mark.parametrize("blur_kernel_size", [0.2, 1])
+@pytest.mark.parametrize("threshold", [0.1, 0.9])
+@pytest.mark.parametrize("downsampling_factor", [4, 10])
+@pytest.mark.parametrize(
+ "extractor", [ColorMergedSuperpixelExtractor, SLICSuperpixelExtractor]
+)
+def test_tissue_extractors(
+ superpixel_size,
+ compactness,
+ blur_kernel_size,
+ threshold,
+ downsampling_factor,
+ extractor,
+):
+ image_size = (256, 256)
+
+ instance_map = make_fake_instance_maps(
+ num=30, image_size=image_size, ellipse_height=20, ellipse_width=8
+ )
+ image = make_fake_image(instance_map.copy())
+
+ tissue_detector = extractor(
+ superpixel_size=superpixel_size,
+ compactness=compactness,
+ blur_kernel_size=blur_kernel_size,
+ threshold=threshold,
+ downsampling_factor=downsampling_factor,
+ )
+
+ superpixels = tissue_detector.process(image)
+
+ if isinstance(superpixels, tuple):
+ superpixels = superpixels[0]
+
+ assert superpixels.shape == image_size
diff --git a/tests/ml_tests/test_hactnet.py b/tests/ml_tests/test_hactnet.py
new file mode 100644
index 00000000..cd55cba7
--- /dev/null
+++ b/tests/ml_tests/test_hactnet.py
@@ -0,0 +1,93 @@
+"""
+Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
+License: GNU GPL 2.0
+"""
+
+import pytest
+import torch
+from torch_geometric.utils import erdos_renyi_graph
+
+from pathml.graph.utils import HACTPairData
+from pathml.ml import HACTNet
+
+
+def fake_hactnet_inputs():
+ """fake batch of input for HACTNet"""
+ cell_features = torch.rand(200, 256)
+ cell_edge_index = erdos_renyi_graph(200, 0.2, directed=False)
+ tissue_features = torch.rand(20, 256)
+ tissue_edge_index = erdos_renyi_graph(20, 0.2, directed=False)
+ target = torch.tensor([1, 2])
+ assignment = torch.randint(low=0, high=20, size=(200,)).long()
+ data = HACTPairData(
+ x_cell=cell_features,
+ edge_index_cell=cell_edge_index,
+ x_tissue=tissue_features,
+ edge_index_tissue=tissue_edge_index,
+ assignment=assignment,
+ target=target,
+ )
+ return data
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("readout_op", ["lstm", "concat", None])
+def test_hactnet_forward_pass(batch_size, readout_op):
+ """Make sure that dimensions of outputs are as expected from forward pass"""
+ batch = fake_hactnet_inputs()
+ batch["x_cell_batch"] = torch.zeros(200).long()
+ batch["x_tissue_batch"] = torch.zeros(20).long()
+ if batch_size > 1:
+ batch["x_cell_batch"][-100:] = 1
+ batch["x_tissue_batch"][-10:] = 1
+
+ cell_deg = torch.randint(low=0, high=20000, size=(14,))
+ tissue_deg = torch.randint(low=0, high=2000, size=(14,))
+
+ kwargs_pna_cell = {
+ "aggregators": ["mean", "max", "min", "std"],
+ "scalers": ["identity", "amplification", "attenuation"],
+ "deg": cell_deg,
+ }
+ kwargs_pna_tissue = {
+ "aggregators": ["mean", "max", "min", "std"],
+ "scalers": ["identity", "amplification", "attenuation"],
+ "deg": tissue_deg,
+ }
+
+ cell_params = {
+ "layer": "PNAConv",
+ "in_channels": 256,
+ "hidden_channels": 64,
+ "num_layers": 3,
+ "out_channels": 64,
+ "readout_op": readout_op,
+ "readout_type": "mean",
+ "kwargs": kwargs_pna_cell,
+ }
+
+ tissue_params = {
+ "layer": "PNAConv",
+ "in_channels": 256,
+ "hidden_channels": 64,
+ "num_layers": 3,
+ "out_channels": 64,
+ "readout_op": readout_op,
+ "readout_type": "mean",
+ "kwargs": kwargs_pna_tissue,
+ }
+
+ classifier_params = {
+ "in_channels": 128,
+ "hidden_channels": 128,
+ "out_channels": 7,
+ "num_layers": 2,
+ "norm": "batch_norm" if batch_size > 1 else "layer_norm",
+ }
+
+ model = HACTNet(cell_params, tissue_params, classifier_params)
+
+ with torch.no_grad():
+ outputs = model(batch)
+
+ assert outputs.shape == (batch_size, 7)
diff --git a/tests/preprocessing_tests/test_tilestitcher.py b/tests/preprocessing_tests/test_tilestitcher.py
new file mode 100644
index 00000000..cf2b9a7d
--- /dev/null
+++ b/tests/preprocessing_tests/test_tilestitcher.py
@@ -0,0 +1,624 @@
+import glob
+import os
+import subprocess
+import tempfile
+import zipfile
+from unittest.mock import MagicMock, mock_open, patch
+
+import javabridge
+import jpype
+import pytest
+
+from pathml.preprocessing.tilestitcher import TileStitcher
+from pathml.utils import setup_qupath
+
+
+@pytest.mark.exclude
+@pytest.fixture(scope="module")
+def tile_stitcher(request):
+ try:
+ javabridge.kill_vm()
+ print("Javabridge vm terminated")
+ except Exception as e:
+ print(f"JVM isn't running: {e}")
+ pass # JVM was not running, so nothing to kill
+
+ # Set JAVA_HOME
+ # os.environ["JAVA_HOME"] = "/usr/lib/jvm/jdk-17/"
+
+ # Setup QuPath using the setup_qupath function
+ qupath_home = setup_qupath(
+ "../../tools/qupath"
+ ) # Replace with the appropriate path
+
+ # Ensure QUPATH_HOME is set
+ os.environ["QUPATH_HOME"] = qupath_home
+
+ # Construct path to QuPath jars
+ qupath_jars_dir = os.path.join(qupath_home, "lib", "app")
+ qupath_jars = glob.glob(os.path.join(qupath_jars_dir, "*.jar"))
+ qupath_jars.append(os.path.join(qupath_jars_dir, "libopenslide-jni.so"))
+
+ bfconvert_dir = "./"
+    stitcher = TileStitcher(qupath_jars, bfconvert_dir=bfconvert_dir)
+ stitcher._start_jvm()
+
+ def teardown():
+ try:
+ javabridge.kill_vm()
+ print("Javabridge vm terminated in teardown")
+ except Exception as e:
+ print(f"Error during JVM teardown: {e}")
+ # f"Error during JVM teardown: {error_message}"
+
+ request.addfinalizer(teardown)
+
+ return stitcher
+
+
+@pytest.mark.exclude
+def test_set_environment_paths(tile_stitcher):
+ tile_stitcher.set_environment_paths()
+ assert "JAVA_HOME" in os.environ
+
+
+@pytest.mark.exclude
+def test_get_system_java_home(tile_stitcher):
+ path = tile_stitcher.get_system_java_home()
+ assert isinstance(path, str)
+
+
+@pytest.mark.exclude
+@patch("pathml.preprocessing.tilestitcher.jpype.startJVM")
+def test_start_jvm(mocked_jvm, tile_stitcher):
+ # Check if JVM was already started
+ if jpype.isJVMStarted():
+ pytest.skip("JVM was already started, so we skip this test.")
+ tile_stitcher._start_jvm()
+ mocked_jvm.assert_called()
+
+
+# @pytest.mark.exclude
+# @patch("pathml.preprocessing.tilestitcher.tifffile")
+# def test_parse_region(mocked_tifffile, tile_stitcher):
+# # Mock the return values
+# mocked_tifffile.return_value.__enter__.return_value.pages[
+# 0
+# ].tags.get.side_effect = [
+# MagicMock(value=(0, 1)), # XPosition
+# MagicMock(value=(0, 1)), # YPosition
+# MagicMock(value=(1, 1)), # XResolution
+# MagicMock(value=(1, 1)), # YResolution
+# MagicMock(value=100), # ImageLength
+# MagicMock(value=100), # ImageWidth
+# ]
+# # filename = "tests/testdata/MISI3542i_M3056_3_Panel1_Scan1_[10530,40933]_component_data.tif"
+# filename = "tests/testdata/tilestitching_testdata/MISI3542i_W21-04143_bi016966_M394_OVX_LM_Scan1_[14384,29683]_component_data.tif"
+# region = tile_stitcher.parseRegion(filename)
+# assert region is not None
+# assert isinstance(region, tile_stitcher.ImageRegion)
+
+
+
+# @pytest.mark.exclude
+# @patch("pathml.preprocessing.tilestitcher.tifffile")
+# def test_parse_region_missing_tags(mocked_tifffile, tile_stitcher):
+# # Mock tifffile to return None for required tags
+# mocked_tifffile.return_value.__enter__.return_value.pages[0].tags.get.side_effect = [
+# None, # XPosition missing
+# None, # YPosition missing
+# None, # XResolution missing
+# None, # YResolution missing
+# ]
+
+# # Test filename
+# filename = "tests/testdata/tilestitching_testdata/nonexistent_tags.tif"
+
+# # Call the parseRegion function
+# region = tile_stitcher.parseRegion(filename)
+
+# # Assert that the function returns None due to missing tags
+# assert region is None
+
+
+@pytest.mark.exclude
+@patch("pathml.preprocessing.tilestitcher.tifffile.TiffFile")
+@patch("pathml.preprocessing.tilestitcher.TileStitcher.checkTIFF")
+def test_parse_region_exception(mocked_check_tiff, mocked_tiff_file, tile_stitcher):
+ # Mock the checkTIFF method to always return True
+ mocked_check_tiff.return_value = True
+
+ # Mock the TiffFile to raise a FileNotFoundError when used
+ mocked_tiff_file.side_effect = FileNotFoundError(
+ "Error: File not found dummy_file.tif"
+ )
+ filename = "dummy_file.tif"
+
+ # Expect FileNotFoundError to be raised
+ with pytest.raises(FileNotFoundError) as exc_info:
+ tile_stitcher.parseRegion(filename)
+
+ # Assert that the exception message matches what we expect
+ assert str(exc_info.value) == "Error: File not found dummy_file.tif"
+
+
+@pytest.mark.exclude
+def test_collect_tif_files(tile_stitcher):
+ # Assuming a directory with one tif file for testing
+ dir_path = "some_directory"
+ os.makedirs(dir_path, exist_ok=True)
+ with open(os.path.join(dir_path, "test.tif"), "w") as f:
+ f.write("test content")
+
+ files = tile_stitcher._collect_tif_files(dir_path)
+ assert len(files) == 1
+ assert "test.tif" in files[0]
+
+ os.remove(os.path.join(dir_path, "test.tif"))
+ os.rmdir(dir_path)
+
+
+@pytest.mark.exclude
+def test_checkTIFF_valid(tile_stitcher, tmp_path):
+ # Create a mock TIFF file
+ tiff_path = tmp_path / "mock.tiff"
+ tiff_path.write_bytes(b"II*\x00") # Little-endian TIFF signature
+ assert tile_stitcher.checkTIFF(tiff_path)
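+
+
+# Companion sketch (not part of the original suite) for the big-endian TIFF
+# signature, b"MM\x00*". It assumes checkTIFF recognizes both byte orders;
+# kept commented out until that behaviour is confirmed against the implementation.
+# @pytest.mark.exclude
+# def test_checkTIFF_valid_big_endian(tile_stitcher, tmp_path):
+#     tiff_path = tmp_path / "mock_be.tiff"
+#     tiff_path.write_bytes(b"MM\x00*")  # Big-endian TIFF signature
+#     assert tile_stitcher.checkTIFF(tiff_path)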
+
+
+@pytest.mark.exclude
+def test_checkTIFF_invalid(tile_stitcher, tmp_path):
+ # Create a mock non-TIFF file
+ txt_path = tmp_path / "mock.txt"
+ txt_path.write_text("Not a TIFF file")
+ assert not tile_stitcher.checkTIFF(txt_path)
+
+
+@pytest.mark.exclude
+def test_checkTIFF_nonexistent(tile_stitcher):
+ # Test with a file that doesn't exist
+ with pytest.raises(FileNotFoundError):
+ tile_stitcher.checkTIFF("nonexistent_file.tiff")
+
+
+@pytest.mark.exclude
+def test_check_tiff(tile_stitcher):
+    valid_tif = b"II*\x00"  # complete little-endian TIFF signature
+    invalid_tif = b"abcd"
+
+    with open("valid_test.tif", "wb") as f:
+        f.write(valid_tif)
+
+    with open("invalid_test.tif", "wb") as f:
+        f.write(invalid_tif)
+
+    assert tile_stitcher.checkTIFF("valid_test.tif") is True
+    assert tile_stitcher.checkTIFF("tests/testdata/smalltif.tif") is True
+    assert tile_stitcher.checkTIFF("invalid_test.tif") is False
+
+    os.remove("valid_test.tif")
+    os.remove("invalid_test.tif")
+
+
+@pytest.mark.exclude
+def test_get_outfile_ending_with_ome_tif(tile_stitcher):
+ result, result_jpype = tile_stitcher._get_outfile("test.ome.tif")
+ assert result == "test.ome.tif"
+ assert str(result_jpype) == "test.ome.tif"
+
+
+@pytest.mark.exclude
+def test_get_outfile_without_ending(tile_stitcher):
+ result, result_jpype = tile_stitcher._get_outfile("test.ome.tif")
+ assert result == "test.ome.tif"
+ assert str(result_jpype) == "test.ome.tif"
+
+
+@pytest.mark.exclude
+# Dummy function to "fake" the file download
+def mocked_urlretrieve(*args, **kwargs):
+ pass
+
+
+@pytest.mark.exclude
+# Minimal stand-in for zipfile.ZipFile, used when mocking the bfconvert download
+class MockZip:
+ def __init__(self, zip_path, *args, **kwargs):
+ self.zip_path = zip_path
+
+ def __enter__(self):
+ with zipfile.ZipFile(self.zip_path, "w") as zipf:
+ zipf.writestr("dummy.txt", "This is dummy file content")
+ return self
+
+ def __exit__(self, *args):
+ os.remove(self.zip_path)
+
+ def extractall(self, path, *args, **kwargs):
+ bftools_dir = os.path.join(path, "bftools")
+ if not os.path.exists(bftools_dir):
+ os.makedirs(bftools_dir)
+
+ with open(os.path.join(bftools_dir, "bfconvert"), "w") as f:
+ f.write("#!/bin/sh\necho 'dummy bfconvert'")
+
+ with open(os.path.join(bftools_dir, "bf.sh"), "w") as f:
+ f.write("#!/bin/sh\necho 'dummy bf.sh'")
+
+
+@pytest.mark.exclude
+def mock_create_zip(zip_path):
+ """
+ Creates a mock zip file at the given path.
+ """
+ # Create mock files to add to the ZIP
+ with tempfile.NamedTemporaryFile(delete=False) as tmpfile:
+ tmpfile.write(b"Mock file content")
+
+ # Create the mock ZIP file
+ with zipfile.ZipFile(zip_path, "w") as zipf:
+ zipf.write(tmpfile.name, "mock_file.txt")
+ os.unlink(tmpfile.name) # Clean up the temporary file
+
+
+@pytest.mark.exclude
+@pytest.fixture
+def bfconvert_dir(tmp_path):
+ return tmp_path / "bfconvert_dir"
+
+
+@pytest.mark.exclude
+def test_bfconvert_version_print(tile_stitcher, bfconvert_dir):
+ tile_stitcher.setup_bfconvert(bfconvert_dir)
+ output = subprocess.check_output([tile_stitcher.bfconvert_path, "-version"])
+ assert output.lower().startswith(b"version:")
+
+
+@pytest.mark.exclude
+def test_permission_error_on_directory_creation(tile_stitcher):
+ with patch("os.makedirs", side_effect=PermissionError("Permission denied")):
+ with pytest.raises(PermissionError):
+ tile_stitcher.setup_bfconvert("/fake/path")
+
+
+@pytest.mark.exclude
+def test_permission_error_on_chmod(tile_stitcher):
+ with patch("os.chmod", side_effect=PermissionError("Permission denied")):
+ with pytest.raises(PermissionError):
+ tile_stitcher.setup_bfconvert("/fake/path")
+
+
+@pytest.mark.exclude
+def throwing_function(*args, **kwargs):
+ raise Exception("Simulated error")
+
+
+@pytest.mark.exclude
+@pytest.fixture
+def mock_tools_dir(tmp_path):
+ return tmp_path / "tools"
+
+
+@pytest.mark.exclude
+@pytest.fixture
+def mock_zip_path(mock_tools_dir):
+ return mock_tools_dir / "bftools.zip"
+
+
+@pytest.mark.exclude
+def mock_urlretrieve(*args, **kwargs):
+ with zipfile.ZipFile(args[1], "w") as zipf:
+ zipf.writestr("bftools/bfconvert", "dummy content")
+ zipf.writestr("bftools/bf.sh", "dummy content")
+
+
+@pytest.mark.exclude
+@patch("urllib.request.urlretrieve", side_effect=mock_urlretrieve)
+@patch("os.makedirs", side_effect=PermissionError)
+def test_invalid_path(mock_makedirs, mock_urlretrieve, tile_stitcher, mock_tools_dir):
+ with pytest.raises(PermissionError):
+ tile_stitcher.setup_bfconvert(str(mock_tools_dir))
+
+
+@pytest.mark.exclude
+@patch("urllib.request.urlretrieve", side_effect=mock_urlretrieve)
+@patch("zipfile.ZipFile", side_effect=zipfile.BadZipFile)
+def test_invalid_zip_file(
+ mock_zipfile, mock_urlretrieve, tile_stitcher, mock_tools_dir
+):
+ with pytest.raises(zipfile.BadZipFile):
+ tile_stitcher.setup_bfconvert(str(mock_tools_dir))
+
+
+@pytest.mark.exclude
+@patch("urllib.request.urlretrieve", side_effect=mock_urlretrieve)
+@patch("subprocess.check_output", side_effect=subprocess.CalledProcessError(1, "cmd"))
+def test_bfconvert_failure(
+ mock_subprocess, mock_urlretrieve, tile_stitcher, mock_tools_dir
+):
+ with pytest.raises(subprocess.CalledProcessError):
+ tile_stitcher.setup_bfconvert(str(mock_tools_dir))
+
+
+@pytest.mark.exclude
+def test_is_bfconvert_available_true(tile_stitcher):
+ with patch(
+ "subprocess.run",
+ return_value=subprocess.CompletedProcess(args=[], returncode=0),
+ ):
+ assert tile_stitcher.is_bfconvert_available() is True
+
+
+@pytest.mark.exclude
+def test_is_bfconvert_available_false(tile_stitcher):
+ with patch(
+ "subprocess.run",
+ return_value=subprocess.CompletedProcess(args=[], returncode=1),
+ ):
+ assert tile_stitcher.is_bfconvert_available() is False
+
+
+@pytest.mark.exclude
+def test_is_bfconvert_available_file_not_found(tile_stitcher):
+ with patch("subprocess.run", side_effect=FileNotFoundError()):
+ assert tile_stitcher.is_bfconvert_available() is False
+
+
+@pytest.mark.exclude
+def test_run_bfconvert_bfconvert_not_available(tile_stitcher, capsys):
+ tile_stitcher.bfconvert_path = "dummy_path"
+ with patch.object(tile_stitcher, "is_bfconvert_available", return_value=False):
+ tile_stitcher.run_bfconvert("dummy_stitched_image_path")
+ captured = capsys.readouterr()
+ assert (
+ "bfconvert command not available. Skipping bfconvert step." in captured.out
+ )
+
+
+@pytest.mark.exclude
+def test_run_bfconvert_custom_bfconverted_path(tile_stitcher, capsys):
+ tile_stitcher.bfconvert_path = "dummy_path"
+ with patch.object(tile_stitcher, "is_bfconvert_available", return_value=True):
+ with patch("subprocess.run") as mock_run:
+ tile_stitcher.run_bfconvert("dummy_stitched_image_path", "custom_path.tif")
+ mock_run.assert_called_once_with(
+ "./dummy_path -series 0 -separate 'dummy_stitched_image_path' 'custom_path.tif'",
+ shell=True,
+ check=True,
+ )
+ captured = capsys.readouterr()
+ assert "bfconvert completed. Output file: custom_path.tif" in captured.out
+
+
+@pytest.mark.exclude
+def test_run_bfconvert_default_bfconverted_path(tile_stitcher, capsys):
+ tile_stitcher.bfconvert_path = "dummy_path"
+ with patch.object(tile_stitcher, "is_bfconvert_available", return_value=True):
+ with patch("subprocess.run") as mock_run:
+ tile_stitcher.run_bfconvert("dummy_stitched_image_path.tif")
+ mock_run.assert_called_once_with(
+ "./dummy_path -series 0 -separate 'dummy_stitched_image_path.tif' 'dummy_stitched_image_path_sep.tif'",
+ shell=True,
+ check=True,
+ )
+ captured = capsys.readouterr()
+ assert (
+ "bfconvert completed. Output file: dummy_stitched_image_path_sep.tif"
+ in captured.out
+ )
+
+
+@pytest.mark.exclude
+def test_run_bfconvert_error(tile_stitcher, capsys):
+ tile_stitcher.bfconvert_path = "dummy_path"
+ with patch.object(tile_stitcher, "is_bfconvert_available", return_value=True):
+ with patch(
+ "subprocess.run", side_effect=subprocess.CalledProcessError(1, cmd=[])
+ ):
+ tile_stitcher.run_bfconvert("dummy_stitched_image_path.tif")
+ captured = capsys.readouterr()
+ assert "Error running bfconvert command." in captured.out
+
+
+@pytest.mark.exclude
+@pytest.fixture
+def sample_files():
+    # Sample TIF files used by the tile-stitching tests
+ return [
+ "tests/testdata/tilestitching_testdata/MISI3542i_W21-04143_bi016966_M394_OVX_LM_Scan1_[14384,29683]_component_data.tif"
+ ]
+
+
+@pytest.mark.exclude
+def test_integration_stitching(tile_stitcher, sample_files):
+ # Mocking the Java object returned by parse_regions
+ mocked_java_object = MagicMock()
+ with patch.object(tile_stitcher, "parse_regions", return_value=mocked_java_object):
+ # Test _collect_tif_files
+ collected_files = tile_stitcher._collect_tif_files(sample_files)
+ assert set(collected_files) == set(sample_files)
+
+ # Test parse_regions
+ regions = tile_stitcher.parse_regions(collected_files)
+ assert regions == mocked_java_object
+
+        # Run the image stitching on the sample files, writing the output
+        # to a temporary location with separate_series enabled
+ tile_stitcher.run_image_stitching(
+ sample_files,
+ "tests/testdata/tilestitching_testdata/temp",
+ separate_series=True,
+ )
+
+
+# @pytest.mark.exclude
+# def test_write_pyramidal_image_server(tile_stitcher, sample_files):
+# infiles = tile_stitcher._collect_tif_files(sample_files)
+# fileout, file_jpype = tile_stitcher._get_outfile(
+# "tests/testdata/tilestitching_testdata/output_temp"
+# )
+# downsamples = [1]
+# if not infiles or not file_jpype:
+# return
+
+# server = tile_stitcher.parse_regions(infiles)
+# server = tile_stitcher.ImageServers.pyramidalize(server)
+# tile_stitcher._write_pyramidal_image_server(server, file_jpype, downsamples)
+
+# downsamples = None
+# tile_stitcher._write_pyramidal_image_server(server, file_jpype, downsamples)
+
+
+@pytest.mark.exclude
+def test_set_environment_paths_without_java_path(tile_stitcher):
+ with patch.dict(os.environ, {}, clear=True):
+ with patch.object(
+ tile_stitcher, "get_system_java_home", return_value="/dummy/java/home"
+ ):
+ tile_stitcher.__init__(
+ qupath_jarpath=[], java_path=None, memory="40g", bfconvert_dir="./"
+ )
+ assert "JAVA_HOME" in os.environ
+ assert os.environ["JAVA_HOME"] == "/dummy/java/home"
+
+
+@pytest.mark.exclude
+def test_setup_bfconvert_permission_error_on_directory_creation(tile_stitcher):
+ with patch("os.path.exists", return_value=False):
+ with patch("os.makedirs", side_effect=PermissionError("Permission denied")):
+ with pytest.raises(PermissionError) as exc_info:
+ tile_stitcher.setup_bfconvert("/fake/path")
+ assert "Permission denied: Cannot create directory" in str(exc_info.value)
+
+
+@pytest.mark.exclude
+def test_setup_bfconvert_bad_zip_file(tile_stitcher, mock_tools_dir):
+ with patch("os.path.exists", return_value=False):
+ with patch("urllib.request.urlretrieve"):
+ with patch(
+ "zipfile.ZipFile", side_effect=zipfile.BadZipFile("Invalid ZIP file")
+ ):
+ with pytest.raises(zipfile.BadZipFile):
+ tile_stitcher.setup_bfconvert(str(mock_tools_dir))
+
+
+@pytest.mark.exclude
+def test_setup_bfconvert_permission_error_on_chmod(
+ tile_stitcher, mock_tools_dir, tmp_path
+):
+ dummy_bfconvert = tmp_path / "bfconvert"
+ dummy_bfconvert.touch()
+ dummy_bf_sh = tmp_path / "bf.sh"
+ dummy_bf_sh.touch()
+
+ tile_stitcher.bfconvert_path = str(dummy_bfconvert)
+ tile_stitcher.bf_sh_path = str(dummy_bf_sh)
+
+ with patch("os.path.exists", return_value=True):
+ with patch("os.stat", return_value=os.stat(dummy_bf_sh)):
+ with patch("os.chmod", side_effect=PermissionError("Permission denied")):
+ with pytest.raises(PermissionError):
+ tile_stitcher.setup_bfconvert(str(mock_tools_dir))
+
+
+@pytest.mark.exclude
+def test_set_environment_paths_without_java_path_exception(tile_stitcher):
+ with patch.dict(os.environ, {}, clear=True):
+ with patch.object(tile_stitcher, "get_system_java_home", return_value=""):
+ with pytest.raises(EnvironmentError) as exc_info:
+ tile_stitcher.set_environment_paths()
+ assert "JAVA_HOME not found" in str(exc_info.value)
+
+
+@pytest.mark.exclude
+def test_get_system_java_home_failure(tile_stitcher):
+ with patch("subprocess.getoutput", side_effect=Exception("Command failed")):
+ result = tile_stitcher.get_system_java_home()
+ assert result == ""
+
+
+@pytest.mark.exclude
+def test_collect_tif_files_invalid_input(tile_stitcher):
+ invalid_input = 123 # not a string or list
+ result = tile_stitcher._collect_tif_files(invalid_input)
+ assert result == []
+
+
+@pytest.mark.exclude
+def test_check_tiff_io_error(tile_stitcher):
+ with patch("builtins.open", side_effect=IOError("IO error occurred")):
+ with pytest.raises(IOError):
+ tile_stitcher.checkTIFF("invalid_file.tif")
+
+
+@pytest.mark.exclude
+def test_start_jvm_exception(tile_stitcher):
+ with patch("jpype.isJVMStarted", return_value=False):
+ with patch("jpype.startJVM", side_effect=Exception("JVM start error")):
+ with pytest.raises(SystemExit) as exc_info:
+ tile_stitcher._start_jvm()
+ assert exc_info.type == SystemExit
+
+
+@pytest.mark.exclude
+def test_import_qupath_classes_exception(tile_stitcher):
+ with patch("jpype.JPackage", side_effect=Exception("Import error")):
+ with pytest.raises(RuntimeError) as exc_info:
+ tile_stitcher._import_qupath_classes()
+ assert "Failed to import QuPath classes" in str(exc_info.value)
+
+
+@pytest.mark.exclude
+@patch("builtins.open", mock_open(read_data=b"non TIFF data"))
+def test_parse_region_invalid_tiff(tile_stitcher):
+ non_tiff_file = "non_tiff_file.txt"
+ assert tile_stitcher.parseRegion(non_tiff_file) is None
+
+
+@pytest.mark.exclude
+def test_run_image_stitching_with_empty_input(tile_stitcher, sample_files):
+    # Simulate an empty-input scenario: no TIF files are collected
+    with patch.object(tile_stitcher, "_collect_tif_files", return_value=[]):
+        output_file = "output.ome.tif"
+        # With nothing to stitch, the method should return early without raising
+        tile_stitcher.run_image_stitching(sample_files, output_file)
+        # A stricter version would assert that downstream methods were never
+        # called; see the commented sketch below
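+
+
+# Commented sketch of the stricter check referenced above (not part of the
+# original suite). It assumes run_image_stitching returns early, before calling
+# parse_regions, when no TIF files are collected; adjust if the implementation
+# calls a different method first.
+# @pytest.mark.exclude
+# def test_run_image_stitching_empty_input_skips_parsing(tile_stitcher, sample_files):
+#     with patch.object(
+#         tile_stitcher, "_collect_tif_files", return_value=[]
+#     ), patch.object(tile_stitcher, "parse_regions") as mocked_parse_regions:
+#         tile_stitcher.run_image_stitching(sample_files, "output.ome.tif")
+#         mocked_parse_regions.assert_not_called()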
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 011646eb..2ce5769a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,6 +3,12 @@
License: GNU GPL 2.0
"""
+import os
+import shutil
+import tempfile
+from pathlib import Path
+from unittest.mock import mock_open, patch
+
import cv2
import numpy as np
import pytest
@@ -17,12 +23,14 @@
_pad_or_crop_1d,
contour_centroid,
download_from_url,
+ find_qupath_home,
normalize_matrix_cols,
normalize_matrix_rows,
pad_or_crop,
parse_file_size,
plot_mask,
segmentation_lines,
+ setup_qupath,
sort_points_clockwise,
upsample_array,
)
@@ -43,6 +51,51 @@ def test_download_from_url(tmp_path):
assert file1.readline() == "format: Aperio SVS\n"
+# Test successful download
+def test_successful_download(tmp_path):
+ url = "http://example.com/testfile.txt"
+ download_dir = tmp_path / "downloads"
+ file_content = b"Sample file content"
+ with patch(
+ "urllib.request.urlopen", mock_open(read_data=file_content)
+ ) as mocked_url_open, patch("builtins.open", mock_open()) as mocked_file:
+ download_from_url(url, download_dir, "downloaded_file.txt")
+
+ mocked_url_open.assert_called_with(url)
+ mocked_file.assert_called_with(
+ os.path.join(download_dir, "downloaded_file.txt"), "wb"
+ )
+
+
+# Test skipping download for existing file
+def test_skip_existing_download(tmp_path):
+ url = "http://example.com/testfile.txt"
+ download_dir = tmp_path / "downloads"
+ download_dir.mkdir(parents=True, exist_ok=True) # Ensure directory exists
+ existing_file = download_dir / "existing_file.txt"
+ existing_file.touch() # Create an empty file
+
+ with patch("urllib.request.urlopen", mock_open()) as mocked_url_open:
+ download_from_url(url, download_dir, "existing_file.txt")
+
+ mocked_url_open.assert_not_called()
+
+
+# Test download with default filename
+def test_download_default_filename(tmp_path):
+ url = "http://example.com/testfile.txt"
+ download_dir = tmp_path / "downloads"
+ file_content = b"Sample file content for default"
+
+ with patch(
+ "urllib.request.urlopen", mock_open(read_data=file_content)
+ ) as mocked_url_open, patch("builtins.open", mock_open()) as mocked_file:
+ download_from_url(url, download_dir)
+
+ mocked_url_open.assert_called_with(url)
+ mocked_file.assert_called_with(os.path.join(download_dir, "testfile.txt"), "wb")
+
+
@pytest.fixture(scope="module")
def random_rgb():
im = np.random.randint(low=0, high=255, size=(50, 50, 3), dtype=np.uint8)
@@ -196,3 +249,52 @@ def test_normalize_matrix_rows(random_50_50):
def test_normalize_matrix_cols(random_50_50):
a = normalize_matrix_cols(random_50_50)
assert np.all(np.isclose(np.linalg.norm(a, axis=0), 1.0))
+
+
+def test_find_existing_qupath_home(tmp_path):
+ # Create a mock QuPath directory structure
+ qupath_dir = tmp_path / "qupath"
+ qupath_dir.mkdir(parents=True)
+ qupath_jar = qupath_dir / "qupath.jar"
+ qupath_jar.touch()
+
+ # Test if the function finds the QuPath home correctly
+ qupath_home = find_qupath_home(str(tmp_path))
+ assert qupath_home == str(qupath_dir.parent.parent)
+
+
+def test_no_qupath_home_found(tmp_path):
+ # Test with a directory without QuPath JAR
+ qupath_home = find_qupath_home(str(tmp_path))
+ assert qupath_home is None
+
+
+def test_find_qupath_home():
+ # Create a temporary directory
+ with tempfile.TemporaryDirectory() as temp_dir:
+ # Case 1: QuPath jar is present
+ os.makedirs(Path(temp_dir) / "qupath")
+ open(Path(temp_dir) / "qupath/qupath.jar", "a").close()
+ assert find_qupath_home(temp_dir) is not None
+
+ # Cleanup
+ shutil.rmtree(Path(temp_dir) / "qupath")
+
+ # Case 2: QuPath jar is not present
+ assert find_qupath_home(temp_dir) is None
+
+
+@patch("builtins.print") # To suppress print statements in the test
+def test_setup_qupath(mock_print):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ qupath_home = Path(temp_dir) / "qupath"
+
+        # Before installation: setup_qupath should install QuPath and return its path
+        expected_path = qupath_home / "QuPath"
+        assert setup_qupath(str(qupath_home)) == str(expected_path)
+
+        # When QuPath is already installed, the same path should be returned
+        assert setup_qupath(str(qupath_home)) == str(expected_path)
diff --git a/tests/testdata/tilestitching_testdata/MISI3542i_W21-04143_bi016966_M394_OVX_LM_Scan1_[14384,29683]_component_data.tif b/tests/testdata/tilestitching_testdata/MISI3542i_W21-04143_bi016966_M394_OVX_LM_Scan1_[14384,29683]_component_data.tif
new file mode 100644
index 00000000..5bb80d6a
--- /dev/null
+++ b/tests/testdata/tilestitching_testdata/MISI3542i_W21-04143_bi016966_M394_OVX_LM_Scan1_[14384,29683]_component_data.tif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:663fc47acee22358bfa1ddccf3aac3a89d850650583c2ef08a66b08c4cf9c9b1
+size 35747841