Skip to content

Commit

Permalink
[add] Add stardist fluorescence model
Browse files Browse the repository at this point in the history
- we have had stardist H&E for a while, which is pretrained for tissue
  images, this adds the variant of stardist pretrained on flourescence
  data from here: https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6348084
- TODO: DRY the code, by unifying preprocessing, prediction and loading
  from the two stardist implementations (majority is identical between
  the two)
  • Loading branch information
gnodar01 committed Jun 12, 2024
1 parent a8fa812 commit f50e0ca
Show file tree
Hide file tree
Showing 9 changed files with 2,784 additions and 0 deletions.
Binary file not shown.
Binary file not shown.
2,246 changes: 2,246 additions & 0 deletions src/data/model-data/stardist-fluo/model.json

Large diffs are not rendered by default.

165 changes: 165 additions & 0 deletions src/utils/models/StardistFluo/StardistFluo.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import { GraphModel, History, LayersModel } from "@tensorflow/tfjs";

import { Segmenter } from "../AbstractSegmenter/AbstractSegmenter";
import { loadStardist } from "./loadStardist";
import { preprocessStardist } from "./preprocessStardist";
import { predictStardist } from "./predictStardist";
import { generateUUID } from "utils/common/helpers";
import { LoadInferenceDataArgs } from "../types";
import { ModelTask } from "../enums";
import { Kind, ImageObject } from "store/data/types";

/*
* TODO: make sure this has Fluorescence model info:
* Stardist (Versatile) H&E Nuclei Segmentation
* https://zenodo.org/record/6338615
* https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614&type=model
* https://github.com/stardist/stardist/blob/master/README.md#pretrained-models-for-2d
* Stardist: model for object detection / instance segmentation with star-convex shapes
* This pretrained model: meant to segment individual cell nuclei from brightfield images with H&E staining
*/
export class StardistFluo extends Segmenter {
protected _fgKind?: Kind;
protected _inferenceDataDims?: Array<{
width: number;
height: number;
padX: number;
padY: number;
}>;

constructor() {
super({
name: "StardistFluo",
task: ModelTask.Segmentation,
graph: true,
pretrained: true,
trainable: false,
requiredChannels: 3,
});
}

public async loadModel() {
if (this._model) return;
// inputs: [ {name: 'input', shape: [-1,-1,-1,3], dtype: 'float32'} ]
// outputs: [ {name: 'concatenate_4/concat', shape: [-1, -1, -1, 33], dtype: 'float32'} ]
// where each -1 matches on input and output of corresponding dim/axis
// 33 -> 1 probability score, followed by 32 radial equiangular distances of rays
this._model = await loadStardist();
}

public loadTraining(images: ImageObject[], preprocessingArgs: any): void {}

public loadValidation(images: ImageObject[], preprocessingArgs: any): void {}

// This Stardist model requires image dimensions to be a multiple of 16
// (for VHE in particular), see:
// https://github.com/stardist/stardist/blob/468c60552c8c93403969078e51bddc9c2c702035/stardist/models/model2d.py#L543
// https://github.com/stardist/stardist/blob/master/stardist/models/model2d.py#L201C30-L201C30
// and config here (under source -> grid): https://bioimage.io/#/?tags=stardist&id=10.5281%2Fzenodo.6338614
// basically, in the case of VHE: 2^3 * 2 = 16
protected _getPaddings(height: number, width: number) {
const padY = height % 16 === 0 ? 0 : 16 - (height % 16);
const padX = width % 16 === 0 ? 0 : 16 - (width % 16);

return { padY, padX };
}

public loadInference(
images: ImageObject[],
preprocessingArgs: LoadInferenceDataArgs
): void {
this._inferenceDataDims = images.map((im) => {
const { height, width } = im.shape;
const { padX, padY } = this._getPaddings(height, width);
return { height, width, padY, padX };
});

this._inferenceDataset = preprocessStardist(
images,
1,
this._inferenceDataDims
);

if (preprocessingArgs.kinds) {
if (preprocessingArgs.kinds.length !== 1)
throw Error(
`${this.name} Model only takes a single foreground category`
);
this._fgKind = preprocessingArgs.kinds[0];
} else if (!this._fgKind) {
const unknownCategoryId = generateUUID({ definesUnknown: true });
this._fgKind = {
id: "Nucleus",
categories: [unknownCategoryId],
containing: [],
unknownCategoryId,
};
}
}

public async train(options: any, callbacks: any): Promise<History> {
if (!this.trainable) {
throw new Error(`Training not supported for Model ${this.name}`);
} else {
throw new Error(`Training not yet implemented for Model ${this.name}`);
}
}

public async predict() {
if (!this._model) {
throw Error(`"${this.name}" Model not loaded`);
}

if (this._model instanceof LayersModel) {
throw Error(`"${this.name}" Model must a Graph, not Layers`);
}

if (!this._inferenceDataset) {
throw Error(`"${this.name}" Model's inference data not loaded`);
}

if (!this._fgKind) {
throw Error(`"${this.name}" Model's foreground kind is not loaded`);
}

if (!this._inferenceDataDims) {
throw Error(
`"${this.name}" Model's inference data dimensions and padding information not loaded`
);
}

const graphModel = this._model as GraphModel;

const infT = await this._inferenceDataset.toArray();
// imTensor disposed in `predictStardist`
const annotationsPromises = infT.map((imTensor, idx) => {
return predictStardist(
graphModel,
imTensor,
this._fgKind!.id,
this._fgKind!.unknownCategoryId,
this._inferenceDataDims![idx]
);
});
const annotations = await Promise.all(annotationsPromises);

return annotations;
}

public inferenceCategoriesById(catIds: Array<string>) {
return [];
}
public inferenceKindsById(kinds: string[]) {
if (!this._fgKind) {
throw Error(`"${this.name}" Model has no foreground kind loaded`);
}

return kinds.includes(this._fgKind.id) ? [this._fgKind] : [];
}

public override dispose() {
this._inferenceDataDims = undefined;
this._fgKind = undefined;
super.dispose();
}
}
1 change: 1 addition & 0 deletions src/utils/models/StardistFluo/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export { StardistFluo } from "./StardistFluo";
72 changes: 72 additions & 0 deletions src/utils/models/StardistFluo/loadStardist.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { io, loadGraphModel } from "@tensorflow/tfjs";

