Skip to content

Commit

Permalink
Better tiling for GEMMs.
Browse files Browse the repository at this point in the history
This change does the following:
 * Flip the order in which the GEMM tiles are computed (loop over `m`, then `n`),
 * Change GEMM task functions in `operator-run.c` to allow tiling over multiples of `mr`,
 * Tile GEMMs over multiples of `mr` and `nr` such that we keep the minimum number of tasks required,
 * Choose GEMM tile sizes that fit into the L1 or L2 data cache.

PiperOrigin-RevId: 708308468
  • Loading branch information
gonnet authored and xnnpack-bot committed Dec 20, 2024
1 parent 08185b7 commit f1d038f
Show file tree
Hide file tree
Showing 10 changed files with 509 additions and 305 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,8 @@ xnnpack_cc_library(
hdrs = ["src/xnnpack/microkernel-utils.h"],
deps = [
":common",
":hardware_config",
":logging",
":math",
],
)
Expand Down
130 changes: 115 additions & 15 deletions src/microkernel-utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,126 @@
#include <assert.h>
#include <stddef.h>

#include "xnnpack/common.h"
#include "xnnpack/hardware-config.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"

size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr,
size_t nr, size_t num_threads) {
size_t nc = n;
if (num_threads > 1) {
const size_t min_num_tiles = num_threads * XNN_GEMM_TILES_PER_THREAD;
const size_t num_tile_rows = divide_round_up(m, mr) * num_groups;
const size_t num_tile_cols = divide_round_up(min_num_tiles, num_tile_rows);

// We are looking for an `nc` that is the smallest integer multiple of `nr`
// such that `divide_round_up(n, nc)` is `num_tile_cols`.
nc = max(1, round_up(n, nr) / (nr * num_tile_cols)) * nr;
while (nr < nc && divide_round_up(n, nc - nr) == divide_round_up(n, nc)) {
nc -= nr;
static bool fits_in_cache(size_t mr, size_t nc, size_t m_stride,
size_t n_stride, size_t cm_stride, size_t cn_stride,
size_t cache_size, size_t cache_line_size) {
// Check if the bytes fit.
const size_t lines_mr = divide_round_up(mr * m_stride, cache_line_size);
const size_t lines_nc = divide_round_up(nc * n_stride, cache_line_size);
const size_t lines_output =
mr * divide_round_up(nc * cn_stride, cache_line_size);
const size_t lines_per_row = lines_mr + lines_nc + lines_output;
if (cache_size < lines_per_row * cache_line_size) {
return false;
}

// Otherwiese, we're good.
return true;
}

void xnn_gemm_best_tile_size(size_t num_groups, size_t m, size_t n,
size_t m_stride, size_t n_stride, size_t cm_stride,
size_t cn_stride, size_t mr, size_t nr,
size_t num_threads, size_t *mc, size_t *nc) {
// Adjust `mr` and `nr` if they are larger than `m` and `n`, respectively.
mr = min(mr, m);
nr = min(nr, n);

// We only care about the number of tiles if we have more than one thread.
const size_t min_num_tiles =
num_threads > 1 ? XNN_GEMM_TILES_PER_THREAD * num_threads : 1;

// Start with a `mr`x`nr` tile.
*mc = mr;
*nc = nr;
size_t best_num_tiles =
divide_round_up(m, *mc) * divide_round_up(n, *nc) * num_groups;

// Select which cache we want the tiles to fit in. Start with L1, and if the
// smallest possible tile won't fit, try L2. If the smallest tile still won't
// fit, then don't try to fit to the cache size.
const struct xnn_hardware_config *hardware_config =
xnn_init_hardware_config();
size_t cache_size = hardware_config->l1_data_cache_bytes;
size_t cache_line_size = hardware_config->l1_data_cache_line_size;
if (XNN_ARCH_X86 || XNN_ARCH_X86_64 ||
(cache_size && !fits_in_cache(mr, nr, m_stride, n_stride, cm_stride,
cn_stride, cache_size, cache_line_size))) {
cache_size = hardware_config->l2_data_cache_bytes;
cache_line_size = hardware_config->l2_data_cache_line_size;
if (cache_size && !fits_in_cache(mr, nr, m_stride, n_stride, cm_stride,
cn_stride, cache_size, cache_line_size)) {
// Don't check for cache fit.
cache_size = 0;
}
}

// Loop over all multiples of `mr`.
for (int i = 1; (i - 1) * mr < m; i++) {
// Skip this `i` if it results in the same number of tiles as `i - 1`.
if (1 < i &&
divide_round_up(m, i * mr) == divide_round_up(m, (i - 1) * mr)) {
continue;
}

// Make sure this value of `i` generates enough tiles.
const size_t num_tiles_m = divide_round_up(m, i * mr);
const size_t num_tiles_n = divide_round_up(n, nr);
const size_t num_tiles = num_tiles_n * num_tiles_m * num_groups;
if (num_tiles < min_num_tiles) {
break;
}

// Loop over all multiples of `nr`.
for (int j = 1; (j - 1) * nr < n; j++) {
// Skip this `j` if it results in the same number of tiles as `j - 1`.
if (1 < j &&
divide_round_up(n, j * nr) == divide_round_up(n, (j - 1) * nr)) {
continue;
}

// If we have at most `mr` rows per group, then there will be no cache
// re-use across tile rows and we don't care about whether the data fits
// in cache or not.
// If, however, we have more than one tile row, then we want the data used
// to compute a tile of size `mr`x`j*nr` to fit in the cache.
if (mr < m && cache_size &&
!fits_in_cache(mr, j * nr, m_stride, n_stride, cm_stride, cn_stride,
cache_size, cache_line_size)) {
break;
}

// Make sure this pair of `i` and `j` generates enough tiles.
const size_t num_tiles_n = divide_round_up(n, j * nr);
const size_t num_tiles = num_tiles_n * num_tiles_m * num_groups;
if (num_tiles < min_num_tiles) {
break;
}

// New best tile size? We define the "best" tiling as the smallest total
// number of tiles, and for tilings with the same number of tiles, we take
// the tiling with the largest `nc`.
if (num_tiles < best_num_tiles ||
(num_tiles == best_num_tiles && *nc < j * nr)) {
*mc = i * mr;
*nc = j * nr;
best_num_tiles = num_tiles;
}
}
nc = min(nc, n);
}

return nc;
// Restrict the resulting `mc` and `nc` to `m` and `n`, respectively.
*mc = min(*mc, m);
*nc = min(*nc, n);
xnn_log_debug(
"Tile size for GEMM with num_groups=%zi, m=%zi, n=%zi and mr=%zi, nr=%zi "
"set to [%zi, %zi]",
num_groups, m, n, mr, nr, *mc, *nc);
}

static size_t dwconv_num_middle_pass(
Expand Down
Loading

0 comments on commit f1d038f

Please sign in to comment.