diff --git a/src/io/dicom.ts b/src/io/dicom.ts index 28a10f129..a972d25f1 100644 --- a/src/io/dicom.ts +++ b/src/io/dicom.ts @@ -190,7 +190,7 @@ type ReadOverlappingSegmentationResultWithRealMeta = metaInfo: ReadOverlappingSegmentationMeta; }; -export async function buildLabelMap(file: File) { +export async function buildSegmentGroups(file: File) { const inputImage = sanitizeFile(file); const result = (await readOverlappingSegmentation(inputImage, { webWorker: getWorker(), diff --git a/src/store/datasets-dicom.ts b/src/store/datasets-dicom.ts index d68478891..2e0a3962c 100644 --- a/src/store/datasets-dicom.ts +++ b/src/store/datasets-dicom.ts @@ -3,7 +3,6 @@ import { defineStore } from 'pinia'; import { Image } from 'itk-wasm'; import { DataSourceWithFile } from '@/src/io/import/dataSource'; import * as DICOM from '@/src/io/dicom'; -import { pullComponent0 } from '@/src/utils/images'; import { identity, pick, removeFromArray } from '../utils'; import { useImageStore } from './datasets-images'; import { useFileStore } from './datasets-files'; @@ -55,16 +54,10 @@ const buildImage = async (seriesFiles: File[], modality: string) => { const messages: string[] = []; if (modality === 'SEG') { const segFile = seriesFiles[0]; - const results = await DICOM.buildLabelMap(segFile); - if (results.outputImage.imageType.components !== 1) { - messages.push( - `${segFile.name} SEG file has overlapping segments. Using first set.` - ); - results.outputImage = pullComponent0(results.segImage); - } + const results = await DICOM.buildSegmentGroups(segFile); if (seriesFiles.length > 1) messages.push( - 'SEG image has multiple components. Using only the first component.' + 'Tried to make one volume from 2 SEG modality files. Using only the first file!' ); return { modality: 'SEG', diff --git a/src/store/segmentGroups.ts b/src/store/segmentGroups.ts index 4fcef5ad6..7363f9d13 100644 --- a/src/store/segmentGroups.ts +++ b/src/store/segmentGroups.ts @@ -17,6 +17,7 @@ import { getImage, isRegularImage, } from '@/src/utils/dataSelection'; +import vtkImageExtractComponents from '@/src/utils/imageExtractComponentsFilter'; import vtkLabelMap from '../vtk/LabelMap'; import { StateFile, @@ -35,6 +36,7 @@ export const DEFAULT_SEGMENT_COLOR: RGBAColor = [255, 0, 0, 255]; export const makeDefaultSegmentName = (value: number) => `Segment ${value}`; export const makeDefaultSegmentGroupName = (baseName: string, index: number) => `Segment Group ${index} for ${baseName}`; +const numberer = (index: number) => (index <= 1 ? '' : `${index}`); // start numbering at 2 export interface SegmentGroupMetadata { name: string; @@ -79,6 +81,20 @@ export function toLabelMap(imageData: vtkImageData) { return labelmap; } +export function extractEachComponent(input: vtkImageData) { + const numComponents = input + .getPointData() + .getScalars() + .getNumberOfComponents(); + const extractComponentsFilter = vtkImageExtractComponents.newInstance(); + extractComponentsFilter.setInputData(input); + return Array.from({ length: numComponents }, (_, i) => { + extractComponentsFilter.setComponents([i]); + extractComponentsFilter.update(); + return extractComponentsFilter.getOutputData() as vtkImageData; + }); +} + export const useSegmentGroupStore = defineStore('segmentGroup', () => { type _This = ReturnType; @@ -156,6 +172,22 @@ export const useSegmentGroupStore = defineStore('segmentGroup', () => { }); }); + function pickUniqueName( + formatName: (index: number) => string, + parentID: string + ) { + const existingNames = new Set( + Object.values(metadataByID).map((meta) => meta.name) + ); + let name = ''; + do { + const nameIndex = nextDefaultIndex[parentID] ?? 1; + nextDefaultIndex[parentID] = nameIndex + 1; + name = formatName(nameIndex); + } while (existingNames.has(name)); + return name; + } + /** * Creates a new labelmap entry from a parent/source image. */ @@ -174,16 +206,10 @@ export const useSegmentGroupStore = defineStore('segmentGroup', () => { 'value' ); - // pick a unique name - let name = ''; - const existingNames = new Set( - Object.values(metadataByID).map((meta) => meta.name) + const name = pickUniqueName( + (index: number) => makeDefaultSegmentGroupName(baseName, index), + parentID ); - do { - const nameIndex = nextDefaultIndex[parentID] ?? 1; - nextDefaultIndex[parentID] = nameIndex + 1; - name = makeDefaultSegmentGroupName(baseName, nameIndex); - } while (existingNames.has(name)); return addLabelmap.call(this, labelmap, { name, @@ -210,7 +236,11 @@ export const useSegmentGroupStore = defineStore('segmentGroup', () => { return [...color]; } - async function decodeSegments(imageId: DataSelection, image: vtkLabelMap) { + async function decodeSegments( + imageId: DataSelection, + image: vtkLabelMap, + component = 0 + ) { if (!isRegularImage(imageId)) { // dicom image const dicomStore = useDICOMStore(); @@ -218,7 +248,9 @@ export const useSegmentGroupStore = defineStore('segmentGroup', () => { const volumeBuildResults = await dicomStore.volumeBuildResults[imageId]; if (volumeBuildResults.modality === 'SEG') { const segments = - volumeBuildResults.builtImageResults.metaInfo.segmentAttributes[0]; + volumeBuildResults.builtImageResults.metaInfo.segmentAttributes[ + component + ]; return segments.map((segment) => ({ value: segment.labelID, name: segment.SegmentLabel, @@ -272,27 +304,42 @@ export const useSegmentGroupStore = defineStore('segmentGroup', () => { ); } - const name = imageStore.metadata[imageID].name; - // Don't remove image if DICOM as user may have selected segment group image as primary selection by now + // Don't remove image if DICOM. User may have selected segment group image as primary selection by now const deleteImage = isRegularImage(imageID); if (deleteImage) { imageStore.deleteData(imageID); } - const matchingParentSpace = await ensureSameSpace( - parentImage, - childImage, - true - ); - const labelmapImage = toLabelMap(matchingParentSpace); + const componentCount = childImage + .getPointData() + .getScalars() + .getNumberOfComponents(); + // for each component, create create new vtkImageData with just one component, pulled from each component of childImage + const images = + componentCount === 1 ? [childImage] : extractEachComponent(childImage); + + const baseName = imageStore.metadata[imageID].name; + images.forEach(async (image, component) => { + const matchingParentSpace = await ensureSameSpace( + parentImage, + image, + true + ); + const labelmapImage = toLabelMap(matchingParentSpace); - const segments = await decodeSegments(imageID, labelmapImage); - const { order, byKey } = normalizeForStore(segments, 'value'); - const segmentGroupStore = useSegmentGroupStore(); - segmentGroupStore.addLabelmap(labelmapImage, { - name, - parentImage: parentID, - segments: { order, byValue: byKey }, + const segments = await decodeSegments(imageID, labelmapImage, component); + const { order, byKey } = normalizeForStore(segments, 'value'); + const segmentGroupStore = useSegmentGroupStore(); + + const name = pickUniqueName( + (index: number) => `${baseName} ${numberer(index)}`, + parentID + ); + segmentGroupStore.addLabelmap(labelmapImage, { + name, + parentImage: parentID, + segments: { order, byValue: byKey }, + }); }); } diff --git a/src/utils/__tests__/imageExtractComponentsFilter.ts b/src/utils/__tests__/imageExtractComponentsFilter.ts new file mode 100644 index 000000000..012fba42e --- /dev/null +++ b/src/utils/__tests__/imageExtractComponentsFilter.ts @@ -0,0 +1,56 @@ +import { describe, it, expect } from 'vitest'; +import vtkImageData from '@kitware/vtk.js/Common/DataModel/ImageData'; +import vtkDataArray from '@kitware/vtk.js/Common/Core/DataArray'; +import vtkImageExtractComponentsFilter from '../imageExtractComponentsFilter'; + +describe('vtkImageExtractComponentsFilter', () => { + it('should extract specified components', () => { + // Create an image data with known scalar components + const imageData = vtkImageData.newInstance(); + imageData.setDimensions([2, 2, 1]); + + // Create scalar data with 3 components per voxel + const scalars = vtkDataArray.newInstance({ + numberOfComponents: 3, + values: new Uint8Array([ + // Voxel 0 + 10, 20, 30, + // Voxel 1 + 40, 50, 60, + // Voxel 2 + 70, 80, 90, + // Voxel 3 + 100, 110, 120, + ]), + }); + + imageData.getPointData().setScalars(scalars); + + // Create the filter and set components to extract + const extractComponentsFilter = + vtkImageExtractComponentsFilter.newInstance(); + extractComponentsFilter.setComponents([0, 2]); // Extract components 0 and 2 + extractComponentsFilter.setInputData(imageData); + extractComponentsFilter.update(); + + const outputData = extractComponentsFilter.getOutputData(); + const outputScalars = outputData.getPointData().getScalars(); + const outputValues = outputScalars.getData(); + + // Expected output + const expectedValues = new Uint8Array([ + // Voxel 0 + 10, 30, + // Voxel 1 + 40, 60, + // Voxel 2 + 70, 90, + // Voxel 3 + 100, 120, + ]); + + // Check if output matches expected values + expect(outputScalars.getNumberOfComponents()).toBe(2); + expect(outputValues).toEqual(expectedValues); + }); +}); diff --git a/src/utils/imageExtractComponentsFilter.js b/src/utils/imageExtractComponentsFilter.js new file mode 100644 index 000000000..46803bb63 --- /dev/null +++ b/src/utils/imageExtractComponentsFilter.js @@ -0,0 +1,86 @@ +/* eslint-disable no-param-reassign */ +import macro from '@kitware/vtk.js/macro'; +import vtkImageData from '@kitware/vtk.js/Common/DataModel/ImageData'; +import vtkDataArray from '@kitware/vtk.js/Common/Core/DataArray'; + +function vtkImageExtractComponentsFilter(publicAPI, model) { + model.classHierarchy.push('vtkImageExtractComponentsFilter'); + + publicAPI.requestData = (inData, outData) => { + const inputData = inData[0]; + const outputData = vtkImageData.newInstance(); + + const components = model.components; + if (!components || !Array.isArray(components) || components.length === 0) { + throw Error('No components specified for extraction.'); + } + + const inputScalars = inputData.getPointData().getScalars(); + const numInputComponents = inputScalars.getNumberOfComponents(); + components.forEach((c) => { + if (c < 0) { + throw Error('Component index must be greater than or equal to 0.'); + } + if (c >= numInputComponents) { + throw Error( + 'Component index must be less than the number of components in the input data.' + ); + } + }); + + outputData.shallowCopy(inputData); + + const inputArray = inputScalars.getData(); + const numPixels = inputArray.length / numInputComponents; + + const outputNumComponents = components.length; + const outputArray = new inputArray.constructor( + numPixels * outputNumComponents + ); + + for (let pixel = 0; pixel < numPixels; pixel++) { + for (let c = 0; c < components.length; c++) { + outputArray[pixel * outputNumComponents + c] = + inputArray[pixel * numInputComponents + components[c]]; + } + } + + outputData.getPointData().setScalars( + vtkDataArray.newInstance({ + numberOfComponents: outputNumComponents, + values: outputArray, + }) + ); + + outData[0] = outputData; + }; +} + +const DEFAULT_VALUES = { + components: [], +}; + +// ---------------------------------------------------------------------------- + +export function extend(publicAPI, model, initialValues = {}) { + Object.assign(model, DEFAULT_VALUES, initialValues); + + macro.obj(publicAPI, model); + + macro.algo(publicAPI, model, 1, 1); + + macro.setGet(publicAPI, model, ['components']); + + vtkImageExtractComponentsFilter(publicAPI, model); +} + +// ---------------------------------------------------------------------------- + +export const newInstance = macro.newInstance( + extend, + 'vtkImageExtractComponentsFilter' +); + +// ---------------------------------------------------------------------------- + +export default { newInstance, extend }; diff --git a/src/utils/images.ts b/src/utils/images.ts deleted file mode 100644 index 2482b75df..000000000 --- a/src/utils/images.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { Image } from 'itk-wasm'; - -export const pullComponent0 = (image: Image) => { - const srcComponentCount = image.imageType.components; - const srcPixelArray = image.data!; - const oneComponentArrayLength = srcPixelArray.length / srcComponentCount; - const pixelArray = new (srcPixelArray.constructor as { - new (length: number): typeof srcPixelArray; - })(oneComponentArrayLength); - for (let pixel = 0; pixel < oneComponentArrayLength; pixel++) { - pixelArray[pixel] = srcPixelArray[pixel * srcComponentCount]; - } - return { - ...image, - data: pixelArray, - imageType: { - ...image.imageType, - components: 1, - }, - }; -};