From 9913cfbf68dab8c3c8c418dd28b71c2a65a55ae0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 22 Nov 2024 18:55:47 -0500 Subject: [PATCH] [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred (#5630) * [AD] Add support for resolving custom derivatives where generic parameters can't be automatically inferred * Fix failing tests * Update custom-derivative-generic.slang --- source/slang/slang-check-decl.cpp | 76 ++++++++++++++++++- .../custom-derivative-enum-param.slang | 57 ++++++++++++++ ...trinsic.slang => custom-intrinsic-1.slang} | 0 ... => custom-intrinsic-1.slang.expected.txt} | 0 .../custom-derivative-generic.slang | 2 +- 5 files changed, 133 insertions(+), 2 deletions(-) create mode 100644 tests/autodiff/custom-derivative-enum-param.slang rename tests/autodiff/{custom-intrinsic.slang => custom-intrinsic-1.slang} (100%) rename tests/autodiff/{custom-intrinsic.slang.expected.txt => custom-intrinsic-1.slang.expected.txt} (100%) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 251ce6a696..e4206827f8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl( SemanticsContext::ExprLocalScope scope; auto ctx = visitor->withExprLocalScope(&scope); auto subVisitor = SemanticsVisitor(ctx); - auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx); + + auto exprToCheck = attr->funcExpr; + + // If this is a generic, we want to wrap the call to the derivative method + // with the generic parameters of the source. + // + if (as(funcDecl->parentDecl) && !as(attr->funcExpr)) + { + auto genericDecl = as(funcDecl->parentDecl); + auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl); + auto appExpr = ctx.getASTBuilder()->create(); + + Index count = 0; + for (auto member : genericDecl->members) + { + if (as(member) || as(member) || + as(member)) + count++; + } + + appExpr->functionExpr = attr->funcExpr; + + for (auto arg : substArgs) + { + if (count == 0) + break; + + if (auto declRefType = as(arg)) + { + auto baseTypeExpr = ctx.getASTBuilder()->create(); + baseTypeExpr->base.type = declRefType; + auto baseTypeType = ctx.getASTBuilder()->getOrCreate(declRefType); + baseTypeExpr->type.type = baseTypeType; + + appExpr->arguments.add(baseTypeExpr); + } + else if (auto genericValParam = as(arg)) + { + auto declRef = genericValParam->getDeclRef(); + appExpr->arguments.add( + subVisitor + .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr)); + } + else + { + SLANG_UNEXPECTED("Unhandled substitution arg type"); + } + + count--; + } + + exprToCheck = appExpr; + } + + auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx); attr->funcExpr = checkedFuncExpr; if (attr->args.getCount()) attr->args[0] = attr->funcExpr; @@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl( calleeDeclRef = calleeDeclRefExpr->declRef; auto calleeFunc = as(calleeDeclRef.getDecl()); + + if (!calleeFunc) + { + // If we couldn't find a direct function, it might be a generic. + if (auto genericDecl = as(calleeDeclRef.getDecl())) + { + calleeFunc = as(genericDecl->inner); + + if (as(resolved->type.type)) + { + // If we can't resolve a type, something went wrong. If we're working with a generic + // decl, the most likely cause is a failure of generic argument inference. + // + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::cannotResolveGenericArgumentForDerivativeFunction); + } + } + } + if (!calleeFunc) { visitor->getSink()->diagnose( diff --git a/tests/autodiff/custom-derivative-enum-param.slang b/tests/autodiff/custom-derivative-enum-param.slang new file mode 100644 index 0000000000..aa67338732 --- /dev/null +++ b/tests/autodiff/custom-derivative-enum-param.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +enum MyEnum { A, B, C }; + +[BackwardDerivative(mDiff)] +float m(float x) +{ + switch (M) + { + case MyEnum.A: + return x * x; + case MyEnum.B: + return x; + case MyEnum.C: + return 3 * x; + default: + return 0; + } +} + +void mDiff(inout DifferentialPair x, float dResult) +{ + switch (M) + { + case MyEnum.A: + updateDiff(x, 2 * dResult * x.p); + break; + case MyEnum.B: + updateDiff(x, dResult); + break; + case MyEnum.C: + updateDiff(x, 3 * dResult); + break; + default: + updateDiff(x, 0); + break; + } +} + +[Differentiable] +float test(float x) +{ + return m(x); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var a = diffPair(3.0); + __bwd_diff(test)(a, 1.0); + outputBuffer[dispatchThreadID.x] = a.d; + // CHECK: 6.0 +} diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic-1.slang similarity index 100% rename from tests/autodiff/custom-intrinsic.slang rename to tests/autodiff/custom-intrinsic-1.slang diff --git a/tests/autodiff/custom-intrinsic.slang.expected.txt b/tests/autodiff/custom-intrinsic-1.slang.expected.txt similarity index 100% rename from tests/autodiff/custom-intrinsic.slang.expected.txt rename to tests/autodiff/custom-intrinsic-1.slang.expected.txt diff --git a/tests/diagnostics/custom-derivative-generic.slang b/tests/diagnostics/custom-derivative-generic.slang index 5f2cd9951f..fb65dd2cca 100644 --- a/tests/diagnostics/custom-derivative-generic.slang +++ b/tests/diagnostics/custom-derivative-generic.slang @@ -34,7 +34,7 @@ DifferentialPair dd1(DifferentialPair x) } // CHECK-DAG: {{.*}}(37): error 31151 -[BackwardDerivative(f)] +[BackwardDerivativeOf(f)] DifferentialPair df(inout DifferentialPair x, float dOut) { var primal = x.p * x.p;