-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OPT] Tail Loop Optimization #1567
base: develop
Are you sure you want to change the base?
Conversation
Any brief before/after comparison of the tail loop asm code? |
ef4242e
to
4b4f883
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good Opt. If you can share the performance gain for sensitive sizes, it will be much better.
details: 1. Separate tailLoopOpt for A / B: tailLoopOptA / tailLoopOptB. 2. Not supported: DTV, SparseGemm. 3. Reorder load instructions with more vgprs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
tailLoopOpt2nd == False) else 3 | ||
|
||
globalReadMode1st = 3 if tensorParameters1st["isSwizzled"] else globalReadMode1st | ||
globalReadMode2nd = 3 if tensorParameters2nd["isSwizzled"] else globalReadMode2nd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you put more comments about what are globalReadMode1st and globalReadMode2nd here.
details:
Compare:
globalReadMode = 3 -> use more vgpr to reorder GR, waitcnt, v_or_b32 instructions
Before:
/* g2l=0, load component 0 /
buffer_load_ubyte_d16 v[vgprG2LA+0+0], ..., 0 offen offset:0 // load one buffer value
/ g2l=0, load component 1 /
buffer_load_ubyte_d16 v0, ..., 0 offen offset:1 // load one buffer value
s_waitcnt vmcnt(0)
v_lshlrev_b32 v0, 0x8, v0 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+0], v[vgprG2LA+0+0], v0 // pack a sub 8-bit with dest
/ g2l=0, load component 0 /
buffer_load_ubyte_d16 v[vgprG2LA+0+4], ... offen offset:0 // load one buffer value
/ g2l=0, load component 1 */
buffer_load_ubyte_d16 v0, ... offen offset:1 // load one buffer value
s_waitcnt vmcnt(0)
v_lshlrev_b32 v0, 0x8, v0 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+4], v[vgprG2LA+0+4], v0 // pack a sub 8-bit with dest
...
After:
buffer_load_ubyte_d16 v[vgprG2LA+0+0], ... offen offset:0 // load one buffer value
buffer_load_ubyte_d16 v0, ..., 0 offen offset:1 // load one buffer value
buffer_load_ubyte_d16 v[vgprG2LA+0+4], ... offen offset:0 // load one buffer value
buffer_load_ubyte_d16 v1, ... offen offset:1 // load one buffer value
buffer_load_ubyte_d16 v[vgprG2LA+1+0], offen offset:0 // load one buffer value
...
s_waitcnt vmcnt(10)
v_lshlrev_b32 v0, 0x8, v0 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+0], v[vgprG2LA+0+0], v0 // pack a sub 8-bit with dest
s_waitcnt vmcnt(8)
v_lshlrev_b32 v1, 0x8, v1 // shift left to higher 8 bits
v_or_b32 v[vgprG2LA+0+4], v[vgprG2LA+0+4], v1 // pack a sub 8-bit with dest
...
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
globalReadMode = 2 -> use wider global load instructions
Before:
/* g2l=0, load component 0 /
buffer_load_ubyte_d16 v[vgprG2LB+0+0], ..., 0 offen offset:0 // load one buffer value
/ g2l=0, load component 1 /
buffer_load_ubyte_d16 v51, ..., 0 offen offset:1 // load one buffer value
/ g2l=0, load component 2 /
buffer_load_ubyte_d16_hi v52, ..., 0 offen offset:2 // load one buffer value
/ g2l=0, load component 3 */
buffer_load_ubyte_d16_hi v53, ..., 0 offen offset:3 // load one buffer value
...
s_waitcnt vmcnt(14)
v_lshlrev_b32 v51, 0x8, v51 // shift left to higher 8 bits
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v51 // pack a sub 8-bit with dest
s_waitcnt vmcnt(13)
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v52 // pack a sub 8-bit with dest
s_waitcnt vmcnt(12)
v_lshlrev_b32 v53, 0x8, v53 // shift left to higher 8 bits
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v53 // pack a sub 8-bit with dest
...
After:
buffer_load_dwordx4 v[vgprG2LB+0:vgprG2LB+0+3], v[vgprGlobalReadOffsetB+0], s[sgprSrdB:sgprSrdB+3], 0 offen offset:0 // G -> Reg 0_0_0_0
... (calculate some data to determine how to load the last data)
label_LoadB:
... (jump to specified load tile)
label_LOAD_B0:
label_LOAD_B0_K1:
s_cmp_ge_u32 s11, 1
s_cbranch_scc0 label_MergeB
/* g2l=0, load component 0 */
buffer_load_ubyte_d16 v54, ... 0 offen offset:0 // load one buffer value
label_LOAD_B0_K2
...
label_LOAD_B0_K15:
... (load code)
s_branch label_MergeB
label_MergeB:
... (jump to specified load tile)
label_MERGE_B0:
label_MERGE_B0_K1:
s_cmp_ge_u32 s11, 1
s_cbranch_scc0 label_CheckB_OOB
s_waitcnt vmcnt(0)
v_or_b32 v[vgprG2LB+0+0], v[vgprG2LB+0+0], v54 // pack a sub 8-bit with dest
label_MERGE_B0_K2:
...
label_MERGE_B0_K15:
... (pack code)
s_branch label_CheckB_OOB
label_CheckB_OOB:
...
label_CheckLoopBeginB:
... (calculate size to be loaded and size can be loaded)
label_B0:
... (check if there's other tile should be loaded again due to OOB)
s_cbranch_scc1 label_LoadB // Reload
s_branch label_CheckLoopBeginB // Re check
label_TailGlobalLoadEnd:
s_waitcnt vmcnt(0)