Release and allocate VGPR resources in tail loop. #1586

Open · wants to merge 1 commit into base: develop
240 changes: 135 additions & 105 deletions tensilelite/Tensile/KernelWriter.py
@@ -26,7 +26,7 @@
from .TensileInstructions import Item, TensileInstructions, slash50, replaceHolder, \
KernelBody, Module, StructuredModule, TextBlock, Dump, LabelManager, \
RegisterPool, Assert, fastdeepcopy, TensileInstructionsPassOptions, \
TensileInstructionsPass, \
TensileInstructionsPass, ValueSet, RegSet, \
SLongBranchPositive, SBranch, SCBranchSCC0, SCBranchSCC1
from .TensileInstructions.Instructions import *
from .KernelWriterModules import *
@@ -200,6 +200,7 @@ class StateValues:
totalSgprs: int = 0
lastValuAB: int = 0
lastVgprForReads: int = 0
startVgpr: int = 0
startVgprAddressDbg: int = -1
startVgprAlphaTmp: int = -1
startVgprSerial: int = -1
@@ -2581,6 +2582,15 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
module.add(self.removeStagger(kernel, tensorParametersA))
module.add(self.removeStagger(kernel, tensorParametersB))

self.vgprPool.add(self.states.a.startVgprValu , \
self.states.lastValuAB - self.states.a.startVgprValu, "ValuAB")
module.addComment1("Tail: add ValuA/B vgpr buffer [%u...%u) to pool" % \
(self.states.a.startVgprValu, self.states.lastValuAB))
self.vgprPool.add(self.states.lastValuAB , \
self.states.lastVgprForReads - self.states.lastValuAB, "address vgpr")
module.addComment1("Tail: add address/G2L vgpr [%u...%u) to pool" % \
(self.states.lastValuAB, self.states.lastVgprForReads))
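# Illustrative sketch (not from the patch): the RegisterPool contract these
# add()/remove() calls assume, with SimplePool as a hypothetical stand-in for
# TensileInstructions' RegisterPool and the indices made up:
#
#     class SimplePool:
#         def __init__(self):
#             self.free = set()
#         def add(self, start, size, tag=""):       # mark [start, start+size) available
#             self.free.update(range(start, start + size))
#         def remove(self, start, size, tag=""):    # re-reserve [start, start+size)
#             self.free.difference_update(range(start, start + size))
#
#     pool = SimplePool()
#     pool.add(8, 24 - 8, "ValuAB")            # e.g. startVgprValu=8, lastValuAB=24
#     pool.add(24, 40 - 24, "address vgpr")    # e.g. lastVgprForReads=40
#     pool.remove(8, 24 - 8, "ValuAB")         # check out again before issuing local reads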

if not kernel["NoTailLoop"]:
########################################
# Tail Loop
@@ -2596,14 +2606,6 @@
if (kernel["DirectToVgprA"] or kernel["DirectToVgprB"] or kernel["DirectToLdsA"] or kernel["DirectToLdsB"]):
mEnd = kernel["DepthU"]//(kernel["MatrixInstK"]*kernel["LocalSplitU"])

# TailLoop unroll case (mEnd > 1), we need to keep these vgpr
if mEnd == 1:
# add vgprBuffer for local read to vgprPool because we won't issue local read in this section
self.vgprPool.add(self.states.a.startVgprValu , \
self.states.lastValuAB - self.states.a.startVgprValu, "ValuAB") # Add as available
module.addComment1("Tail: add ValuA/B vgpr buffer [%u...%u) to pool" % \
(self.states.a.startVgprValu, self.states.a.startVgprValu+self.states.lastValuAB))

# Update local write pointers in case the upcoming global reads are writing directly to LDS:
if kernel["PrefetchGlobalRead"]:
module.addComment1("local write reset offsets a")
@@ -2618,6 +2620,16 @@
module.add(self.localWriteResetOffsets(kernel, False, tensorParametersB))

# tail: global read
# Check out VGPR for DTV
vDtvResources = self.tailLoopAllocDTVVgpr(kernel, tensorParametersA, tensorParametersB)
for item in vDtvResources:
if item[0] != -1:
module.add(item[1])

# Check out VGPR for G2L
moduleMacroG2lVgpr, vgprG2L = self.tailLoopAllocG2LVgpr(kernel)
module.add(moduleMacroG2lVgpr)

module.add(self.calculateLoopNumIter(kernel, tensorParametersA, tensorParametersB, -1))
if self.states.actualSummationLoops==1:
module.addComment1("remove stagger offsets for tail loop")
@@ -2681,6 +2693,15 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
module.add(self._syncThreads(kernel))
#module.add(self.dumpLds(kernel, 0, 8))

# tail: free G2L Vgpr
module.add(self.tailLoopFreeVgpr(vgprG2L, moduleMacroG2lVgpr))

# Check out VGPR for VALU
valuResources = self.tailLoopAllocValuVgpr(kernel, tensorParametersA, tensorParametersB, tPM)
for item in valuResources:
if item[0] != -1:
module.add(item[1])

# tail: re-init local read addresses
if kernel["PrefetchGlobalRead"]:
module.addComment1("Tail: local read reset offsets a")
@@ -2709,20 +2730,6 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
if (kernel["AssertSummationElementMultiple"] % KinInnerUnroll == 0):
tailLoopInnerUnroll = kernel["InnerUnroll"]

# TailLoop unroll case (mEnd > 1), we need to keep these vgpr
if mEnd == 1:
# remove vgprBuffer for local read from vgprPool because we are ready to issue local read
self.vgprPool.remove(self.states.a.startVgprValu , \
self.states.lastValuAB - self.states.a.startVgprValu , "ValuAB") # remove from pool
module.addComment1("Tail: remove ValuA/B vgpr buffer [%u...%u) from pool" % \
(self.states.a.startVgprValu , self.states.lastValuAB))

# add address vgpr to vgprPool
self.vgprPool.add(self.states.lastValuAB , \
self.states.lastVgprForReads - self.states.lastValuAB, "address vgpr") # Add as available
module.addComment1("Tail: add address/G2L vgpr [%u...%u) to pool" % \
(self.states.lastValuAB, self.states.lastVgprForReads))

for mValue in range(mEnd):
if mEnd > 1:
# print tail loop counter if mEnd>1 (means do tail loop unroll)
@@ -2784,15 +2791,28 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
module.add(self.closeLoop(kernel, tensorParametersA, tensorParametersB, -1, finalLoop, skipCondJumpCounter=mValue))
# always emit the skip-tail-loop label
module.add(self.closeLoop(kernel, tensorParametersA, tensorParametersB, -1, None, emitEndLabelOnly=True))

# Check in VGPR for VALU
for item in valuResources:
if item[0] != -1:
module.add(self.tailLoopFreeVgpr(item[0], item[1]))

# Check in VGPR for DTV
for item in vDtvResources:
if item[0] != -1:
module.add(self.tailLoopFreeVgpr(item[0], item[1]))
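# Hedged sketch of the checkout/check-in pairing above, assuming each
# tailLoopAlloc* helper returns (vgprIndex, module) pairs and -1 marks an
# unused slot (a simplification of the real return values):
#
#     resources = self.tailLoopAllocValuVgpr(kernel, tPA, tPB, tPM)
#     for idx, mod in resources:
#         if idx != -1:
#             module.add(mod)                               # emit checkout code
#     # ... tail-loop body uses the checked-out vgprs ...
#     for idx, mod in resources:
#         if idx != -1:
#             module.add(self.tailLoopFreeVgpr(idx, mod))   # symmetric check-in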

# tail: close
self.states.inTailLoop = False


# FIXME: Add back.
if mEnd == 1:
# remove address vgpr from vgprPool
self.vgprPool.remove(self.states.lastValuAB , \
self.states.lastVgprForReads - self.states.lastValuAB, "address vgpr") # Remove from pool
module.addComment1("Tail: remove address/G2L [%u...%u) from pool" % \
(self.states.lastValuAB, self.states.lastVgprForReads))
# add misc vgpr to vgprPool
self.vgprPool.add(self.states.startVgprMisc , \
self.states.startVgpr - self.states.startVgprMisc, "misc vgpr") # Add as available
module.addComment1("Tail: add MISC Vgpr [%u...%u) from pool" % \
(self.states.startVgprMisc, self.states.startVgpr))

if self.do["executeToLoopEnd"]:
module.add(self.functionEnd(kernel, addLabel=False))
@@ -3483,6 +3503,8 @@ def readWriteVectors(mat, vw, kernel):
tpALocal) / (float)(self.states.bpr))
else:
self.states.a.numVgprG2LAllocated = self.states.a.numVgprG2L
else:
self.states.a.numVgprG2LAllocated = 0
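# No G2L buffer is needed for A on this path; zeroing the allocated count
# keeps it defined for the vgpr-layout math below (B and metadata get the
# same else-branch treatment).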
# using _ds_store_b8: need one more vgpr space to do lshr
if tensorParametersA["localWriteInstruction"].blockWidth == 0.25:
self.states.a.numVgprG2L = self.states.a.numVgprG2L * 2
@@ -3516,6 +3538,8 @@ def readWriteVectors(mat, vw, kernel):
tpBLocal) / (float)(self.states.bpr))
else:
self.states.b.numVgprG2LAllocated = self.states.b.numVgprG2L
else:
self.states.b.numVgprG2LAllocated = 0
# using _ds_store_b8: need one more vgpr space to do lshr
if tensorParametersB["localWriteInstruction"].blockWidth == 0.25:
self.states.b.numVgprG2L = self.states.b.numVgprG2L * 2
@@ -3544,6 +3568,8 @@ def readWriteVectors(mat, vw, kernel):
if tensorParametersM["localWriteInstruction"].blockWidth == 0.25:
self.states.m.numVgprG2L = self.states.m.numVgprG2L * 2
self.states.m.numVgprG2LAllocated = self.states.m.numVgprG2LAllocated * 2
else:
self.states.m.numVgprG2LAllocated = 0
####################################
# num vgprs: local read addresses
self.states.a.numVgprLocalReadAddr = 1 * self.states.rpla
@@ -3660,12 +3686,76 @@ def readWriteVectors(mat, vw, kernel):
vgprIdx = self.states.totalMixedAgprs
self.states.c.numVgprValu = self.states.totalMixedAgprs

#----------------------------------
# Moved to the front so these allocations carry through to the tail loop
self.states.startVgprMisc = vgprIdx
if not kernel["LocalWriteUseSgprA"]:
self.states.a.startVgprLocalWriteAddr = vgprIdx
vgprIdx += self.states.a.numVgprLocalWriteAddr

if not kernel["LocalWriteUseSgprB"]:
self.states.b.startVgprLocalWriteAddr = vgprIdx
vgprIdx += self.states.b.numVgprLocalWriteAddr

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
if self.states.combineLocalAddresses:
self.states.m.startVgprLocalWriteAddr = self.states.m.startVgprLocalReadAddr
else:
self.states.m.startVgprLocalWriteAddr = vgprIdx
vgprIdx += self.states.m.numVgprLocalWriteAddr

# BufferLoad:
# Uses a resource descriptor (SRD) which is stored in 4 SGPRs and thus shared by all work-items.
# Each work-item also uses a unique 32-bit offset into vgprGlobalReadOffset. These offsets are set when
# the tile is initialized and stay constant through the execution of the kernel.
# The base address in the SRD is updated when the algorithm moves to a new tile
# BufferLoad disables the vgprGlobalReadAddr used in flat addressing.
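# For illustration (hypothetical formula, restating the comment above):
#     addr(lane) = SRD.base + vgprGlobalReadOffset[lane]
# so moving to the next tile rewrites only the scalar SRD base while the
# per-lane vgpr offsets stay fixed.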
if kernel["BufferLoad"]:
self.startVgprGlobalReadOffsetA = vgprIdx
vgprIdx += 1 if kernel["_UseSgprForGRO"] else self.states.a.numVgprGlobalReadOffsets
self.startVgprGlobalReadOffsetB = vgprIdx
vgprIdx += 1 if kernel["_UseSgprForGRO"] else self.states.b.numVgprGlobalReadOffsets
if kernel["ProblemType"]["Sparse"]:
self.startVgprGlobalReadOffsetMetadata = vgprIdx
if kernel["DirectToVgprSparseMetadata"]:
miWaveTile = kernel["MIWaveTileB"] if kernel["ProblemType"]["Sparse"] == 2 else kernel["MIWaveTileA"]
vgprIdx += miWaveTile
else:
vgprIdx += 1 if kernel["_UseSgprForGRO"] else self.states.m.numVgprGlobalReadOffsets
else:
# TODO: alignment hack, figure out a better solution
vgprIdx = ((vgprIdx+1)//2)*2
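# Worked example (made-up index): vgprIdx = 7 -> ((7+1)//2)*2 = 8, rounding
# up to an even register, presumably so each 64-bit flat read address below
# occupies an aligned vgpr pair.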
self.startVgprGlobalReadAddressesA = vgprIdx
vgprIdx += numVgprGlobalReadAddressesA
self.startVgprGlobalReadAddressesB = vgprIdx
vgprIdx += numVgprGlobalReadAddressesB

self.startVgprGlobalReadIncsA = vgprIdx
vgprIdx += numVgprGlobalReadIncsA
self.startVgprGlobalReadIncsB = vgprIdx
vgprIdx += numVgprGlobalReadIncsB
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
self.startVgprGlobalReadIncsMetadata = vgprIdx
vgprIdx += numVgprGlobalReadIncsMetadata

self.states.a.startVgprLocalReadAddr = vgprIdx
vgprIdx += self.states.a.numVgprLocalReadAddr

self.states.b.startVgprLocalReadAddr = vgprIdx
vgprIdx += self.states.b.numVgprLocalReadAddr

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
self.states.m.startVgprLocalReadAddr = vgprIdx
vgprIdx += self.states.m.numVgprLocalReadAddr

# ----------------------------
# TODO: alignment hack, figure out a better solution
vgprIdx = ((vgprIdx+1)//2)*2
# Avoid bank conflict between VgprA and VgprC
if(self.states.archCaps["VgprBank"]):
vgprIdx += 2
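# Presumably the extra +2 skews ValuA's start into a different vgpr bank
# than the accumulators (illustration: with 4 banks, bank = index % 4), so
# operand reads of A and C avoid same-bank conflicts.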
self.states.a.startVgprValu = vgprIdx
self.states.startVgpr = vgprIdx
vgprIdx += self.states.a.numVgprValu
numVgprValuPackA = 0
if tensorParametersA["bpe"] < 4 and not kernel["UnrollMajorLDSA"]:
@@ -3723,28 +3813,15 @@ def readWriteVectors(mat, vw, kernel):
vgprIdx = self.states.b.startVgprValu \
+ max(self.states.b.numVgprValu + numVgprValuPackB, self.states.b.numVgprG2LAllocated)

if kernel["ProblemType"]["Gradient"] and kernel["ProblemType"]["UseBias"]:
if kernel["ProblemType"]["BiasSrc"] == "A":
self.states.bias.numVgprValu = kernel["MIWaveTile"][0]
elif kernel["ProblemType"]["BiasSrc"] == "B":
self.states.bias.numVgprValu = kernel["MIWaveTile"][1]
else:
self.states.bias.numVgprValu = 0
self.states.bias.numVgprValu *= max(kernel["ProblemType"]["ComputeDataType"].numRegisters(), 1)
else:
self.states.bias.numVgprValu = 0
self.states.bias.startVgprValu = vgprIdx
vgprIdx += self.states.bias.numVgprValu

if ((tensorParametersA["bpe"] < 4 and not kernel["UnrollMajorLDSA"]) or (tensorParametersB["bpe"] < 4 and not kernel["UnrollMajorLDSB"])) \
and (kernel["ProblemType"]["DataType"].isInt8() or kernel["ProblemType"]["DataType"].is8bitFloat()):
self.states.a.startVgprValuPackTemp = vgprIdx
self.states.b.startVgprValuPackTemp = vgprIdx
vgprIdx += 1

self.states.a.startVgprValuCvtTemp = -1
self.states.b.startVgprValuCvtTemp = -1
if kernel["ConvertAfterDS"]:
self.states.a.startVgprValuCvtTemp = -1
self.states.b.startVgprValuCvtTemp = -1
if ((tensorParametersA["bpe"] > tensorParametersA["bpeDS"]) and kernel["ProblemType"]["DataTypeA"].is8bitFloat()):
self.states.a.startVgprValuCvtTemp = vgprIdx
if ((tensorParametersB["bpe"] > tensorParametersB["bpeDS"]) and kernel["ProblemType"]["DataTypeB"].is8bitFloat()):
@@ -3783,62 +3860,12 @@ def readWriteVectors(mat, vw, kernel):
# code
if kernel["PrefetchGlobalRead"]:
self.states.lastValuAB = vgprIdx
#----------------------------------

if not kernel["LocalWriteUseSgprA"]:
self.states.a.startVgprLocalWriteAddr = vgprIdx
vgprIdx += self.states.a.numVgprLocalWriteAddr

if not kernel["LocalWriteUseSgprB"]:
self.states.b.startVgprLocalWriteAddr = vgprIdx
vgprIdx += self.states.b.numVgprLocalWriteAddr

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
if self.states.combineLocalAddresses:
self.states.m.startVgprLocalWriteAddr = self.states.m.startVgprLocalReadAddr
else:
self.states.m.startVgprLocalWriteAddr = vgprIdx
vgprIdx += self.states.m.numVgprLocalWriteAddr

# BufferLoad:
# Uses a resource descriptor (SRD) which is stored in 4 SGPRs and thus shared by all work-items.
# Each work-item also uses a unique 32-bit offset into vgprGlobalReadOffset. These offsets are set when
# the tile is initialized and stay constant through the execution of the kernel.
# The base address in the SRD is updated when the algorithm moves to a new tile
# BufferLoad disables the gptGlobalReadAddr used in flat addressing.
if kernel["BufferLoad"]:
self.startVgprGlobalReadOffsetA = vgprIdx
vgprIdx += 1 if kernel["_UseSgprForGRO"] else self.states.a.numVgprGlobalReadOffsets
self.startVgprGlobalReadOffsetB = vgprIdx
vgprIdx += 1 if kernel["_UseSgprForGRO"] else self.states.b.numVgprGlobalReadOffsets
if kernel["ProblemType"]["Sparse"]:
self.startVgprGlobalReadOffsetMetadata = vgprIdx
if kernel["DirectToVgprSparseMetadata"]:
miWaveTile = kernel["MIWaveTileB"] if kernel["ProblemType"]["Sparse"] == 2 else kernel["MIWaveTileA"]
vgprIdx += miWaveTile
else:
vgprIdx += 1 if kernel["_UseSgprForGRO"] else self.states.m.numVgprGlobalReadOffsets
else:
# TODO: alignment hack, figure out a better solution
vgprIdx = ((vgprIdx+1)//2)*2
self.startVgprGlobalReadAddressesA = vgprIdx
vgprIdx += numVgprGlobalReadAddressesA
self.startVgprGlobalReadAddressesB = vgprIdx
vgprIdx += numVgprGlobalReadAddressesB

self.startVgprGlobalReadIncsA = vgprIdx
vgprIdx += numVgprGlobalReadIncsA
self.startVgprGlobalReadIncsB = vgprIdx
vgprIdx += numVgprGlobalReadIncsB
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
self.startVgprGlobalReadIncsMetadata = vgprIdx
vgprIdx += numVgprGlobalReadIncsMetadata
#-----------

if self.states.a.startVgprG2L is None and self.states.a.numVgprG2LAllocated > 0:
# TODO: alignment hack, figure out a better solution
vgprIdx = ((vgprIdx+1)//2)*2
self.states.a.startVgprG2L = vgprIdx;
self.states.a.startVgprG2L = vgprIdx
if kernel["ULSGRODoubleG2L"] == 1:
vgprIdx += self.states.a.numVgprG2LAllocated*2
else:
@@ -3847,7 +3874,7 @@ def readWriteVectors(mat, vw, kernel):
if self.states.b.startVgprG2L is None and self.states.b.numVgprG2LAllocated > 0:
# TODO: alignment hack, figure out a better solution
vgprIdx = ((vgprIdx+1)//2)*2
self.states.b.startVgprG2L = vgprIdx;
self.states.b.startVgprG2L = vgprIdx
if kernel["ULSGRODoubleG2L"] == 1:
vgprIdx += self.states.b.numVgprG2LAllocated*2
else:
@@ -3859,21 +3886,24 @@ def readWriteVectors(mat, vw, kernel):
vgprIdx = ((vgprIdx+1)//2)*2
self.states.m.startVgprG2L = vgprIdx; vgprIdx += self.states.m.numVgprG2LAllocated

self.states.a.startVgprLocalReadAddr = vgprIdx
vgprIdx += self.states.a.numVgprLocalReadAddr

self.states.b.startVgprLocalReadAddr = vgprIdx
vgprIdx += self.states.b.numVgprLocalReadAddr

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
self.states.m.startVgprLocalReadAddr = vgprIdx
vgprIdx += self.states.m.numVgprLocalReadAddr

# GlobalRead, LocalWrite, LocalRead, G2L can be reclaimed, extend the "lastVgprForReads" value
if kernel["PrefetchGlobalRead"]:
self.states.lastVgprForReads = vgprIdx
#-----------
if kernel["ProblemType"]["Gradient"] and kernel["ProblemType"]["UseBias"]:
if kernel["ProblemType"]["BiasSrc"] == "A":
self.states.bias.numVgprValu = kernel["MIWaveTile"][0]
elif kernel["ProblemType"]["BiasSrc"] == "B":
self.states.bias.numVgprValu = kernel["MIWaveTile"][1]
else:
self.states.bias.numVgprValu = 0
self.states.bias.numVgprValu *= max(kernel["ProblemType"]["ComputeDataType"].numRegisters(), 1)
else:
self.states.bias.numVgprValu = 0
self.states.bias.startVgprValu = vgprIdx
vgprIdx += self.states.bias.numVgprValu
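# Sizing example (hypothetical tile): BiasSrc == "A", MIWaveTile[0] = 4 and
# a float ComputeDataType (1 register per element) give 4 * 1 = 4 bias
# vgprs, now reserved after the reclaimable read registers.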

#-----------
if kernel["ProblemType"]["OutputAmaxD"]:
self.startVgprAmaxOut = vgprIdx
self.startVgprAmaxOutB = vgprIdx + 1