From 62c84cef31a06a81c642d3d0b8e96452a1aae0d4 Mon Sep 17 00:00:00 2001 From: briannwu Date: Tue, 10 Dec 2024 16:43:48 +0000 Subject: [PATCH] [OPT] Tail Loop Optimization details: 1. Separate tailLoopOpt for A / B: tailLoopOptA / tailLoopOptB. 2. Not supported: DTV, SparseGemm. 3. Reorder load instructions with more vgprs. --- tensilelite/Tensile/KernelWriter.py | 46 +- tensilelite/Tensile/KernelWriterAssembly.py | 840 ++++++++++++++++---- tensilelite/Tensile/SolutionStructs.py | 18 +- 3 files changed, 721 insertions(+), 183 deletions(-) diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index d71f1e2419..e56764a22a 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -2627,33 +2627,59 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): # if swapGlobalRoad is true, swap the order of global read (B->A) tensorParameters1st = tensorParametersA tensorParameters2nd = tensorParametersB + tailLoopOpt1st = kernel["tailLoopOptA"] + tailLoopOpt2nd = kernel["tailLoopOptB"] + tc1 = 'A' tc2 = 'B' if self.isSwapGlobalReadOrderForDtvOrDtl(kernel): tensorParameters1st, tensorParameters2nd = tensorParameters2nd, tensorParameters1st + tailLoopOpt1st, tailLoopOpt2nd = tailLoopOpt2nd, tailLoopOpt1st tc1, tc2 = tc2, tc1 globalReadMode1st = 2 if (((tensorParameters1st["glvw"] * tensorParameters1st["bpeGR"]) < 4) or \ - kernel["tailLoopOpt"] == False) else 0 + tailLoopOpt1st == False) else 3 globalReadMode2nd = 2 if (((tensorParameters2nd["glvw"] * tensorParameters2nd["bpeGR"]) < 4) or \ - kernel["tailLoopOpt"] == False) else 0 - globalReadMode1st = 0 if tensorParameters1st["isSwizzled"] else globalReadMode1st - globalReadMode2nd = 0 if tensorParameters2nd["isSwizzled"] else globalReadMode2nd + tailLoopOpt2nd == False) else 3 + + globalReadMode1st = 3 if tensorParameters1st["isSwizzled"] else globalReadMode1st + globalReadMode2nd = 3 if tensorParameters2nd["isSwizzled"] else globalReadMode2nd module.addComment1("Update M0 for DTLDS") moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters1st) module.add(replaceHolder(moduleTmp, 0)) module.addComment1("Tail global read %s"%tc1) - module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st)) + if tailLoopOpt1st and (globalReadMode1st == 2): + module.add(self.doTailLoopOpt(kernel, tensorParameters1st)) + else: + module.add(self.globalReadDo(kernel, globalReadMode1st, tensorParameters1st)) module.addComment1("Update M0 for DTLDS") moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters2nd) module.add(replaceHolder(moduleTmp, 0)) module.addComment1("Tail global read %s"%tc2) - module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd)) - if kernel["tailLoopOpt"] and \ - (((tensorParameters1st["glvw"] * tensorParameters1st["bpeGR"]) >= 4) or \ - ((tensorParameters2nd["glvw"] * tensorParameters2nd["bpeGR"]) >= 4)): - module.add(self.tailLoopGlobalRead(kernel, tensorParameters1st, tensorParameters2nd)) + if tailLoopOpt2nd and (globalReadMode2nd == 2): + module.add(self.doTailLoopOpt(kernel, tensorParameters2nd)) + else: + module.add(self.globalReadDo(kernel, globalReadMode2nd, tensorParameters2nd)) + + doA = False + doB = False + if globalReadMode1st == 3: + if tc1 == 'A': + doA = True if (tensorParameters1st["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc1)]) else False + else: + doB = True if (tensorParameters1st["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc1)]) else False + if globalReadMode2nd == 3: + if tc2 == 'A': + doA = True if (tensorParameters2nd["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc2)]) else False + else: + doB = True if (tensorParameters2nd["bpeGR"] % 4 != 0) and (not kernel["ProblemType"]["TLU%s"%(tc2)]) else False + + if doA or doB: + if tc1 == 'A': + module.add(self.tailLoopGlobalRead(kernel, tensorParameters1st, tensorParameters2nd, doA, doB)) + else: + module.add(self.tailLoopGlobalRead(kernel, tensorParameters2nd, tensorParameters1st, doA, doB)) module.add(self._wait(kernel, tensorParameters1st, tensorParameters2nd, 0, -1, -1, "2wait for global read")) module.add(self._syncThreads(kernel)) diff --git a/tensilelite/Tensile/KernelWriterAssembly.py b/tensilelite/Tensile/KernelWriterAssembly.py index 53081a09c2..cf2e2bdb06 100644 --- a/tensilelite/Tensile/KernelWriterAssembly.py +++ b/tensilelite/Tensile/KernelWriterAssembly.py @@ -49,6 +49,7 @@ from .Activation import ActivationType from .Utils import DataDirection from .CustomKernels import isCustomKernelConfig +from dataclasses import dataclass from math import ceil, log, floor from copy import deepcopy @@ -58,10 +59,24 @@ import os import subprocess +@dataclass +class TailOptParams: + idx: int = 0 + behavior: str = "" + jumpLabel: Label = None + kLabelsList: List[Label] = None + tmpVgpr: int = None + kSgpr: int = None + vgprSlot: List[int] = None + loadNum: int = 0 + periodParam: List[int] = None + preDirectToLdsLoads: int = 0 + firstLoop: int = 0 + finalLoop: int = 0 + ################################################################################ # Assembly Kernel ################################################################################ - class KernelWriterAssembly(KernelWriter): ############################################################################## @@ -4117,19 +4132,45 @@ def removeStagger(self, kernel, tP): return imod - def tailLoopGlobalRead(self, kernel, tPA, tPB): + ############################################################################## + # Using wider load instructions to improve the GR efficiency in tail loop. + # If loading size is smaller than a dword(32bit), it will return 0 instead. + # Need to call buffer_load_d16 to load the data which is out of boundary. + ############################################################################## + def tailLoopGlobalRead(self, kernel, tPA, tPB, doA, doB): imod = Module("tailLoopGlobalRead") - doA = True if ((tPA["glvw"] * tPA["bpeGR"] >= 4) and (tPA["bpeGR"] % 4 != 0)) else False - doB = True if ((tPB["glvw"] * tPB["bpeGR"] >= 4) and (tPB["bpeGR"] % 4 != 0)) else False - loadALabel = Label(label="LOAD_A", comment="") - loadBLabel = Label(label="LOAD_B", comment="") - mergeALabel = Label(label="MERGE_A", comment="") - mergeBLabel = Label(label="MERGE_B", comment="") - skipLabel = Label(label="SKIP_LOAD_SINGLE_ELEMENT", comment="") + + tagList = ["AddressA", "AddressB", "WrapUA", "WrapUB", "StaggerU", "WGM", \ + "StaggerUIter", "GlobalReadIncsA", "GlobalReadIncsB"] + lastRegTag = None + for i in range(0, self.sgprPool.size()): + regTag = self.sgprPool.pool[i].tag + if regTag != lastRegTag: + lastRegTag = regTag + if (lastRegTag not in self.states.nonPostLoopSgpr) and \ + (self.sgprPool.pool[i].status == RegisterPool.Status.InUse) and \ + (lastRegTag in tagList): + imod.add(self.undefineSgpr(regTag)) + + loadALabel = Label(label="LoadA", comment="") + loadBLabel = Label(label="LoadB", comment="") + mergeALabel = Label(label="MergeA", comment="") + mergeBLabel = Label(label="MergeB", comment="") + skipLabel = Label(label="TailGlobalLoadEnd", comment="") + checkALabel = Label(label="CheckA", comment="") + checkBLabel = Label(label="CheckB", comment="") + checkAOOBLabel = Label(label="CheckA_OOB", comment="") + checkALoopBeginLabel = Label(label="CheckLoopBeginA", comment="") + checkBOOBLabel = Label(label="CheckB_OOB", comment="") + checkBLoopBeginLabel = Label(label="CheckLoopBeginB", comment="") + reloadALabel = Label(label="Reload_A", comment="") + reloadBLabel = Label(label="Reload_B", comment="") lspA = kernel[tPA["lsp"]] lscA = kernel[tPA["lsc"]] lspB = kernel[tPB["lsp"]] lscB = kernel[tPB["lsc"]] + nlcA = kernel["NumLoadsCoalescedA"] + nlcB = kernel["NumLoadsCoalescedB"] nlpA = kernel["NumLoadsPerpendicularA"] nlpB = kernel["NumLoadsPerpendicularB"] numElementsPer4BytesA = int(4 / tPA["bpeGR"]) @@ -4137,6 +4178,24 @@ def tailLoopGlobalRead(self, kernel, tPA, tPB): maxNumOOBElementsA = numElementsPer4BytesA - 1 maxNumOOBElementsB = numElementsPer4BytesB - 1 + glvwWorkaround = 8 * kernel["ProblemType"]["DataType"].numRegisters() + dataTypeA = kernel["ProblemType"]["DataType"] if tPA["glvw"] < glvwWorkaround else \ + kernel["ProblemType"]["DataTypeA"] + dataTypeB = kernel["ProblemType"]["DataType"] if tPB["glvw"] < glvwWorkaround else \ + kernel["ProblemType"]["DataTypeB"] + numElementsPerLoadA = -1 + numElementsPerLoadB = -1 + if dataTypeA.isHalf() or dataTypeA.isBFloat16(): + if tPA["glvw"] > 1 and kernel["AssertSummationElementMultiple"] % 2 == 0: + numElementsPerLoadA = 2 + if dataTypeB.isHalf() or dataTypeB.isBFloat16(): + if tPB["glvw"] > 1 and kernel["AssertSummationElementMultiple"] % 2 == 0: + numElementsPerLoadB = 2 + if numElementsPerLoadA == 2: + doA = False + if numElementsPerLoadB == 2: + doB = False + self.param = TailOptParams() def LOAD_FUNC(tP, tmpVgpr, behavior, jumpLabel, tileSgpr, kSgpr): tc = tP["tensorChar"] bpe = tP["bpeGR"] @@ -4156,6 +4215,9 @@ def func(idx, bevavior, jumpLabel, tileSgpr, kSgpr): if (idx != 0): imod.add(SCmpEQU32(src0=sgpr(tileSgpr), src1=idx, comment="")) imod.add(SCBranchSCC1(labelName=labelTmp.getLabelName(), comment="")) + else: + imod.add(SCmpEQU32(src0=sgpr(tileSgpr), src1=idx, comment="")) + imod.add(SCBranchSCC0(labelName=jumpLabel.getLabelName(), comment="")) func(idx-1, bevavior, jumpLabel, tileSgpr, kSgpr) imod.add(labelTmp) for i in range(tP["glvw"], 0, -1): @@ -4163,74 +4225,97 @@ def func(idx, bevavior, jumpLabel, tileSgpr, kSgpr): labelStr2 = labelStr+"_K"+str(i) labelTmp = Label(label=labelStr2, comment="") kLabelsList.append(labelTmp) - imod.add(self.globalReadGuardK(kernel, tP, True, idx, jumpLabel, tmpVgpr, kLabelsList, behavior, kSgpr)) + if idx == (numTiles - 1): + finalLoop = 1 + else: + finalLoop = 0 + + #self.param = TailOptParams(idx, jumpLabel, tmpVgpr, kLabelsList, behavior, kSgpr, \ + # [], [], 0, 0, 0, finalLoop) + self.param.idx = idx + self.param.jumpLabel = jumpLabel + self.param.tmpVgpr = tmpVgpr + self.param.kLabelsList = kLabelsList + self.param.behavior = behavior + self.param.kSgpr = kSgpr + self.param.finalLoop = finalLoop + imod.add(self.globalReadGuardK(kernel, tP, 1, self.param)) +# imod.add(self.globalReadGuardK(kernel, tP, 1, idx, jumpLabel, tmpVgpr, kLabelsList, behavior, kSgpr, \ +# [], [], 0, 0, finalLoop)) func(numTiles - 1, behavior, jumpLabel, tileSgpr, kSgpr) tmpSgprA1 = self.sgprPool.checkOut(1, preventOverflow=False) tmpSgprB1 = self.sgprPool.checkOut(1, preventOverflow=False) -# tmpSgprA2 = self.sgprPool.checkOut(1, preventOverflow=False) -# tmpSgprB2 = self.sgprPool.checkOut(1, preventOverflow=False) + tmpSgprA2 = self.sgprPool.checkOut(1, preventOverflow=False) + tmpSgprB2 = self.sgprPool.checkOut(1, preventOverflow=False) tmpSgpr = self.sgprPool.checkOutAligned(2, 2, preventOverflow=False) tmpSgprQregA = self.sgprPool.checkOut(1, preventOverflow=False) tmpSgprQregB = self.sgprPool.checkOut(1, preventOverflow=False) tmpSgprKA = self.sgprPool.checkOut(1, preventOverflow=False) tmpSgprKB = self.sgprPool.checkOut(1, preventOverflow=False) + tmpSgpr1 = self.sgprPool.checkOut(1, preventOverflow=False) + tmpSgpr2 = self.sgprPool.checkOut(1, preventOverflow=False) + tmpSgpr3 = self.sgprPool.checkOut(1, preventOverflow=False) + tmpSgpr4 = self.sgprPool.checkOut(1, preventOverflow=False) + loopIdx = self.states.unrollIdx # for A if doA: if (kernel["WaveSeparateGlobalReadA"] == 0): tmpSgprA = tmpSgprQregA -# else: -# tmpSgprA = tmpSgprA2 + else: + tmpSgprA = tmpSgprA2 imod.add(SSubU32(dst=sgpr(tmpSgprA1), src0=sgpr("SizeI"), src1=1)) imod.add(scalarStaticDivideAndRemainder(tmpSgprA, tmpSgprA, tmpSgprA1, \ kernel["MacroTile0"], \ RegisterPoolResource(tmpSgpr, 2), 1)) -# if (kernel["WaveSeparateGlobalReadA"] == 1): -# imod.add(scalarStaticDivideAndRemainder(tmpSgprQregA, tmpSgprQregA, tmpSgprA, \ -# (nlpA * lspA), \ -# RegisterPoolResource(tmpSgpr, 2), 1)) + if (kernel["WaveSeparateGlobalReadA"] > 0): + imod.add(scalarStaticDivideAndRemainder(tmpSgprQregA, tmpSgprQregA, tmpSgprA, \ + (nlpA * lspA), \ + RegisterPoolResource(tmpSgpr, 2), 1)) imod.add(SLShiftRightB32(dst=sgpr(tmpSgprQregA), shiftHex=hex(log2(lspA)), \ src=sgpr(tmpSgprQregA), comment="divide lsp")) + imod.add(SMulI32(dst=sgpr(tmpSgprQregA), src0=sgpr(tmpSgprQregA), src1=nlcA, comment="")) + imod.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(lscA)), \ + src=sgpr("LoopCounterL"), comment="")) + imod.add(SAddI32(dst=sgpr(tmpSgprQregA), src0=sgpr(tmpSgprQregA), src1=sgpr(tmpSgpr), comment="")) + imod.add(scalarStaticDivideAndRemainder(tmpSgpr, tmpSgprA1, "SizesSum+%u"%loopIdx, \ + kernel["DepthU"], RegisterPoolResource(tmpSgpr, 2), 2)) # for B if doB: if (kernel["WaveSeparateGlobalReadB"] == 0): tmpSgprB = tmpSgprQregB -# else: -# tmpSgprB = tmpSgprA2 + else: + tmpSgprB = tmpSgprB2 imod.add(SSubU32(dst=sgpr(tmpSgprB1), src0=sgpr("SizeJ"), src1=1)) imod.add(scalarStaticDivideAndRemainder(tmpSgprB, tmpSgprB, tmpSgprB1, \ kernel["MacroTile1"], \ RegisterPoolResource(tmpSgpr, 2), 1)) -# if (kernel["WaveSeparateGlobalReadB"] == 1): -# imod.add(scalarStaticDivideAndRemainder(tmpSgprQregB, tmpSgprQregB, tmpSgprB, \ -# (nlpB * lspB), \ -# RegisterPoolResource(tmpSgpr, 2), 1)) + if (kernel["WaveSeparateGlobalReadB"] > 0): + imod.add(scalarStaticDivideAndRemainder(tmpSgprQregB, tmpSgprQregB, tmpSgprB, \ + (nlpB * lspB), \ + RegisterPoolResource(tmpSgpr, 2), 1)) imod.add(SLShiftRightB32(dst=sgpr(tmpSgprQregB), shiftHex=hex(log2(lspB)), \ src=sgpr(tmpSgprQregB), comment="divide lsp")) - + imod.add(SMulI32(dst=sgpr(tmpSgprQregB), src0=sgpr(tmpSgprQregB), src1=nlcB, comment="")) + imod.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(lscB)), \ + src=sgpr("LoopCounterL"), comment="")) + imod.add(SAddI32(dst=sgpr(tmpSgprQregB), src0=sgpr(tmpSgprQregB), src1=sgpr(tmpSgpr), comment="")) + imod.add(scalarStaticDivideAndRemainder(tmpSgpr, tmpSgprB1, "SizesSum+%u"%loopIdx, \ + kernel["DepthU"], RegisterPoolResource(tmpSgpr, 2), 2)) # A if doA: - imod.add(SAndB32(dst=sgpr(tmpSgprA1), src0=sgpr("LoopCounterL"), src1=(tPA["glvw"] - 1), \ + imod.add(SAndB32(dst=sgpr(tmpSgprA1), src0=sgpr(tmpSgprA1), src1=(tPA["glvw"] - 1), \ comment="s[sgprLoopCounterL] % glvw")) imod.add(SAndB32(dst=sgpr(tmpSgprKA), src0=sgpr(tmpSgprA1), src1=hex(maxNumOOBElementsA), \ comment=" % numElementsPer4Bytes")) - imod.add(SAndB32(dst=sgpr(tmpSgpr), src0=sgpr(tmpSgprA1), src1=maxNumOOBElementsA, \ - comment="LoopCounterL + maxNumOOBElementsA")) - imod.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(numElementsPer4BytesA)), \ - src=sgpr(tmpSgpr), comment="divide numElementsPer4BytesA")) # B if doB: - imod.add(SAndB32(dst=sgpr(tmpSgprB1), src0=sgpr("LoopCounterL"), src1=(tPB["glvw"] - 1), \ + imod.add(SAndB32(dst=sgpr(tmpSgprB1), src0=sgpr(tmpSgprB1), src1=(tPB["glvw"] - 1), \ comment="s[sgprLoopCounterL] % glvw")) imod.add(SAndB32(dst=sgpr(tmpSgprKB), src0=sgpr(tmpSgprB1), src1=hex(maxNumOOBElementsB), \ comment=" % numElementsPer4Bytes")) - imod.add(SAndB32(dst=sgpr(tmpSgpr), src0=sgpr(tmpSgprB1), src1=maxNumOOBElementsB, \ - comment="LoopCounterL + maxNumOOBElementsB")) - imod.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(numElementsPer4BytesB)), \ - src=sgpr(tmpSgpr), \ - comment="divide numElementsPer4BytesB")) ######################################################################################################### numDwordA = (tPA["glvw"] * tPA["bpeGR"]) >> 2 numDwordB = (tPB["glvw"] * tPB["bpeGR"]) >> 2 @@ -4240,10 +4325,11 @@ def func(idx, bevavior, jumpLabel, tileSgpr, kSgpr): if doA or doB: tmpVgpr = self.vgprPool.checkOut(numTmpVgpr) + imod.add(SMovB32(sgpr(tmpSgpr4), 0)) # A - imod.add(loadALabel) if doA: + imod.add(loadALabel) imod.add(SCmpEQU32(src0=sgpr(tmpSgprKA), src1=0, \ comment="Valid loading size per thread is multiples of 4 bytes")) if doB: @@ -4254,8 +4340,8 @@ def func(idx, bevavior, jumpLabel, tileSgpr, kSgpr): LOAD_FUNC(tPA, tmpVgpr, "LOAD", mergeALabel, tmpSgprQregA, tmpSgprA1) # B - imod.add(loadBLabel) if doB: + imod.add(loadBLabel) imod.add(SCmpEQU32(src0=sgpr(tmpSgprKB), src1=0, \ comment="Valid loading size per thread is multiples of 4 bytes")) if doA: @@ -4266,39 +4352,225 @@ def func(idx, bevavior, jumpLabel, tileSgpr, kSgpr): imod.add(SCBranchSCC1(labelName=mergeBLabel.getLabelName(), comment="Skip loading B")) LOAD_FUNC(tPB, tmpVgpr + (maxNumOOBElementsA * numDwordA), "LOAD", mergeBLabel, \ tmpSgprQregB, tmpSgprB1) - # A - imod.add(mergeALabel) if doA: - imod.add(SCmpEQU32(src0=sgpr(tmpSgprKA), src1=0, \ - comment="Valid loading size per thread is multiples of 4 bytes")) + imod.add(mergeALabel) + if numElementsPerLoadA != 2: + imod.add(SCmpEQU32(src0=sgpr(tmpSgprKA), src1=0, \ + comment="Valid loading size per thread is multiples of 4 bytes")) + if doB: + imod.add(SCBranchSCC1(labelName=mergeBLabel.getLabelName(), comment="Skip mergeing A")) + LOAD_FUNC(tPA, tmpVgpr, "MERGE", mergeBLabel, tmpSgprQregA, tmpSgprA1) + else: + imod.add(SCBranchSCC1(labelName=checkAOOBLabel.getLabelName(), comment="Skip mergeing A")) + LOAD_FUNC(tPA, tmpVgpr, "MERGE", checkAOOBLabel, tmpSgprQregA, tmpSgprA1) + + if doB and kernel["DirectToLds%s"%tPA["tensorChar"]]: + imod.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ + comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) + # B + if doB: + imod.add(mergeBLabel) + if numElementsPerLoadB != 2: + if doA and kernel["DirectToLds%s"%tPA["tensorChar"]]: + imod.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ + comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) + + imod.add(SCmpEQU32(src0=sgpr(tmpSgprKB), src1=0, \ + comment="Valid loading size per thread is multiples of 4 bytes")) + + if doA: + imod.add(SCBranchSCC1(labelName=checkAOOBLabel.getLabelName(), comment="Skip mergeing B")) + LOAD_FUNC(tPB, tmpVgpr + (maxNumOOBElementsA * numDwordA), "MERGE", checkAOOBLabel, \ + tmpSgprQregB, tmpSgprB1) + else: + imod.add(SCBranchSCC1(labelName=checkBOOBLabel.getLabelName(), comment="Skip mergeing B")) + LOAD_FUNC(tPB, tmpVgpr + (maxNumOOBElementsA * numDwordA), "MERGE", checkBOOBLabel, \ + tmpSgprQregB, tmpSgprB1) + + if doA: + imod.add(checkAOOBLabel) + imod.add(SCmpEQU32(src0=sgpr(tmpSgpr4), src1=0, comment="")) + imod.add(SCMovB32(sgpr(tmpSgprQregA), nlpA * nlcA)) + if not doB: + imod.add(SAddU32(sgpr(tmpSgpr4), sgpr(tmpSgpr4), 1)) + imod.add(checkALoopBeginLabel) + imod.add(SSubI32(dst=sgpr(tmpSgprQregA), src0=sgpr(tmpSgprQregA), src1=1)) + imod.add(scalarStaticDivideAndRemainder(tmpSgprA2, tmpSgprA2, "SizesSum+%u"%loopIdx, \ + kernel["DepthU"], \ + RegisterPoolResource(tmpSgpr, 2), 1)) + imod.add(SSubU32(dst=sgpr(tmpSgprB2), src0=sgpr("SizeI"), src1=1)) + imod.add(scalarStaticDivideAndRemainder(tmpSgpr1, tmpSgpr1, tmpSgprB2, \ + kernel["MacroTile0"], \ + RegisterPoolResource(tmpSgpr, 2), 1)) + imod.add(SMulI32(dst=sgpr(tmpSgpr2), src0=sgpr(tmpSgprA2), src1=sgpr(tmpSgpr1))) + numThreadsCoalA = lscA // tPA["glvw"] + numThreadsPerpA = kernel["NumThreads"] // numThreadsCoalA + if kernel["WaveSeparateGlobalReadA"] == 2 and nlcA == 1: + ofst = kernel["NumLoadsPerpendicularA"]*kernel["NumThreads"]//kernel["WavefrontSize"] + else: + ofst = 1 + + if kernel["WaveSeparateGlobalReadA"] == 2: + numVecA = (kernel["NumThreads"] // kernel["LocalSplitU"]) // (lscA // tPA["glvw"]) + baseValue = (lspA * nlpA) * (numVecA - 1) + (kernel["LocalSplitU"] - 1) + elif kernel["WaveSeparateGlobalReadA"] == 1: + baseValue = (lspA * nlpA) * (kernel["LocalSplitU"] - 1) + (lspA - 1) + else: + baseValue = (numThreadsPerpA - 1) + if nlcA > 1: + baseValue = 0 + imod.add(SMovB32(dst=sgpr(tmpSgpr3), src=baseValue)) + startIdx = nlpA * nlcA - 1 + labelCurr = Label(label="A"+str(startIdx), comment="") + for n in range(startIdx, -1, -1): + labelNext = Label(label="A"+str(n - 1), comment="") + imod.add(labelCurr) + imod.add(SCmpEQU32(src0=sgpr(tmpSgprQregA), src1=n, comment="")) + if n == 0: + if doB: + imod.add(SCBranchSCC0(labelName=checkBOOBLabel.getLabelName(), comment="")) + else: + imod.add(SCBranchSCC0(labelName=skipLabel.getLabelName(), comment="")) + imod.add(SAddU32(dst=sgpr(tmpSgpr3), src0=sgpr(tmpSgpr3), src1=(tPA["glvw"] - 1))) + else: + imod.add(SCBranchSCC0(labelName=labelNext.getLabelName(), comment="")) + stride = "StrideA%s"%(self.states.indexChars[tPA['tileIdx']]) + q = n // nlcA + r = n % nlcA + imod.add(SMulI32(dst=sgpr(tmpSgprA2), src0=q, src1=lspA)) + imod.add(SAddU32(dst=sgpr(tmpSgpr3), src0=sgpr(tmpSgpr3), src1=sgpr(tmpSgprA2))) + imod.add(SMulI32(dst=sgpr(tmpSgpr3), src0=sgpr(tmpSgpr3), src1=sgpr(stride))) + imod.add(SAddU32(dst=sgpr(tmpSgpr3), src0=sgpr(tmpSgpr3), src1=(tPA["glvw"] - 1))) + if (nlcA > 1): + imod.add(SMulI32(dst=sgpr(tmpSgprA2), src0=r, src1=lscA)) + imod.add(SAddU32(dst=sgpr(tmpSgpr3), src0=sgpr(tmpSgpr3), src1=sgpr(tmpSgprA2))) + imod.add(SBranch(checkALabel.getLabelName(), comment="")) + labelCurr = labelNext + imod.add(checkALabel) + imod.add(SCmpGeI32(src0=sgpr(tmpSgprQregA), src1=0, comment="")) + if doB: - imod.add(SCBranchSCC1(labelName=mergeBLabel.getLabelName(), comment="Skip mergeing A")) - LOAD_FUNC(tPA, tmpVgpr, "MERGE", mergeBLabel, tmpSgprQregA, tmpSgprA1) + imod.add(SCBranchSCC0(checkBOOBLabel.getLabelName(), comment="")) else: - imod.add(SCBranchSCC1(labelName=skipLabel.getLabelName(), comment="Skip mergeing A")) - LOAD_FUNC(tPA, tmpVgpr, "MERGE", skipLabel, tmpSgprQregA, tmpSgprA1) - # B - imod.add(mergeBLabel) + imod.add(SCBranchSCC0(skipLabel.getLabelName(), comment="")) + imod.add(SCmpGtU32(src0=sgpr(tmpSgpr3), src1=sgpr(tmpSgpr2), comment="lastIdxLoaded > last available index ?")) + + if doB: + imod.add(SCBranchSCC1(labelName=checkBOOBLabel.getLabelName(), comment="Reload")) + else: + imod.add(SCBranchSCC1(labelName=loadALabel.getLabelName(), comment="Reload")) + imod.add(SBranch(labelName=checkALoopBeginLabel.getLabelName(), comment="Re check")) + if doB: - imod.add(SCmpEQU32(src0=sgpr(tmpSgprKB), src1=0, \ - comment="Valid loading size per thread is multiples of 4 bytes")) - imod.add(SCBranchSCC1(labelName=skipLabel.getLabelName(), comment="Skip mergeing B")) - LOAD_FUNC(tPB, tmpVgpr + (maxNumOOBElementsA * numDwordA), "MERGE", skipLabel, \ - tmpSgprQregB, tmpSgprB1) + imod.add(checkBOOBLabel) + imod.add(SCmpEQU32(src0=sgpr(tmpSgpr4), src1=0, comment="")) + imod.add(SCMovB32(sgpr(tmpSgprQregB), nlpB * nlcB)) + imod.add(SAddU32(sgpr(tmpSgpr4), sgpr(tmpSgpr4), 1)) + imod.add(checkBLoopBeginLabel) + imod.add(SSubI32(dst=sgpr(tmpSgprQregB), src0=sgpr(tmpSgprQregB), src1=1)) + imod.add(scalarStaticDivideAndRemainder(tmpSgprA2, tmpSgprA2, "SizesSum+%u"%loopIdx, \ + kernel["DepthU"], \ + RegisterPoolResource(tmpSgpr, 2), 1)) + imod.add(SSubU32(dst=sgpr(tmpSgprB2), src0=sgpr("SizeJ"), src1=1)) + imod.add(scalarStaticDivideAndRemainder(tmpSgpr1, tmpSgpr1, tmpSgprB2, \ + kernel["MacroTile1"], \ + RegisterPoolResource(tmpSgpr, 2), 1)) + imod.add(SMulI32(dst=sgpr(tmpSgprB2), src0=sgpr(tmpSgprA2), src1=sgpr(tmpSgpr1))) + numThreadsCoalB = lscB // tPB["glvw"] + numThreadsPerpB = kernel["NumThreads"] // numThreadsCoalB + if kernel["WaveSeparateGlobalReadB"] == 2 and nlcB == 1: + ofst = kernel["NumLoadsPerpendicularB"]*kernel["NumThreads"]//kernel["WavefrontSize"] + else: + ofst = 1 + if kernel["WaveSeparateGlobalReadB"] == 2: + numVecB = (kernel["NumThreads"] // kernel["LocalSplitU"]) // (lscA // tPB["glvw"]) + baseValue = (lspB * nlpB) * (numVecB - 1) + (kernel["LocalSplitU"] - 1) + elif kernel["WaveSeparateGlobalReadB"] == 1: + baseValue = (lspB * nlpB) * (kernel["LocalSplitU"] - 1) + (lspB - 1) + else: + baseValue = (numThreadsPerpB - 1) + if nlcB > 1: + baseValue = 0 + + imod.add(SMovB32(dst=sgpr(tmpSgpr1), src=baseValue)) + startIdx = nlpB * nlcB - 1 + labelCurr = Label(label="B"+str(startIdx), comment="") + for n in range(startIdx, -1, -1): + labelNext = Label(label="B"+str(n - 1), comment="") + imod.add(labelCurr) + imod.add(SCmpEQU32(src0=sgpr(tmpSgprQregB), src1=n, comment="")) + if n == 0: + if doA: + imod.add(SCBranchSCC0(labelName=reloadALabel.getLabelName(), comment="")) + else: + imod.add(SCBranchSCC0(labelName=skipLabel.getLabelName(), comment="")) + imod.add(SAddU32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=(tPB["glvw"] - 1))) + else: + imod.add(SCBranchSCC0(labelName=labelNext.getLabelName(), comment="")) + stride = "StrideB%s"%(self.states.indexChars[tPB['tileIdx']]) + q = n // nlcB + r = n % nlcB + imod.add(SMulI32(dst=sgpr(tmpSgprA2), src0=q, src1=lspB)) + imod.add(SAddU32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=sgpr(tmpSgprA2))) + imod.add(SMulI32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=sgpr(stride))) + imod.add(SAddU32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=(tPB["glvw"] - 1))) + if (nlcA > 1): + imod.add(SMulI32(dst=sgpr(tmpSgprA2), src0=r, src1=lscB)) + imod.add(SAddU32(dst=sgpr(tmpSgpr1), src0=sgpr(tmpSgpr1), src1=sgpr(tmpSgprA2))) + imod.add(SBranch(checkBLabel.getLabelName(), comment="")) + labelCurr = labelNext + imod.add(checkBLabel) + imod.add(SCmpGeI32(src0=sgpr(tmpSgprQregB), src1=0, comment="")) + + if doA: + imod.add(SCBranchSCC0(reloadALabel.getLabelName(), comment="")) + else: + imod.add(SCBranchSCC0(skipLabel.getLabelName(), comment="")) + imod.add(SCmpGtU32(src0=sgpr(tmpSgpr1), src1=sgpr(tmpSgprB2), comment="lastIdxLoaded > last available index ?")) + + if doA: + imod.add(SCBranchSCC1(labelName=reloadALabel.getLabelName(), comment="Reload")) + else: + imod.add(SCBranchSCC1(labelName=loadBLabel.getLabelName(), comment="Reload")) + imod.add(SBranch(labelName=checkBLoopBeginLabel.getLabelName(), comment="Re check")) + + if doA and doB: + imod.add(reloadALabel) + imod.add(SCmpGeI32(src0=sgpr(tmpSgprQregA), src1=0, comment="")) + imod.add(SCBranchSCC0(reloadBLabel.getLabelName(), comment="")) + imod.add(SCmpGtU32(src0=sgpr(tmpSgpr3), src1=sgpr(tmpSgpr2), comment="lastIdxLoaded > last available index ?")) + imod.add(SCBranchSCC1(loadALabel.getLabelName(), comment="")) + imod.add(reloadBLabel) + imod.add(SCmpGeI32(src0=sgpr(tmpSgprQregB), src1=0, comment="")) + imod.add(SCBranchSCC0(skipLabel.getLabelName(), comment="")) + imod.add(SCmpGtU32(src0=sgpr(tmpSgpr1), src1=sgpr(tmpSgprB2), comment="lastIdxLoaded > last available index ?")) + imod.add(SCBranchSCC1(loadBLabel.getLabelName(), comment="")) imod.add(skipLabel) + + if doA and not doB and kernel["DirectToLds%s"%tPA["tensorChar"]]: + imod.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ + comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) + elif doB and kernel["DirectToLds%s"%tPB["tensorChar"]]: + imod.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ + comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) + if doA or doB: self.vgprPool.checkIn(tmpVgpr) self.sgprPool.checkIn(tmpSgprA1) self.sgprPool.checkIn(tmpSgprB1) -# self.sgprPool.checkIn(tmpSgprA2) -# self.sgprPool.checkIn(tmpSgprB2) + self.sgprPool.checkIn(tmpSgprA2) + self.sgprPool.checkIn(tmpSgprB2) self.sgprPool.checkIn(tmpSgpr) self.sgprPool.checkIn(tmpSgprQregA) self.sgprPool.checkIn(tmpSgprQregB) self.sgprPool.checkIn(tmpSgprKA) self.sgprPool.checkIn(tmpSgprKB) + self.sgprPool.checkIn(tmpSgpr1) + self.sgprPool.checkIn(tmpSgpr2) + self.sgprPool.checkIn(tmpSgpr3) + self.sgprPool.checkIn(tmpSgpr4) return imod ############################################################################## @@ -6271,12 +6543,11 @@ def globalReadIncrementAB(self, kernel, tPA, tPB, loopIdx, prefetchIndex): ############################################################################## # Global Read: - # globalReadGuardK is called for loads in the tail loop + # globalReadTrueGuardK is called for loads in the tail loop # Must ensure each load is in bounds - either using buffer bounds # or exec-mask checks. ############################################################################## - def globalReadGuardK(self, kernel, tP, doTailOpt = False, \ - idx = 0, jumpLabel = None, tmpVgpr = None, kLabelsList = [], behavior = "", kSgpr = None): + def globalReadGuardK(self, kernel, tP, doTailOpt = 0, optParams = None): module = Module("globalReadGuardK") tc = tP["tensorChar"] problemType = self.states.kernel["ProblemType"] @@ -6335,10 +6606,35 @@ def globalReadGuardK(self, kernel, tP, doTailOpt = False, \ zeroVgpr = self.vgprPool.checkOut(1,"zeroVgpr") module.add(VMovB32(dst=vgpr(zeroVgpr), src=hex(0), comment="zero")) - def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", jumpLabel = None, \ - numElementsPer4Bytes = 0, kSgpr = None, doTailOpt = False): + def globalReadGuardKBody(tP, optParams = None): + if optParams != None: + jumpLabel = optParams.jumpLabel + idx = optParams.idx + tmpVgpr = optParams.tmpVgpr + kLabelsList = optParams.kLabelsList + behavior = optParams.behavior + kSgpr = optParams.kSgpr + vgprSlot = optParams.vgprSlot + periodParam = optParams.periodParam + loadNum = optParams.loadNum + preDirectToLdsLoads = optParams.preDirectToLdsLoads + finalLoop = optParams.finalLoop + else: + jumpLabel = None + idx = 0 + tmpVgpr = None + kLabelsList = [] + behavior = "" + kSgpr = None + vgprSlot = [] + periodParam = [] + loadNum = 0 + preDirectToLdsLoads = 0 + finalLoop = 0 + tc = tP["tensorChar"] self.vgprs.globalReadRegisters[tc] = [] + tcDataType = "" if tc == "Metadata" else tc graIdx = 0 g2lIdx = 0 @@ -6350,16 +6646,32 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju isLds = True if kernel["DirectToLds%s"%tc] else False directToLdsLoads = 0 + if doTailOpt == 2 and behavior == "LOAD": + directToLdsLoads += preDirectToLdsLoads prevLdsOffset = 0 instOffset = 0 loopCnt = -1 - for perp in range(0, tP["nrp"]): - for sPerp in range(0, tP["nrpv"]): - for para in range(0, tP["nrc"]): - for sPara in range(0, tP["nrcv"]//tP["nrcvpi"]): + vgprIdx = 0 + loadCnt = 0 + perpStart = periodParam[0] if doTailOpt == 2 else 0 + perpEnd = periodParam[1] if doTailOpt == 2 else tP["nrp"] + sPerpStart = periodParam[2] if doTailOpt == 2 else 0 + sPerpEnd = periodParam[3] if doTailOpt == 2 else tP["nrpv"] + paraStart = periodParam[4] if doTailOpt == 2 else 0 + paraEnd = periodParam[5] if doTailOpt == 2 else tP["nrc"] + sParaStart = periodParam[6] if doTailOpt == 2 else 0 + sParaEnd = periodParam[7] if doTailOpt == 2 else (tP["nrcv"]//tP["nrcvpi"]) + rStart = periodParam[8] if doTailOpt == 2 else 0 + rEnd = periodParam[9] if doTailOpt == 2 else 0 + + for perp in range(perpStart, perpEnd): + for sPerp in range(sPerpStart, sPerpEnd): + for para in range(paraStart, paraEnd): + for sPara in range(sParaStart, sParaEnd): i = sPara + (tP["nrcv"] // tP["nrcvpi"]) * (para + tP["nrc"] * (sPerp + tP["nrpv"] * perp)) - loopCnt += 1 + loopCnt = sPara + para * sParaEnd + sPerp * paraEnd * sParaEnd + perp * sPerpEnd * paraEnd * sParaEnd + graIdx = i * self.states.rpgo if kernel["BufferLoad"] else i * self.states.rpga g2lIdx = i * loadWidth * tP["bpeRatio"] if (tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc] and kernel["ConvertAfterDS"]: @@ -6368,7 +6680,6 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju g2lIdx *= tP["bpe"] // tP["bpeGR"] destVgprHi = None - destVgprHitmp = None # Fix tmpVgprIdx = tmpVgpr dataIsByte = False packInt8Code = None @@ -6385,8 +6696,15 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju # so far, limit to double only numLoadVectorComp = numLoadVectorComp // kernel["GlobalReadVectorWidth%c"%tc] int8TempVgpr = numLoadVectorComp - 1 + # for each component in vector while r < numLoadVectorComp: + if doTailOpt == 1 and i == idx and ((r+1) % numElementsPer4Bytes != 0): + if kLabelsList != None: + module.add(kLabelsList.pop()) + module.add(SCmpGeU32(src0=sgpr(kSgpr), src1=(r + 1), comment="")) + module.add(SCBranchSCC0(labelName=jumpLabel.getLabelName(), comment="")) + numElementsPerLoad = 1 # FIXME: Don't know why for grvw == 1, need further investigate glvwWorkaround = 8 * kernel["ProblemType"]["DataType"].numRegisters() @@ -6408,12 +6726,19 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju dataIsByte = True # Check out 3 regs once , for component 1,2,3 (r = 1,2,3) - if doTailOpt: + if doTailOpt == 1: if r == 1: packInt8Code = Module() destVgprHi = tmpVgprIdx else: - if r == 1: + if doTailOpt == 2: + if r == 1: + packInt8Code = Module() + if r != 0: + destVgprHi = vgprSlot[vgprIdx] + if r >= rStart and r < rEnd and (r % 4 != 0): + vgprIdx = vgprIdx + 1 + elif r == 1: packInt8Code = Module() destVgprHi = self.vgprPool.checkOut( int8TempVgpr , 'destVgprHi') regIdx = r // 4 @@ -6430,7 +6755,7 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju # In some cards, loading half types into register will zero out # the other half. Therefore we need to load into a separate register # then pack 2 registers into one - if doTailOpt: + if doTailOpt == 1: destVgprHi = tmpVgprIdx else: if (tP["localWriteInstruction"].blockWidth == 0.5) and (r%2 == 0): @@ -6439,8 +6764,12 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju eccOffset = _getEccOffset(tP["globalReadInstruction"].totalWidth, bpr=self.states.bpr, bpe=eccBpe, \ glvw=tP["glvw"], idx=loopCnt, numVgprG2L=numVgprG2L) else: - destVgprHi = self.vgprPool.checkOut(1, 'destVgprHi') - + if doTailOpt == 2: + destVgprHi = vgprSlot[vgprIdx] + if r >= rStart and r < rEnd and (r % 2 != 0): + vgprIdx = vgprIdx + 1 + else: + destVgprHi = self.vgprPool.checkOut(1, 'destVgprHi') regIdx = r // 2 elif dataType.isInt8x4() or dataType.isSingle(): regIdx = r @@ -6453,7 +6782,7 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju regIdx = r*4 else: printWarning("DataType unsupported") - if not doTailOpt: + if doTailOpt == 0: module.addComment0("g2l=%u, load component %u"%(g2lIdx, r)) offset = 0 @@ -6520,7 +6849,10 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju ldsOffset = ldsInc * tP["nrc"] * (sPerp + tP["nrpv"] * perp) + lscaOffset ldsInc = ldsOffset - prevLdsOffset prevLdsOffset = ldsOffset - module.add(SAddU32(dst=mgpr(0), src0=mgpr(0), src1=ldsInc, comment="Move LDS write address to next line" )) + + if (doTailOpt == 0) or (doTailOpt == 2 and behavior == "LOAD") or\ + (doTailOpt == 1 and i == idx and ((r+1) % numElementsPer4Bytes != 0) and r != 0): + module.add(SAddU32(dst=mgpr(0), src0=mgpr(0), src1=ldsInc, comment="Move LDS write address to next line" )) destVgpr=0 self.vgprs.globalReadRegisters[tc].append(0) else: @@ -6555,22 +6887,32 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju # if hi8=1 or hi16=1 (component 1,2,3 for int8) or (component 1 for half), use the temp destVgprHi # but only when hi16=1 we use the _d16_hi version instruction, see the below visualized int8 comment - if doTailOpt: + if doTailOpt == 1: loadVgpr = destVgprHi else: loadVgpr = destVgprHi if ((hi16 or hi8) and destVgprHi != None) else destVgpr self.vgprs.globalReadRegisters[tc][-1] = destVgprHi if ((hi16 or hi8) and destVgprHi != None) else self.vgprs.globalReadRegisters[tc][-1] if (kernel["ProblemType"]["DataType%s"%tcDataType].isInt8() or kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat() or tP["isM"]) and (not self.states.archCaps["HasEccHalf"]): module.add(VMovB32(dst=vgpr(loadVgpr), src=0, comment="set to zero to avoid unexpected value")) - - if doTailOpt: - if behavior == "LOAD" and i == idx and ((r + 1) % numElementsPer4Bytes != 0): - #hi16 = False - if kLabelsList != None: - module.add(kLabelsList.pop()) - module.add(SCmpGeU32(src0=sgpr(kSgpr), src1=(r + 1), comment="")) - module.add(SCBranchSCC0(labelName=jumpLabel.getLabelName(), comment="")) - module.addComment0("g2l=%u, load component %u"%(g2lIdx, r)) + if doTailOpt == 1: + if behavior == "LOAD" and i == idx: + if (numElementsPerLoad == 2 and r % numElementsPer4Bytes != 0) or \ + (numElementsPerLoad != 2 and ((r + 1) % numElementsPer4Bytes != 0)): + module.addComment0("g2l=%u, load component %u"%(g2lIdx, r)) + module.add(self.chooseGlobalRead(True, \ + bpl, destVgpr=loadVgpr, \ + addr0=vgpr(offsetVgpr), addr1=sgpr("Srd%s"%tc, 4), \ + soffset=soffset, offset=offset, \ + glc=isGlc, slc=isSlc, nt=isNT, lds=isLds, \ + hi16=hi16, \ + comment=comment)) + tmpVgprIdx += 1 + + if (numElementsPerLoad == 2 and r == (numLoadVectorComp - 1)) or \ + (numElementsPerLoad != 2 and (r + 1) == (numLoadVectorComp - 1)): + module.add(SBranch(labelName=jumpLabel.getLabelName(), comment="")) + else: + if (doTailOpt == 0) or (doTailOpt == 2 and behavior == "LOAD" and r >= rStart and r < rEnd): module.add(self.chooseGlobalRead(True, \ bpl, destVgpr=loadVgpr, \ addr0=vgpr(offsetVgpr), addr1=sgpr("Srd%s"%tc, 4), \ @@ -6578,17 +6920,7 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju glc=isGlc, slc=isSlc, nt=isNT, lds=isLds, \ hi16=hi16, \ comment=comment)) - tmpVgprIdx += 1 - if (r + 1) == (numLoadVectorComp - 1): - module.add(SBranch(labelName=jumpLabel.getLabelName(), comment="")) - else: - module.add(self.chooseGlobalRead(True, \ - bpl, destVgpr=loadVgpr, \ - addr0=vgpr(offsetVgpr), addr1=sgpr("Srd%s"%tc, 4), \ - soffset=soffset, offset=offset, \ - glc=isGlc, slc=isSlc, nt=isNT, lds=isLds, \ - hi16=hi16, \ - comment=comment)) + loadCnt = loadCnt + 1 if unrollMirrorWithSoffset: codeMod = Module("mirrorIdx%u"%loopCnt) codeMod.add(VAddU32(dst=vgpr(offsetVgpr), src0=vgpr(offsetVgpr), src1=soffset_prev, comment="mirror unroll: restore GRO=GRO+SGRO")) @@ -6639,13 +6971,11 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju # V1, V3 -> shift left 8 bits, or 4 regs (pack) # DestV0|=(V1 << 8), DestV0|= V2, DestV0|=(V3 << 8) # Int8 (byte) + if doTailOpt == 2 and r >= rStart and r < rEnd: + loadNum = loadNum - 1 if dataIsByte and (destVgprHi != None): - if doTailOpt: + if doTailOpt == 1: if behavior == "MERGE" and i == idx and ((r + 1) % numElementsPer4Bytes != 0): - if kLabelsList != None: - module.add(kLabelsList.pop()) - module.add(SCmpGeU32(src0=sgpr(kSgpr), src1=(r + 1), comment="")) - module.add(SCBranchSCC0(labelName=jumpLabel.getLabelName(), comment="")) # hi8 -> r = 1,3 # hi16 -> r = 2,3 module.add(SWaitCnt(vmcnt=0, comment="")) @@ -6656,47 +6986,60 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju if (r + 1) == (numLoadVectorComp - 1): module.add(SBranch(labelName=jumpLabel.getLabelName(), comment="")) else: - # hi8 -> r = 1,3 - # hi16 -> r = 2,3 - if hi8 or hi16: - # r = 1,2,3, vmcnt needed for one packing - packInt8Code.add(SWaitCnt(vmcnt=(int8TempVgpr-r), comment="")) - if hi8: - # r = 1,3, shift needed - packInt8Code.add(VLShiftLeftB32(dst=vgpr(destVgprHi), shiftHex=hex(0x8), src=vgpr(destVgprHi), comment="shift left to higher 8 bits")) - if hi8 or hi16: - # r = 1,2,3, packing - packInt8Code.add(VOrB32(dst=vgpr(destVgpr), src0=vgpr(destVgpr), src1=vgpr(destVgprHi), comment="pack a sub 8-bit with dest")) - destVgprHi += 1 + if (doTailOpt == 0) or (doTailOpt == 2 and behavior == "MERGE" and r >= rStart and r < rEnd): + # hi8 -> r = 1,3 + # hi16 -> r = 2,3 + if hi8 or hi16: + # r = 1,2,3, vmcnt needed for one packing + if doTailOpt == 0: + packInt8Code.add(SWaitCnt(vmcnt=(int8TempVgpr-r), comment="")) + else: + packInt8Code.add(SWaitCnt(vmcnt=(loadNum), comment="")) + if hi8: + # r = 1,3, shift needed + packInt8Code.add(VLShiftLeftB32(dst=vgpr(destVgprHi), shiftHex=hex(0x8), src=vgpr(destVgprHi), comment="shift left to higher 8 bits")) + if hi8 or hi16: + # r = 1,2,3, packing + packInt8Code.add(VOrB32(dst=vgpr(destVgpr), src0=vgpr(destVgpr), src1=vgpr(destVgprHi), comment="pack a sub 8-bit with dest")) + destVgprHi += 1 # Half elif destVgprHi != None: - if doTailOpt: - if behavior == "MERGE" and i == idx and ((r + 1) % numElementsPer4Bytes != 0): - if kLabelsList != None: - module.add(kLabelsList.pop()) - module.add(SCmpGeU32(src0=sgpr(kSgpr), src1=(r + 1), comment="")) - module.add(SCBranchSCC0(labelName=jumpLabel.getLabelName(), comment="")) - module.add(SWaitCnt(vmcnt=0, comment="")) - if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat(): - module.add(VLShiftRightB32(dst=vgpr(destVgprHi), shiftHex=hex(8), src=vgpr(destVgprHi), comment="shift right to lower 8 bits")) - module.add(VOrB32(dst=vgpr(destVgpr), src0=vgpr(destVgpr), src1=vgpr(destVgprHi), comment="HasEccHalf: pack")) - if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat() and (g2lIdx % 2 == 1): - module.add(VLShiftLeftB32(dst=vgpr(destVgpr), shiftHex=hex(16), src=vgpr(destVgpr), comment="shift left to higher 16 bits")) - tmpVgprIdx += 1 - if (r + 1) == (numLoadVectorComp - 1): - module.add(SBranch(labelName=jumpLabel.getLabelName(), comment="")) + if doTailOpt == 1: + if behavior == "MERGE" and i == idx: + if (numElementsPerLoad == 2 and r % numElementsPer4Bytes != 0) or \ + (numElementsPerLoad != 2 and ((r + 1) % numElementsPer4Bytes != 0)): + module.add(SWaitCnt(vmcnt=0, comment="")) + if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat(): + module.add(VLShiftRightB32(dst=vgpr(destVgprHi), shiftHex=hex(8), src=vgpr(destVgprHi), comment="shift right to lower 8 bits")) + module.add(VOrB32(dst=vgpr(destVgpr), src0=vgpr(destVgpr), src1=vgpr(destVgprHi), comment="HasEccHalf: pack")) + if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat() and (g2lIdx % 2 == 1): + module.add(VLShiftLeftB32(dst=vgpr(destVgpr), shiftHex=hex(16), src=vgpr(destVgpr), comment="shift left to higher 16 bits")) + tmpVgprIdx += 1 + if (numElementsPerLoad == 2 and r == (numLoadVectorComp - 1)) or\ + (numElementsPerLoad != 2 and (r + 1) == (numLoadVectorComp - 1)): + module.add(SBranch(labelName=jumpLabel.getLabelName(), comment="")) else: - if r % 2 == 1: - module.add(SWaitCnt(vmcnt=0, comment="")) - if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat(): - module.add(VLShiftRightB32(dst=vgpr(destVgprHi), shiftHex=hex(8), src=vgpr(destVgprHi), comment="shift right to lower 8 bits")) - module.add(VOrB32(dst=vgpr(destVgpr), src0=vgpr(destVgpr), src1=vgpr(destVgprHi), comment="HasEccHalf: pack")) - if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat() and (g2lIdx % 2 == 1): - module.add(VLShiftLeftB32(dst=vgpr(destVgpr), shiftHex=hex(16), src=vgpr(destVgpr), comment="shift left to higher 16 bits")) + if (doTailOpt == 0) or (doTailOpt == 2 and behavior == "MERGE" and r >= rStart and r < rEnd): + if r % 2 == 1: + if doTailOpt == 0: + module.add(SWaitCnt(vmcnt=0, comment="")) + else: + + module.add(SWaitCnt(vmcnt=(loadNum), comment="")) + if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat(): + module.add(VLShiftRightB32(dst=vgpr(destVgprHi), shiftHex=hex(8), src=vgpr(destVgprHi), comment="shift right to lower 8 bits")) + module.add(VOrB32(dst=vgpr(destVgpr), src0=vgpr(destVgpr), src1=vgpr(destVgprHi), comment="HasEccHalf: pack")) + if kernel["ProblemType"]["DataType%s"%tcDataType].is8bitFloat() and (g2lIdx % 2 == 1): + module.add(VLShiftLeftB32(dst=vgpr(destVgpr), shiftHex=hex(16), src=vgpr(destVgpr), comment="shift left to higher 16 bits")) + else: + if doTailOpt == 1 and i == idx and behavior == "MERGE" and\ + ((numElementsPerLoad == 2 and r == (numLoadVectorComp - 1)) or\ + (numElementsPerLoad != 2 and (r + 1) == (numLoadVectorComp - 1))): + module.add(SBranch(labelName=jumpLabel.getLabelName(), comment="")) # For half (bf16). Note: for int8, we will checkin after loading all components if (destVgprHi != None) and (not dataIsByte): - if not doTailOpt: + if doTailOpt == 0: self.vgprPool.checkIn(destVgprHi) destVgprHi = None @@ -6713,7 +7056,7 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju if dataIsByte and int8TempVgpr: assert packInt8Code != None and destVgprHi != None module.add(packInt8Code) - if not doTailOpt: + if doTailOpt == 0: self.vgprPool.checkIn(destVgprHi - int8TempVgpr) destVgprHi = None @@ -6753,26 +7096,171 @@ def globalReadGuardKBody(tP, tmpVgpr = None, kLabelsList = [], behavior = "", ju self.vgprPool.checkIn(destVgprHi) destVgprHi = None - globalReadGuardKBody(tP, tmpVgpr, kLabelsList, behavior, jumpLabel, numElementsPer4Bytes, kSgpr, doTailOpt) - if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: - if tP["is_sparse"]: - globalReadGuardKBody(tP["tpsMetadata"]) + return loadCnt, self.vgprs.globalReadRegisters[tc], directToLdsLoads - if self.db["ConservativeWaitCnt"] & 0x1: - module.add(SBarrier(comment="debug")) - module.add(SWaitCnt(lgkmcnt=0, vmcnt=0, vscnt=0, comment="")) - module.add(SBarrier(comment="debug")) - #module.add(self.getCmpAssert(self.asmAssert.lt, vgpr("Serial"), 64)) # examine second wavefront + loadCnt, vgprList, directToLdsLoads = globalReadGuardKBody(tP, optParams) - # TODO - can remove one of these m0 restores if A and B both TLU - if kernel["DirectToLds%s"%tP["tensorChar"]]: - module.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ - comment="Restore LDS clamp at %u bytes"%(kernel["LdsNumBytes"]))) + if doTailOpt == 0 or \ + (doTailOpt == 2 and optParams.finalLoop == 1 and optParams.behavior == "MERGE"): + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: + if tP["is_sparse"]: + globalReadGuardKBody(tP["tpsMetadata"]) + + if self.db["ConservativeWaitCnt"] & 0x1: + module.add(SBarrier(comment="debug")) + module.add(SWaitCnt(lgkmcnt=0, vmcnt=0, vscnt=0, comment="")) + module.add(SBarrier(comment="debug")) + + # TODO - can remove one of these m0 restores if A and B both TLU + if kernel["DirectToLds%s"%tP["tensorChar"]]: + module.add(SMovB32(dst=mgpr(0), src=hex(kernel["LdsNumBytes"]), \ + comment="Restore LDS clamp at %u bytes HERE"%(kernel["LdsNumBytes"]))) + + if not kernel["BufferLoad"]: + self.vgprPool.checkIn(maxAddrVgpr) + self.vgprPool.checkIn(bpeVgpr) + self.vgprPool.checkIn(zeroVgpr) + + if doTailOpt == 2: + return module, loadCnt, vgprList, directToLdsLoads + else: + return module - if not kernel["BufferLoad"]: - self.vgprPool.checkIn(maxAddrVgpr) - self.vgprPool.checkIn(bpeVgpr) - self.vgprPool.checkIn(zeroVgpr) + ############################################################################## + # Redorder the wait instructions to reduce overall waiting time in tail loop. + # It needs more vgprs to store data which is needed to be merged into a dword. + ############################################################################## + def doTailLoopOpt(self, kernel, tP): + module = Module("doTailLoop") + tc = tP["tensorChar"] + orinrp = tP["nrp"] + orinrpv = tP["nrpv"] + orinrc = tP["nrc"] + orinrcv_div_nrcvpi = tP["nrcv"] // tP["nrcvpi"] + loadWidth = tP["globalReadInstruction"].totalWidth + oriNumLoadVectorComp = (int(loadWidth*self.states.bpr//tP["bpeGR"])) + numElementsPerLoad = 1 + glvwWorkaround = 8 * kernel["ProblemType"]["DataType"].numRegisters() + dataType = kernel["ProblemType"]["DataType"] if tP["glvw"] < glvwWorkaround \ + else kernel["ProblemType"]["DataType%s"%tc] + + if dataType.isHalf() or dataType.isBFloat16(): + if tP["glvw"] > 1 and kernel["AssertSummationElementMultiple"] % 2 == 0: + # Pack two FP16 values into a single load dword x2 + numElementsPerLoad = 2 + oriNumLoadVectorComp = oriNumLoadVectorComp // 2 + totalVgprNum = (tP["nrp"] * tP["nrpv"] * tP["nrc"] * (tP["nrcv"] // tP["nrcvpi"])) * (oriNumLoadVectorComp) + if kernel["ProblemType"]["Gradient"] and kernel["ProblemType"]["UseBias"] \ + and (kernel["ProblemType"]["BiasSrc"] == "A" or kernel["ProblemType"]["BiasSrc"] == "B"): + totalVgprNum += 1 + + numLoadVectorComp = oriNumLoadVectorComp + nrp = orinrp + nrpv = orinrpv + nrc = orinrc + nrcv_div_nrcvpi = orinrcv_div_nrcvpi + currentSize = self.vgprPool.size() + VgprSlot = [] + VgprSlotBk = [] + + maxVgpr = 50 + while (len(VgprSlot) < totalVgprNum and len(VgprSlot) <= maxVgpr): + tempVgpr = self.vgprPool.checkOut(1,"") + if tempVgpr >= currentSize: + self.vgprPool.checkIn(tempVgpr) + break + if kernel["ProblemType"]["Gradient"] and kernel["ProblemType"]["UseBias"] \ + and (kernel["ProblemType"]["BiasSrc"] == "A" or kernel["ProblemType"]["BiasSrc"] == "B"): + if tempVgpr != self.states.bias.startVgprValu: + VgprSlot.append(tempVgpr) + else: + VgprSlot.append(tempVgpr) + VgprSlotBk.append(tempVgpr) + + loopNum = 1 + finalVgprNum = totalVgprNum + while (finalVgprNum > len(VgprSlot) and totalVgprNum > 0): + if nrp > 1: + nrp -= 1 + elif nrpv > 1: + nrpv -= 1 + elif nrc > 1: + nrc -= 1 + elif nrcv_div_nrcvpi > 1: + nrcv_div_nrcvpi -= 1 + else: + numLoadVectorComp = numLoadVectorComp - 1 + finalVgprNum = (nrp * nrpv * nrc * nrcv_div_nrcvpi) * (numLoadVectorComp) + + nrpLoopNum = ceil(tP["nrp"] / nrp) + nrpvLoopNum = ceil(tP["nrpv"] / nrpv) + nrcLoopNum = ceil(tP["nrc"] / nrc) + nrcv_div_nrcvpiLoopNum = ceil(tP["nrcv"] // tP["nrcvpi"] / nrcv_div_nrcvpi) + loadVectorCompLoopNum = int(ceil(oriNumLoadVectorComp / numLoadVectorComp)) + totalLoopNum = nrpLoopNum * nrpvLoopNum * nrcLoopNum * nrcv_div_nrcvpiLoopNum * loadVectorCompLoopNum + + i = 0 + globalReadRegisters = [] + directToLdsLoads = 0 + finalLoop = 0 + + self.optParamsLoad = TailOptParams() + self.optParamsMerge = TailOptParams() + for nrp_idx in range(nrpLoopNum): + for nrpv_idx in range(nrpvLoopNum): + for nrc_idx in range(nrcLoopNum): + for nrcv_div_nrcvpi_idx in range(nrcv_div_nrcvpiLoopNum): + for loadVectorComp_idx in range(loadVectorCompLoopNum): + firstLoop = 0 + if numElementsPerLoad == 2: + periodParam = [nrp_idx * nrp, min(orinrp, (nrp_idx + 1) * nrp), \ + nrpv_idx * nrpv, min(orinrpv, (nrpv_idx + 1) * nrpv), \ + nrc_idx * nrc, min(orinrc, (nrc_idx + 1) * nrc), \ + nrcv_div_nrcvpi_idx * nrcv_div_nrcvpi, \ + min(orinrcv_div_nrcvpi, (nrcv_div_nrcvpi_idx + 1) * nrcv_div_nrcvpi), \ + loadVectorComp_idx * numLoadVectorComp * 2, \ + min(oriNumLoadVectorComp * 2, (loadVectorComp_idx + 1) * numLoadVectorComp * 2)] + else: + periodParam = [nrp_idx * nrp, min(orinrp, (nrp_idx + 1) * nrp), \ + nrpv_idx * nrpv, min(orinrpv, (nrpv_idx + 1) * nrpv), \ + nrc_idx * nrc, min(orinrc, (nrc_idx + 1) * nrc), \ + nrcv_div_nrcvpi_idx * nrcv_div_nrcvpi, \ + min(orinrcv_div_nrcvpi, (nrcv_div_nrcvpi_idx + 1) * nrcv_div_nrcvpi), \ + loadVectorComp_idx * numLoadVectorComp, \ + min(oriNumLoadVectorComp, (loadVectorComp_idx + 1) * numLoadVectorComp)] + if i == 0: + firstLoop = 1 + + self.optParamsLoad.behavior = "LOAD" + self.optParamsLoad.vgprSlot = VgprSlot + self.optParamsLoad.periodParam = periodParam + self.optParamsLoad.directToLdsLoads = directToLdsLoads + self.optParamsLoad.firstLoop = firstLoop + self.optParamsLoad.finalLoop = finalLoop + imod, loadCnt, vgprList, directToLdsLoads = self.globalReadDo(kernel, 2, tP, -1, 0, 2, self.optParamsLoad) + module.add(imod) + + globalReadRegisters = globalReadRegisters + vgprList + + if i == totalLoopNum - 1: + finalLoop = 1 + + self.optParamsMerge.behavior = "MERGE" + self.optParamsMerge.vgprSlot = VgprSlot + self.optParamsMerge.periodParam = periodParam + self.optParamsMerge.loadNum = loadCnt + self.optParamsMerge.directToLdsLoads = directToLdsLoads + self.optParamsMerge.firstLoop = firstLoop + self.optParamsMerge.finalLoop = finalLoop + imod, loadCnt, vgprList, directToLdsLoads = self.globalReadDo(kernel, 2, tP, -1, 0, 2, self.optParamsMerge) + module.add(imod) + + i += 1 + + self.vgprs.globalReadRegisters[tc] = globalReadRegisters + while VgprSlotBk: + tempVgpr = VgprSlotBk.pop(0) + self.vgprPool.checkIn(tempVgpr) return module @@ -6813,11 +7301,13 @@ def directToLdsM0Update(self, kernel, mode, tP, usePlaceHolder=False): ############################################################################## # Global Read: Do It A/B ############################################################################## - def globalReadDo(self, kernel, mode, tP, unrollLoopIdx=-1, g2lBufIdx=0): + def globalReadDo(self, kernel, mode, tP, unrollLoopIdx=-1, g2lBufIdx=0, \ + doTailOpt = 0, optParams = None): tc = tP["tensorChar"] problemType = self.states.kernel["ProblemType"] imod = StructuredModule("globalReadDo%s_%u"%(tc,mode)) - if not self.do["GlobalRead%s"%tP["tensorChar"]]: return imod + if not self.do["GlobalRead%s"%tP["tensorChar"]]: + return imod # sizeK % LOCAL_DEPTHU guardK = (mode==2) @@ -6841,20 +7331,29 @@ def globalReadDo(self, kernel, mode, tP, unrollLoopIdx=-1, g2lBufIdx=0): # if DirectToVgpr is enabled and swapGlobalRead is true, change the first to B if self.isSwapGlobalReadOrderForDtvOrDtl(kernel, prefetch1=(mode==0)): tc1st = 'B' + if tc == tc1st and (kernel["DirectToLdsA"] or kernel["DirectToLdsB"]) \ + and not kernel["PrefetchGlobalRead"]==2: + if doTailOpt == 0 or \ + (doTailOpt == 2 and optParams.behavior == "LOAD" and optParams.firstLoop == 1): + # generate local read wait for DirectToLds except for + # PrefetchGlobalRead=2 (for PGR=2, generate wait after m0 value setting) + imod.header.addComment0("before DirectToLds load, ensure prior ds_reads have finished") + if (kernel["DirectToVgprA"] or kernel["DirectToVgprB"]) and not guardK and mode != 3: + # no need to generate sync here if DirectToVgpr is enabled and not tail loop + imod.header.add(SWaitCnt(lgkmcnt=0, comment="wait for LDS read/write")) + else: + imod.header.add(self._syncThreads(kernel)) - if tc == tc1st and (kernel["DirectToLdsA"] or kernel["DirectToLdsB"]) and not kernel["PrefetchGlobalRead"]==2: - # generate local read wait for DirectToLds except for PrefetchGlobalRead=2 (for PGR=2, generate wait after m0 value setting) - imod.header.addComment0("before DirectToLds load, ensure prior ds_reads have finished") - if (kernel["DirectToVgprA"] or kernel["DirectToVgprB"]) and not guardK: - # no need to generate sync here if DirectToVgpr is enabled and not tail loop - imod.header.add(SWaitCnt(lgkmcnt=0, comment="wait for LDS read/write")) + if guardK: + if doTailOpt == 0: + imod.middle.add(self.globalReadGuardK(kernel, tP)) + return imod else: - imod.header.add(self._syncThreads(kernel)) + module, loadCnt, vgprList, directToLdsLoads = \ + self.globalReadGuardK(kernel, tP, doTailOpt, optParams) - - if guardK: - imod.middle.add(self.globalReadGuardK(kernel, tP)) - return imod + imod.middle.add(module) + return imod, loadCnt, vgprList, directToLdsLoads # else not-guardK below: @@ -7524,10 +8023,21 @@ def localWriteBody(tP): paramList = [] numsOfRegister = [] globalBlockWidth = tP["globalReadInstruction"].totalWidth +# print("tc = ", tc, ", numBlocks = ", numBlocks) +# print("regs: ", self.vgprs.globalReadRegisters[tc]) +# print("destVgprPrefix = ", destVgprPrefix, ", blockWidth = ", blockWidth) for _ in range(0, numBlocks): # FIXME: In the future all registers should pass from global read instead of recalculate them if globalBlockWidth == blockWidth and tP["glvw"] == 1: +# print("destVgprPrefix = ", destVgprPrefix) +# print("tc = ", tc, ", i", i) +# print(self.vgprs.globalReadRegisters[tc][i]) +# print("blockWidth = ", blockWidth) +# print(vgpr(destVgprPrefix + "+%u"%(self.vgprs.globalReadRegisters[tc][i]), blockWidth)) +# print("i = ", i) +# print("_ = ", _) paramList.append(vgpr(destVgprPrefix + "+%u"%(self.vgprs.globalReadRegisters[tc][i]), blockWidth)) +# print("DONE") elif blockWidth == 1: paramList.append(vgpr(destVgprPrefix + "+%u"%(g2lIdx))) numsOfRegister.append(1) diff --git a/tensilelite/Tensile/SolutionStructs.py b/tensilelite/Tensile/SolutionStructs.py index 2d5ec86fdc..129de9ed88 100644 --- a/tensilelite/Tensile/SolutionStructs.py +++ b/tensilelite/Tensile/SolutionStructs.py @@ -1386,15 +1386,17 @@ def assignProblemIndependentDerivedParameters(state): reject(state, "MacroTile mismatch") # tail loop optimization + state["tailLoopOptA"] = True + state["tailLoopOptB"] = True + if (tuple(state["ISA"]) != (9, 4, 2)) or \ - (state["ProblemType"]["Sparse"]) or \ - (state["LocalSplitU"] > 1) or \ - (state["WaveSeparateGlobalReadA"] != 0) or \ - (state["WaveSeparateGlobalReadB"] != 0) or \ - (state["DirectToVgprA"] or state["DirectToVgprB"]): - state["tailLoopOpt"] = False - else: - state["tailLoopOpt"] = True + (state["ProblemType"]["Sparse"]): + state["tailLoopOptA"] = False + state["tailLoopOptB"] = False + if (state["DirectToVgprA"]): + state["tailLoopOptA"] = False + if (state["DirectToVgprB"]): + state["tailLoopOptB"] = False # done state["AssignedProblemIndependentDerivedParameters"] = True