Skip to content

Commit

Permalink
ggml : fixes for AVXVNNI instruction set with MSVC and Clang (#11027)
Browse files Browse the repository at this point in the history
* Fixes for clang AVX VNNI

* enable AVX VNNI and alder lake build for MSVC

* Apply suggestions from code review

---------

Co-authored-by: slaren <[email protected]>
  • Loading branch information
Srihari-mcw and slaren authored Dec 31, 2024
1 parent 45095a6 commit 0827b2c
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions ggml/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA)
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AVX-VNNI or AMX
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
# MSVC doesn't support AMX
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
else ()
Expand Down
3 changes: 1 addition & 2 deletions ggml/src/ggml-cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
endif()
if (GGML_AVX_VNNI)
# MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2
#list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
endif()
else ()
if (GGML_NATIVE)
Expand Down
5 changes: 4 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
}

static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_epi32(zero, ax, sy);
#elif defined(__AVXVNNI__)
const __m256i zero = _mm256_setzero_si256();
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
Expand Down
6 changes: 5 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
}

static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#elif defined(__AVXVNNI__)
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
return _mm256_cvtepi32_ps(summed_pairs);
#else
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
Expand Down
4 changes: 3 additions & 1 deletion ggml/src/ggml-cpu/llamafile/sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX {

inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#elif defined(__AVXVNNI__)
res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
#else
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
Expand Down

0 comments on commit 0827b2c

Please sign in to comment.