diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 300bc01859def..832f6e132901e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -3,6 +3,7 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; import { @@ -257,7 +258,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: TensorView, n: number, d: number) => { +const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => { const components = getMaxComponents(d); let WG = 64; const dComp = d / components; @@ -358,7 +359,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor }; const createAttentionProbsProgramInfo = ( - context: ComputeContext, + outputCount: number, q: TensorView, key: TensorView, pastKey: TensorView | undefined, @@ -369,7 +370,7 @@ const createAttentionProbsProgramInfo = ( ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1; + const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey; const presentKeyShape = presentKey ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] : undefined; @@ -394,9 +395,10 @@ const createAttentionProbsProgramInfo = ( { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: parameters.kvSequenceLength }, ]; - + // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced + const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; - if (pastKey) { + if (feedPastKey) { inputDependencies.push('type'); } if (attentionBias) { @@ -410,7 +412,7 @@ const createAttentionProbsProgramInfo = ( const qInput = inputVariable('q', q.dataType, q.dims, components); const kInput = inputVariable('key', key.dataType, key.dims, components); const inputVars = [qInput, kInput]; - if (pastKey) { + if (feedPastKey) { const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); inputVars.push(pastKeyInput); } @@ -446,7 +448,7 @@ const createAttentionProbsProgramInfo = ( let n = workgroup_id.x * TILE_SIZE; let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; ${(() => { - if (pastKey && presentKey) { + if (feedPastKey && presentKey) { return ` let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; @@ -464,7 +466,7 @@ const createAttentionProbsProgramInfo = ( if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { var idx = TILE_SIZE * local_id.y + local_id.x; ${(() => { - if (pastKey && presentKey) { + if (feedPastKey && presentKey) { return ` if (n + local_id.y < uniforms.past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; @@ -513,7 +515,7 @@ const createAttentionProbsProgramInfo = ( return { name: 'AttentionProbs', shaderCache: { - hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, + hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`, inputDependencies, }, getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }), @@ -522,7 +524,7 @@ const createAttentionProbsProgramInfo = ( }; const createVxAttentionScoreProgramInfo = ( - context: ComputeContext, + outputCount: number, probs: TensorView, v: TensorView, pastValue: TensorView | undefined, @@ -532,7 +534,7 @@ const createVxAttentionScoreProgramInfo = ( const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && context.outputCount > 1; + const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue; const presentValueShape = presentValue ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined; @@ -553,7 +555,12 @@ const createVxAttentionScoreProgramInfo = ( { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: params.kvSequenceLength }, ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = pastValue ? ['type', 'type', 'type'] : ['type', 'type']; + // Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced + const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + if (feedPastValue) { + inputDependencies.push('type'); + } const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }]; if (presentValue) { outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default }); @@ -562,7 +569,7 @@ const createVxAttentionScoreProgramInfo = ( const probsHelper = inputVariable('probs', probs.dataType, probs.dims); const vHelper = inputVariable('v', v.dataType, v.dims); const inputVars = [probsHelper, vHelper]; - if (pastValue) { + if (feedPastValue) { inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); } const output = outputVariable('output', probs.dataType, outputShape); @@ -591,7 +598,7 @@ const createVxAttentionScoreProgramInfo = ( let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; ${(() => { - if (pastValue && presentValue) { + if (feedPastValue && presentValue) { return ` let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; @@ -611,7 +618,7 @@ const createVxAttentionScoreProgramInfo = ( if (n < uniforms.N && w + local_id.y < uniforms.K) { var idx = TILE_SIZE * local_id.y + local_id.x; ${(() => { - if (pastValue && presentValue) { + if (feedPastValue && presentValue) { return ` if (w + local_id.y < uniforms.past_sequence_length) { tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; @@ -647,7 +654,7 @@ const createVxAttentionScoreProgramInfo = ( return { name: 'AttentionScore', - shaderCache: { hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies }, + shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies }, getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }), getShaderSource, }; @@ -662,15 +669,21 @@ export const applyAttention = ( _past: TensorView | undefined, pastKey: TensorView | undefined, pastValue: TensorView | undefined, - attentionBias: TensorView | undefined, + attentionBiasInput: TensorView | undefined, parameters: AttentionParameters, attributes: AttentionAttrs, ) => { - const outputCount = context.outputCount; + // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. + const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0)); const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; + const attentionBias = + attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined; - const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k]; + const inputsK = [q, k]; + if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { + inputsK.push(pastKey); + } if (attentionBias) { inputsK.push(attentionBias); } @@ -678,10 +691,10 @@ export const applyAttention = ( // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( - context, + outputCount, q, k, - outputCount > 1 ? pastKey : undefined, + pastKey, attentionBias, parameters, attributes, @@ -693,7 +706,6 @@ export const applyAttention = ( // Run Softmax context.compute( createInPlaceSoftmaxProgramInfo( - context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, totalSequenceLength, @@ -702,19 +714,14 @@ export const applyAttention = ( ); // Run AttrionScore - const inputsV = - parameters.kvNumHeads === undefined && outputCount > 1 && pastValue ? [probs, v, pastValue] : [probs, v]; - context.compute( - createVxAttentionScoreProgramInfo( - context, - probs, - v, - outputCount > 1 && pastValue ? pastValue : undefined, - parameters, - pastSequenceLength, - ), - { inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0] }, - ); + const inputsV = [probs, v]; + if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { + inputsV.push(pastValue); + } + context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), { + inputs: inputsV, + outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0], + }); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 72e09303ba76f..485ebec9847fd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -18,7 +18,7 @@ import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from ' import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; const getInput = (inputs: readonly TensorView[], i: number) => - inputs.length > i && inputs[i].dims.length > 0 && ShapeUtil.size(inputs[i].dims) > 0 ? inputs[i] : undefined; + inputs.length > i && inputs[i].dims.length > 0 ? inputs[i] : undefined; const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { const query = inputs[0]; diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc index ed937a22c0b84..9ae866327b3f2 100644 --- a/js/web/test/data/ops/multihead-attention.jsonc +++ b/js/web/test/data/ops/multihead-attention.jsonc @@ -1073,5 +1073,80 @@ ] } ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=1 with empty pastKey, pastValue inputs and optional presentKey, presentValue outputs", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1], + "dims": [1, 1, 1], + "type": "float32" + }, + // K + { + "data": [2], + "dims": [1, 1, 1], + "type": "float32" + }, + // V + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // AttentionBias + { + "data": null, + "type": "float32" + }, + // PastKey + { + "data": [], + "dims": [1, 1, 0, 1], + "type": "float32" + }, + // PastValue + { + "data": [], + "dims": [1, 1, 0, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + }, + { + "data": [2], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + { + "data": [3], + "dims": [1, 1, 1, 1], + "type": "float32" + } + ] + } + ] } ]