Skip to content

Commit

Permalink
move ggml_amx_init from ggml.c to ggml-amx/mmq.cpp
Browse files: browse the repository at this point in the history
  • Loading branch information
mingfeima committed Aug 14, 2024
1 parent 0b4de32 commit 2c95fa5
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
29 changes: 16 additions & 13 deletions ggml/src/ggml-amx/mmq.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -2363,8 +2363,18 @@ bool ggml_amx_init() {
}

bool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor * dst) {
// load tile config
ggml_tile_config_init();

static thread_local bool is_first_time = true;
if (is_first_time) {
#pragma omp single
{
ggml_amx_init();
}

// load tile config
ggml_tile_config_init();
}
is_first_time = false;

const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
Expand Down Expand Up @@ -2464,7 +2474,7 @@ void ggml_mul_mat_amx(struct ggml_tensor * dst, int nth, int ith, void * wdata,
return;
}

#pragma omp master
#pragma omp single
{
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
Expand All @@ -2479,20 +2489,13 @@ void ggml_mul_mat_amx(struct ggml_tensor * dst, int nth, int ith, void * wdata,
src0->extra = aligned_alloc(64, N * row_size_B);
convert_B_packed_format<type, blck_size>((void *)src0->extra, (const type *)src0->data, N, K);
}
});
}
#pragma omp barrier

const float * A_data = static_cast<const float *>(src1->data);
parallel_for(nth, ith, M, [&](int begin, int end) {
GGML_DISPATCH_QTYPES(TYPE, [&] {
const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
for (int m = begin; m < end; ++m) {
const float * A_data = static_cast<const float *>(src1->data);
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
}
});
});
#pragma omp barrier
}

GGML_ASSERT(src0->extra != nullptr);
if (M == 1) {
Expand Down
11 changes: 1 addition & 10 deletions ggml/src/ggml.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -411,11 +411,6 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
float ggml_table_f32_f16[1 << 16];

#if GGML_USE_AMX
// global flag for amx init
static bool ggml_amx_initialized = false;
#endif

GGML_CALL const char * ggml_status_to_string(enum ggml_status status) {
switch (status) {
case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
Expand Down Expand Up @@ -3530,10 +3525,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}

#if GGML_USE_AMX
ggml_amx_initialized = ggml_amx_init();
#endif

is_first_call = false;
}

Expand Down Expand Up @@ -12334,7 +12325,7 @@ static void ggml_compute_forward_mul_mat(
// compute by src0 rows

#if GGML_USE_AMX
if (ggml_compute_forward_mul_mat_use_amx(dst) && ggml_amx_initialized) {
if (ggml_compute_forward_mul_mat_use_amx(dst)) {
ggml_mul_mat_amx(dst, nth, ith, params->wdata, params->wsize);
return;
}
Expand Down

0 comments on commit 2c95fa5

Please sign in to comment.