Skip to content

Commit

Permalink
Allow passing sized array to unsized array parameter. (shader-slang#4744
Browse files Browse the repository at this point in the history
)
  • Loading branch information
csyonghe authored Jul 27, 2024
1 parent c0bff66 commit 7e2bc8e
Show file tree
Hide file tree
Showing 15 changed files with 278 additions and 17 deletions.
42 changes: 41 additions & 1 deletion docs/user-guide/02-conventional-features.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,54 @@ In some cases the element count is then inferred from the initial value of a var
int a[] = { 1, 2, 3 };
```

In other cases, the result is a _runtime-sized_ array, where the actual element count will be determined later:
In other cases, the result is a _unsized_ array, where the actual element count will be determined later:

```hlsl
// the type of `b` is `int[]`
void f( int b[] )
{ ... }
```

It is allowed to pass a sized array as argument to an unsized array parameter when calling a function.

Array types has a `getCount()` memeber function that returns the length of the array.

```hlsl
int f( int b[] )
{
return b.getCount(); // Note: all arguments to `b` must be resolvable to sized arrays.
}
void test()
{
int arr[3] = { 1, 2, 3 };
int x = f(arr); // OK, passing sized array to unsized array parameter, x will be 3.
}
```

Please note that if a function calls `getCount()` method on an unsized array parameter, then all
calls to that function must provide a sized array argument, otherwise the compiler will not be able
to resolve the size and will report an error. The following code shows an example of valid and
invalid cases.

```hlsl
int f( int b[] )
{
return b.getCount();
}
int g( int b[] )
{
return f(b); // transitive calls are allowed.
}
uniform int unsizedParam[];
void test()
{
g(unsizedParam); // Not OK, `unsizedParam` doesn't have a known size at compile time.
int arr[3];
g(arr); // OK.
}
```

There are more limits on how runtime-sized arrays can be used than on arrays of statically-known element count.

> #### Note ####
Expand Down
5 changes: 2 additions & 3 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -1040,16 +1040,15 @@ __generic<T, let N:int>
__magic_type(ArrayExpressionType)
struct Array : IArray<T>
{
[ForceInline]
int getCount() { return N; }
__intrinsic_op($(kIROp_GetArrayLength))
int getCount();

__subscript(int index) -> T
{
__intrinsic_op($(kIROp_GetElement))
get;
}
}

/// An `N` component vector with elements of type `T`.
__generic<T = float, let N : int = 4>
__magic_type(VectorExpressionType)
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ast-support-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ namespace Slang

kConversionCost_GenericParamUpcast = 1,
kConversionCost_UnconstraintGenericParam = 20,
kConversionCost_SizedArrayToUnsizedArray = 30,

// Convert between matrices of different layout
kConversionCost_MatrixLayout = 5,
Expand Down
25 changes: 25 additions & 0 deletions source/slang/slang-check-conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,31 @@ namespace Slang
return true;
}

// Allow implicit conversion from sized array to unsized array when
// calling a function.
// Note: we implement the logic here instead of an implicit_conversion
// intrinsic in the stdlib because we only want to allow this conversion
// when calling a function.
//
if (site == CoercionSite::Argument)
{
if (auto fromArrayType = as<ArrayExpressionType>(fromType))
{
if (auto toArrayType = as<ArrayExpressionType>(toType))
{
if (fromArrayType->getElementType()->equals(toArrayType->getElementType())
&& toArrayType->isUnsized())
{
if (outToExpr)
*outToExpr = fromExpr;
if (outCost)
*outCost = kConversionCost_SizedArrayToUnsizedArray;
return true;
}
}
}
}

// Another important case is when either the "to" or "from" type
// represents an error. In such a case we must have already
// reporeted the error, so it is better to allow the conversion
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ DIAGNOSTIC(55201, Error, unsupportedRecursion, "recursion detected in call to '$
DIAGNOSTIC(55202, Error, systemValueAttributeNotSupported, "system value semantic '$0' is not supported for the current target.")
DIAGNOSTIC(55203, Error, systemValueTypeIncompatible, "system value semantic '$0' should have type '$1' or be convertible to type '$1'.")
DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatically map '$0' to a host type. Automatic binding generation failed for '$1'")
DIAGNOSTIC(56002, Error, attemptToQuerySizeOfUnsizedArray, "cannot obtain the size of an unsized array.")

DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
Expand Down
14 changes: 7 additions & 7 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,13 +947,13 @@ Result linkAndOptimizeIR(
specializeResourceUsage(codeGenContext, irModule);
specializeFuncsForBufferLoadArgs(codeGenContext, irModule);

// For GLSL targets, we also want to specialize calls to functions that
// takes array parameters if possible, to avoid performance issues on
// those platforms.
if (isKhronosTarget(targetRequest))
{
specializeArrayParameters(codeGenContext, irModule);
}
// We also want to specialize calls to functions that
// takes unsized array parameters if possible.
// Moreover, for Khronos targets, we also want to specialize calls to functions
// that takes arrays/structs containing arrays as parameters with the actual
// global array object to avoid loading big arrays into SSA registers, which seems
// to cause performance issues.
specializeArrayParameters(codeGenContext, irModule);

#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION");
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-autodiff-fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1943,6 +1943,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_DebugLine:
case kIROp_DebugVar:
case kIROp_DebugValue:
case kIROp_GetArrayLength:
case kIROp_SizeOf:
case kIROp_AlignOf:
return transcribeNonDiffInst(builder, origInst);

// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
Expand Down
27 changes: 27 additions & 0 deletions source/slang/slang-ir-check-unsupported-inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ namespace Slang
}
}

void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink)
{
SLANG_UNUSED(target);
for (auto block : func->getBlocks())
{
for (auto inst : block->getChildren())
{
switch (inst->getOp())
{
case kIROp_GetArrayLength:
sink->diagnose(inst, Diagnostics::attemptToQuerySizeOfUnsizedArray);
break;
}
}
}
}

void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
{
HashSet<IRFunc*> checkedFuncsForRecursionDetection;
Expand All @@ -62,6 +79,16 @@ namespace Slang
case kIROp_Func:
if (!isCPUTarget(target))
checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink);
checkUnsupportedInst(target, as<IRFunc>(globalInst), sink);
break;
case kIROp_Generic:
{
auto generic = as<IRGeneric>(globalInst);
auto innerFunc = as<IRFunc>(findGenericReturnVal(generic));
if (innerFunc)
checkUnsupportedInst(target, innerFunc, sink);
break;
}
default:
break;
}
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ INST(TreatAsDynamicUniform, TreatAsDynamicUniform, 1, 0)

INST(SizeOf, sizeOf, 1, 0)
INST(AlignOf, alignOf, 1, 0)

INST(GetArrayLength, GetArrayLength, 1, 0)
INST(IsType, IsType, 3, 0)
INST(TypeEquals, TypeEquals, 2, 0)
INST(IsInt, IsInt, 1, 0)
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang-ir-peephole.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,14 @@ struct PeepholeContext : InstPassBase
changed = true;
}
break;
case kIROp_GetArrayLength:
if (auto arrayType = as<IRArrayType>(inst->getOperand(0)->getDataType()))
{
inst->replaceUsesWith(arrayType->getElementCount());
maybeRemoveOldInst(inst);
changed = true;
}
break;
case kIROp_GetResultError:
if (inst->getOperand(0)->getOp() == kIROp_MakeResultError)
{
Expand Down
47 changes: 45 additions & 2 deletions source/slang/slang-ir-specialize-arrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
// This pass is intended to specialize functions
// with struct parameters that has array fields
// to avoid performance problems for GLSL targets.

// Returns true if `type` is an `IRStructType` with array-typed fields.
// It will also specialize functions with unsized array parameters into
// sized arrays, if the function is called with an argument that has a
// sized array type.
//
bool isStructTypeWithArray(IRType* type)
{
if (auto structType = as<IRStructType>(type))
Expand All @@ -38,15 +41,55 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
bool doesParamWantSpecialization(IRParam* param, IRInst* arg)
{
SLANG_UNUSED(arg);
return isStructTypeWithArray(param->getDataType());
if (isKhronosTarget(codeGenContext->getTargetReq()))
return isStructTypeWithArray(param->getDataType());
return false;
}

bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
{
auto paramType = param->getDataType();
auto argType = arg->getDataType();
if (auto outTypeBase = as<IROutTypeBase>(paramType))
{
paramType = outTypeBase->getValueType();
SLANG_ASSERT(as<IRPtrTypeBase>(argType));
argType = as<IRPtrTypeBase>(argType)->getValueType();
}
else if (auto refType = as<IRRefType>(paramType))
{
paramType = refType->getValueType();
SLANG_ASSERT(as<IRPtrTypeBase>(argType));
argType = as<IRPtrTypeBase>(argType)->getValueType();
}
else if (auto constRefType = as<IRConstRefType>(paramType))
{
paramType = constRefType->getValueType();
SLANG_ASSERT(as<IRPtrTypeBase>(argType));
argType = as<IRPtrTypeBase>(argType)->getValueType();
}
auto arrayType = as<IRUnsizedArrayType>(paramType);
if (!arrayType)
return false;
auto argArrayType = as<IRArrayType>(argType);
if (!argArrayType)
return false;
if (as<IRIntLit>(argArrayType->getElementCount()))
{
return true;
}
return false;
}

CodeGenContext* codeGenContext = nullptr;
};

void specializeArrayParameters(
CodeGenContext* codeGenContext,
IRModule* module)
{
ArrayParameterSpecializationCondition condition;
condition.codeGenContext = codeGenContext;
specializeFunctionCalls(codeGenContext, module, &condition);
}

Expand Down
Loading

0 comments on commit 7e2bc8e

Please sign in to comment.