diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 03dda0fe5a..b9dbd4b420 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2193,6 +2193,19 @@ __generic __intrinsic_op($(kIROp_Reinterpret)) T reinterpret(U value); +// Bitfield extract / insert +__generic +[__readNone] +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_BitfieldInsert)) +T bitfieldInsert(T base, T insert, int offset, int bits); + +__generic +[__readNone] +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_BitfieldExtract)) +T bitfieldExtract(T value, int offset, int bits); + // Use an otherwise unused value // // This can be used to silence the warning about returning before initializing an out paramter. diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 0078d39cb2..6f7a12c5a3 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -1150,168 +1150,6 @@ public void imulExtended(highp vector x, highp vector y, out highp } } -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public int bitfieldExtract(int value, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldExtract"; - case spirv: return spirv_asm { - result:$$int = OpBitFieldSExtract $value $offset $bits - }; - default: - return int(uint(value >> offset) & ((1u << bits) - 1)); - } -} - -__generic -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public vector bitfieldExtract(vector value, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldExtract"; - case spirv: return spirv_asm { - result:$$vector = OpBitFieldSExtract $value $offset $bits - }; - default: - vector result; - [ForceUnroll] - for (int i = 0; i < N; ++i) - { - result[i] = bitfieldExtract(value[i], offset, bits); - } - return result; - } -} - -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public uint bitfieldExtract(uint value, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldExtract"; - case spirv: return spirv_asm { - result:$$uint = OpBitFieldUExtract $value $offset $bits - }; - default: - return (value >> offset) & ((1u << bits) - 1); - } -} - -__generic -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public vector bitfieldExtract(vector value, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldExtract"; - case spirv: return spirv_asm { - result:$$vector = OpBitFieldUExtract $value $offset $bits - }; - default: - vector result; - [ForceUnroll] - for (int i = 0; i < N; ++i) - { - result[i] = bitfieldExtract(value[i], offset, bits); - } - return result; - } -} - -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public uint bitfieldInsert(uint base, uint insert, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldInsert"; - case spirv: return spirv_asm { - result:$$uint = OpBitFieldInsert $base $insert $offset $bits - }; - default: - uint clearMask = ~(((1u << bits) - 1u) << offset); - uint clearedBase = base & clearMask; - uint maskedInsert = (insert & ((1u << bits) - 1u)) << offset; - return clearedBase | maskedInsert; - } -} - -__generic -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public vector bitfieldInsert(vector base, vector insert, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldInsert"; - case spirv: return spirv_asm { - result:$$vector = OpBitFieldInsert $base $insert $offset $bits - }; - default: - vector result; - [ForceUnroll] - for (int i = 0; i < N; ++i) - { - result[i] = bitfieldInsert(base[i], insert[i], offset, bits); - } - return result; - } -} - -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public int bitfieldInsert(int base, int insert, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldInsert"; - case spirv: return spirv_asm { - result:$$int = OpBitFieldInsert $base $insert $offset $bits - }; - default: - uint clearMask = ~(((1u << bits) - 1u) << offset); - uint clearedBase = base & clearMask; - uint maskedInsert = (insert & ((1u << bits) - 1u)) << offset; - return clearedBase | maskedInsert; - } -} - -__generic -[__readNone] -[ForceInline] -[require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] -public vector bitfieldInsert(vector base, vector insert, int offset, int bits) -{ - __target_switch - { - case glsl: __intrinsic_asm "bitfieldInsert"; - case spirv: return spirv_asm { - result:$$vector = OpBitFieldInsert $base $insert $offset $bits - }; - default: - vector result; - [ForceUnroll] - for (int i = 0; i < N; ++i) - { - result[i] = bitfieldInsert(base[i], insert[i], offset, bits); - } - return result; - } -} - [__readNone] [ForceInline] [require(cpp_cuda_glsl_hlsl_spirv, GLSL_400)] diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index caf3613a71..f48167ad39 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2788,6 +2788,16 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_writer->emit(")"); break; } + case kIROp_BitfieldExtract: + { + emitBitfieldExtractImpl(inst); + break; + } + case kIROp_BitfieldInsert: + { + emitBitfieldInsertImpl(inst); + break; + } case kIROp_PackAnyValue: { m_writer->emit("packAnyValue<"); @@ -3704,6 +3714,234 @@ void CLikeSourceEmitter::emitFuncDecorationsImpl(IRFunc* func) } } +bool CLikeSourceEmitter::tryGetIntInfo(IRType* elementType, bool &isSigned, int &bitWidth) +{ + Slang::IROp type = elementType->getOp(); + if (!(type >= kIROp_Int8Type && type <= kIROp_UInt64Type)) return false; + isSigned = (type >= kIROp_Int8Type && type <= kIROp_Int64Type); + + Slang::IROp stype = (isSigned) ? type : Slang::IROp(type - 4); + bitWidth = 8 << (stype - kIROp_Int8Type); + return true; +} + +void CLikeSourceEmitter::emitVecNOrScalar(IRVectorType* vectorType, std::function emitComponentLogic) +{ + if (vectorType) + { + int N = int(getIntVal(vectorType->getElementCount())); + Slang::IRType *elementType = vectorType->getElementType(); + + // Special handling required for CUDA target + if (isCUDATarget(getTargetReq())) + { + m_writer->emit("make_"); + + switch(elementType->getOp()) + { + case kIROp_Int8Type: m_writer->emit("char"); break; + case kIROp_Int16Type: m_writer->emit("short"); break; + case kIROp_IntType: m_writer->emit("int"); break; + case kIROp_Int64Type: m_writer->emit("longlong"); break; + case kIROp_UInt8Type: m_writer->emit("uchar"); break; + case kIROp_UInt16Type: m_writer->emit("ushort"); break; + case kIROp_UIntType: m_writer->emit("uint"); break; + case kIROp_UInt64Type: m_writer->emit("ulonglong"); break; + default: SLANG_ABORT_COMPILATION("Unhandled type emitting CUDA vector"); + } + + m_writer->emitRawText(std::to_string(N).c_str()); + } + // Special handling required for Metal target + else if (isMetalTarget(getTargetReq())) + { + m_writer->emit("vec<"); + emitType(elementType); + m_writer->emit(", "); + m_writer->emit(N); + m_writer->emit(">"); + } + + // In other languages, we can output the Slang vector type directly + else { + emitType(vectorType); + } + + m_writer->emit("("); + for (int i = 0; i < N; ++i) + { + emitType(elementType); + m_writer->emit("("); + emitComponentLogic(); + m_writer->emit(")"); + if (i != N - 1) + m_writer->emit(", "); + } + m_writer->emit(")"); + } + else + { + m_writer->emit("("); + emitComponentLogic(); + m_writer->emit(")"); + } +} + +void CLikeSourceEmitter::emitBitfieldExtractImpl(IRInst* inst) +{ + // If unsigned, bfue := ((val>>off)&((1u<>off)&((1u<>(nbts-bts)); + Slang::IRType* dataType = inst->getDataType(); + Slang::IRInst* val = inst->getOperand(0); + Slang::IRInst* off = inst->getOperand(1); + Slang::IRInst* bts = inst->getOperand(2); + + Slang::IRType* elementType = dataType; + IRVectorType* vectorType = as(elementType); + if (vectorType) + elementType = vectorType->getElementType(); + + bool isSigned; + int bitWidth; + if (!tryGetIntInfo(elementType, isSigned, bitWidth)) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "non-integer element type given to bitfieldExtract"); + return; + } + + String one; + switch(bitWidth) + { + case 8: one = "uint8_t(1)"; break; + case 16: one = "uint16_t(1)"; break; + case 32: one = "uint32_t(1)"; break; + case 64: one = "uint64_t(1)"; break; + default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unexpected bit width"); + } + + // Emit open paren and type cast for later sign extension + if (isSigned) + { + m_writer->emit("("); + emitType(inst->getDataType()); + m_writer->emit("("); + } + + // Emit bitfield extraction ((val>>off)&((1u<emit("(("); + emitOperand(val, getInfo(EmitOp::General)); + m_writer->emit(">>"); + emitVecNOrScalar(vectorType, [&]() { + emitOperand(off, getInfo(EmitOp::General)); + }); + m_writer->emit(")&("); + emitVecNOrScalar(vectorType, [&]() { + m_writer->emit("((" + one + "<<"); + emitOperand(bts, getInfo(EmitOp::General)); + m_writer->emit(")-" + one + ")"); + }); + m_writer->emit("))"); + + // Emit sign extension logic + // (type(bitfield<<(numBits-bts))>>(numBits-bts)) + // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + if (isSigned) + { + m_writer->emit("<<"); + emitVecNOrScalar(vectorType, [&]() + { + m_writer->emit("("); + m_writer->emit(bitWidth); + m_writer->emit("-"); + emitOperand(bts, getInfo(EmitOp::General)); + m_writer->emit(")"); + }); + m_writer->emit(")>>"); + emitVecNOrScalar(vectorType, [&]() + { + m_writer->emit("("); + m_writer->emit(bitWidth); + m_writer->emit("-"); + emitOperand(bts, getInfo(EmitOp::General)); + m_writer->emit(")"); + }); + m_writer->emit(")"); + } +} + +void CLikeSourceEmitter::emitBitfieldInsertImpl(IRInst* inst) +{ + // uint clearMask = ~(((1u << bits) - 1u) << offset); + // uint clearedBase = base & clearMask; + // uint maskedInsert = (insert & ((1u << bits) - 1u)) << offset; + // BitfieldInsert := T(uint(clearedBase) | uint(maskedInsert)); + Slang::IRType* dataType = inst->getDataType(); + Slang::IRInst* bse = inst->getOperand(0); + Slang::IRInst* ins = inst->getOperand(1); + Slang::IRInst* off = inst->getOperand(2); + Slang::IRInst* bts = inst->getOperand(3); + + Slang::IRType* elementType = dataType; + IRVectorType* vectorType = as(elementType); + if (vectorType) + elementType = vectorType->getElementType(); + + bool isSigned; + int bitWidth; + if (!tryGetIntInfo(elementType, isSigned, bitWidth)) + { + SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "non-integer element type given to bitfieldInsert"); + return; + } + + String one; + switch(bitWidth) { + case 8: one = "uint8_t(1)"; break; + case 16: one = "uint16_t(1)"; break; + case 32: one = "uint32_t(1)"; break; + case 64: one = "uint64_t(1)"; break; + default: SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unexpected bit width"); + } + + m_writer->emit("(("); + + // emit clearedBase := uint(bse & ~(((1u<emit("&"); + emitVecNOrScalar(vectorType, [&]() + { + m_writer->emit("~(((" + one + "<<"); + emitOperand(bts, getInfo(EmitOp::General)); + m_writer->emit(")-" + one + ")<<"); + emitOperand(off, getInfo(EmitOp::General)); + m_writer->emit(")"); + }); + + + // bitwise or clearedBase with maskedInsert + m_writer->emit(")|("); + + // Emit maskedInsert := ((insert & ((1u << bits) - 1u)) << offset); + + // - first emit mask := (insert & ((1u << bits) - 1u)) + m_writer->emit("("); + emitOperand(ins, getInfo(EmitOp::General)); + m_writer->emit("&"); + emitVecNOrScalar(vectorType, [&](){ + m_writer->emit("(" + one + "<<"); + emitOperand(bts, getInfo(EmitOp::General)); + m_writer->emit(")-" + one); + }); + m_writer->emit(")"); + + // then emit shift := << offset + m_writer->emit("<<"); + emitVecNOrScalar(vectorType, [&](){ + emitOperand(off, getInfo(EmitOp::General)); + }); + m_writer->emit("))"); +} + void CLikeSourceEmitter::emitStruct(IRStructType* structType) { ensureTypePrelude(structType); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index be769f31f9..2e96714dd5 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -534,6 +534,11 @@ class CLikeSourceEmitter: public SourceEmitterBase virtual void emitFuncDecorationsImpl(IRFunc* func); + bool tryGetIntInfo(IRType* elementType, bool &isSigned, int &bitWidth); + void emitVecNOrScalar(IRVectorType* vectorType, std::function func); + virtual void emitBitfieldExtractImpl(IRInst* inst); + virtual void emitBitfieldInsertImpl(IRInst* inst); + // Only needed for glsl output with $ prefix intrinsics - so perhaps removable in the future virtual void emitTextureOrTextureSamplerTypeImpl(IRTextureTypeBase* type, char const* baseName) { SLANG_UNUSED(type); SLANG_UNUSED(baseName); } diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 56113409d3..ce88218655 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2449,6 +2449,30 @@ void GLSLSourceEmitter::emitFuncDecorationImpl(IRDecoration* decoration) } } +void GLSLSourceEmitter::emitBitfieldExtractImpl(IRInst* inst) +{ + m_writer->emit("bitfieldExtract("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(")"); +} + +void GLSLSourceEmitter::emitBitfieldInsertImpl(IRInst* inst) +{ + m_writer->emit("bitfieldInsert("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(","); + emitOperand(inst->getOperand(3), getInfo(EmitOp::General)); + m_writer->emit(")"); +} + void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) { switch (type->getOp()) diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index 8958c7608e..569d4d783f 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -48,6 +48,9 @@ class GLSLSourceEmitter : public CLikeSourceEmitter virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE; virtual void emitGlobalParamDefaultVal(IRGlobalParam* decl) SLANG_OVERRIDE; + virtual void emitBitfieldExtractImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void emitBitfieldInsertImpl(IRInst* inst) SLANG_OVERRIDE; + virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index a0fe220e58..2ba41da1ae 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -538,7 +538,8 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO void MetalSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) { - emitSimpleTypeImpl(elementType); + // NM: Passing count here, as Metal 64-bit vector type names do not match their scalar equivalents. + emitSimpleTypeKnowingCount(elementType, elementCount); switch (elementType->getOp()) { @@ -656,7 +657,7 @@ void MetalSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) emitType(type, name); } -void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) +void MetalSourceEmitter::emitSimpleTypeKnowingCount(IRType* type, IRIntegerValue elementCount) { switch (type->getOp()) { @@ -664,10 +665,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_BoolType: case kIROp_Int8Type: case kIROp_IntType: - case kIROp_Int64Type: case kIROp_UInt8Type: case kIROp_UIntType: - case kIROp_UInt64Type: case kIROp_FloatType: case kIROp_DoubleType: case kIROp_HalfType: @@ -681,11 +680,19 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_UInt16Type: m_writer->emit("ushort"); return; + case kIROp_Int64Type: case kIROp_IntPtrType: - m_writer->emit("int64_t"); + // NM: note, "long" is only type that works for i64 vec + m_writer->emit("long"); return; + case kIROp_UInt64Type: case kIROp_UIntPtrType: - m_writer->emit("uint64_t"); + // NM: note, "ulong" is only type that works for i64 vec, but can't be used for scalars. + // (See metal specification pg 26) + if (elementCount > 1) + m_writer->emit("ulong"); + else + m_writer->emit("uint64_t"); return; case kIROp_StructType: m_writer->emit(getName(type)); @@ -887,6 +894,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) } } +void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) +{ + emitSimpleTypeKnowingCount(type, 1); +} + void MetalSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) { switch (type->getOp()) diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 67aa0d506e..66e3523978 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -39,6 +39,8 @@ class MetalSourceEmitter : public CLikeSourceEmitter virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; + void emitSimpleTypeKnowingCount(IRType* type, IRIntegerValue elementCount); + virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; virtual void emitParamTypeImpl(IRType* type, String const& name) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index a2da4801e7..dbd8371d1c 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2971,6 +2971,12 @@ struct SPIRVEmitContext inst->getOperand(0) ); break; + case kIROp_BitfieldExtract: + result = emitBitfieldExtract(parent, inst); + break; + case kIROp_BitfieldInsert: + result = emitBitfieldInsert(parent, inst); + break; case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -5594,6 +5600,45 @@ struct SPIRVEmitContext return emitInst(parent, inst, SpvOpConvertUToPtr, inst->getFullType(), kResultID, inst->getOperand(0)); } + SpvInst* emitBitfieldExtract(SpvInstParent* parent, IRInst* inst) + { + auto dataType = inst->getDataType(); + IRVectorType* vectorType = as(dataType); + Slang::IRType* elementType = dataType; + if (vectorType) + elementType = vectorType->getElementType(); + + const IntInfo i = getIntTypeInfo(elementType); + + // NM: technically, using bitfield intrinsics for anything non-32-bit goes against + // VK specification: VUID-StandaloneSpirv-Base-04781. However, it works on at least + // NVIDIA HW. + SpvOp opcode = i.isSigned ? SpvOpBitFieldSExtract : SpvOpBitFieldUExtract; + return emitInst(parent, inst, opcode, inst->getFullType(), kResultID, + inst->getOperand(0), inst->getOperand(1), inst->getOperand(2)); + } + + SpvInst* emitBitfieldInsert(SpvInstParent* parent, IRInst* inst) + { + auto dataType = inst->getDataType(); + IRVectorType* vectorType = as(dataType); + Slang::IRType* elementType = dataType; + if (vectorType) + elementType = vectorType->getElementType(); + + const IntInfo i = getIntTypeInfo(elementType); + + if (i.width == 64) + requireSPIRVCapability(SpvCapabilityInt64); + if (i.width == 16) + requireSPIRVCapability(SpvCapabilityInt16); + if (i.width == 8) + requireSPIRVCapability(SpvCapabilityInt8); + + return emitInst(parent, inst, SpvOpBitFieldInsert, inst->getFullType(), kResultID, + inst->getOperand(0), inst->getOperand(1), inst->getOperand(2), inst->getOperand(3)); + } + template SpvInst* emitCompositeConstruct( SpvInstParent* parent, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b526df3a92..723eff15e7 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -379,6 +379,9 @@ INST(Alloca, alloca, 1, 0) INST(UpdateElement, updateElement, 2, 0) INST(DetachDerivative, detachDerivative, 1, 0) +INST(BitfieldExtract, bitfieldExtract, 3, 0) +INST(BitfieldInsert, bitfieldInsert, 4, 0) + INST(PackAnyValue, packAnyValue, 1, 0) INST(UnpackAnyValue, unpackAnyValue, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 69f1299862..068d9f341c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3831,6 +3831,10 @@ struct IRBuilder IRInst* emitGlobalValueRef(IRInst* globalInst); + IRInst* emitBitfieldExtract(IRType* type, IRInst* op0, IRInst* op1, IRInst* op2); + + IRInst* emitBitfieldInsert(IRType* type, IRInst* op0, IRInst* op1, IRInst* op2, IRInst* op3); + IRInst* emitPackAnyValue(IRType* type, IRInst* value); IRInst* emitUnpackAnyValue(IRType* type, IRInst* value); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9305d17830..f9f5f696c1 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1980,6 +1980,25 @@ namespace Slang &args[0]); } + template + static T* createInst( + IRBuilder* builder, + IROp op, + IRType* type, + IRInst* arg1, + IRInst* arg2, + IRInst* arg3, + IRInst* arg4) + { + IRInst* args[] = { arg1, arg2, arg3, arg4 }; + return createInstImpl( + builder, + op, + type, + 4, + &args[0]); + } + template static T* createInstWithTrailingArgs( IRBuilder* builder, @@ -3632,6 +3651,20 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBitfieldExtract(IRType* type, IRInst* value, IRInst* offset, IRInst* bits) + { + auto inst = createInst(this, kIROp_BitfieldExtract, type, value, offset, bits); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBitfieldInsert(IRType* type, IRInst* base, IRInst* insert, IRInst* offset, IRInst* bits) + { + auto inst = createInst(this, kIROp_BitfieldInsert, type, base, insert, offset, bits); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitPackAnyValue(IRType* type, IRInst* value) { auto inst = createInst( @@ -8452,6 +8485,8 @@ namespace Slang case kIROp_PtrCast: case kIROp_CastDynamicResource: case kIROp_AllocObj: + case kIROp_BitfieldExtract: + case kIROp_BitfieldInsert: case kIROp_PackAnyValue: case kIROp_UnpackAnyValue: case kIROp_Reinterpret: diff --git a/tests/language-feature/bitfield/bitfield-extract.slang b/tests/language-feature/bitfield/bitfield-extract.slang new file mode 100644 index 0000000000..06896157f6 --- /dev/null +++ b/tests/language-feature/bitfield/bitfield-extract.slang @@ -0,0 +1,159 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compile-arg -skip-spirv-validation -emit-spirv-directly +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda + +// CHECK: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 35 +// CHECK-NEXT: BE +// CHECK-NEXT: FFFFFFFA +// CHECK-NEXT: A +// CHECK-NEXT: 23 +// CHECK-NEXT: FFFFFFAB +// CHECK-NEXT: 76 +// CHECK-NEXT: FFFFFFED + +// CHECK-NEXT: 1 +// CHECK-NEXT: 2 +// CHECK-NEXT: 3 +// CHECK-NEXT: 4 +// CHECK-NEXT: 5 +// CHECK-NEXT: 6 +// CHECK-NEXT: 7 +// CHECK-NEXT: 8 +// CHECK-NEXT: 21 +// CHECK-NEXT: 7A +// CHECK-NEXT: FFFFFFFA +// CHECK-NEXT: A +// CHECK-NEXT: 67 +// CHECK-NEXT: FFFFFFEF +// CHECK-NEXT: 32 +// CHECK-NEXT: FFFFFFA9 + +// CHECK-NEXT: 7654321A +// CHECK-NEXT: FEDCBA98 +// CHECK-NEXT: 76543210 +// CHECK-NEXT: FEDCBA9A +// CHECK-NEXT: 76543210 +// CHECK-NEXT: FEDABA98 +// CHECK-NEXT: 76543210 +// CHECK-NEXT: AEDCBA98 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 0 +// CHECK-NEXT: F00 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 + +//TEST_INPUT:ubuffer(data=[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1], stride=4):out,name=i16Buffer +RWStructuredBuffer i16Buffer; + +//TEST_INPUT:ubuffer(data=[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1], stride=4):out,name=i32Buffer +RWStructuredBuffer i32Buffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=i64Buffer +RWStructuredBuffer i64Buffer; + +[numthreads(1, 1, 1)] +void computeMain() +{ + // 16-bit tests + { + // Simple hex extraction to test, varying the offset. + uint16_t value = 0x4321; + i16Buffer[0] = bitfieldExtract(value, 4 * 0, 4); + i16Buffer[1] = bitfieldExtract(value, 4 * 1, 4); + i16Buffer[2] = bitfieldExtract(value, 4 * 2, 4); + i16Buffer[3] = bitfieldExtract(value, 4 * 3, 4); + + // Now varying the bit length + value = 0b001110111110110101; + i16Buffer[4] = bitfieldExtract(value, 0, 6); + i16Buffer[5] = bitfieldExtract(value, 6, 8); + + // Sign extension case + // - For unsigned data types, the most significant bits of the result will be set to zero. + // - For signed data types, the most significant bits will be set to the value of bit offset + base - 1 + // (i.e., it is sign extended to the width of the return type). + i16Buffer[6] = bitfieldExtract(int16_t(0b1010111), 3, 4); // 0b1010 -> 0b1111111111111010 + i16Buffer[7] = bitfieldExtract(uint16_t(0b1010111u), 3, 4); // 0b1111 -> 0b0000000000001010 + + // // Component-wise extraction + vector val4 = vector(0x1234, 0x9abc, 0x8765, 0xfedc); + vector ext4 = bitfieldExtract(val4, 4, 8); + i16Buffer[8] = ext4.x; + i16Buffer[9] = ext4.y; + i16Buffer[10] = ext4.z; + i16Buffer[11] = ext4.w; + } + + // 32-bit tests + { + // Simple hex extraction to test, varying the offset. + uint value = 0x87654321; + i32Buffer[0] = bitfieldExtract(value, 4 * 0, 4); + i32Buffer[1] = bitfieldExtract(value, 4 * 1, 4); + i32Buffer[2] = bitfieldExtract(value, 4 * 2, 4); + i32Buffer[3] = bitfieldExtract(value, 4 * 3, 4); + i32Buffer[4] = bitfieldExtract(value, 4 * 4, 4); + i32Buffer[5] = bitfieldExtract(value, 4 * 5, 4); + i32Buffer[6] = bitfieldExtract(value, 4 * 6, 4); + i32Buffer[7] = bitfieldExtract(value, 4 * 7, 4); + + // Now varying the bit length + value = 0b00111011111011110001111010100001; + i32Buffer[8] = bitfieldExtract(value, 0, 6); + i32Buffer[9] = bitfieldExtract(value, 6, 8); + + // Sign extension case + // - For unsigned data types, the most significant bits of the result will be set to zero. + // - For signed data types, the most significant bits will be set to the value of bit offset + base - 1 + // (i.e., it is sign extended to the width of the return type). + i32Buffer[10] = bitfieldExtract(0b1010111, 3, 4); // 0b1010 -> 0b11111111111111111111111111111010 + i32Buffer[11] = bitfieldExtract(0b1010111u, 3, 4); // 0b1111 -> 0b00000000000000000000000000001010 + + // Component-wise extraction + int4 val4 = int4(0x12345678, 0x9abcdef0, 0x87654321, 0xfedcba98); + int4 ext4 = bitfieldExtract(val4, 4, 8); + i32Buffer[12] = ext4.x; + i32Buffer[13] = ext4.y; + i32Buffer[14] = ext4.z; + i32Buffer[15] = ext4.w; + } + + // 64-bit tests + { + // Simple hex insertion to test, varying the offset. + uint64_t base = 0xFEDCBA9876543210ull; + uint64_t insert = 0xAull; + i64Buffer[0] = bitfieldInsert(base, insert, 4 * 0, 4); // 0xFEDCBA987654321Aull -> 2271560495 + i64Buffer[1] = bitfieldInsert(base, insert, 4 * 8, 4); // 0xFEDCBA98A6543210ull -> 2271560689 + i64Buffer[2] = bitfieldInsert(base, insert, 4 * 12, 4); // 0xFEDCAA9876543210ull -> 2271563553 + i64Buffer[3] = bitfieldInsert(base, insert, 4 * 15, 4); // 0xAEDCBA9876543210ull -> 2271605537 + + // Test with varying bit length + base = 0; + insert = 0xFEDCBA987654321Full; + i64Buffer[4] = bitfieldInsert(base, insert, 4, 40); // 0xA987654321 -> 16492674416640 + i64Buffer[5] = bitfieldInsert(base, insert, 40, 4); // 0xF000000000 -> 10477124133360 + + // Test with a vector + vector base4 = vector(base, base, base, base); + vector insert4 = vector(insert, insert, insert, insert); + vector output4 = bitfieldInsert(base4, insert4, 4, 40); + i64Buffer[6] = output4.x; + i64Buffer[7] = output4.y; + i64Buffer[8] = output4.z; + i64Buffer[9] = output4.w; + } +} diff --git a/tests/language-feature/bitfield/bitfield-insert.slang b/tests/language-feature/bitfield/bitfield-insert.slang new file mode 100644 index 0000000000..ee389bcc6c --- /dev/null +++ b/tests/language-feature/bitfield/bitfield-insert.slang @@ -0,0 +1,141 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compile-arg -skip-spirv-validation -emit-spirv-directly +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-mtl +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda + +// CHECK: 432D +// CHECK-NEXT: 43D1 +// CHECK-NEXT: 4D21 +// CHECK-NEXT: D321 +// CHECK-NEXT: A8 +// CHECK-NEXT: 3FC0 +// CHECK-NEXT: 12D4 +// CHECK-NEXT: 9A4C +// CHECK-NEXT: 8755 +// CHECK-NEXT: FE4C + +// CHECK-NEXT: 8765432F +// CHECK-NEXT: 876543F1 +// CHECK-NEXT: 87654F21 +// CHECK-NEXT: 8765F321 +// CHECK-NEXT: A8 +// CHECK-NEXT: 3FC0 +// CHECK-NEXT: 123456F8 +// CHECK-NEXT: 9ABCDE60 +// CHECK-NEXT: 87654331 +// CHECK-NEXT: FEDCBA68 + +// CHECK-NEXT: 7654321A +// CHECK-NEXT: FEDCBA98 +// CHECK-NEXT: 76543210 +// CHECK-NEXT: FEDCBA9A +// CHECK-NEXT: 76543210 +// CHECK-NEXT: FEDABA98 +// CHECK-NEXT: 76543210 +// CHECK-NEXT: AEDCBA98 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 0 +// CHECK-NEXT: F00 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 +// CHECK-NEXT: 654321F0 +// CHECK-NEXT: 987 + +//TEST_INPUT:ubuffer(data=[0 1 2 3 4 5 6 7 8 9], stride=4):out,name=i16Buffer +RWStructuredBuffer i16Buffer; + +//TEST_INPUT:ubuffer(data=[0 1 2 3 4 5 6 7 8 9], stride=4):out,name=i32Buffer +RWStructuredBuffer i32Buffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=i64Buffer +RWStructuredBuffer i64Buffer; + +[numthreads(1, 1, 1)] +void computeMain() +{ + // 16-bit tests + { + // Simple hex insertion to test, varying the offset. + uint16_t base = 0x4321; + uint16_t value = 0xABCD; + i16Buffer[0] = bitfieldInsert(base, value, 4 * 0, 4); // 0x432D + i16Buffer[1] = bitfieldInsert(base, value, 4 * 1, 4); // 0x43D1 + i16Buffer[2] = bitfieldInsert(base, value, 4 * 2, 4); // 0x4D21 + i16Buffer[3] = bitfieldInsert(base, value, 4 * 3, 4); // 0xD321 + + // Test with varying bit length + base = 0; + value = 0b101010; + i16Buffer[4] = bitfieldInsert(base, value, 2, 6); // 0b101010 00 -> 0xA8 + value = 0b11111111; + i16Buffer[5] = bitfieldInsert(base, value, 6, 8); // 0b11111111 000000 -> 0x3FC0 + + // Test with a vector + vector base4 = vector(0x1234, 0x9abc, 0x8765, 0xfedc); + vector value4 = vector(0xABCD, 0x1234, 0x8765, 0x1234); + vector output4 = bitfieldInsert(base4, value4, 4, 4); + i16Buffer[6] = output4.x; + i16Buffer[7] = output4.y; + i16Buffer[8] = output4.z; + i16Buffer[9] = output4.w; + } + + // 32-bit tests + { + // Simple hex insertion to test, varying the offset. + uint base = 0x87654321; + uint value = 0xABCDEF; + i32Buffer[0] = bitfieldInsert(base, value, 4 * 0, 4); // 0x8765432F + i32Buffer[1] = bitfieldInsert(base, value, 4 * 1, 4); // 0x876543F1 + i32Buffer[2] = bitfieldInsert(base, value, 4 * 2, 4); // 0x8765F321 + i32Buffer[3] = bitfieldInsert(base, value, 4 * 3, 4); // 0x87F54321 + + // Test with varying bit length + base = 0; + value = 0b101010; + i32Buffer[4] = bitfieldInsert(base, value, 2, 6); // 0b10101000 + value = 0b11111111; + i32Buffer[5] = bitfieldInsert(base, value, 6, 8); // 0b11111111000000 + + // Test with a vector + uint4 base4 = uint4(0x12345678, 0x9abcdef0, 0x87654321, 0xfedcba98); + uint4 value4 = uint4(0xABCDEF, 0x123456, 0x876543, 0x123456); + uint4 output4 = bitfieldInsert(base4, value4, 4, 4); + i32Buffer[6] = output4.x; + i32Buffer[7] = output4.y; + i32Buffer[8] = output4.z; + i32Buffer[9] = output4.w; + } + + // 64-bit tests + { + // Simple hex insertion to test, varying the offset. + uint64_t base = 0xFEDCBA9876543210ull; + uint64_t insert = 0xAull; + i64Buffer[0] = bitfieldInsert(base, insert, 4 * 0, 4); // 0xFEDCBA987654321Aull + i64Buffer[1] = bitfieldInsert(base, insert, 4 * 8, 4); // 0xFEDCBA98A6543210ull + i64Buffer[2] = bitfieldInsert(base, insert, 4 * 12, 4); // 0xFEDCAA9876543210ull + i64Buffer[3] = bitfieldInsert(base, insert, 4 * 15, 4); // 0xAEDCBA9876543210ull + + // Test with varying bit length + base = 0; + insert = 0xFEDCBA987654321Full; + i64Buffer[4] = bitfieldInsert(base, insert, 4, 40); // 0xA987654321 + i64Buffer[5] = bitfieldInsert(base, insert, 40, 4); // 0xF000000000 + + // Test with a vector + vector base4 = vector(base, base, base, base); + vector insert4 = vector(insert, insert, insert, insert); + vector output4 = bitfieldInsert(base4, insert4, 4, 40); + i64Buffer[6] = output4.x; + i64Buffer[7] = output4.y; + i64Buffer[8] = output4.z; + i64Buffer[9] = output4.w; + } +}