Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support shader-printf and remove write-only texture WAR #26

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions build.mk
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/ui.js
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/styles
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/compiler.js
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/language-server.js
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/playgroundShader.js

.PHONY: $(TRY_SLANG_SLANG_SOURCE_DIRECTORY_PATH)/build.em/Release/bin/slang-wasm.js
$(TRY_SLANG_SLANG_SOURCE_DIRECTORY_PATH)/build.em/Release/bin/slang-wasm.js $(TRY_SLANG_SLANG_SOURCE_DIRECTORY_PATH)/build.em/Release/bin/slang-wasm.wasm &:
Expand Down Expand Up @@ -94,3 +95,6 @@ $(TRY_SLANG_TARGET_DIRECTORY_PATH)/compiler.js: $(TRY_SLANG_SOURCE_DIRECTORY_PAT

$(TRY_SLANG_TARGET_DIRECTORY_PATH)/language-server.js: $(TRY_SLANG_SOURCE_DIRECTORY_PATH)/language-server.js
$(COPY) $^ $@

$(TRY_SLANG_TARGET_DIRECTORY_PATH)/playgroundShader.js: $(TRY_SLANG_SOURCE_DIRECTORY_PATH)/playgroundShader.js
$(COPY) $^ $@
46 changes: 24 additions & 22 deletions compiler.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ const imageMainSource = `
import user;
import playground;

RWStructuredBuffer<int> outputBuffer;
[format("r32f")] RWTexture2D<float> texture;
RWStructuredBuffer<int> outputBuffer;

[format("rgba8")]
WTexture2D outputTexture;

inline float encodeColor(float4 color)
{
Expand All @@ -27,42 +29,32 @@ void imageMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
uint width = 0;
uint height = 0;
texture.GetDimensions(width, height);
outputTexture.GetDimensions(width, height);

if (dispatchThreadID.x >= width || dispatchThreadID.y >= height)
return;

float4 color = imageMain(dispatchThreadID.xy, int2(width, height));
float encodedColor = encodeColor(color);

texture[dispatchThreadID.xy] = encodedColor;
outputTexture.Store(dispatchThreadID.xy, color);
}
`;

const playgroundSource = `
internal uniform float time;

// Return the current time in milliseconds
public float getTime()
{
return time;
}

`;

const printMainSource = `
import user;
import playground;

RWStructuredBuffer<int> outputBuffer;
[format("r32f")] RWTexture2D<float> texture;

[format("rgba8")]
WTexture2D outputTexture;

[shader("compute")]
[numthreads(1, 1, 1)]
void printMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
int res = printMain();
outputBuffer[0] = res;
printMain();
}
`;

Expand All @@ -78,9 +70,9 @@ float4 imageMain(uint2 dispatchThreadID, int2 screenSize)
const emptyPrintShader = `
import playground;

int printMain()
void printMain()
{
return 1;
print("%d, %3.2d, 0x%x, %8.3f, %s, %e\\n", 2, 3456, 2134, 40.1234, "hello world", 12.547);
}
`;

Expand Down Expand Up @@ -108,6 +100,9 @@ class SlangCompiler

mainModules = new Map();

// store the string hash if appears in the shader code
hashedString = null;

constructor(module)
{
this.slangWasmModule = module;
Expand Down Expand Up @@ -197,7 +192,6 @@ class SlangCompiler

spirvDisassembly(spirvBinary)
{

const disAsmCode = this.spirvToolsModule.dis(
spirvBinary,
this.spirvToolsModule.SPV_ENV_UNIVERSAL_1_3,
Expand Down Expand Up @@ -393,6 +387,12 @@ class SlangCompiler
compile(shaderSource, entryPointName, compileTargetStr, stage)
{
this.diagnosticsMsg = "";
if (this.hashedString)
{
this.hashedString.delete();
this.hashedString = null;
}

const compileTarget = this.compileTargetMap.findCompileTarget(compileTargetStr);
let isWholeProgram = isWholeProgramTarget(compileTargetStr);

Expand Down Expand Up @@ -423,6 +423,7 @@ class SlangCompiler

var program = slangSession.createCompositeComponentType(components);
var linkedProgram = program.link();
this.hashedString = linkedProgram.loadStrings();

var outCode;
if (compileTargetStr == "SPIRV")
Expand All @@ -441,7 +442,8 @@ class SlangCompiler
0 /* entryPointIndex */, 0 /* targetIndex */);
}

if(outCode == "") {
if (outCode == "")
{
var error = this.slangWasmModule.getLastError();
console.error(error.type + " error: " + error.message);
this.diagnosticsMsg += (error.type + " error: " + error.message);
Expand Down
100 changes: 92 additions & 8 deletions compute.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ class ComputePipeline
uniformBuffer;
uniformBufferHost = new Float32Array(4);


printfBufferElementSize = 12;
printfBufferSize = this.printfBufferElementSize * 100; // 16 bytes per printf struct
printfBuffer;
printfBufferRead;

outputBuffer;
outputBufferRead;
outputTexture;
Expand All @@ -29,7 +35,8 @@ class ComputePipeline
entries: [
{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {type: 'uniform'}},
{binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: {type: 'storage'}},
{binding: 2, visibility: GPUShaderStage.COMPUTE, storageTexture: {access: "read-write", format: this.outputTexture.format}},
{binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: {type: 'storage'}},
{binding: 3, visibility: GPUShaderStage.COMPUTE, storageTexture: {access: "write-only", format: this.outputTexture.format}},
],
};

Expand Down Expand Up @@ -57,8 +64,9 @@ class ComputePipeline
layout: this.pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: this.uniformBuffer }},
{ binding: 1, resource: { buffer: this.outputBuffer }},
{ binding: 2, resource: this.outputTexture.createView() },
{ binding: 1, resource: { buffer: this.printfBuffer }},
{ binding: 2, resource: { buffer: this.outputBuffer }},
{ binding: 3, resource: this.outputTexture.createView() },
],
});

