-
Notifications
You must be signed in to change notification settings - Fork 375
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement qs8 x8c4 pack using avxvnni #7462
base: master
Are you sure you want to change the base?
Conversation
xnn_prefetch_to_l1((const int8_t*) w4 + 448); | ||
xnn_prefetch_to_l1((const int8_t*) w5 + 448); | ||
xnn_prefetch_to_l1((const int8_t*) w6 + 448); | ||
xnn_prefetch_to_l1((const int8_t*) w7 + 448); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you have time, try some different offsets here and run on meteorlake, alderlake or lunar lake.
448 is 7 cache lines of 64 and a good choice
960 is 15 cache lines of 64 and may work better on large packw
128 is 2 cache lines ahead and sometimes good when doing a small amount of loads. You are reading 32 bytes, so these prefetches are somewhat redundent.
4096 is interesting due to page
instead of prefetching the same cachelines, it might be worth the logic to only prefetch half the time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have tried those different offsets on MTL on benchmark qs8_packw_bench, and 960 has the lowest cache miss, but 448 has better performance. for alternating prefetch, instructions is reduced, but totally there is no performance improvement.
out += 256; | ||
} | ||
|
||
// KC main loop multiple of 8x8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This loop could be removed. In practice KC is a large value most of the time
v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); | ||
v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); | ||
v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); | ||
xnn_prefetch_to_l1((const int8_t*) w0 + 224); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the prefetch distance should remain constant.
|
||
const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); | ||
|
||
__m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unaligned_load_u32 will read out of bounds. I think we need to be asan friendly and remove the 'EXTRA_BYTES' in the calling unittests and stay within bounds.
To keep the source code simple, maybe write a function to replace unaligned_load_u32(w0)); which reads 1 to 3 bytes, passing k.
On avx10 we'll be able to set up a kmask. The same source code will be used, but compiled as avx512
22972ed
to
2adadf1
Compare
uint32_t value = 0; | ||
const uint8_t* bytes = (const uint8_t*)src; | ||
for (size_t i = 0; i < k; ++i) { | ||
value |= (uint32_t)bytes[i] << (i * 8); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a space after casting.
value |= (uint32_t) bytes[i] << (i * 8);
const void* scale, | ||
int8_t* packed_weights, | ||
size_t extra_bytes, | ||
const void* params) XNN_OOB_READS |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove XNN_OOB_READS
This kernel should be asan friendly now
|
||
const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); | ||
|
||
__m256i v0 = _mm256_set1_epi32((int32_t) safe_load_32bits(w0, k)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this function should follow the naming conventions of sse or the function it replaces
unaligned_load_u32
becomes
safe_load_u32
// KC main loop multiple of 8x4 | ||
for (; k >= 4; k -= 4) { | ||
__m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); | ||
v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
on a future avx10 version of this kernel, this and the remainder code could use a kmask to read and embedded broadcast to the correct channels.
if XNN_LIKELY(b != NULL) { | ||
size_t nb = n; | ||
do { | ||
*((int32_t*) out) = *b++; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optionally use something like safe_load to read 1 to 7 int32 and then store a full vector.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we can write something like:
XNN_INLINE static __m256i safe_load_avx(const void* src, size_t n) {
int32_t temp[8] = {0};
for (size_t i = 0; i < n; ++i) {
temp[i] = ((const int32_t*)src)[i];
}
return _mm256_loadu_si256((const __m256i*) temp);
}
when n
is little, I wonder the memory overhead will be a problem
} else { | ||
size_t nb = n; | ||
do { | ||
*((int32_t*) out) = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can store full vectors of 0 the same as the main NC loop
_mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256());
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that maybe cause memory out of bound? unless add a mask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated. the output is padded, so it will not out of bound. Thanks the feedback
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall looks good
-- 2adadf1 by zhenweijin <[email protected]>: implement qs8 x8c4 pack using avxvnni FUTURE_COPYBARA_INTEGRATE_REVIEW=#7462 from kylo5aby:qs8-x8c4 2adadf1 PiperOrigin-RevId: 698941566
This is failing to compile:
|
-- 2adadf1 by zhenweijin <[email protected]>: implement qs8 x8c4 pack using avxvnni FUTURE_COPYBARA_INTEGRATE_REVIEW=#7462 from kylo5aby:qs8-x8c4 2adadf1 PiperOrigin-RevId: 698941566
-- 2adadf1 by zhenweijin <[email protected]>: implement qs8 x8c4 pack using avxvnni FUTURE_COPYBARA_INTEGRATE_REVIEW=#7462 from kylo5aby:qs8-x8c4 2adadf1 PiperOrigin-RevId: 698941566
-- c490961 by zhenweijin <[email protected]>: implement qs8 x8c4 pack using avxvnni FUTURE_COPYBARA_INTEGRATE_REVIEW=#7462 from kylo5aby:qs8-x8c4 c490961 PiperOrigin-RevId: 698941566
Resolved. |
qs8_packw/xnn_qs8_packw_gemm_goi_ukernel_x8c4__scalar_sd1x_encoder_decoder/B:1/M:512/N:4096/K:4096/real_time 3220123 ns 2962085 ns 211 bytes=10.4215G/s elements=5.21012G/s
qs8_packw/xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni_sd1x_encoder_decoder/B:1/M:512/N:4096/K:4096/real_time 1534454 ns 710227 ns
qs8_packw/xnn_qs8_packw_gemm_goi_ukernel_x8c4__scalar_sd1x_diffusion/B:8/M:64/N:64/K:160/real_time 15215 ns 13587 ns 43699 bytes=10.8019G/s elements=5.38414G/s
qs8_packw/xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni_sd1x_diffusion/B:8/M:64/N:64/K:160/real_time 5801 ns 1592 ns 117772 bytes=28.33G/s elements=14.1209G/s
qs8_packw/xnn_qs8_packw_gemm_goi_ukernel_x8c4__scalar_attention/B:7/M:1/N:16/K:2304/real_time 48944 ns 47492 ns 13489 bytes=10.5469G/s elements=5.27231G/s
qs8_packw/xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni_attention/B:7/M:1/N:16/K:2304/real_time 14653 ns 7623 ns 45095 bytes=35.2285G/s elements=17.6104G/s