Skip to content

Commit

Permalink
Merge pull request #26 from kaizhangNV/main
Browse files Browse the repository at this point in the history
Support shader-printf and remove write-only texture WAR
  • Loading branch information
kaizhangNV authored Oct 31, 2024
2 parents e83e9ce + 40c0204 commit efbccca
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 54 deletions.
4 changes: 4 additions & 0 deletions build.mk
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/ui.js
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/styles
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/compiler.js
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/language-server.js
website_runtime: $(TRY_SLANG_TARGET_DIRECTORY_PATH)/playgroundShader.js

.PHONY: $(TRY_SLANG_SLANG_SOURCE_DIRECTORY_PATH)/build.em/Release/bin/slang-wasm.js
$(TRY_SLANG_SLANG_SOURCE_DIRECTORY_PATH)/build.em/Release/bin/slang-wasm.js $(TRY_SLANG_SLANG_SOURCE_DIRECTORY_PATH)/build.em/Release/bin/slang-wasm.wasm &:
Expand Down Expand Up @@ -94,3 +95,6 @@ $(TRY_SLANG_TARGET_DIRECTORY_PATH)/compiler.js: $(TRY_SLANG_SOURCE_DIRECTORY_PAT

$(TRY_SLANG_TARGET_DIRECTORY_PATH)/language-server.js: $(TRY_SLANG_SOURCE_DIRECTORY_PATH)/language-server.js
$(COPY) $^ $@

$(TRY_SLANG_TARGET_DIRECTORY_PATH)/playgroundShader.js: $(TRY_SLANG_SOURCE_DIRECTORY_PATH)/playgroundShader.js
$(COPY) $^ $@
46 changes: 24 additions & 22 deletions compiler.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ const imageMainSource = `
import user;
import playground;
RWStructuredBuffer<int> outputBuffer;
[format("r32f")] RWTexture2D<float> texture;
RWStructuredBuffer<int> outputBuffer;
[format("rgba8")]
WTexture2D outputTexture;
inline float encodeColor(float4 color)
{
Expand All @@ -27,42 +29,32 @@ void imageMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
uint width = 0;
uint height = 0;
texture.GetDimensions(width, height);
outputTexture.GetDimensions(width, height);
if (dispatchThreadID.x >= width || dispatchThreadID.y >= height)
return;
float4 color = imageMain(dispatchThreadID.xy, int2(width, height));
float encodedColor = encodeColor(color);
texture[dispatchThreadID.xy] = encodedColor;
outputTexture.Store(dispatchThreadID.xy, color);
}
`;

const playgroundSource = `
internal uniform float time;
// Return the current time in milliseconds
public float getTime()
{
return time;
}
`;

const printMainSource = `
import user;
import playground;
RWStructuredBuffer<int> outputBuffer;
[format("r32f")] RWTexture2D<float> texture;
[format("rgba8")]
WTexture2D outputTexture;
[shader("compute")]
[numthreads(1, 1, 1)]
void printMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
int res = printMain();
outputBuffer[0] = res;
printMain();
}
`;

Expand All @@ -78,9 +70,9 @@ float4 imageMain(uint2 dispatchThreadID, int2 screenSize)
const emptyPrintShader = `
import playground;
int printMain()
void printMain()
{
return 1;
print("%d, %3.2d, 0x%x, %8.3f, %s, %e\\n", 2, 3456, 2134, 40.1234, "hello world", 12.547);
}
`;

Expand Down Expand Up @@ -108,6 +100,9 @@ class SlangCompiler

mainModules = new Map();

// store the string hash if appears in the shader code
hashedString = null;

constructor(module)
{
this.slangWasmModule = module;
Expand Down Expand Up @@ -197,7 +192,6 @@ class SlangCompiler

spirvDisassembly(spirvBinary)
{

const disAsmCode = this.spirvToolsModule.dis(
spirvBinary,
this.spirvToolsModule.SPV_ENV_UNIVERSAL_1_3,
Expand Down Expand Up @@ -393,6 +387,12 @@ class SlangCompiler
compile(shaderSource, entryPointName, compileTargetStr, stage)
{
this.diagnosticsMsg = "";
if (this.hashedString)
{
this.hashedString.delete();
this.hashedString = null;
}

const compileTarget = this.compileTargetMap.findCompileTarget(compileTargetStr);
let isWholeProgram = isWholeProgramTarget(compileTargetStr);

Expand Down Expand Up @@ -423,6 +423,7 @@ class SlangCompiler

var program = slangSession.createCompositeComponentType(components);
var linkedProgram = program.link();
this.hashedString = linkedProgram.loadStrings();

var outCode;
if (compileTargetStr == "SPIRV")
Expand All @@ -441,7 +442,8 @@ class SlangCompiler
0 /* entryPointIndex */, 0 /* targetIndex */);
}

if(outCode == "") {
if (outCode == "")
{
var error = this.slangWasmModule.getLastError();
console.error(error.type + " error: " + error.message);
this.diagnosticsMsg += (error.type + " error: " + error.message);
Expand Down
100 changes: 92 additions & 8 deletions compute.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ class ComputePipeline
uniformBuffer;
uniformBufferHost = new Float32Array(4);


printfBufferElementSize = 12;
printfBufferSize = this.printfBufferElementSize * 100; // 16 bytes per printf struct
printfBuffer;
printfBufferRead;

outputBuffer;
outputBufferRead;
outputTexture;
Expand All @@ -29,7 +35,8 @@ class ComputePipeline
entries: [
{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {type: 'uniform'}},
{binding: 1, visibility: GPUShaderStage.COMPUTE, buffer: {type: 'storage'}},
{binding: 2, visibility: GPUShaderStage.COMPUTE, storageTexture: {access: "read-write", format: this.outputTexture.format}},
{binding: 2, visibility: GPUShaderStage.COMPUTE, buffer: {type: 'storage'}},
{binding: 3, visibility: GPUShaderStage.COMPUTE, storageTexture: {access: "write-only", format: this.outputTexture.format}},
],
};

Expand Down Expand Up @@ -57,8 +64,9 @@ class ComputePipeline
layout: this.pipeline.getBindGroupLayout(0),
entries: [
{ binding: 0, resource: { buffer: this.uniformBuffer }},
{ binding: 1, resource: { buffer: this.outputBuffer }},
{ binding: 2, resource: this.outputTexture.createView() },
{ binding: 1, resource: { buffer: this.printfBuffer }},
{ binding: 2, resource: { buffer: this.outputBuffer }},
{ binding: 3, resource: this.outputTexture.createView() },
],
});

Expand All @@ -79,6 +87,18 @@ class ComputePipeline
this.outputBufferRead = null;
}

if (this.printfBuffer)
{
this.printfBuffer.destroy();
this.printfBuffer = null;
}

if (this.printfBufferRead)
{
this.printfBufferRead.destroy();
this.printfBufferRead = null;
}

if (this.outputTexture)
{
this.outputTexture.destroy();
Expand All @@ -100,13 +120,13 @@ class ComputePipeline
const size = numberElements * 4; // int type
this.outputBuffer = this.device.createBuffer({lable: 'outputBuffer', size, usage});

usage = GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST;
const outputBufferRead = this.device.createBuffer({lable: 'outputBufferRead', size, usage});
this.outputBufferRead = outputBufferRead;
this.printfBuffer = this.device.createBuffer({lable: 'outputBuffer', size: this.printfBufferSize, usage});

const storageTexture = createOutputTexture(device, windowSize[0], windowSize[1], 'r32float');
this.outputTexture = storageTexture;
usage = GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST;
this.outputBufferRead = this.device.createBuffer({lable: 'outputBufferRead', size, usage});
this.printfBufferRead = this.device.createBuffer({lable: 'outputBufferRead', size: this.printfBufferSize, usage});

this.outputTexture = createOutputTexture(device, windowSize[0], windowSize[1], 'rgba8unorm');
}
}

