-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[add] Add stardist fluorescence model
- we have had stardist H&E for a while, which is pretrained for tissue images; this adds the variant of stardist pretrained on fluorescence data from here: https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6348084 - TODO: DRY the code by unifying preprocessing, prediction and loading from the two stardist implementations (the majority is identical between the two)
- Loading branch information
Showing
9 changed files
with
2,784 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import { GraphModel, History, LayersModel } from "@tensorflow/tfjs"; | ||
|
||
import { Segmenter } from "../AbstractSegmenter/AbstractSegmenter"; | ||
import { loadStardist } from "./loadStardist"; | ||
import { preprocessStardist } from "./preprocessStardist"; | ||
import { predictStardist } from "./predictStardist"; | ||
import { generateUUID } from "utils/common/helpers"; | ||
import { LoadInferenceDataArgs } from "../types"; | ||
import { ModelTask } from "../enums"; | ||
import { Kind, ImageObject } from "store/data/types"; | ||
|
||
/* | ||
* TODO: make sure this has Fluorescence model info: | ||
* Stardist (Versatile) H&E Nuclei Segmentation | ||
* https://zenodo.org/record/6338615 | ||
* https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614&type=model | ||
* https://github.com/stardist/stardist/blob/master/README.md#pretrained-models-for-2d | ||
* Stardist: model for object detection / instance segmentation with star-convex shapes | ||
* This pretrained model: meant to segment individual cell nuclei from brightfield images with H&E staining | ||
*/ | ||
export class StardistFluo extends Segmenter { | ||
protected _fgKind?: Kind; | ||
protected _inferenceDataDims?: Array<{ | ||
width: number; | ||
height: number; | ||
padX: number; | ||
padY: number; | ||
}>; | ||
|
||
constructor() { | ||
super({ | ||
name: "StardistFluo", | ||
task: ModelTask.Segmentation, | ||
graph: true, | ||
pretrained: true, | ||
trainable: false, | ||
requiredChannels: 3, | ||
}); | ||
} | ||
|
||
public async loadModel() { | ||
if (this._model) return; | ||
// inputs: [ {name: 'input', shape: [-1,-1,-1,3], dtype: 'float32'} ] | ||
// outputs: [ {name: 'concatenate_4/concat', shape: [-1, -1, -1, 33], dtype: 'float32'} ] | ||
// where each -1 matches on input and output of corresponding dim/axis | ||
// 33 -> 1 probability score, followed by 32 radial equiangular distances of rays | ||
this._model = await loadStardist(); | ||
} | ||
|
||
public loadTraining(images: ImageObject[], preprocessingArgs: any): void {} | ||
|
||
public loadValidation(images: ImageObject[], preprocessingArgs: any): void {} | ||
|
||
// This Stardist model requires image dimensions to be a multiple of 16 | ||
// (for VHE in particular), see: | ||
// https://github.com/stardist/stardist/blob/468c60552c8c93403969078e51bddc9c2c702035/stardist/models/model2d.py#L543 | ||
// https://github.com/stardist/stardist/blob/master/stardist/models/model2d.py#L201C30-L201C30 | ||
// and config here (under source -> grid): https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614 | ||
// basically, in the case of VHE: 2^3 * 2 = 16 | ||
protected _getPaddings(height: number, width: number) { | ||
const padY = height % 16 === 0 ? 0 : 16 - (height % 16); | ||
const padX = width % 16 === 0 ? 0 : 16 - (width % 16); | ||
|
||
return { padY, padX }; | ||
} | ||
|
||
public loadInference( | ||
images: ImageObject[], | ||
preprocessingArgs: LoadInferenceDataArgs | ||
): void { | ||
this._inferenceDataDims = images.map((im) => { | ||
const { height, width } = im.shape; | ||
const { padX, padY } = this._getPaddings(height, width); | ||
return { height, width, padY, padX }; | ||
}); | ||
|
||
this._inferenceDataset = preprocessStardist( | ||
images, | ||
1, | ||
this._inferenceDataDims | ||
); | ||
|
||
if (preprocessingArgs.kinds) { | ||
if (preprocessingArgs.kinds.length !== 1) | ||
throw Error( | ||
`${this.name} Model only takes a single foreground category` | ||
); | ||
this._fgKind = preprocessingArgs.kinds[0]; | ||
} else if (!this._fgKind) { | ||
const unknownCategoryId = generateUUID({ definesUnknown: true }); | ||
this._fgKind = { | ||
id: "Nucleus", | ||
categories: [unknownCategoryId], | ||
containing: [], | ||
unknownCategoryId, | ||
}; | ||
} | ||
} | ||
|
||
public async train(options: any, callbacks: any): Promise<History> { | ||
if (!this.trainable) { | ||
throw new Error(`Training not supported for Model ${this.name}`); | ||
} else { | ||
throw new Error(`Training not yet implemented for Model ${this.name}`); | ||
} | ||
} | ||
|
||
public async predict() { | ||
if (!this._model) { | ||
throw Error(`"${this.name}" Model not loaded`); | ||
} | ||
|
||
if (this._model instanceof LayersModel) { | ||
throw Error(`"${this.name}" Model must a Graph, not Layers`); | ||
} | ||
|
||
if (!this._inferenceDataset) { | ||
throw Error(`"${this.name}" Model's inference data not loaded`); | ||
} | ||
|
||
if (!this._fgKind) { | ||
throw Error(`"${this.name}" Model's foreground kind is not loaded`); | ||
} | ||
|
||
if (!this._inferenceDataDims) { | ||
throw Error( | ||
`"${this.name}" Model's inference data dimensions and padding information not loaded` | ||
); | ||
} | ||
|
||
const graphModel = this._model as GraphModel; | ||
|
||
const infT = await this._inferenceDataset.toArray(); | ||
// imTensor disposed in `predictStardist` | ||
const annotationsPromises = infT.map((imTensor, idx) => { | ||
return predictStardist( | ||
graphModel, | ||
imTensor, | ||
this._fgKind!.id, | ||
this._fgKind!.unknownCategoryId, | ||
this._inferenceDataDims![idx] | ||
); | ||
}); | ||
const annotations = await Promise.all(annotationsPromises); | ||
|
||
return annotations; | ||
} | ||
|
||
public inferenceCategoriesById(catIds: Array<string>) { | ||
return []; | ||
} | ||
public inferenceKindsById(kinds: string[]) { | ||
if (!this._fgKind) { | ||
throw Error(`"${this.name}" Model has no foreground kind loaded`); | ||
} | ||
|
||
return kinds.includes(this._fgKind.id) ? [this._fgKind] : []; | ||
} | ||
|
||
public override dispose() { | ||
this._inferenceDataDims = undefined; | ||
this._fgKind = undefined; | ||
super.dispose(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
// Public entry point for the Stardist fluorescence segmenter.
export { StardistFluo } from "./StardistFluo";
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import { io, loadGraphModel } from "@tensorflow/tfjs"; | ||
|
||
import Stardist2DFluorescenceModel from "data/model-data/stardist-fluo/model.json"; | ||
//@ts-ignore | ||
import Stardist2DFluorescenceWeights1 from "data/model-data/stardist-fluo/group1-shard1of2.bin"; | ||
//@ts-ignore | ||
import Stardist2DFluorescenceWeights2 from "data/model-data/stardist-fluo/group1-shard2of2.bin"; | ||
|
||
/* | ||
 * model.json contains 'modelTopology' and 'weightsManifest'
* | ||
* 'modelTopology': A JSON object that can be either of: | ||
* 1) a model architecture JSON consistent with the format of the return value of keras.Model.to_json() | ||
* 2) a full model JSON in the format of keras.models.save_model(). | ||
* | ||
* 'weightsManifest': A TensorFlow.js weights manifest. | ||
* See the Python converter function save_model() for more details. | ||
* It is also assumed that model weights (.bin files) can be accessed | ||
* from relative paths described by the paths fields in weights manifest. | ||
*/ | ||
|
||
export async function loadStardist() { | ||
let modelDescription: File; | ||
let modelWeights1: File; | ||
let modelWeights2: File; | ||
|
||
try { | ||
const model_desc_blob = new Blob( | ||
[JSON.stringify(Stardist2DFluorescenceModel)], | ||
{ | ||
type: "application/json", | ||
} | ||
); | ||
modelDescription = new File([model_desc_blob], "model.json", { | ||
type: "application/json", | ||
}); | ||
|
||
const model_weights_fetch1 = await fetch(Stardist2DFluorescenceWeights1); | ||
const model_weights_blob1 = await model_weights_fetch1.blob(); | ||
modelWeights1 = new File([model_weights_blob1], "group1-shard1of2.bin", { | ||
type: "application/octet-stream", | ||
}); | ||
|
||
const model_weights_fetch2 = await fetch(Stardist2DFluorescenceWeights2); | ||
const model_weights_blob2 = await model_weights_fetch2.blob(); | ||
modelWeights2 = new File([model_weights_blob2], "group1-shard2of2.bin", { | ||
type: "application/octet-stream", | ||
}); | ||
} catch (err) { | ||
const error: Error = err as Error; | ||
process.env.NODE_ENV !== "production" && | ||
process.env.REACT_APP_LOG_LEVEL === "1" && | ||
console.error(`error loading stardist: ${error.message}`); | ||
throw err; | ||
} | ||
|
||
try { | ||
const model = await loadGraphModel( | ||
io.browserFiles([modelDescription, modelWeights1, modelWeights2]) | ||
); | ||
|
||
return model; | ||
} catch (err) { | ||
const error: Error = err as Error; | ||
|
||
process.env.NODE_ENV !== "production" && | ||
process.env.REACT_APP_LOG_LEVEL === "1" && | ||
console.error(`error loading stardist: ${error.message}`); | ||
|
||
throw err; | ||
} | ||
} |
Oops, something went wrong.