Commit 0bd416b

Flip the order in which the GEMM tiles are computed, i.e. instead of traversing the tiles in the `n` direction first, traverse them in the `m` direction first, and then in the `n` direction.

This yields some minor performance improvements due to better cache re-use across tiles in cases where a tile's `nc`x`k` block of weights fits in the lowest-level cache.

PiperOrigin-RevId: 690625625
gonnet authored and xnnpack-bot committed Nov 27, 2024
1 parent d62fa0e commit 0bd416b
Showing 6 changed files with 101 additions and 102 deletions.
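For context, here is a minimal serial sketch of the traversal change described in the commit message. It is not code from this commit: `compute_tile`, `m`, `n`, `mr`, and `nc` are illustrative stand-ins for the microkernel invocation and the operator's dimensions and tile sizes.

```c
#include <stddef.h>

// Stand-in for one mr-by-nc GEMM tile computed by a microkernel.
extern void compute_tile(size_t mr_block_start, size_t nr_block_start,
                         size_t mr_block_size, size_t nr_block_size);

// Before this commit: `n` varies fastest, so consecutive tiles read
// different nc-by-k slices of the packed weights.
void gemm_tiles_before(size_t m, size_t n, size_t mr, size_t nc) {
  for (size_t i = 0; i < m; i += mr) {
    for (size_t j = 0; j < n; j += nc) {
      compute_tile(i, j, m - i < mr ? m - i : mr, n - j < nc ? n - j : nc);
    }
  }
}

// After this commit: `m` varies fastest, so one nc-by-k slice of packed
// weights is reused across all row blocks while it may still be
// cache-resident.
void gemm_tiles_after(size_t m, size_t n, size_t mr, size_t nc) {
  for (size_t j = 0; j < n; j += nc) {
    for (size_t i = 0; i < m; i += mr) {
      compute_tile(i, j, m - i < mr ? m - i : mr, n - j < nc ? n - j : nc);
    }
  }
}
```

The per-tile work is unchanged; only the iteration order over the tile grid flips, which is why the diff below mostly swaps `mr`/`nr` argument pairs and `range[]`/`tile[]` indices.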
91 changes: 42 additions & 49 deletions src/operator-run.c
@@ -379,8 +379,8 @@ void xnn_compute_batched_packw_gemm_goi(
 
 void xnn_compute_hmp_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t group_index, size_t mr_block_start,
-    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t group_index, size_t nr_block_start,
+    size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) {
   const size_t k_scaled = context->k_scaled;
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
@@ -448,20 +448,17 @@ void xnn_compute_hmp_grouped_gemm(
 
 void xnn_compute_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t group_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    size_t group_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   xnn_compute_hmp_grouped_gemm(context, XNN_UARCH_DEFAULT, group_index,
-                               mr_block_start, nr_block_start, mr_block_size,
-                               nr_block_size);
+                               nr_block_start, mr_block_start, nr_block_size,
+                               mr_block_size);
 }
 
 void xnn_compute_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
 
@@ -480,11 +477,8 @@ void xnn_compute_gemm(
 
 void xnn_compute_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
 
@@ -504,8 +498,8 @@ void xnn_compute_dqgemm(
 
 void xnn_compute_hmp_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset(
       mr_block_start, context->k_scaled, context->mr, context->kr, context->sr);
   const size_t cm_stride = context->cm_stride;
@@ -523,10 +517,10 @@ void xnn_compute_hmp_qp8gemm(
 
 void xnn_compute_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
-    size_t nr_block_size) {
-  xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, mr_block_start,
-                          nr_block_start, mr_block_size, nr_block_size);
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
+  xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, nr_block_start,
+                          mr_block_start, nr_block_size, mr_block_size);
 }
 
 void xnn_compute_spmm(
@@ -2368,8 +2362,8 @@ void xnn_compute_rope(
 #if XNN_MAX_UARCH_TYPES > 1
 void xnn_compute_hmp_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
 
@@ -2382,32 +2376,31 @@ void xnn_compute_hmp_gemm(
       (void*)((uintptr_t)context->c + mr_block_start * cm_stride +
               (nr_block_start << context->log2_csize)),
       cm_stride, context->cn_stride, context->fused_params);
 }
 
-void xnn_compute_hmp_dqgemm(
-    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index,
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
-  const size_t a_stride = context->a_stride;
-  const size_t cm_stride = context->cm_stride;
-
-  context->dq_ukernel.function[uarch_index](
-      mr_block_size,
-      nr_block_size,
-      context->k_scaled,
-      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
-      a_stride,
-      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
-      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
-      cm_stride,
-      context->cn_stride,
-      context->fused_params,
-      (const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
-}
+void xnn_compute_hmp_dqgemm(
+    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
+  const size_t a_stride = context->a_stride;
+  const size_t cm_stride = context->cm_stride;
+  const size_t k_scaled = context->k_scaled;
+  const size_t cn_stride = context->cn_stride;
+  const uintptr_t a = (uintptr_t)context->a;
+  const void* packed_w = (const void*)((uintptr_t)context->packed_w +
+                                       nr_block_start * context->w_stride);
+  const uintptr_t c =
+      (uintptr_t)context->c + (nr_block_start << context->log2_csize);
+  const void* fused_params = context->fused_params;
+  const void* quantization_params =
+      (const void*)((uintptr_t)&context->quantization_params[mr_block_start]);
+
+  context->dq_ukernel.function[uarch_index](
+      mr_block_size, nr_block_size, k_scaled,
+      (const void*)(a + mr_block_start * a_stride), a_stride,
+      (const void*)packed_w, (void*)(c + mr_block_start * cm_stride), cm_stride,
+      cn_stride, fused_params, quantization_params);
+}
 
 void xnn_compute_hmp_grouped_batch_igemm(
     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
8 changes: 4 additions & 4 deletions src/operators/batch-matrix-multiply-nc.c
@@ -646,10 +646,10 @@ static enum xnn_status reshape_batch_matrix_multiply_nc(
         (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm;
 #endif
   gemm_compute->range[0] = batch_size_c;
-  gemm_compute->range[1] = m;
-  gemm_compute->range[2] = n;
-  gemm_compute->tile[0] = mr;
-  gemm_compute->tile[1] = nc;
+  gemm_compute->range[2] = m;
+  gemm_compute->range[1] = n;
+  gemm_compute->tile[1] = mr;
+  gemm_compute->tile[0] = nc;
   batch_matrix_multiply_op->state = xnn_run_state_needs_setup;
 
   return xnn_status_success;
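To see why swapping the `range[]`/`tile[]` indices flips the traversal, note that with pthreadpool's `*_tile_2d` parallelizations the last two dimensions are the tiled ones and their start/size pairs are passed to the callback in that order. Below is a rough serial equivalent of the post-commit configuration above: the loop nest (last dimension innermost) reflects the intended traversal order, not pthreadpool's actual scheduling code, and the wrapper function, its parameters, and the include path are illustrative assumptions.

```c
#include <stddef.h>

#include "xnnpack/compute.h"  // struct gemm_context, xnn_compute_grouped_gemm

// Serial equivalent of the parallelization set up above:
//   range[0] = batch_size_c (untiled),
//   range[1] = n  with tile[0] = nc,
//   range[2] = m  with tile[1] = mr,
// with the last dimension assumed to vary fastest.
static void traverse_tiles(const struct gemm_context* context,
                           size_t batch_size_c, size_t m, size_t n,
                           size_t mr, size_t nc) {
  for (size_t group_index = 0; group_index < batch_size_c; group_index++) {
    for (size_t j = 0; j < n; j += nc) {    // nr_block_start
      for (size_t i = 0; i < m; i += mr) {  // mr_block_start, innermost
        xnn_compute_grouped_gemm(context, group_index, j, i,
                                 n - j < nc ? n - j : nc,   // nr_block_size
                                 m - i < mr ? m - i : mr);  // mr_block_size
      }
    }
  }
}
```

With `m` innermost, the weight block selected by `nr_block_start` stays fixed across the inner loop, which is the cache-reuse effect the commit message describes.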
16 changes: 8 additions & 8 deletions src/operators/convolution-nhwc.c
@@ -1962,10 +1962,10 @@ static enum xnn_status reshape_gemm(
       convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d;
       convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
     #endif
-    convolution_op->compute[0].range[0] = batch_output_size;
-    convolution_op->compute[0].range[1] = group_output_channels;
-    convolution_op->compute[0].tile[0] = mr;
-    convolution_op->compute[0].tile[1] = nc;
+    convolution_op->compute[0].range[1] = batch_output_size;
+    convolution_op->compute[0].range[0] = group_output_channels;
+    convolution_op->compute[0].tile[1] = mr;
+    convolution_op->compute[0].tile[0] = nc;
   } else {
     #if XNN_MAX_UARCH_TYPES > 1
       if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
@@ -1980,10 +1980,10 @@
       convolution_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
     #endif
     convolution_op->compute[0].range[0] = groups;
-    convolution_op->compute[0].range[1] = batch_output_size;
-    convolution_op->compute[0].range[2] = group_output_channels;
-    convolution_op->compute[0].tile[0] = mr;
-    convolution_op->compute[0].tile[1] = nc;
+    convolution_op->compute[0].range[2] = batch_output_size;
+    convolution_op->compute[0].range[1] = group_output_channels;
+    convolution_op->compute[0].tile[1] = mr;
+    convolution_op->compute[0].tile[0] = nc;
   }
   convolution_op->state = xnn_run_state_needs_setup;
 
36 changes: 21 additions & 15 deletions src/operators/dynamic-fully-connected-nc.c
@@ -396,21 +396,27 @@ static enum xnn_status reshape_dynamic_fully_connected_nc(
       pthreadpool_get_threads_count(threadpool));
 
 #if XNN_MAX_UARCH_TYPES > 1
-  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
-    dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch;
-    dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
-  } else {
-    dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d;
-    dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-  }
-#else
-  dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d;
-  dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-#endif
-  dynamic_fully_connected_op->compute[1].range[0] = batch_size;
-  dynamic_fully_connected_op->compute[1].range[1] = output_channels;
-  dynamic_fully_connected_op->compute[1].tile[0] = mr;
-  dynamic_fully_connected_op->compute[1].tile[1] = nc;
+  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+    dynamic_fully_connected_op->compute[1].type =
+        xnn_parallelization_type_2d_tile_2d_with_uarch;
+    dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id =
+        (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_gemm;
+  } else {
+    dynamic_fully_connected_op->compute[1].type =
+        xnn_parallelization_type_2d_tile_2d;
+    dynamic_fully_connected_op->compute[1].task_2d_tile_2d =
+        (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm;
+  }
+#else
+  dynamic_fully_connected_op->compute[1].type =
+      xnn_parallelization_type_2d_tile_2d;
+  dynamic_fully_connected_op->compute[1].task_2d_tile_2d =
+      (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm;
+#endif
+  dynamic_fully_connected_op->compute[1].range[1] = batch_size;
+  dynamic_fully_connected_op->compute[1].range[0] = output_channels;
+  dynamic_fully_connected_op->compute[1].tile[1] = mr;
+  dynamic_fully_connected_op->compute[1].tile[0] = nc;
   dynamic_fully_connected_op->state = xnn_run_state_needs_setup;
 
   return xnn_status_success;
8 changes: 4 additions & 4 deletions src/operators/fully-connected-nc.c
@@ -2019,10 +2019,10 @@ static enum xnn_status reshape_fully_connected_nc(
     fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
   }
 #endif
-  fully_connected_op->compute[0].range[0] = batch_size;
-  fully_connected_op->compute[0].range[1] = output_channels;
-  fully_connected_op->compute[0].tile[0] = mr;
-  fully_connected_op->compute[0].tile[1] = nc;
+  fully_connected_op->compute[0].range[1] = batch_size;
+  fully_connected_op->compute[0].range[0] = output_channels;
+  fully_connected_op->compute[0].tile[1] = mr;
+  fully_connected_op->compute[0].tile[0] = nc;
   fully_connected_op->state = xnn_run_state_needs_setup;
 
   return xnn_status_success;
44 changes: 22 additions & 22 deletions src/xnnpack/compute.h
@@ -345,59 +345,59 @@ struct gemm_context {
 XNN_PRIVATE void xnn_compute_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     size_t group_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
-    size_t nr_block_size);
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size);
 #if XNN_MAX_UARCH_TYPES > 1
 XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
     size_t group_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size);
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size);
 #endif // XNN_MAX_UARCH_TYPES > 1
 #endif
 
