Flip the order in which the GEMM tiles are computed, i.e. instead of traversing the tiles in the n direction first, traverse them in the m direction first and then in the n direction. #7513
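In effect, the parallelized tile space is transposed: the n (output-channel) tiles become the slower-varying outer dimension and the m (row) tiles the faster-varying inner dimension. A minimal sketch of the traversal before and after, assuming an M x N output tiled into MR x NC blocks; M, N, MR, NC, and compute_tile() are illustrative placeholders, not XNNPACK identifiers:

#include <stddef.h>

// Illustrative sketch only: compute_tile() is a hypothetical placeholder
// standing in for one MR x NC GEMM microkernel invocation.
void compute_tile(size_t mr_block_start, size_t nr_block_start);

// Before this PR: m tiles outermost, n tiles innermost (n varies fastest).
void traverse_n_first(size_t M, size_t N, size_t MR, size_t NC) {
  for (size_t mr_block_start = 0; mr_block_start < M; mr_block_start += MR) {
    for (size_t nr_block_start = 0; nr_block_start < N; nr_block_start += NC) {
      compute_tile(mr_block_start, nr_block_start);
    }
  }
}

// After this PR: n tiles outermost, m tiles innermost (m varies fastest),
// so the packed weight panel selected by nr_block_start stays hot across
// consecutive iterations (a presumed benefit; the PR states no rationale).
void traverse_m_first(size_t M, size_t N, size_t MR, size_t NC) {
  for (size_t nr_block_start = 0; nr_block_start < N; nr_block_start += NC) {
    for (size_t mr_block_start = 0; mr_block_start < M; mr_block_start += MR) {
      compute_tile(mr_block_start, nr_block_start);
    }
  }
}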

Open · wants to merge 1 commit into base: master
src/operator-run.c: 91 changes (42 additions, 49 deletions)
@@ -379,8 +379,8 @@ void xnn_compute_batched_packw_gemm_goi(
 
 void xnn_compute_hmp_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t group_index, size_t mr_block_start,
-    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t group_index, size_t nr_block_start,
+    size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) {
   const size_t k_scaled = context->k_scaled;
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
@@ -448,20 +448,17 @@ void xnn_compute_hmp_grouped_gemm(
 
 void xnn_compute_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t group_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    size_t group_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   xnn_compute_hmp_grouped_gemm(context, XNN_UARCH_DEFAULT, group_index,
-                               mr_block_start, nr_block_start, mr_block_size,
-                               nr_block_size);
+                               nr_block_start, mr_block_start, nr_block_size,
+                               mr_block_size);
 }
 
 void xnn_compute_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
 
@@ -480,11 +477,8 @@ void xnn_compute_gemm(
 
 void xnn_compute_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
 
@@ -504,8 +498,8 @@ void xnn_compute_dqgemm(
 
 void xnn_compute_hmp_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset(
       mr_block_start, context->k_scaled, context->mr, context->kr, context->sr);
   const size_t cm_stride = context->cm_stride;
@@ -523,10 +517,10 @@ void xnn_compute_hmp_qp8gemm(
 
 void xnn_compute_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
-    size_t nr_block_size) {
-  xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, mr_block_start,
-                          nr_block_start, mr_block_size, nr_block_size);
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
+  xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, nr_block_start,
+                          mr_block_start, nr_block_size, mr_block_size);
 }
 
 void xnn_compute_spmm(
@@ -2401,8 +2395,8 @@ void xnn_compute_rope(
 #if XNN_MAX_UARCH_TYPES > 1
 void xnn_compute_hmp_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
 
@@ -2415,32 +2409,31 @@ void xnn_compute_hmp_gemm(
       (void*)((uintptr_t)context->c + mr_block_start * cm_stride +
               (nr_block_start << context->log2_csize)),
       cm_stride, context->cn_stride, context->fused_params);
 }
 
-void xnn_compute_hmp_dqgemm(
-    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index,
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
-  const size_t a_stride = context->a_stride;
-  const size_t cm_stride = context->cm_stride;
-
-  context->dq_ukernel.function[uarch_index](
-      mr_block_size,
-      nr_block_size,
-      context->k_scaled,
-      (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
-      a_stride,
-      (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
-      (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
-      cm_stride,
-      context->cn_stride,
-      context->fused_params,
-      (const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
-}
+void xnn_compute_hmp_dqgemm(
+    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
+  const size_t a_stride = context->a_stride;
+  const size_t cm_stride = context->cm_stride;
+  const size_t k_scaled = context->k_scaled;
+  const size_t cn_stride = context->cn_stride;
+  const uintptr_t a = (uintptr_t)context->a;
+  const void* packed_w = (const void*)((uintptr_t)context->packed_w +
+                                       nr_block_start * context->w_stride);
+  const uintptr_t c =
+      (uintptr_t)context->c + (nr_block_start << context->log2_csize);
+  const void* fused_params = context->fused_params;
+  const void* quantization_params =
+      (const void*)((uintptr_t)&context->quantization_params[mr_block_start]);
+
+  context->dq_ukernel.function[uarch_index](
+      mr_block_size, nr_block_size, k_scaled,
+      (const void*)(a + mr_block_start * a_stride), a_stride,
+      (const void*)packed_w, (void*)(c + mr_block_start * cm_stride), cm_stride,
+      cn_stride, fused_params, quantization_params);
+}
 
 void xnn_compute_hmp_grouped_batch_igemm(
     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
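The argument flips above mirror pthreadpool's calling convention: a 2d-tile-2d task receives the range[0] index before the range[1] index, so once the operators put n in range[0] and m in range[1] (see the reshape changes below), nr_block_start has to precede mr_block_start in every compute function. A serial sketch of that dispatch, assuming pthreadpool's documented 2d-tile-2d semantics; the typedef and helper here are local stand-ins, not the real pthreadpool API:

#include <stddef.h>

// Local stand-in for pthreadpool_task_2d_tile_2d_t.
typedef void (*task_2d_tile_2d_t)(void* context, size_t i, size_t j,
                                  size_t tile_i, size_t tile_j);

static size_t min_size(size_t a, size_t b) { return a < b ? a : b; }

static void parallelize_2d_tile_2d_sketch(task_2d_tile_2d_t task,
                                          void* context, size_t range0,
                                          size_t range1, size_t tile0,
                                          size_t tile1) {
  // i walks range[0] in steps of tile[0]; j walks range[1] in steps of
  // tile[1]. With range[0] = n and tile[0] = nc, i is nr_block_start;
  // with range[1] = m and tile[1] = mr, j is mr_block_start. The task
  // therefore sees n-coordinates before m-coordinates, which is why
  // every compute function above now takes nr_block_start first.
  for (size_t i = 0; i < range0; i += tile0) {
    for (size_t j = 0; j < range1; j += tile1) {
      task(context, i, j, min_size(tile0, range0 - i),
           min_size(tile1, range1 - j));
    }
  }
}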
src/operators/batch-matrix-multiply-nc.c: 8 changes (4 additions, 4 deletions)

@@ -670,10 +670,10 @@ static enum xnn_status reshape_batch_matrix_multiply_nc(
         (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm;
 #endif
   gemm_compute->range[0] = batch_size_c;
-  gemm_compute->range[1] = m;
-  gemm_compute->range[2] = n;
-  gemm_compute->tile[0] = mr;
-  gemm_compute->tile[1] = nc;
+  gemm_compute->range[2] = m;
+  gemm_compute->range[1] = n;
+  gemm_compute->tile[1] = mr;
+  gemm_compute->tile[0] = nc;
   batch_matrix_multiply_op->state = xnn_run_state_needs_setup;
 
   return xnn_status_success;
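For the batched case above, only the trailing two dimensions swap; the batch index stays in range[0]. A hypothetical configuration (values invented for illustration, not from the PR) showing how the new layout feeds xnn_compute_grouped_gemm(context, group_index, nr_block_start, mr_block_start, nr_block_size, mr_block_size):

// Hypothetical values, for illustration only. In the 3d-tile-2d dispatch,
// tile[0] tiles range[1] and tile[1] tiles range[2], so after this PR the
// n tiles vary slower than the m tiles within each batch entry.
gemm_compute->range[0] = 4;    // batch_size_c: four batched GEMMs
gemm_compute->range[1] = 512;  // n: output channels
gemm_compute->range[2] = 128;  // m: output rows
gemm_compute->tile[0] = 32;    // nc: n-tile width
gemm_compute->tile[1] = 8;     // mr: m-tile height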
src/operators/convolution-nhwc.c: 16 changes (8 additions, 8 deletions)

@@ -2106,10 +2106,10 @@ static enum xnn_status reshape_gemm(
     convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d;
     convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
 #endif
-    convolution_op->compute[0].range[0] = batch_output_size;
-    convolution_op->compute[0].range[1] = group_output_channels;
-    convolution_op->compute[0].tile[0] = mr;
-    convolution_op->compute[0].tile[1] = nc;
+    convolution_op->compute[0].range[1] = batch_output_size;
+    convolution_op->compute[0].range[0] = group_output_channels;
+    convolution_op->compute[0].tile[1] = mr;
+    convolution_op->compute[0].tile[0] = nc;
   } else {
 #if XNN_MAX_UARCH_TYPES > 1
     if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {

@@ -2124,10 +2124,10 @@ static enum xnn_status reshape_gemm(
       convolution_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
 #endif
     convolution_op->compute[0].range[0] = groups;
-    convolution_op->compute[0].range[1] = batch_output_size;
-    convolution_op->compute[0].range[2] = group_output_channels;
-    convolution_op->compute[0].tile[0] = mr;
-    convolution_op->compute[0].tile[1] = nc;
+    convolution_op->compute[0].range[2] = batch_output_size;
+    convolution_op->compute[0].range[1] = group_output_channels;
+    convolution_op->compute[0].tile[1] = mr;
+    convolution_op->compute[0].tile[0] = nc;
   }
   convolution_op->state = xnn_run_state_needs_setup;
 
src/operators/dynamic-fully-connected-nc.c: 36 changes (21 additions, 15 deletions)

@@ -396,21 +396,27 @@ static enum xnn_status reshape_dynamic_fully_connected_nc(
       pthreadpool_get_threads_count(threadpool));
 
 #if XNN_MAX_UARCH_TYPES > 1
-  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
-    dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch;
-    dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
-  } else {
-    dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d;
-    dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-  }
-#else
-  dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d;
-  dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-#endif
-  dynamic_fully_connected_op->compute[1].range[0] = batch_size;
-  dynamic_fully_connected_op->compute[1].range[1] = output_channels;
-  dynamic_fully_connected_op->compute[1].tile[0] = mr;
-  dynamic_fully_connected_op->compute[1].tile[1] = nc;
+  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+    dynamic_fully_connected_op->compute[1].type =
+        xnn_parallelization_type_2d_tile_2d_with_uarch;
+    dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id =
+        (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_gemm;
+  } else {
+    dynamic_fully_connected_op->compute[1].type =
+        xnn_parallelization_type_2d_tile_2d;
+    dynamic_fully_connected_op->compute[1].task_2d_tile_2d =
+        (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm;
+  }
+#else
+  dynamic_fully_connected_op->compute[1].type =
+      xnn_parallelization_type_2d_tile_2d;
+  dynamic_fully_connected_op->compute[1].task_2d_tile_2d =
+      (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm;
+#endif
+  dynamic_fully_connected_op->compute[1].range[1] = batch_size;
+  dynamic_fully_connected_op->compute[1].range[0] = output_channels;
+  dynamic_fully_connected_op->compute[1].tile[1] = mr;
+  dynamic_fully_connected_op->compute[1].tile[0] = nc;
   dynamic_fully_connected_op->state = xnn_run_state_needs_setup;
 
   return xnn_status_success;
src/operators/fully-connected-nc.c: 8 changes (4 additions, 4 deletions)

@@ -2254,10 +2254,10 @@ static enum xnn_status reshape_fully_connected_nc(
     fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
   }
 #endif
-  fully_connected_op->compute[0].range[0] = batch_size;
-  fully_connected_op->compute[0].range[1] = output_channels;
-  fully_connected_op->compute[0].tile[0] = mr;
-  fully_connected_op->compute[0].tile[1] = nc;
+  fully_connected_op->compute[0].range[1] = batch_size;
+  fully_connected_op->compute[0].range[0] = output_channels;
+  fully_connected_op->compute[0].tile[1] = mr;
+  fully_connected_op->compute[0].tile[0] = nc;
   fully_connected_op->state = xnn_run_state_needs_setup;
 
   return xnn_status_success;
src/xnnpack/compute.h: 44 changes (22 additions, 22 deletions)

@@ -345,59 +345,59 @@ struct gemm_context {
 XNN_PRIVATE void xnn_compute_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     size_t group_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
-    size_t nr_block_size);
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size);
 #if XNN_MAX_UARCH_TYPES > 1
 XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
     size_t group_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size);
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size);
 #endif  // XNN_MAX_UARCH_TYPES > 1
 #endif
 