Expand All @@ -79,6 +87,18 @@ class ComputePipeline
this.outputBufferRead = null;
}

if (this.printfBuffer)
{
this.printfBuffer.destroy();
this.printfBuffer = null;
}

if (this.printfBufferRead)
{
this.printfBufferRead.destroy();
this.printfBufferRead = null;
}

if (this.outputTexture)
{
this.outputTexture.destroy();
Expand All @@ -100,13 +120,13 @@ class ComputePipeline
const size = numberElements * 4; // int type
this.outputBuffer = this.device.createBuffer({lable: 'outputBuffer', size, usage});

usage = GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST;
const outputBufferRead = this.device.createBuffer({lable: 'outputBufferRead', size, usage});
this.outputBufferRead = outputBufferRead;
this.printfBuffer = this.device.createBuffer({lable: 'outputBuffer', size: this.printfBufferSize, usage});

const storageTexture = createOutputTexture(device, windowSize[0], windowSize[1], 'r32float');
this.outputTexture = storageTexture;
usage = GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST;
this.outputBufferRead = this.device.createBuffer({lable: 'outputBufferRead', size, usage});
this.printfBufferRead = this.device.createBuffer({lable: 'outputBufferRead', size: this.printfBufferSize, usage});

this.outputTexture = createOutputTexture(device, windowSize[0], windowSize[1], 'rgba8unorm');
}
}

Expand All @@ -126,4 +146,68 @@ class ComputePipeline
this.createOutput(true, windowSize);
this.createComputePipelineLayout();
}


