diff --git a/src/utils/models/StardistFluo/StardistFluo.ts b/src/utils/models/Stardist/AbstractStardist.ts similarity index 77% rename from src/utils/models/StardistFluo/StardistFluo.ts rename to src/utils/models/Stardist/AbstractStardist.ts index 44bb3c51..8b947de5 100644 --- a/src/utils/models/StardistFluo/StardistFluo.ts +++ b/src/utils/models/Stardist/AbstractStardist.ts @@ -1,23 +1,16 @@ import { GraphModel, History, LayersModel } from "@tensorflow/tfjs"; import { Segmenter } from "../AbstractSegmenter/AbstractSegmenter"; -import { loadStardist } from "./loadStardist"; import { preprocessStardist } from "./preprocessStardist"; import { predictStardist } from "./predictStardist"; import { generateUUID } from "utils/common/helpers"; import { LoadInferenceDataArgs } from "../types"; -import { ModelTask } from "../enums"; import { Kind, ImageObject } from "store/data/types"; /* - * Stardist (Versatile) Fluorescence Nuclei Segmentation - * https://zenodo.org/records/6348085 - * https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6348084 - * https://github.com/stardist/stardist/blob/master/README.md#pretrained-models-for-2d - * Stardist: model for object detection / instance segmentation with star-convex shapes - * This pretrained model: meant to segment individual cell nuclei from single channel fluorescence data (2018 DSB) + * Abstract model for Stardist variants */ -export class StardistFluo extends Segmenter { +export abstract class Stardist extends Segmenter { protected _fgKind?: Kind; protected _inferenceDataDims?: Array<{ width: number; @@ -26,25 +19,7 @@ export class StardistFluo extends Segmenter { padY: number; }>; - constructor() { - super({ - name: "StardistFluo", - task: ModelTask.Segmentation, - graph: true, - pretrained: true, - trainable: false, - requiredChannels: 3, - }); - } - - public async loadModel() { - if (this._model) return; - // inputs: [ {name: 'input', shape: [-1,-1,-1,1], dtype: 'float32'} ] - // outputs: [ {name: 
'concatenate_4/concat', shape: [-1, -1, -1, 33], dtype: 'float32'} ] - // where each -1 matches on input and output of corresponding dim/axis - // 33 -> 1 probability score, followed by 32 radial equiangular distances of rays - this._model = await loadStardist(); - } + public abstract loadModel(): Promise<void>; public loadTraining(images: ImageObject[], preprocessingArgs: any): void {} diff --git a/src/utils/models/Stardist/StardistFluo/StardistFluo.ts b/src/utils/models/Stardist/StardistFluo/StardistFluo.ts new file mode 100644 index 00000000..e96b7461 --- /dev/null +++ b/src/utils/models/Stardist/StardistFluo/StardistFluo.ts @@ -0,0 +1,33 @@ +import { Stardist } from "../AbstractStardist"; +import { loadStardistFluo } from "./loadStardistFluo"; +import { ModelTask } from "../../enums"; + +/* + * Stardist (Versatile) Fluorescence Nuclei Segmentation + * https://zenodo.org/records/6348085 + * https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6348084 + * https://github.com/stardist/stardist/blob/master/README.md#pretrained-models-for-2d + * Stardist: model for object detection / instance segmentation with star-convex shapes + * This pretrained model: meant to segment individual cell nuclei from single channel fluorescence data (2018 DSB) + */ +export class StardistFluo extends Stardist { + constructor() { + super({ + name: "StardistFluo", + task: ModelTask.Segmentation, + graph: true, + pretrained: true, + trainable: false, + requiredChannels: 3, + }); + } + + public async loadModel() { + if (this._model) return; + // inputs: [ {name: 'input', shape: [-1,-1,-1,1], dtype: 'float32'} ] + // outputs: [ {name: 'concatenate_4/concat', shape: [-1, -1, -1, 33], dtype: 'float32'} ] + // where each -1 matches on input and output of corresponding dim/axis + // 33 -> 1 probability score, followed by 32 radial equiangular distances of rays + this._model = await loadStardistFluo(); + } +} diff --git a/src/utils/models/StardistFluo/index.ts
b/src/utils/models/Stardist/StardistFluo/index.ts similarity index 100% rename from src/utils/models/StardistFluo/index.ts rename to src/utils/models/Stardist/StardistFluo/index.ts diff --git a/src/utils/models/StardistFluo/loadStardist.ts b/src/utils/models/Stardist/StardistFluo/loadStardistFluo.ts similarity index 92% rename from src/utils/models/StardistFluo/loadStardist.ts rename to src/utils/models/Stardist/StardistFluo/loadStardistFluo.ts index 02fb4bb6..2548fff5 100644 --- a/src/utils/models/StardistFluo/loadStardist.ts +++ b/src/utils/models/Stardist/StardistFluo/loadStardistFluo.ts @@ -19,7 +19,7 @@ import Stardist2DFluorescenceWeights2 from "data/model-data/stardist-fluo/group1 * from relative paths described by the paths fields in weights manifest. */ -export async function loadStardist() { +export async function loadStardistFluo() { let modelDescription: File; let modelWeights1: File; let modelWeights2: File; @@ -50,7 +50,7 @@ export async function loadStardist() { const error: Error = err as Error; process.env.NODE_ENV !== "production" && process.env.REACT_APP_LOG_LEVEL === "1" && - console.error(`error loading stardist: ${error.message}`); + console.error(`error loading stardist fluorescence: ${error.message}`); throw err; } @@ -65,7 +65,7 @@ export async function loadStardist() { process.env.NODE_ENV !== "production" && process.env.REACT_APP_LOG_LEVEL === "1" && - console.error(`error loading stardist: ${error.message}`); + console.error(`error loading stardist fluorescence: ${error.message}`); throw err; } diff --git a/src/utils/models/Stardist/StardistVHE/StardistVHE.ts b/src/utils/models/Stardist/StardistVHE/StardistVHE.ts new file mode 100644 index 00000000..867ce28a --- /dev/null +++ b/src/utils/models/Stardist/StardistVHE/StardistVHE.ts @@ -0,0 +1,33 @@ +import { Stardist } from "../AbstractStardist"; +import { loadStardistVHE } from "./loadStardistVHE"; +import { ModelTask } from "../../enums"; + +/* + * Stardist (Versatile) H&E Nuclei 
Segmentation + * https://zenodo.org/record/6338615 + * https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614&type=model + * https://github.com/stardist/stardist/blob/master/README.md#pretrained-models-for-2d + * Stardist: model for object detection / instance segmentation with star-convex shapes + * This pretrained model: meant to segment individual cell nuclei from brightfield images with H&E staining + */ +export class StardistVHE extends Stardist { + constructor() { + super({ + name: "StardistVHE", + task: ModelTask.Segmentation, + graph: true, + pretrained: true, + trainable: false, + requiredChannels: 3, + }); + } + + public async loadModel() { + if (this._model) return; + // inputs: [ {name: 'input', shape: [-1,-1,-1,3], dtype: 'float32'} ] + // outputs: [ {name: 'concatenate_4/concat', shape: [-1, -1, -1, 33], dtype: 'float32'} ] + // where each -1 matches on input and output of corresponding dim/axis + // 33 -> 1 probability score, followed by 32 radial equiangular distances of rays + this._model = await loadStardistVHE(); + } +} diff --git a/src/utils/models/StardistVHE/index.ts b/src/utils/models/Stardist/StardistVHE/index.ts similarity index 100% rename from src/utils/models/StardistVHE/index.ts rename to src/utils/models/Stardist/StardistVHE/index.ts diff --git a/src/utils/models/StardistVHE/loadStardist.ts b/src/utils/models/Stardist/StardistVHE/loadStardistVHE.ts similarity index 92% rename from src/utils/models/StardistVHE/loadStardist.ts rename to src/utils/models/Stardist/StardistVHE/loadStardistVHE.ts index 4ec38373..bc36e4cd 100644 --- a/src/utils/models/StardistVHE/loadStardist.ts +++ b/src/utils/models/Stardist/StardistVHE/loadStardistVHE.ts @@ -19,7 +19,7 @@ import Stardist2DBrightfieldWeights2 from "data/model-data//stardist-vhe/group1- * from relative paths described by the paths fields in weights manifest. 
*/ -export async function loadStardist() { +export async function loadStardistVHE() { let modelDescription: File; let modelWeights1: File; let modelWeights2: File; @@ -50,7 +50,7 @@ export async function loadStardist() { const error: Error = err as Error; process.env.NODE_ENV !== "production" && process.env.REACT_APP_LOG_LEVEL === "1" && - console.error(`error loading stardist: ${error.message}`); + console.error(`error loading stardist H&E: ${error.message}`); throw err; } @@ -65,7 +65,7 @@ export async function loadStardist() { process.env.NODE_ENV !== "production" && process.env.REACT_APP_LOG_LEVEL === "1" && - console.error(`error loading stardist: ${error.message}`); + console.error(`error loading stardist H&E: ${error.message}`); throw err; } diff --git a/src/utils/models/Stardist/index.ts b/src/utils/models/Stardist/index.ts new file mode 100644 index 00000000..c9406c99 --- /dev/null +++ b/src/utils/models/Stardist/index.ts @@ -0,0 +1,2 @@ +export { StardistFluo } from "./StardistFluo"; +export { StardistVHE } from "./StardistVHE"; diff --git a/src/utils/models/StardistFluo/predictStardist.ts b/src/utils/models/Stardist/predictStardist.ts similarity index 100% rename from src/utils/models/StardistFluo/predictStardist.ts rename to src/utils/models/Stardist/predictStardist.ts diff --git a/src/utils/models/StardistFluo/preprocessStardist.ts b/src/utils/models/Stardist/preprocessStardist.ts similarity index 100% rename from src/utils/models/StardistFluo/preprocessStardist.ts rename to src/utils/models/Stardist/preprocessStardist.ts diff --git a/src/utils/models/StardistVHE/StardistVHE.ts b/src/utils/models/StardistVHE/StardistVHE.ts deleted file mode 100644 index d00a4a4a..00000000 --- a/src/utils/models/StardistVHE/StardistVHE.ts +++ /dev/null @@ -1,164 +0,0 @@ -import { GraphModel, History, LayersModel } from "@tensorflow/tfjs"; - -import { Segmenter } from "../AbstractSegmenter/AbstractSegmenter"; -import { loadStardist } from "./loadStardist"; -import { 
preprocessStardist } from "./preprocessStardist"; -import { predictStardist } from "./predictStardist"; -import { generateUUID } from "utils/common/helpers"; -import { LoadInferenceDataArgs } from "../types"; -import { ModelTask } from "../enums"; -import { Kind, ImageObject } from "store/data/types"; - -/* - * Stardist (Versatile) H&E Nuclei Segmentation - * https://zenodo.org/record/6338615 - * https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614&type=model - * https://github.com/stardist/stardist/blob/master/README.md#pretrained-models-for-2d - * Stardist: model for object detection / instance segmentation with star-convex shapes - * This pretrained model: meant to segment individual cell nuclei from brightfield images with H&E staining - */ -export class StardistVHE extends Segmenter { - protected _fgKind?: Kind; - protected _inferenceDataDims?: Array<{ - width: number; - height: number; - padX: number; - padY: number; - }>; - - constructor() { - super({ - name: "StardistVHE", - task: ModelTask.Segmentation, - graph: true, - pretrained: true, - trainable: false, - requiredChannels: 3, - }); - } - - public async loadModel() { - if (this._model) return; - // inputs: [ {name: 'input', shape: [-1,-1,-1,3], dtype: 'float32'} ] - // outputs: [ {name: 'concatenate_4/concat', shape: [-1, -1, -1, 33], dtype: 'float32'} ] - // where each -1 matches on input and output of corresponding dim/axis - // 33 -> 1 probability score, followed by 32 radial equiangular distances of rays - this._model = await loadStardist(); - } - - public loadTraining(images: ImageObject[], preprocessingArgs: any): void {} - - public loadValidation(images: ImageObject[], preprocessingArgs: any): void {} - - // This Stardist model requires image dimensions to be a multiple of 16 - // (for VHE in particular), see: - // https://github.com/stardist/stardist/blob/468c60552c8c93403969078e51bddc9c2c702035/stardist/models/model2d.py#L543 - // 
https://github.com/stardist/stardist/blob/master/stardist/models/model2d.py#L201C30-L201C30 - // and config here (under source -> grid): https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614 - // basically, in the case of VHE: 2^3 * 2 = 16 - protected _getPaddings(height: number, width: number) { - const padY = height % 16 === 0 ? 0 : 16 - (height % 16); - const padX = width % 16 === 0 ? 0 : 16 - (width % 16); - - return { padY, padX }; - } - - public loadInference( - images: ImageObject[], - preprocessingArgs: LoadInferenceDataArgs - ): void { - this._inferenceDataDims = images.map((im) => { - const { height, width } = im.shape; - const { padX, padY } = this._getPaddings(height, width); - return { height, width, padY, padX }; - }); - - this._inferenceDataset = preprocessStardist( - images, - 1, - this._inferenceDataDims - ); - - if (preprocessingArgs.kinds) { - if (preprocessingArgs.kinds.length !== 1) - throw Error( - `${this.name} Model only takes a single foreground category` - ); - this._fgKind = preprocessingArgs.kinds[0]; - } else if (!this._fgKind) { - const unknownCategoryId = generateUUID({ definesUnknown: true }); - this._fgKind = { - id: "Nucleus", - categories: [unknownCategoryId], - containing: [], - unknownCategoryId, - }; - } - } - - public async train(options: any, callbacks: any): Promise<History> { - if (!this.trainable) { - throw new Error(`Training not supported for Model ${this.name}`); - } else { - throw new Error(`Training not yet implemented for Model ${this.name}`); - } - } - - public async predict() { - if (!this._model) { - throw Error(`"${this.name}" Model not loaded`); - } - - if (this._model instanceof LayersModel) { - throw Error(`"${this.name}" Model must a Graph, not Layers`); - } - - if (!this._inferenceDataset) { - throw Error(`"${this.name}" Model's inference data not loaded`); - } - - if (!this._fgKind) { - throw Error(`"${this.name}" Model's foreground kind is not loaded`); - } - - if (!this._inferenceDataDims) { - throw 
Error( - `"${this.name}" Model's inference data dimensions and padding information not loaded` - ); - } - - const graphModel = this._model as GraphModel; - - const infT = await this._inferenceDataset.toArray(); - // imTensor disposed in `predictStardist` - const annotationsPromises = infT.map((imTensor, idx) => { - return predictStardist( - graphModel, - imTensor, - this._fgKind!.id, - this._fgKind!.unknownCategoryId, - this._inferenceDataDims![idx] - ); - }); - const annotations = await Promise.all(annotationsPromises); - - return annotations; - } - - public inferenceCategoriesById(catIds: Array<string>) { - return []; - } - public inferenceKindsById(kinds: string[]) { - if (!this._fgKind) { - throw Error(`"${this.name}" Model has no foreground kind loaded`); - } - - return kinds.includes(this._fgKind.id) ? [this._fgKind] : []; - } - - public override dispose() { - this._inferenceDataDims = undefined; - this._fgKind = undefined; - super.dispose(); - } -} diff --git a/src/utils/models/StardistVHE/predictStardist.ts b/src/utils/models/StardistVHE/predictStardist.ts deleted file mode 100644 index a696e1aa..00000000 --- a/src/utils/models/StardistVHE/predictStardist.ts +++ /dev/null @@ -1,230 +0,0 @@ -import { - Tensor4D, - GraphModel, - dispose, - tidy, - image as TFImage, - tensor2d, - tensor1d, - setBackend, - getBackend, -} from "@tensorflow/tfjs"; - -import { encode, scanline, simplifyPolygon } from "utils/annotator"; -import { connectPoints } from "utils/annotator"; -import { OrphanedAnnotationObject } from "../AbstractSegmenter/AbstractSegmenter"; -import { generateUUID } from "utils/common/helpers"; -import { Partition } from "../enums"; -import { Point } from "utils/annotator/types"; - -const computeAnnotationMaskFromPoints = ( - cropDims: { x: number; y: number; width: number; height: number }, - coordinates: Array<Point>, - imH: number, - imW: number -) => { - // get coordinates of connected points and draw boundaries of mask - const connectedPoints = 
connectPoints(coordinates); - const simplifiedPoints = simplifyPolygon(connectedPoints); - const maskImage = scanline(simplifiedPoints, imW, imH); - // @ts-ignore: getChannel API is not exposed - const greyScaleMask = maskImage.getChannel(0); // as ImageJS.Image; - const cropped = greyScaleMask.crop(cropDims); - - return cropped; -}; - -function buildPolygon( - distances: Array<number>, - row: number, - col: number, - imH: number, - imW: number, - inputImDims: { - width: number; - height: number; - padX: number; - padY: number; - } -) { - const THETA = (2 / distances.length) * Math.PI; // 0.19635, for 32 distances - const points: Array<Point> = []; - var xMin = Infinity; - var yMin = Infinity; - var xMax = 0; - var yMax = 0; - - distances.forEach((length, idx) => { - const y = Math.max( - 0, - Math.min( - imH - 1, - Math.round( - row + - length * Math.sin(idx * THETA) - - Math.floor(inputImDims.padY / 2) - ) - ) - ); - - const x = Math.max( - 0, - Math.min( - imW - 1, - Math.round( - col + - length * Math.cos(idx * THETA) - - Math.floor(inputImDims.padX / 2) - ) - ) - ); - - yMin = y < yMin ? y : yMin; - yMax = y > yMax ? y : yMax; - - xMin = x < xMin ? x : xMin; - xMax = x > 
x : xMax; - - points.push({ x, y }); - }); - - let bbox: [number, number, number, number] = [ - Math.max(0, xMin), - Math.max(0, yMin), - Math.min(inputImDims.width, xMax), - Math.min(inputImDims.height, yMax), - ]; - - const boxH = bbox[3] - bbox[1]; - const boxW = bbox[2] - bbox[0]; - - if (boxW <= 0 || boxH <= 0) return; - - let cropDims = { x: bbox[0], y: bbox[1], width: boxW, height: boxH }; - - const poly = computeAnnotationMaskFromPoints(cropDims, points, imH, imW); - - return { decodedMask: poly.data, bbox: bbox }; -} - -function generateAnnotations( - preds: number[][][], - kindId: string, - unknownCategoryId: string, - height: number, - width: number, - inputImDims: { - width: number; - height: number; - padX: number; - padY: number; - }, - scoreThresh: number -) { - const generatedAnnotations: Array<OrphanedAnnotationObject> = []; - const scores: Array<number> = []; - const generatedBboxes: Array<[number, number, number, number]> = []; - - preds.forEach((row: Array<Array<number>>, i) => { - row.forEach((output: Array<number>, j) => { - if (output[0] >= scoreThresh) { - const polygon = buildPolygon( - output.slice(1), // radial distances - i, - j, - height, - width, - inputImDims - ); - - if (!polygon) return; - - scores.push(output[0]); - - const { decodedMask, bbox } = polygon; - - const annotation = { - encodedMask: encode(decodedMask), - boundingBox: bbox, - kind: kindId, - categoryId: unknownCategoryId, - partition: Partition.Unassigned, - id: generateUUID(), - activePlane: 0, - }; - - generatedAnnotations.push(annotation); - generatedBboxes.push(bbox); - } - }); - }); - return { generatedAnnotations, generatedBboxes, scores }; -} -export const predictStardist = async ( - model: GraphModel, - imTensor: Tensor4D, // expects 1 for batchSize dim (axis 0) - kindId: string, // foreground (nucleus kind) - unknownCategoryId: string, - inputImDims: { - width: number; - height: number; - padX: number; - padY: number; - }, - NMS_IoUThresh: number = 0.1, - NMS_scoreThresh: number = 0.3, - NMS_maxOutputSize: number = 
500, - NMS_softNmsSigma: number = 0.0 -) => { - // [batchSize, H, W, 33] - const res = model.execute(imTensor) as Tensor4D; - const preds = (await res.array())[0]; - - dispose(imTensor); - dispose(res); - - const prevBackend = getBackend(); - - if (prevBackend === "webgl") { - setBackend("cpu"); - } - - const { generatedAnnotations, generatedBboxes, scores } = generateAnnotations( - preds, - kindId, - unknownCategoryId, - res.shape[1], // H - res.shape[2], // W - inputImDims, - NMS_scoreThresh - ); - - const indexTensor = tidy(() => { - const bboxTensor = tensor2d(generatedBboxes); - const scoresTensor = tensor1d(scores); - - return TFImage.nonMaxSuppressionWithScore( - bboxTensor, - scoresTensor, - NMS_maxOutputSize, - NMS_IoUThresh, - NMS_scoreThresh, - NMS_softNmsSigma - ).selectedIndices; - }); - - const indices = (await indexTensor.data()) as Float32Array; - dispose(indexTensor); - - const selectedAnnotations: Array<OrphanedAnnotationObject> = []; - - indices.forEach((index) => { - selectedAnnotations.push(generatedAnnotations[index]); - }); - - if (prevBackend !== getBackend()) { - setBackend(prevBackend); - } - - return selectedAnnotations; -}; diff --git a/src/utils/models/StardistVHE/preprocessStardist.ts b/src/utils/models/StardistVHE/preprocessStardist.ts deleted file mode 100644 index 7707181a..00000000 --- a/src/utils/models/StardistVHE/preprocessStardist.ts +++ /dev/null @@ -1,68 +0,0 @@ -import { Tensor3D, Tensor4D, data as tfdata, tidy } from "@tensorflow/tfjs"; -import { padToMatch } from "../helpers"; -import { getImageSlice } from "utils/common/tensorHelpers"; -import { OldImageType, ImageObject } from "store/data/types"; - -const sampleGenerator = ( - images: Array<ImageObject>, - padVals: Array<{ padX: number; padY: number }> -) => { - const count = images.length; - - return function* () { - let index = 0; - - while (index < count) { - const image = images[index]; - const dataPlane = getImageSlice(image.data, image.activePlane); - - yield { - data: dataPlane, - bitDepth: 
image.bitDepth, - padX: padVals[index].padX, - padY: padVals[index].padY, - }; - - index++; - } - }; -}; - -const padImage = (image: { - data: Tensor3D; - bitDepth: OldImageType["bitDepth"]; - padX: number; - padY: number; -}) => { - const imageTensor = tidy(() => { - if (image.padX !== 0 || image.padY !== 0) { - const padded = padToMatch( - image.data, - { - height: image.data.shape[0] + image.padY, - width: image.data.shape[1] + image.padX, - }, - "reflect" - ); - - // image.data disposed by padToMatch, and would be disposed by tf anyway - return padded; - } else { - return image.data; - } - }); - - // no casting, stardistVHE input should be float32 - return imageTensor as Tensor3D; -}; - -export const preprocessStardist = ( - images: Array<ImageObject>, - batchSize: number, - dataDims: Array<{ padX: number; padY: number }> -) => { - return tfdata - .generator(sampleGenerator(images, dataDims)) - .map((im) => padImage(im)) - .batch(batchSize) as tfdata.Dataset<Tensor4D>; -}; diff --git a/src/utils/models/availableSegmentationModels.ts b/src/utils/models/availableSegmentationModels.ts index 3722e5a4..50df5a00 100644 --- a/src/utils/models/availableSegmentationModels.ts +++ b/src/utils/models/availableSegmentationModels.ts @@ -3,8 +3,7 @@ import { Cellpose } from "./Cellpose"; import { CocoSSD } from "./CocoSSD"; import { FullyConvolutionalSegmenter } from "./FullyConvolutionalSegmenter"; import { Glas } from "./Glas/Glas"; -import { StardistVHE } from "./StardistVHE"; -import { StardistFluo } from "./StardistFluo"; +import { StardistVHE, StardistFluo } from "./Stardist"; export const availableSegmenterModels: Array<Segmenter> = [ new FullyConvolutionalSegmenter(),