import Stardist2DFluorescenceModel from "data/model-data/stardist-fluo/model.json";
//@ts-ignore
import Stardist2DFluorescenceWeights1 from "data/model-data/stardist-fluo/group1-shard1of2.bin";
//@ts-ignore
import Stardist2DFluorescenceWeights2 from "data/model-data/stardist-fluo/group1-shard2of2.bin";

/*
* model.json contains 'modelTopology' and 'weightsManifest
*
* 'modelTopology': A JSON object that can be either of:
* 1) a model architecture JSON consistent with the format of the return value of keras.Model.to_json()
* 2) a full model JSON in the format of keras.models.save_model().
*
* 'weightsManifest': A TensorFlow.js weights manifest.
* See the Python converter function save_model() for more details.
* It is also assumed that model weights (.bin files) can be accessed
* from relative paths described by the paths fields in weights manifest.
*/

export async function loadStardist() {
let modelDescription: File;
let modelWeights1: File;
let modelWeights2: File;

try {
const model_desc_blob = new Blob(
[JSON.stringify(Stardist2DFluorescenceModel)],
{
type: "application/json",
}
);
modelDescription = new File([model_desc_blob], "model.json", {
type: "application/json",
});

const model_weights_fetch1 = await fetch(Stardist2DFluorescenceWeights1);
const model_weights_blob1 = await model_weights_fetch1.blob();
modelWeights1 = new File([model_weights_blob1], "group1-shard1of2.bin", {
type: "application/octet-stream",
});

const model_weights_fetch2 = await fetch(Stardist2DFluorescenceWeights2);
const model_weights_blob2 = await model_weights_fetch2.blob();
modelWeights2 = new File([model_weights_blob2], "group1-shard2of2.bin", {
type: "application/octet-stream",
});
} catch (err) {
const error: Error = err as Error;
process.env.NODE_ENV !== "production" &&
process.env.REACT_APP_LOG_LEVEL === "1" &&
console.error(`error loading stardist: ${error.message}`);
throw err;
}

try {
const model = await loadGraphModel(
io.browserFiles([modelDescription, modelWeights1, modelWeights2])
);

return model;
} catch (err) {
const error: Error = err as Error;

process.env.NODE_ENV !== "production" &&
process.env.REACT_APP_LOG_LEVEL === "1" &&
console.error(`error loading stardist: ${error.message}`);

throw err;
}
}
Loading

0 comments on commit f50e0ca

Please sign in to comment.