From f1d038f0fa8ffb1d792d3195a7dc6b8d44d7bcf6 Mon Sep 17 00:00:00 2001 From: Pedro Gonnet Date: Fri, 20 Dec 2024 06:46:11 -0800 Subject: [PATCH] Better tiling for GEMMs. This change does the following: * Flip the order in which the GEMM tiles are computed (loop over `m`, then `n`), * Change GEMM task functions in `operator-run.c` to allow tiling over multiples of `mr`, * Tile GEMMs over multiples of `mr` and `nr` such that we keep the minimum number of tasks required, * Choose GEMM tile sizes that fit into the L1 or L2 data cache. PiperOrigin-RevId: 708308468 --- BUILD.bazel | 2 + src/microkernel-utils.c | 130 ++++++++- src/operator-run.c | 325 ++++++++++++--------- src/operators/batch-matrix-multiply-nc.c | 19 +- src/operators/convolution-nhwc.c | 28 +- src/operators/dynamic-fully-connected-nc.c | 66 +++-- src/operators/fully-connected-nc.c | 74 +++-- src/xnnpack/compute.h | 109 +++---- src/xnnpack/microkernel-utils.h | 11 +- test/microkernel-utils.cc | 50 +++- 10 files changed, 509 insertions(+), 305 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 6b443921e0f6..60ba71826ce3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -895,6 +895,8 @@ xnnpack_cc_library( hdrs = ["src/xnnpack/microkernel-utils.h"], deps = [ ":common", + ":hardware_config", + ":logging", ":math", ], ) diff --git a/src/microkernel-utils.c b/src/microkernel-utils.c index 1dbf36268640..aa953f09a647 100644 --- a/src/microkernel-utils.c +++ b/src/microkernel-utils.c @@ -8,26 +8,126 @@ #include #include +#include "xnnpack/common.h" +#include "xnnpack/hardware-config.h" +#include "xnnpack/log.h" #include "xnnpack/math.h" -size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr, - size_t nr, size_t num_threads) { - size_t nc = n; - if (num_threads > 1) { - const size_t min_num_tiles = num_threads * XNN_GEMM_TILES_PER_THREAD; - const size_t num_tile_rows = divide_round_up(m, mr) * num_groups; - const size_t num_tile_cols = divide_round_up(min_num_tiles, num_tile_rows); - - // We are looking for an `nc` that is the smallest integer multiple of `nr` - // such that `divide_round_up(n, nc)` is `num_tile_cols`. - nc = max(1, round_up(n, nr) / (nr * num_tile_cols)) * nr; - while (nr < nc && divide_round_up(n, nc - nr) == divide_round_up(n, nc)) { - nc -= nr; +static bool fits_in_cache(size_t mr, size_t nc, size_t m_stride, + size_t n_stride, size_t cm_stride, size_t cn_stride, + size_t cache_size, size_t cache_line_size) { + // Check if the bytes fit. + const size_t lines_mr = divide_round_up(mr * m_stride, cache_line_size); + const size_t lines_nc = divide_round_up(nc * n_stride, cache_line_size); + const size_t lines_output = + mr * divide_round_up(nc * cn_stride, cache_line_size); + const size_t lines_per_row = lines_mr + lines_nc + lines_output; + if (cache_size < lines_per_row * cache_line_size) { + return false; + } + + // Otherwiese, we're good. + return true; +} + +void xnn_gemm_best_tile_size(size_t num_groups, size_t m, size_t n, + size_t m_stride, size_t n_stride, size_t cm_stride, + size_t cn_stride, size_t mr, size_t nr, + size_t num_threads, size_t *mc, size_t *nc) { + // Adjust `mr` and `nr` if they are larger than `m` and `n`, respectively. + mr = min(mr, m); + nr = min(nr, n); + + // We only care about the number of tiles if we have more than one thread. + const size_t min_num_tiles = + num_threads > 1 ? XNN_GEMM_TILES_PER_THREAD * num_threads : 1; + + // Start with a `mr`x`nr` tile. + *mc = mr; + *nc = nr; + size_t best_num_tiles = + divide_round_up(m, *mc) * divide_round_up(n, *nc) * num_groups; + + // Select which cache we want the tiles to fit in. Start with L1, and if the + // smallest possible tile won't fit, try L2. If the smallest tile still won't + // fit, then don't try to fit to the cache size. + const struct xnn_hardware_config *hardware_config = + xnn_init_hardware_config(); + size_t cache_size = hardware_config->l1_data_cache_bytes; + size_t cache_line_size = hardware_config->l1_data_cache_line_size; + if (XNN_ARCH_X86 || XNN_ARCH_X86_64 || + (cache_size && !fits_in_cache(mr, nr, m_stride, n_stride, cm_stride, + cn_stride, cache_size, cache_line_size))) { + cache_size = hardware_config->l2_data_cache_bytes; + cache_line_size = hardware_config->l2_data_cache_line_size; + if (cache_size && !fits_in_cache(mr, nr, m_stride, n_stride, cm_stride, + cn_stride, cache_size, cache_line_size)) { + // Don't check for cache fit. + cache_size = 0; + } + } + + // Loop over all multiples of `mr`. + for (int i = 1; (i - 1) * mr < m; i++) { + // Skip this `i` if it results in the same number of tiles as `i - 1`. + if (1 < i && + divide_round_up(m, i * mr) == divide_round_up(m, (i - 1) * mr)) { + continue; + } + + // Make sure this value of `i` generates enough tiles. + const size_t num_tiles_m = divide_round_up(m, i * mr); + const size_t num_tiles_n = divide_round_up(n, nr); + const size_t num_tiles = num_tiles_n * num_tiles_m * num_groups; + if (num_tiles < min_num_tiles) { + break; + } + + // Loop over all multiples of `nr`. + for (int j = 1; (j - 1) * nr < n; j++) { + // Skip this `j` if it results in the same number of tiles as `j - 1`. + if (1 < j && + divide_round_up(n, j * nr) == divide_round_up(n, (j - 1) * nr)) { + continue; + } + + // If we have at most `mr` rows per group, then there will be no cache + // re-use across tile rows and we don't care about whether the data fits + // in cache or not. + // If, however, we have more than one tile row, then we want the data used + // to compute a tile of size `mr`x`j*nr` to fit in the cache. + if (mr < m && cache_size && + !fits_in_cache(mr, j * nr, m_stride, n_stride, cm_stride, cn_stride, + cache_size, cache_line_size)) { + break; + } + + // Make sure this pair of `i` and `j` generates enough tiles. + const size_t num_tiles_n = divide_round_up(n, j * nr); + const size_t num_tiles = num_tiles_n * num_tiles_m * num_groups; + if (num_tiles < min_num_tiles) { + break; + } + + // New best tile size? We define the "best" tiling as the smallest total + // number of tiles, and for tilings with the same number of tiles, we take + // the tiling with the largest `nc`. + if (num_tiles < best_num_tiles || + (num_tiles == best_num_tiles && *nc < j * nr)) { + *mc = i * mr; + *nc = j * nr; + best_num_tiles = num_tiles; + } } - nc = min(nc, n); } - return nc; + // Restrict the resulting `mc` and `nc` to `m` and `n`, respectively. + *mc = min(*mc, m); + *nc = min(*nc, n); + xnn_log_debug( + "Tile size for GEMM with num_groups=%zi, m=%zi, n=%zi and mr=%zi, nr=%zi " + "set to [%zi, %zi]", + num_groups, m, n, mr, nr, *mc, *nc); } static size_t dwconv_num_middle_pass( diff --git a/src/operator-run.c b/src/operator-run.c index afcf6ebbfc2c..d9165029822c 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -20,6 +20,7 @@ #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microkernel-type.h" +#include "xnnpack/microparams-init.h" #include "xnnpack/microparams.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" @@ -382,11 +383,8 @@ void xnn_compute_batched_packw_gemm_goi( void xnn_compute_hmp_grouped_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t group_index, size_t mr_block_start, - size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) { - const size_t k_scaled = context->k_scaled; - const size_t a_stride = context->a_stride; - const size_t cm_stride = context->cm_stride; + uint32_t uarch_index, size_t group_index, size_t nr_block_start, + size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) { const size_t num_batch_dims = context->num_batch_dims; const size_t group_index_c = group_index; @@ -404,111 +402,142 @@ void xnn_compute_hmp_grouped_gemm( group_index_b = (index % context->batch_dims_b[k]) + context->batch_dims_b[k] * group_index_b; } - if (context->quantization_params != NULL) { - // If the effective `mr_block_size` is smaller than the kernel's `mr`, - // create a padded copy of the dynamic quantization params. - const struct xnn_qd8_quantization_params* quantization_params = - &context->quantization_params[group_index_a * context->gq_stride + - mr_block_start]; + + const size_t k_scaled = context->k_scaled; + const size_t a_stride = context->a_stride; + const size_t cm_stride = context->cm_stride; + const size_t cn_stride = context->cn_stride; + const uintptr_t a = + (uintptr_t)context->a + group_index_a * context->ga_stride; + const void* packed_w = (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride + + group_index_b * context->gw_stride); + const uintptr_t c = (uintptr_t)context->c + + (nr_block_start << context->log2_csize) + + group_index_c * context->gc_stride; + const struct xnn_qd8_quantization_params* context_quantization_params = + context->quantization_params; + const void* params = &context->params; + const size_t mr = context->mr; + const size_t mr_block_end = mr_block_start + mr_block_size; + + if (context_quantization_params != NULL) { + const size_t gq_stride = context->gq_stride; struct xnn_qd8_quantization_params padded_quantization_params[XNN_MAX_MR]; - if (mr_block_size < context->mr) { - memcpy(padded_quantization_params, quantization_params, - mr_block_size * sizeof(struct xnn_qd8_quantization_params)); - for (size_t i = mr_block_size; i < context->mr; i++) { - padded_quantization_params[i] = - padded_quantization_params[mr_block_size - 1]; + const xnn_dqgemm_ukernel_fn dq_ukernel_function = + context->dq_ukernel.function[uarch_index]; + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + const size_t mr_block_size = min(mr, mr_block_end - mr_block_start); + + // If the effective `mr_block_size` is smaller than the kernel's `mr`, + // create a padded copy of the dynamic quantization params. + const struct xnn_qd8_quantization_params* quantization_params = + &context_quantization_params[group_index_a * gq_stride + + mr_block_start]; + if (mr_block_size < mr) { + for (size_t i = 0; i < mr_block_size; i++) { + padded_quantization_params[i] = quantization_params[i]; + } + for (size_t i = mr_block_size; i < mr; i++) { + padded_quantization_params[i] = + padded_quantization_params[mr_block_size - 1]; + } + quantization_params = padded_quantization_params; } - quantization_params = padded_quantization_params; - }; - context->dq_ukernel.function[uarch_index]( - mr_block_size, nr_block_size, k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride + - group_index_a * context->ga_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index_b * context->gw_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize) + - group_index_c * context->gc_stride), - cm_stride, context->cn_stride, &context->params, quantization_params); + dq_ukernel_function(mr_block_size, nr_block_size, k_scaled, + (const void*)(a + mr_block_start * a_stride), + a_stride, packed_w, + (void*)(c + mr_block_start * cm_stride), cm_stride, + cn_stride, params, quantization_params); + } } else { - context->ukernel.function[uarch_index]( - mr_block_size, nr_block_size, k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride + - group_index_a * context->ga_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index_b * context->gw_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize) + - group_index_c * context->gc_stride), - cm_stride, context->cn_stride, &context->params); + const xnn_gemm_ukernel_fn ukernel_function = + context->ukernel.function[uarch_index]; + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + const size_t mr_block_size = min(mr, mr_block_end - mr_block_start); + ukernel_function(mr_block_size, nr_block_size, k_scaled, + (const void*)(a + mr_block_start * a_stride), a_stride, + packed_w, (void*)(c + mr_block_start * cm_stride), + cm_stride, cn_stride, params); + } } } void xnn_compute_grouped_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size) { + size_t group_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { xnn_compute_hmp_grouped_gemm(context, XNN_UARCH_DEFAULT, group_index, - mr_block_start, nr_block_start, mr_block_size, - nr_block_size); + nr_block_start, mr_block_start, nr_block_size, + mr_block_size); } void xnn_compute_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size) -{ + size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size) { const size_t a_stride = context->a_stride; + const size_t k_scaled = context->k_scaled; const size_t cm_stride = context->cm_stride; - - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_block_size, - nr_block_size, - context->k_scaled, - (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), - a_stride, - (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), - (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), - cm_stride, - context->cn_stride, - context->fused_params); + const size_t mr = context->mr; + const size_t mr_block_end = mr_block_start + mr_block_size; + const xnn_gemm_ukernel_fn ukernel_function = + context->ukernel.function[XNN_UARCH_DEFAULT]; + const size_t cn_stride = context->cn_stride; + const uintptr_t a = (uintptr_t)context->a; + const void* packed_w = (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride); + const uintptr_t c = + (uintptr_t)context->c + (nr_block_start << context->log2_csize); + const void* fused_params = context->fused_params; + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + mr_block_size = min(mr, mr_block_end - mr_block_start); + ukernel_function(mr_block_size, nr_block_size, k_scaled, + (const void*)(a + mr_block_start * a_stride), a_stride, + packed_w, (void*)(c + mr_block_start * cm_stride), + cm_stride, cn_stride, fused_params); + } } void xnn_compute_dqgemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size) -{ + size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size) { const size_t a_stride = context->a_stride; + const size_t k_scaled = context->k_scaled; const size_t cm_stride = context->cm_stride; - - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_block_size, - nr_block_size, - context->k_scaled, - (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), - a_stride, - (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), - (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), - cm_stride, - context->cn_stride, - context->fused_params, - (const void*) ((uintptr_t) &context->quantization_params[mr_block_start])); + const size_t mr = context->mr; + const size_t mr_block_end = mr_block_start + mr_block_size; + const xnn_dqgemm_ukernel_fn dq_ukernel_function = + context->dq_ukernel.function[XNN_UARCH_DEFAULT]; + const size_t cn_stride = context->cn_stride; + const uintptr_t a = (uintptr_t)context->a; + const void* packed_w = (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride); + const uintptr_t c = + (uintptr_t)context->c + (nr_block_start << context->log2_csize); + const void* fused_params = context->fused_params; + const struct xnn_qd8_quantization_params* quantization_params = + context->quantization_params; + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + mr_block_size = min(mr, mr_block_end - mr_block_start); + dq_ukernel_function(mr_block_size, nr_block_size, k_scaled, + (const void*)(a + mr_block_start * a_stride), a_stride, + packed_w, (void*)(c + mr_block_start * cm_stride), + cm_stride, cn_stride, fused_params, + &quantization_params[mr_block_start]); + } } void xnn_compute_hmp_grouped_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t group_index, size_t mr_block_start, - size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) { + uint32_t uarch_index, size_t group_index, size_t nr_block_start, + size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) { const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset( mr_block_start, context->k_scaled, context->mr, context->kr, context->sr); const size_t cm_stride = context->cm_stride; @@ -546,38 +575,44 @@ void xnn_compute_hmp_grouped_qp8gemm( void xnn_compute_grouped_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size) { + size_t group_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { xnn_compute_hmp_grouped_qp8gemm(context, XNN_UARCH_DEFAULT, group_index, - mr_block_start, nr_block_start, mr_block_size, - nr_block_size); + nr_block_start, mr_block_start, nr_block_size, + mr_block_size); } void xnn_compute_hmp_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size) { + uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset( mr_block_start, context->k_scaled, context->mr, context->kr, context->sr); const size_t cm_stride = context->cm_stride; - - context->qp8_ukernel.function[uarch_index]( - mr_block_size, nr_block_size, context->k_scaled, - (const void*)((uintptr_t)context->a + a_offset), - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, - /*dst_stride_col=*/sizeof(float), context->fused_params); + const void* packed_w = (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride); + const uintptr_t c = + (uintptr_t)context->c + (nr_block_start << context->log2_csize); + void* const fused_params = context->fused_params; + const size_t mr = context->mr; + const size_t mr_block_end = mr_block_start + mr_block_size; + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + mr_block_size = min(mr, mr_block_end - mr_block_start); + context->qp8_ukernel.function[uarch_index]( + mr_block_size, nr_block_size, context->k_scaled, + (const void*)((uintptr_t)context->a + a_offset), packed_w, + (void*)(c + mr_block_start * cm_stride), cm_stride, + /*dst_stride_col=*/sizeof(float), fused_params); + } } void xnn_compute_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, - size_t nr_block_size) { - xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, mr_block_start, - nr_block_start, mr_block_size, nr_block_size); + size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size) { + xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, nr_block_start, + mr_block_start, nr_block_size, mr_block_size); } void xnn_compute_spmm( @@ -2390,46 +2425,60 @@ void xnn_compute_rope( #if XNN_MAX_UARCH_TYPES > 1 void xnn_compute_hmp_gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size) { + uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { const size_t a_stride = context->a_stride; const size_t cm_stride = context->cm_stride; - - context->ukernel.function[uarch_index]( - mr_block_size, nr_block_size, context->k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, context->cn_stride, context->fused_params); + const size_t mr = context->mr; + const size_t mr_block_end = mr_block_start + mr_block_size; + const size_t k_scaled = context->k_scaled; + const size_t cn_stride = context->cn_stride; + const size_t w_stride = context->w_stride; + const uintptr_t a = (uintptr_t)context->a; + const void* packed_w = + (void*)((uintptr_t)context->packed_w + nr_block_start * w_stride); + const uintptr_t c = + (uintptr_t)context->c + (nr_block_start << context->log2_csize); + const void* fused_params = context->fused_params; + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + mr_block_size = min(mr, mr_block_end - mr_block_start); + context->ukernel.function[uarch_index]( + mr_block_size, nr_block_size, k_scaled, + (const void*)(a + mr_block_start * a_stride), a_stride, packed_w, + (void*)(c + mr_block_start * cm_stride), cm_stride, cn_stride, + fused_params); } +} - void xnn_compute_hmp_dqgemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size) - { - const size_t a_stride = context->a_stride; - const size_t cm_stride = context->cm_stride; - +void xnn_compute_hmp_dqgemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size) { + const size_t a_stride = context->a_stride; + const size_t cm_stride = context->cm_stride; + const size_t mr = context->mr; + const size_t mr_block_end = mr_block_start + mr_block_size; + const size_t k_scaled = context->k_scaled; + const size_t cn_stride = context->cn_stride; + const uintptr_t a = (uintptr_t)context->a; + const void* packed_w = (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride); + const uintptr_t c = + (uintptr_t)context->c + (nr_block_start << context->log2_csize); + const void* fused_params = context->fused_params; + const void* quantization_params = + (const void*)((uintptr_t)&context->quantization_params[mr_block_start]); + + for (; mr_block_start < mr_block_end; mr_block_start += mr) { + mr_block_size = min(mr, mr_block_end - mr_block_start); context->dq_ukernel.function[uarch_index]( - mr_block_size, - nr_block_size, - context->k_scaled, - (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), - a_stride, - (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), - (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), - cm_stride, - context->cn_stride, - context->fused_params, - (const void*) ((uintptr_t) &context->quantization_params[mr_block_start])); + mr_block_size, nr_block_size, k_scaled, + (const void*)(a + mr_block_start * a_stride), a_stride, packed_w, + (void*)(c + mr_block_start * cm_stride), cm_stride, cn_stride, + fused_params, quantization_params); } +} void xnn_compute_hmp_grouped_batch_igemm( const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)], diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index c19825035a69..476e5453f488 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -684,7 +684,16 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( memcpy(&batch_matrix_multiply_op->context.gemm.gemm.gemm.params, params, params_size); batch_matrix_multiply_op->context.gemm.gemm.gemm.fused_params = &batch_matrix_multiply_op->context.gemm.gemm.gemm.params; - size_t nc = xnn_gemm_best_nc(batch_size_c, m, n, mr, nr, num_threads); + // Compute the optimal tile size for this GEMM. + size_t mc; + size_t nc; + xnn_gemm_best_tile_size( + /*num_groups=*/batch_size_c, m, n, + /*m_stride=*/batch_matrix_multiply_op->context.gemm.gemm.gemm.a_stride, + /*n_stride=*/batch_matrix_multiply_op->context.gemm.gemm.gemm.w_stride, + /*cm_stride=*/batch_matrix_multiply_op->context.gemm.gemm.gemm.cm_stride, + /*cn_stride=*/1 << log2_output_element_size, mr, nr, num_threads, &mc, + &nc); #if XNN_MAX_UARCH_TYPES > 1 if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { @@ -718,10 +727,10 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( } #endif gemm_compute->range[0] = batch_size_c; - gemm_compute->range[1] = m; - gemm_compute->range[2] = n; - gemm_compute->tile[0] = mr; - gemm_compute->tile[1] = nc; + gemm_compute->range[2] = m; + gemm_compute->range[1] = n; + gemm_compute->tile[1] = mc; + gemm_compute->tile[0] = nc; batch_matrix_multiply_op->state = xnn_run_state_needs_setup; return xnn_status_success; diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index bfc0be24eb70..8e9ddb89983e 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -2083,6 +2083,7 @@ static enum xnn_status reshape_gemm( .log2_csize = log2_output_element_size, .num_batch_dims = 1, .ukernel = gemm_ukernel, + .mr = mr, }; convolution_op->context.gemm.gemm.gemm.batch_dims_a[0] = groups; convolution_op->context.gemm.gemm.gemm.batch_dims_b[0] = groups; @@ -2090,8 +2091,17 @@ static enum xnn_status reshape_gemm( memcpy(&convolution_op->context.gemm.gemm.gemm.params, &convolution_op->params, sizeof(convolution_op->context.gemm.gemm.gemm.params)); convolution_op->context.gemm.gemm.gemm.fused_params = &convolution_op->context.gemm.gemm.gemm.params; - size_t nc = xnn_gemm_best_nc(groups, batch_output_size, group_output_channels, - mr, nr, num_threads); + // Compute the optimal tile size for this GEMM. + size_t mc; + size_t nc; + xnn_gemm_best_tile_size( + /*num_groups=*/groups, /*m=*/batch_output_size, + /*n=*/group_output_channels, + /*m_stride=*/convolution_op->context.gemm.gemm.gemm.a_stride, + /*n_stride=*/convolution_op->context.gemm.gemm.gemm.w_stride, + /*cm_stride=*/convolution_op->context.gemm.gemm.gemm.cm_stride, + /*cn_stride=*/1 << log2_output_element_size, mr, nr, num_threads, &mc, + &nc); if (groups == 1) { #if XNN_MAX_UARCH_TYPES > 1 @@ -2106,10 +2116,8 @@ static enum xnn_status reshape_gemm( convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; #endif - convolution_op->compute[0].range[0] = batch_output_size; - convolution_op->compute[0].range[1] = group_output_channels; - convolution_op->compute[0].tile[0] = mr; - convolution_op->compute[0].tile[1] = nc; + convolution_op->compute[0].range[1] = batch_output_size; + convolution_op->compute[0].range[0] = group_output_channels; } else { #if XNN_MAX_UARCH_TYPES > 1 if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { @@ -2124,11 +2132,11 @@ static enum xnn_status reshape_gemm( convolution_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm; #endif convolution_op->compute[0].range[0] = groups; - convolution_op->compute[0].range[1] = batch_output_size; - convolution_op->compute[0].range[2] = group_output_channels; - convolution_op->compute[0].tile[0] = mr; - convolution_op->compute[0].tile[1] = nc; + convolution_op->compute[0].range[2] = batch_output_size; + convolution_op->compute[0].range[1] = group_output_channels; } + convolution_op->compute[0].tile[1] = mc; + convolution_op->compute[0].tile[0] = nc; convolution_op->state = xnn_run_state_needs_setup; *workspace_size = 0; diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c index cf5be9143101..9b52287a3ea7 100644 --- a/src/operators/dynamic-fully-connected-nc.c +++ b/src/operators/dynamic-fully-connected-nc.c @@ -376,13 +376,15 @@ static enum xnn_status reshape_dynamic_fully_connected_nc( } dynamic_fully_connected_op->context.gemm.gemm.gemm = (struct gemm_context){ - .k_scaled = input_channels << log2_input_element_size, - .w_stride = bias_element_size + (round_up_po2(input_channels, kr * sr) << log2_input_element_size), - .a_stride = input_stride << log2_input_element_size, - .cm_stride = output_stride << log2_output_element_size, - .cn_stride = nr << log2_output_element_size, - .log2_csize = log2_output_element_size, - .ukernel = gemm_ukernel, + .k_scaled = input_channels << log2_input_element_size, + .w_stride = bias_element_size + (round_up_po2(input_channels, kr * sr) + << log2_input_element_size), + .a_stride = input_stride << log2_input_element_size, + .cm_stride = output_stride << log2_output_element_size, + .cn_stride = nr << log2_output_element_size, + .log2_csize = log2_output_element_size, + .ukernel = gemm_ukernel, + .mr = mr, }; memcpy(&dynamic_fully_connected_op->context.gemm.gemm.gemm.params, params, params_size); dynamic_fully_connected_op->context.gemm.gemm.gemm.fused_params = &dynamic_fully_connected_op->context.gemm.gemm.gemm.params; @@ -391,26 +393,40 @@ static enum xnn_status reshape_dynamic_fully_connected_nc( } dynamic_fully_connected_op->context.gemm.gemm.gemm.fused_params = &dynamic_fully_connected_op->context.gemm.gemm.gemm.params; - size_t nc = - xnn_gemm_best_nc(/*num_groups=*/1, batch_size, output_channels, mr, nr, - pthreadpool_get_threads_count(threadpool)); + // Compute the optimal tile size for this GEMM. + size_t mc; + size_t nc; + xnn_gemm_best_tile_size( + /*num_groups=*/1, /*m=*/batch_size, /*n=*/output_channels, + /*m_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.a_stride, + /*n_stride=*/dynamic_fully_connected_op->context.gemm.gemm.gemm.w_stride, + /*cm_stride=*/ + dynamic_fully_connected_op->context.gemm.gemm.gemm.cm_stride, + /*cn_stride=*/1 << log2_output_element_size, mr, nr, + /*num_threads=*/pthreadpool_get_threads_count(threadpool), &mc, &nc); #if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { - dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch; - dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; - } else { - dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d; - dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; - } - #else - dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d; - dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; - #endif - dynamic_fully_connected_op->compute[1].range[0] = batch_size; - dynamic_fully_connected_op->compute[1].range[1] = output_channels; - dynamic_fully_connected_op->compute[1].tile[0] = mr; - dynamic_fully_connected_op->compute[1].tile[1] = nc; + if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { + dynamic_fully_connected_op->compute[1].type = + xnn_parallelization_type_2d_tile_2d_with_uarch; + dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_gemm; + } else { + dynamic_fully_connected_op->compute[1].type = + xnn_parallelization_type_2d_tile_2d; + dynamic_fully_connected_op->compute[1].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm; + } +#else + dynamic_fully_connected_op->compute[1].type = + xnn_parallelization_type_2d_tile_2d; + dynamic_fully_connected_op->compute[1].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm; +#endif + dynamic_fully_connected_op->compute[1].range[1] = batch_size; + dynamic_fully_connected_op->compute[1].range[0] = output_channels; + dynamic_fully_connected_op->compute[1].tile[1] = mc; + dynamic_fully_connected_op->compute[1].tile[0] = nc; dynamic_fully_connected_op->state = xnn_run_state_needs_setup; return xnn_status_success; diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 03e2a75f7be2..f1e3af203a44 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -2269,47 +2269,61 @@ static enum xnn_status reshape_fully_connected_nc( memcpy(&fully_connected_op->context.gemm.gemm.gemm.params, params, params_size); fully_connected_op->context.gemm.gemm.gemm.fused_params = &fully_connected_op->context.gemm.gemm.gemm.params; - size_t nc = - xnn_gemm_best_nc(/*num_groups=*/1, batch_size, output_channels, mr, nr, - pthreadpool_get_threads_count(threadpool)); + // Compute the optimal tile size for this GEMM. + size_t mc; + size_t nc; + xnn_gemm_best_tile_size( + /*num_groups=*/1, /*m=*/batch_size, /*n=*/output_channels, + /*m_stride=*/fully_connected_op->context.gemm.gemm.gemm.a_stride, + /*n_stride=*/fully_connected_op->context.gemm.gemm.gemm.w_stride, + /*cm_stride=*/fully_connected_op->context.gemm.gemm.gemm.cm_stride, + /*cn_stride=*/1 << log2_output_element_size, mr, nr, + /*num_threads=*/pthreadpool_get_threads_count(threadpool), &mc, &nc); #if XNN_MAX_UARCH_TYPES > 1 - if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { - fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch; - if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_dqgemm; - } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = - (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm; - } else { - fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; - } + if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { + fully_connected_op->compute[0].type = + xnn_parallelization_type_2d_tile_2d_with_uarch; + if (dynamic_quantization) { + fully_connected_op->compute[0].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_dqgemm; + } else if (is_qp8_ukernel) { + fully_connected_op->compute[0].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_qp8gemm; } else { - fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; - if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; - } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; - } else { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; - } + fully_connected_op->compute[0].task_2d_tile_2d_with_id = + (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_gemm; } - #else + } else { fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_dqgemm; } else if (is_qp8_ukernel) { fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; } else { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm; } -#endif - fully_connected_op->compute[0].range[0] = batch_size; - fully_connected_op->compute[0].range[1] = output_channels; - fully_connected_op->compute[0].tile[0] = mr; - fully_connected_op->compute[0].tile[1] = nc; + } +#else + fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; + if (dynamic_quantization) { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_dqgemm; + } else if (is_qp8_ukernel) { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + } else { + fully_connected_op->compute[0].task_2d_tile_2d = + (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm; + } +#endif // XNN_MAX_UARCH_TYPES > 1 + fully_connected_op->compute[0].range[1] = batch_size; + fully_connected_op->compute[0].range[0] = output_channels; + fully_connected_op->compute[0].tile[1] = mc; + fully_connected_op->compute[0].tile[0] = nc; fully_connected_op->state = xnn_run_state_needs_setup; return xnn_status_success; diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 6fde480accd3..56ef33414bac 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -342,73 +342,56 @@ struct gemm_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_grouped_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); +XNN_PRIVATE void xnn_compute_grouped_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size); + +XNN_PRIVATE void xnn_compute_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size); + +XNN_PRIVATE void xnn_compute_dqgemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size); - XNN_PRIVATE void xnn_compute_grouped_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size); +XNN_PRIVATE void xnn_compute_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size); - XNN_PRIVATE void xnn_compute_dqgemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); +XNN_PRIVATE void xnn_compute_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t nr_block_start, size_t mr_block_start, size_t nr_block_size, + size_t mr_block_size); +#if XNN_MAX_UARCH_TYPES > 1 +XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t group_index, size_t nr_block_start, + size_t mr_block_start, size_t nr_block_size, size_t mr_block_size); - XNN_PRIVATE void xnn_compute_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); +XNN_PRIVATE void xnn_compute_hmp_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t group_index, size_t mr_block_start, + size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); - XNN_PRIVATE void xnn_compute_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, - size_t nr_block_size); -#if XNN_MAX_UARCH_TYPES > 1 - XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t group_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_grouped_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t group_index, size_t mr_block_start, - size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_dqgemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size); - #endif // XNN_MAX_UARCH_TYPES > 1 +XNN_PRIVATE void xnn_compute_hmp_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size); + +XNN_PRIVATE void xnn_compute_hmp_dqgemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size); + +XNN_PRIVATE void xnn_compute_hmp_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start, + size_t nr_block_size, size_t mr_block_size); +#endif // XNN_MAX_UARCH_TYPES > 1 #endif // Context for Sparse Matrix-Dense Matrix Multiplication. diff --git a/src/xnnpack/microkernel-utils.h b/src/xnnpack/microkernel-utils.h index add9c4dc79a3..c820f39fa604 100644 --- a/src/xnnpack/microkernel-utils.h +++ b/src/xnnpack/microkernel-utils.h @@ -17,10 +17,13 @@ extern "C" { // least this many tiles per thread. #define XNN_GEMM_TILES_PER_THREAD 5 -// Computes the largest `nc`, the largest multiple of `nr` such that there are -// at least five tiles per thread (if `num_threads > 1`). -size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr, - size_t nr, size_t num_threads); +// Compute the optimal tile size (integer multiples of `mr` and `nr`) for a GEMM +// such that the number of tiles is minimized, but at least `min_num_tiles` and +// such that the data needed for each tile fits in either the L1 or L2 cache. +void xnn_gemm_best_tile_size(size_t num_groups, size_t m, size_t n, + size_t m_stride, size_t n_stride, size_t cm_stride, + size_t cn_stride, size_t mr, size_t nr, + size_t num_threads, size_t *mc, size_t *nc); // The total tile size needed to cover kernel_size. XNN_INTERNAL size_t xnn_dwconv_multipass_tile_size( diff --git a/test/microkernel-utils.cc b/test/microkernel-utils.cc index 79ac5971bbee..f227625d0ec0 100644 --- a/test/microkernel-utils.cc +++ b/test/microkernel-utils.cc @@ -13,7 +13,7 @@ #include "xnnpack/microfnptr.h" #include "replicable_random_device.h" -TEST(GEMM_BEST_NC, min_tiles_per_thread) { +TEST(GEMM_BEST_TILE_SIZE, min_tiles_per_thread) { xnnpack::ReplicableRandomDevice rnd; std::uniform_int_distribution rnd_kernel_dim(1, XNN_MAX_MR); std::uniform_int_distribution rnd_tensor_dim(1, 100); @@ -24,39 +24,59 @@ TEST(GEMM_BEST_NC, min_tiles_per_thread) { const size_t mr = rnd_kernel_dim(rnd); const size_t nr = 8 * rnd_kernel_dim(rnd); const size_t m = rnd_tensor_dim(rnd); + const size_t k = rnd_tensor_dim(rnd); const size_t n = nr + rnd_tensor_dim(rnd); const size_t num_threads = rnd_thread_dim(rnd); - const size_t num_tiles_m = divide_round_up(m, mr); const size_t min_num_tiles = XNN_GEMM_TILES_PER_THREAD * num_threads; for (size_t num_groups : {(size_t)1, num_threads, 5 * num_threads, 10 * num_threads}) { - const size_t nc = xnn_gemm_best_nc(num_groups, m, n, mr, nr, num_threads); - - // Check that `nc` is a multiple of `nr` if it is less than `n`. - if (nc < nr) { - EXPECT_EQ(nc % nr, 0) << "Not a multiple of `nr`"; + size_t mc; + size_t nc; + xnn_gemm_best_tile_size( + num_groups, m, n, /*m_stride=*/k * sizeof(float), + /*n_stride=*/k * sizeof(float), /*cm_stride=*/n * sizeof(float), + /*cn_stride=*/sizeof(float), mr, nr, num_threads, &mc, &nc); + + // Check that `mc` and `nc` are multiples of `mr` and `nr` if they are + // less than `m` and `n`, respectively. + if (nc < n) { + EXPECT_EQ(nc % nr, 0) + << "mc=" << nc << " is not a multiple of nr=" << nr; + } + if (mc < m) { + EXPECT_EQ(mc % mr, 0) + << "mc=" << mc << " is not a multiple of mr=" << mr; } - // If an `nc` larger than `nr` was chosen, make sure we still have enough - // tiles. - if (nr < nc) { + // If an `nc` larger than `nr`, or `mc` larger than `mr`, was chosen, make + // sure we still have enough tiles. + if (nr < nc || mr < mc) { + const size_t num_tiles_m = divide_round_up(m, mc); const size_t num_tiles_n = divide_round_up(n, nc); const size_t num_tiles = num_groups * num_tiles_m * num_tiles_n; EXPECT_LE(min_num_tiles, num_tiles) << "Didn't generate enough tiles, num_groups=" << num_groups - << ", m=" << m << ", n=" << n << ", " << "mr=" << mr << " , " - << "nr=" << nr << " , " << "nc=" << nc + << ", m=" << m << ", n=" << n << ", " << "mr=" << mr + << ", nr=" << nr << ", mc=" << mc << ", nc=" << nc << ", num_threads=" << num_threads; } - // Verify that the next-smallest `nc` would increase the number of tiles. + // Verify that the next-smallest `nc` or `mc` would increase the number of + // tiles. if (nr < nc && nc < n) { EXPECT_NE(divide_round_up(n, nc), divide_round_up(n, nc - nr)) << "Failed to get minimal `nc` for num_groups=" << num_groups - << ", m=" << m << ", n=" << n << ", " << "mr=" << mr << " , " - << "nr=" << nr << " , " << "nc=" << nc + << ", m=" << m << ", n=" << n << ", mr=" << mr << ", nr=" << nr + << ", mc=" << mc << ", nc=" << nc + << ", num_threads=" << num_threads; + } + if (mr < mc && mc < m) { + EXPECT_NE(divide_round_up(m, mc), divide_round_up(m, mc - mr)) + << "Failed to get minimal `mc` for num_groups=" << num_groups + << ", m=" << m << ", n=" << n << ", mr=" << mr << ", nr=" << nr + << ", mc=" << mc << ", nc=" << nc << ", num_threads=" << num_threads; } }