From bbe7b22d800c69bc0871014837b7d06e7a1d70f8 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Mon, 12 Aug 2024 14:47:26 +0800 Subject: [PATCH] build_rwkv6: Simplify graph Signed-off-by: Molly Sophia --- src/llama.cpp | 46 ++++++++++++---------------------------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 5e4f491b5edcc3..22bf1f07d3b3cb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8592,40 +8592,18 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6( xxx ); - struct ggml_tensor *mw = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); - mw = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mw, ggml_view_1d(ctx, xxx, n_embed * n_tokens, 0), 0), - n_embed, n_tokens - ); - - struct ggml_tensor *mk = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); - mk = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mk, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * sizeof(float)), 0), - n_embed, n_tokens - ); - - struct ggml_tensor *mv = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); - mv = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mv, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 2 * sizeof(float)), 0), - n_embed, n_tokens - ); - - struct ggml_tensor *mr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); - mr = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mr, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 3 * sizeof(float)), 0), - n_embed, n_tokens - ); - - struct ggml_tensor *mg = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); - mg = ggml_reshape_2d( - ctx, - ggml_set_1d(ctx, mg, ggml_view_1d(ctx, xxx, n_embed * n_tokens, n_embed * n_tokens * 4 * sizeof(float)), 0), - n_embed, n_tokens - ); + // struct ggml_tensor *mw = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embed * n_tokens); + // mw = ggml_reshape_2d( + // ctx, + // ggml_set_1d(ctx, mw, ggml_view_1d(ctx, xxx, n_embed * n_tokens, 0), 0), + // n_embed, n_tokens + // ); + + struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0); + struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float)); + struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float)); + struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float)); + struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float)); struct ggml_tensor * xw = ggml_add( ctx,