diff --git a/src/Deeptable.ts b/src/Deeptable.ts index 2ced0e0b..1f55f172 100644 --- a/src/Deeptable.ts +++ b/src/Deeptable.ts @@ -38,6 +38,7 @@ import { DataSelection } from './selection'; import { Some, TupleMap } from './utilityFunctions'; import { getNestedVector } from './regl_rendering'; import { tileKey_to_tix } from './tixrixqid'; +import { DeepGPU } from './deepscatter'; type TransformationStatus = 'queued' | 'in progress' | 'complete' | 'failed'; @@ -102,6 +103,7 @@ export class Deeptable { public root_tile: Tile; public fetcherId?: number; private _downloaderId?: number; + private _deepGPU?: Promise; // Whether the tileset is structured as a pure quadtree. public readonly tileStucture: DS.TileStructure; @@ -255,6 +257,13 @@ export class Deeptable { await this.root_tile.download_to_depth(max_ix, suffix); } + get deepGPU() : Promise { + if (this._deepGPU !== undefined) { + return this._deepGPU + } + this._deepGPU = DeepGPU.create(this) + return this._deepGPU + } /** * The highest known point that deepscatter has seen so far. This is used * to adjust opacity size. @@ -755,7 +764,7 @@ export class Deeptable { spawnDownloads( bbox: Rectangle | undefined, max_ix: number, - queue_length = 32, + queue_length = 64, fields: string[] = ['x', 'y', 'ix'], priority: 'high' | 'low' = 'high', ): boolean { diff --git a/src/deepscatter.ts b/src/deepscatter.ts index 278ca308..a52e6ada 100644 --- a/src/deepscatter.ts +++ b/src/deepscatter.ts @@ -4,6 +4,8 @@ export { Deeptable } from './Deeptable'; export { LabelMaker } from './label_rendering'; export { dictionaryFromArrays } from './utilityFunctions'; export { Tile } from './tile'; +export { DeepGPU, create_hamming_transform, HammingPipeline, ReusableWebGPUPipeline } from './webGPU/lib' + export type { APICall, CompletePrefs, diff --git a/src/regl_rendering.ts b/src/regl_rendering.ts index 3c8bf6c9..1484a9fe 100644 --- a/src/regl_rendering.ts +++ b/src/regl_rendering.ts @@ -285,7 +285,7 @@ export class ReglRenderer extends Renderer { deeptable.spawnDownloads( this.zoom.current_corners(), this.props.max_ix, - 5, + 64, this.aes.neededFields.map((x) => x[0]), 'high', ); @@ -294,7 +294,7 @@ export class ReglRenderer extends Renderer { deeptable.spawnDownloads( undefined, prefs.max_points, - 5, + 64, this.aes.neededFields.map((x) => x[0]), 'high', ); @@ -1336,7 +1336,7 @@ export abstract class BufferSet { buffer: T; offset: number; stride?: number; - byte_size?: number; // in bytes; + byte_size: number; // in bytes; }; export type WebGPUBufferLocation = BufferLocation & { diff --git a/src/webGPU/buffertools.ts b/src/webGPU/buffertools.ts index 06ecefd6..69242ded 100644 --- a/src/webGPU/buffertools.ts +++ b/src/webGPU/buffertools.ts @@ -1,4 +1,4 @@ -import { isTypedArray, type TypedArray } from 'webgpu-utils'; +import { type TypedArray } from 'webgpu-utils'; import { BufferSet } from '../regl_rendering'; import { WebGPUBufferLocation } from '../types'; import { Some, TupleMap } from '../utilityFunctions'; @@ -76,8 +76,8 @@ export class WebGPUBufferSet extends BufferSet< if (this.store.has(key)) { throw new Error(`Key ${key.join(', ')} already exists in buffer set.`); } - const size = value.byteLength; - const paddedSize = Math.ceil(size / 256) * 256; + const byte_size = value.byteLength; + const paddedSize = Math.ceil(byte_size / 256) * 256; const { buffer, offset } = this.allocate_block(paddedSize); @@ -85,7 +85,7 @@ export class WebGPUBufferSet extends BufferSet< // cast it to uint32array const v2 = value; const data = new Uint32Array(v2.buffer, v2.byteOffset, v2.byteLength / 4); - const description = { buffer, offset, size, paddedSize }; + const description = { buffer, offset, byte_size, paddedSize }; await this.passThroughStagingBuffer(data, description); this.register(key, description); } @@ -111,7 +111,7 @@ export class WebGPUBufferSet extends BufferSet< export function createSingletonBuffer( device: GPUDevice, - data: Uint32Array | Int32Array | Float32Array | ArrayBuffer, + data: Uint32Array | Int32Array | Float32Array | Uint8Array | ArrayBuffer, usage: number, ): GPUBuffer { // Creates a disposable singleton buffer. @@ -123,11 +123,24 @@ export function createSingletonBuffer( mappedAtCreation: true, }); const mappedRange = buffer.getMappedRange(); - if (isTypedArray(data)) { - new Uint32Array(mappedRange).set(data as TypedArray); - } else { - new Uint32Array(mappedRange).set(new Uint32Array(data as ArrayBuffer)); - } + + // Write the data into the buffer + if (data instanceof Uint32Array) { + new Uint32Array(mappedRange).set(data); + } else if (data instanceof Int32Array) { + new Int32Array(mappedRange).set(data); + } else if (data instanceof Float32Array) { + new Float32Array(mappedRange).set(data); + } else if (data instanceof Uint8Array) { + new Uint8Array(mappedRange).set(data); + } else if (data instanceof ArrayBuffer) { + // Treat ArrayBuffer as raw data, copy it into the mapped range + const view = new Uint8Array(mappedRange); + view.set(new Uint8Array(data)); + } else { + throw new Error("Unsupported data type for buffer creation"); + } + buffer.unmap(); return buffer; } diff --git a/src/webGPU/lib.ts b/src/webGPU/lib.ts index e1cf03b7..85ce5379 100644 --- a/src/webGPU/lib.ts +++ b/src/webGPU/lib.ts @@ -1,7 +1,7 @@ import { makeShaderDataDefinitions, makeStructuredView } from 'webgpu-utils'; import { WebGPUBufferSet, createSingletonBuffer } from './buffertools'; -import { Deeptable, Scatterplot, Tile } from '../deepscatter'; -import { Bool, Vector, vectorFromArray } from 'apache-arrow'; +import { Deeptable, Tile, Transformation } from '../deepscatter'; +import { Bool, Type, Vector, vectorFromArray } from 'apache-arrow'; export class DeepGPU { // This is a stateful class for bundling together GPU buffers and resources. @@ -45,6 +45,7 @@ export class DeepGPU { if (this.bufferSet.store.has([field, tile.key])) { return this.bufferSet.store.get([field, tile.key]) } else { + const values = (await tile.get_column(field)).data[0].children[0] .values as Uint8Array; await this.bufferSet.set([field, tile.key], values); @@ -74,11 +75,13 @@ export class HammingPipeline extends ReusableWebGPUPipeline { public gpuState: DeepGPU; public dimensionality? : number; public comparisonBuffer: GPUBuffer; - private fieldName = '_hamming_embeddings'; + private fieldName : string; constructor( gpuState: DeepGPU, + fieldName: string ) { super(gpuState) + this.fieldName = fieldName } bindGroupLayout(device: GPUDevice) { @@ -138,10 +141,17 @@ export class HammingPipeline extends ReusableWebGPUPipeline { setComparisonArray( arr: Vector ) { - const underlying = arr.data[0].values; + const underlying = arr.data[0] + if (underlying.type.typeId !== Type.Bool) { + throw new Error("uhuh") + } + const bytes = underlying.values.slice(underlying.offset / 8, underlying.offset / 8 + underlying.length / 8) + if (bytes.length !== 768 / 8) { + throw new Error("WTF") + } this.comparisonBuffer = createSingletonBuffer( this.gpuState.device, - underlying, + bytes, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, ); this.dimensionality = underlying.length; @@ -244,27 +254,18 @@ export class HammingPipeline extends ReusableWebGPUPipeline { } } -// hide the state in a global variable. -const dumb: DeepGPU[] = []; export async function create_hamming_transform( - scatterplot: Scatterplot, - id: string, + deeptable: Deeptable, + field: string, view: Vector, -) { - if (dumb.length === 0) { - dumb.push(await DeepGPU.create(scatterplot.deeptable)); - } - if (scatterplot.dataset.transformations[id] !== undefined) { - return; - } +) : Promise { - const [gpuState] = dumb; - const pipeline = new HammingPipeline(gpuState); + const gpuState = await deeptable.deepGPU + const pipeline = new HammingPipeline(gpuState, field); pipeline.setComparisonArray(view) pipeline.prep(); - - scatterplot.dataset.transformations[id] = (tile) => pipeline.runOnTile(tile) + return (tile: Tile) => pipeline.runOnTile(tile) }