Expand All @@ -126,4 +146,68 @@ class ComputePipeline
this.createOutput(true, windowSize);
this.createComputePipelineLayout();
}


// This is the definition of the printf buffer.
// struct FormatedStruct
// {
// uint32_t type = 0xFFFFFFFF;
// uint32_t low = 0;
// uint32_t high = 0;
// };
parsePrintfBuffer(hashedString)
{
const printfBufferArray = new Uint32Array(computePipeline.printfBufferRead.getMappedRange())
var elementIndex = 0;
var numberElements = printfBufferArray.byteLength / this.printfBufferElementSize;

var formatString;
if (printfBufferArray[0] == 1) // type field
{
formatString = hashedString.getString(printfBufferArray[1]); // low field
}
else
{
// If the first element is not a string, we will just return an empty string, it indicates
// that the printf buffer is empty.
return "";
}

// TODO: We currently doesn't support 64-bit data type (e.g. uint64_t, int64_t, double, etc.)
// so 32-bit array should be able to contain everything we need.
var dataArray = [];
const elementSizeInWords = this.printfBufferElementSize / 4;
for (elementIndex = 1; elementIndex < numberElements; elementIndex++)
{
var offset = elementIndex * elementSizeInWords;
const type = printfBufferArray[offset];

if (type == 1) // type field, this is a string
{
dataArray.push(hashedString.getString(printfBufferArray[offset + 1])); // low field
}
else if (type == 2) // type field
{
dataArray.push(printfBufferArray[offset + 1]); // low field
}
else if (type == 3) // type field
{
const floatData = reinterpretUint32AsFloat(printfBufferArray[offset + 1]);
dataArray.push(floatData); // low field
}
else if (type == 4) // type field
{
// TODO: We can't handle 64-bit data type yet.
dataArray.push(0); // low field
}
else if (type == 0xFFFFFFFF)
{
break;
}
}

const parsedTokens = parsePrintfFormat(formatString);
const output = formatPrintfString(parsedTokens, dataArray);
return output;
}
}
1 change: 1 addition & 0 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
);
</script>
<script src="util.js"></script>
<script src="playgroundShader.js"></script>
<script src="water_demo.js"></script>
<script src="pass_through.js"></script>
<script src="compute.js"></script>
Expand Down
9 changes: 2 additions & 7 deletions pass_through.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ var passThroughshaderCode = `
@fragment fn fs(fsInput: VertexShaderOutput) -> @location(0) vec4f {
let color = textureSample(ourTexture, ourSampler, fsInput.texcoord);
let value = u32(color.x);
let r = ((value & 0xFF000000) >> 24);
let g = ((value & 0x00FF0000) >> 16);
let b = ((value & 0x0000FF00) >> 8);
return vec4f(f32(r)/255.0f, f32(g)/255.0f, f32(b)/255.0f, 1.0f);
return color;
}
`;