// This is the definition of the printf buffer.
// struct FormatedStruct
// {
// uint32_t type = 0xFFFFFFFF;
// uint32_t low = 0;
// uint32_t high = 0;
// };
parsePrintfBuffer(hashedString)
{
const printfBufferArray = new Uint32Array(computePipeline.printfBufferRead.getMappedRange())
var elementIndex = 0;
var numberElements = printfBufferArray.byteLength / this.printfBufferElementSize;

var formatString;
if (printfBufferArray[0] == 1) // type field
{
formatString = hashedString.getString(printfBufferArray[1]); // low field
}
else
{
// If the first element is not a string, we will just return an empty string, it indicates
// that the printf buffer is empty.
return "";
}

// TODO: We currently doesn't support 64-bit data type (e.g. uint64_t, int64_t, double, etc.)
// so 32-bit array should be able to contain everything we need.
var dataArray = [];
const elementSizeInWords = this.printfBufferElementSize / 4;
for (elementIndex = 1; elementIndex < numberElements; elementIndex++)
{
var offset = elementIndex * elementSizeInWords;
const type = printfBufferArray[offset];

if (type == 1) // type field, this is a string
{
dataArray.push(hashedString.getString(printfBufferArray[offset + 1])); // low field
}
else if (type == 2) // type field
{
dataArray.push(printfBufferArray[offset + 1]); // low field
}
else if (type == 3) // type field
{
const floatData = reinterpretUint32AsFloat(printfBufferArray[offset + 1]);
dataArray.push(floatData); // low field
}
else if (type == 4) // type field
{
// TODO: We can't handle 64-bit data type yet.
dataArray.push(0); // low field
}
else if (type == 0xFFFFFFFF)
{
break;
}
}

const parsedTokens = parsePrintfFormat(formatString);
const output = formatPrintfString(parsedTokens, dataArray);
return output;
}
}
1 change: 1 addition & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
);
</script>
<script src="util.js"></script>
<script src="playgroundShader.js"></script>
<script src="water_demo.js"></script>
<script src="pass_through.js"></script>
<script src="compute.js"></script>
Expand Down
9 changes: 2 additions & 7 deletions pass_through.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ var passThroughshaderCode = `

@fragment fn fs(fsInput: VertexShaderOutput) -> @location(0) vec4f {
let color = textureSample(ourTexture, ourSampler, fsInput.texcoord);
let value = u32(color.x);
let r = ((value & 0xFF000000) >> 24);
let g = ((value & 0x00FF0000) >> 16);
let b = ((value & 0x0000FF00) >> 8);

return vec4f(f32(r)/255.0f, f32(g)/255.0f, f32(b)/255.0f, 1.0f);
return color;
}
`;

Expand Down Expand Up @@ -98,7 +93,7 @@ class GraphicsPipeline
fragment:
{
module: shaderModule,
targets: [{format: navigator.gpu.getPreferredCanvasFormat()}]
targets: [{format: navigator.gpu.getPreferredCanvasFormat(),}]
},
});
this.pipeline = pipeline;
Expand Down
84 changes: 84 additions & 0 deletions playgroundShader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

const playgroundSource = `
internal uniform float time;

// Return the current time in milliseconds
public float getTime()
{
return time;
}

// type field: 1 for string, 2 for integer, 3 for float, 4 for double
struct FormatedStruct
{
uint32_t type = 0xFFFFFFFF;
uint32_t low = 0;
uint32_t high = 0;
};

internal RWStructuredBuffer<FormatedStruct> g_printedBuffer;

interface IPrintf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public interface IPrintable.

{
uint32_t typeFlag();
uint32_t writePrintfWords();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we pass the index counter + g_printBuffer directly into this method, we can make vectors and matrices or even arrays implement IPrintf.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, good point, but now we can not add g_printBuffer into the function parameter. Looks like wgsl doesn't support it and we need some legalization path like glsl to make this happen.

};

extension uint : IPrintf
{
uint32_t typeFlag() { return 2;}
uint32_t writePrintfWords() { return (uint32_t)this; }
}

extension int : IPrintf
{
uint32_t typeFlag() { return 2;}
uint32_t writePrintfWords() { return (uint32_t)this; }
}

// extension int64_t : IPrintf
// {
// uint64_t writePrintfWords() { return (uint64_t)this; }
// }

// extension uint64_t : IPrintf
// {
// uint64_t writePrintfWords() { return (uint64_t)this; }
// }

extension float : IPrintf
{
uint32_t typeFlag() { return 3;}
uint32_t writePrintfWords() { return bit_cast<uint32_t>(this); }
}

// extension double : IPrintf
// {
// uint64_t writePrintfWords() { return bit_cast<uint64_t>(this); }
// }

extension String : IPrintf
{
uint32_t typeFlag() { return 1;}
uint32_t writePrintfWords() { return getStringHash(this); }
}

void handleEach<T>(T value, int index) where T : IPrintf
{
g_printedBuffer[index].type = value.typeFlag();
g_printedBuffer[index].low = value.writePrintfWords();
}

public void print<each T>(String format, expand each T values) where T : IPrintf
{
//if (format.length != 0)
{
g_printedBuffer[0].type = 1;
g_printedBuffer[0].low = getStringHash(format);
int index = 1;
expand(handleEach(each values, index++));

g_printedBuffer[index] = {};
}
}
`;
Loading