Skip to content

Commit

Permalink
Make flake8 compliant
Browse files Browse the repository at this point in the history
  • Loading branch information
VarunUllanat committed Dec 6, 2023
1 parent b275a04 commit 10fdde7
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 992 deletions.
53 changes: 52 additions & 1 deletion examples/construct_graphs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"from pathml.preprocessing import StainNormalizationHE\n",
"from pathml.graph import RAGGraphBuilder, KNNGraphBuilder\n",
"from pathml.graph import ColorMergedSuperpixelExtractor\n",
"from pathml.graph.utils import _valid_image, _exists, plot_graph_on_image, get_full_instance_map, build_assignment_matrix"
"from pathml.graph.utils import get_full_instance_map, build_assignment_matrix"
]
},
{
Expand Down Expand Up @@ -256,6 +256,57 @@
"metadata": {},
"outputs": [],
"source": [
"def plot_graph_on_image(graph, image):\n",
" \"\"\"\n",
" Plots a given graph on the original WSI image\n",
"\n",
" Args:\n",
" graph (torch.tensor): Graph as an sparse edge index\n",
" image (numpy.array): Input image\n",
" \"\"\"\n",
"\n",
" from torch_geometric.utils.convert import to_networkx\n",
"\n",
" pos = graph.node_centroids.numpy()\n",
" G = to_networkx(graph, to_undirected=True)\n",
" plt.imshow(image)\n",
" nx.draw(G, pos, node_size=25)\n",
" plt.show()\n",
"\n",
"def _valid_image(nr_pixels):\n",
" \"\"\"\n",
" Checks if image does not exceed maximum number of pixels or exceeds minimum number of pixels.\n",
"\n",
" Args:\n",
" nr_pixels (int): Number of pixels in given image\n",
" \"\"\"\n",
"\n",
" if nr_pixels > MIN_NR_PIXELS and nr_pixels < MAX_NR_PIXELS:\n",
" return True\n",
" return False\n",
"\n",
"def _exists(cg_out, tg_out, assign_out, overwrite):\n",
" \"\"\"\n",
" Checks if given input files exist or not\n",
"\n",
" Args:\n",
" cg_out (str): Cell graph file\n",
" tg_out (str): Tissue graph file\n",
" assign_out (str): Assignment matrix file\n",
" overwrite (bool): Whether to overwrite files or not. If true, this function return false and files are\n",
" overwritten.\n",
" \"\"\"\n",
" if overwrite:\n",
" return False\n",
" else:\n",
" if (\n",
" os.path.isfile(cg_out)\n",
" and os.path.isfile(tg_out)\n",
" and os.path.isfile(assign_out)\n",
" ):\n",
" return True\n",
" return False\n",
"\n",
"def process(image_path, save_path, split, plot=True, overwrite=False):\n",
" # 1. get image path\n",
" subdirs = os.listdir(image_path)\n",
Expand Down
4 changes: 3 additions & 1 deletion pathml/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def __init__(
)

self.threshold = int(self.patch_size * self.patch_size * threshold)
self.warning_threshold = 0.75
self.warning_threshold = 0.50

try:
from torchvision import transforms
Expand Down Expand Up @@ -256,6 +256,8 @@ def __init__(
elif self.entity == "tissue":
self._precompute_tissue()

self._warning()

def _add_patch(self, center_x, center_y, instance_index, region_count):
"""Extract and include patch information."""

Expand Down
12 changes: 10 additions & 2 deletions pathml/graph/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def process( # type: ignore[override]
node_features = self._compute_node_features(features, image_size)

if annotation is not None:
node_labels = self._compute_node_labels(instance_map, annotation)
node_labels = self._set_node_labels(instance_map, annotation)
else:
node_labels = None

Expand Down Expand Up @@ -230,7 +230,7 @@ def _set_node_labels(self, instance_map, annotation):
assert annotation.shape[0] == len(
regions
), "Number of annotations do not match number of nodes"
return torch.FloatTensor(annotation.astype(float))
return annotation

def _build_topology(self, instance_map):
"""Build topology using (thresholded) kNN"""
Expand Down Expand Up @@ -274,6 +274,14 @@ def __init__(self, kernel_size=3, hops=1, **kwargs):
self.hops = hops
super().__init__(**kwargs)

def _set_node_labels(self, instance_map, annotation):
"""Set the node labels of the graphs using annotation"""
regions = regionprops(instance_map)
assert annotation.shape[0] == len(
regions
), "Number of annotations do not match number of nodes"
return annotation

def _build_topology(self, instance_map):
"""Create the graph topology from the instance connectivty in the instance_map"""

Expand Down
88 changes: 0 additions & 88 deletions pathml/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,8 @@
Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
License: GNU GPL 2.0
"""

import importlib
import math
import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from skimage.measure import label, regionprops
Expand Down Expand Up @@ -87,74 +82,6 @@ def __inc__(self, key, value, *args, **kwargs):
return super().__inc__(key, value, *args, **kwargs)


def dynamic_import_from(source_file: str, class_name: str):
"""Do a from source_file import class_name dynamically
Args:
source_file (str): Where to import from
class_name (str): What to import
Returns:
Any: The class to be imported
"""
module = importlib.import_module(source_file)
return getattr(module, class_name)


def _valid_image(nr_pixels):
"""
Checks if image does not exceed maximum number of pixels or exceeds minimum number of pixels.
Args:
nr_pixels (int): Number of pixels in given image
"""

if nr_pixels > MIN_NR_PIXELS and nr_pixels < MAX_NR_PIXELS:
return True
return False


def plot_graph_on_image(graph, image):
"""
Plots a given graph on the original WSI image
Args:
graph (torch.tensor): Graph as an sparse edge index
image (numpy.array): Input image
"""

from torch_geometric.utils.convert import to_networkx

pos = graph.node_centroids.numpy()
G = to_networkx(graph, to_undirected=True)
plt.imshow(image)
nx.draw(G, pos, node_size=25)
plt.show()


def _exists(cg_out, tg_out, assign_out, overwrite):
"""
Checks if given input files exist or not
Args:
cg_out (str): Cell graph file
tg_out (str): Tissue graph file
assign_out (str): Assignment matrix file
overwrite (bool): Whether to overwrite files or not. If true, this function return false and files are
overwritten.
"""

if overwrite:
return False
else:
if (
os.path.isfile(cg_out)
and os.path.isfile(tg_out)
and os.path.isfile(assign_out)
):
return True
return False


def get_full_instance_map(wsi, patch_size, mask_name="cell"):
"""
Generates and returns the normalized image, cell instance map and cell centroids from pathml SlideData object
Expand Down Expand Up @@ -223,21 +150,6 @@ def build_assignment_matrix(low_level_centroids, high_level_map, matrix=False):
return assignment_matrix


def compute_histogram(input_array: np.ndarray, nr_values: int) -> np.ndarray:
"""Calculates a histogram of a matrix of the values from 0 up to (excluding) nr_values
Args:
x (np.array): Input tensor.
nr_values (int): Possible values. From 0 up to (exclusing) nr_values.
Returns:
np.array: Output tensor.
"""
output_array = np.empty(nr_values, dtype=int)
for i in range(nr_values):
output_array[i] = (input_array == i).sum()
return output_array


def two_hop(edge_index, num_nodes):
"""Calculates the two-hop graph.
Args:
Expand Down
Loading

0 comments on commit 10fdde7

Please sign in to comment.