Expand Down Expand Up @@ -98,7 +93,7 @@ class GraphicsPipeline
fragment:
{
module: shaderModule,
targets: [{format: navigator.gpu.getPreferredCanvasFormat()}]
targets: [{format: navigator.gpu.getPreferredCanvasFormat(),}]
},
});
this.pipeline = pipeline;
Expand Down
84 changes: 84 additions & 0 deletions playgroundShader.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

const playgroundSource = `
internal uniform float time;
// Return the current time in milliseconds
public float getTime()
{
return time;
}
// type field: 1 for string, 2 for integer, 3 for float, 4 for double
struct FormatedStruct
{
uint32_t type = 0xFFFFFFFF;
uint32_t low = 0;
uint32_t high = 0;
};
internal RWStructuredBuffer<FormatedStruct> g_printedBuffer;
interface IPrintf
{
uint32_t typeFlag();
uint32_t writePrintfWords();
};
extension uint : IPrintf
{
uint32_t typeFlag() { return 2;}
uint32_t writePrintfWords() { return (uint32_t)this; }
}
extension int : IPrintf
{
uint32_t typeFlag() { return 2;}
uint32_t writePrintfWords() { return (uint32_t)this; }
}
// extension int64_t : IPrintf
// {
// uint64_t writePrintfWords() { return (uint64_t)this; }
// }
// extension uint64_t : IPrintf
// {
// uint64_t writePrintfWords() { return (uint64_t)this; }
// }
extension float : IPrintf
{
uint32_t typeFlag() { return 3;}
uint32_t writePrintfWords() { return bit_cast<uint32_t>(this); }
}
// extension double : IPrintf
// {
// uint64_t writePrintfWords() { return bit_cast<uint64_t>(this); }
// }
extension String : IPrintf
{
uint32_t typeFlag() { return 1;}
uint32_t writePrintfWords() { return getStringHash(this); }
}
void handleEach<T>(T value, int index) where T : IPrintf
{
g_printedBuffer[index].type = value.typeFlag();
g_printedBuffer[index].low = value.writePrintfWords();
}
public void print<each T>(String format, expand each T values) where T : IPrintf
{
//if (format.length != 0)
{
g_printedBuffer[0].type = 1;
g_printedBuffer[0].low = getStringHash(format);
int index = 1;
expand(handleEach(each values, index++));
g_printedBuffer[index] = {};
}
}
`;
Loading

0 comments on commit efbccca

Please sign in to comment.