Skip to content

Commit

Permalink
add uma support
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Dec 16, 2024
1 parent 6ea605d commit 353c5f8
Showing 1 changed file with 53 additions and 16 deletions.
69 changes: 53 additions & 16 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;

vk_buffer d_D = dst_buf_ctx->dev_buffer;
vk_buffer d_K = k_buf_ctx->dev_buffer;
vk_buffer d_V = v_buf_ctx->dev_buffer;
vk_buffer d_R = r_buf_ctx->dev_buffer;
vk_buffer d_TF = tf_buf_ctx->dev_buffer;
vk_buffer d_TD = td_buf_ctx->dev_buffer;
vk_buffer d_State = state_buf_ctx->dev_buffer;

const uint64_t k_offset = vk_tensor_offset(k);
const uint64_t v_offset = vk_tensor_offset(v);
const uint64_t r_offset = vk_tensor_offset(r);
const uint64_t tf_offset = vk_tensor_offset(tf);
const uint64_t td_offset = vk_tensor_offset(td);
const uint64_t state_offset = vk_tensor_offset(state);
const uint64_t dst_offset = vk_tensor_offset(dst);
ggml_vk_sync_buffers(subctx);

vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);

K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
R_uma = d_R != nullptr;
TF_uma = d_TF != nullptr;
TD_uma = d_TD != nullptr;
STATE_uma = d_State != nullptr;
DST_uma = d_D != nullptr;
}

if (!K_uma) {
d_K = k_buf_ctx->dev_buffer;
k_offset = vk_tensor_offset(k) + k->view_offs;
}
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_offset = vk_tensor_offset(v) + v->view_offs;
}
if (!R_uma) {
d_R = r_buf_ctx->dev_buffer;
r_offset = vk_tensor_offset(r) + r->view_offs;
}
if (!TF_uma) {
d_TF = tf_buf_ctx->dev_buffer;
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
}
if (!TD_uma) {
d_TD = td_buf_ctx->dev_buffer;
td_offset = vk_tensor_offset(td) + td->view_offs;
}
if (!STATE_uma) {
d_State = state_buf_ctx->dev_buffer;
state_offset = vk_tensor_offset(state) + state->view_offs;
}
if (!DST_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}

const uint64_t k_size = ggml_nbytes(k);
const uint64_t v_size = ggml_nbytes(v);
Expand All @@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
1
};

ggml_vk_sync_buffers(subctx);

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_K, k_offset, k_size },
vk_subbuffer{ d_V, v_offset, v_size },
Expand Down

0 comments on commit 353c5f8

Please sign in to comment.