From 5c1d1177d30673cfac46b1b5879808a6ddcb70e0 Mon Sep 17 00:00:00 2001 From: Lucas Nogueira Date: Fri, 15 Nov 2024 14:32:27 -0300 Subject: [PATCH 1/3] Add a complete implementation of the classical PCA algorithm (covariance matrix plus power iteration), with a very simple test file --- Makefile | 7 + .../cvector-generator/cvector-generator.cpp | 3 +- .../mini-tests/test-vanilla-pca.cpp | 116 +++++++ examples/cvector-generator/vanilla_pca.hpp | 314 ++++++++++++++++++ 4 files changed, 438 insertions(+), 2 deletions(-) create mode 100644 examples/cvector-generator/mini-tests/test-vanilla-pca.cpp create mode 100644 examples/cvector-generator/vanilla_pca.hpp diff --git a/Makefile b/Makefile index 87fe795aa8432..d00f3b0571e43 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ BUILD_TARGETS = \ llama-tokenize \ llama-vdot \ llama-cvector-generator \ + llama-test-vanilla-pca \ llama-gen-docs \ tests/test-c.o @@ -1479,6 +1480,12 @@ llama-cvector-generator: examples/cvector-generator/cvector-generator.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +# TODO: Move to tests +llama-test-vanilla-pca: examples/cvector-generator/mini-tests/test-vanilla-pca.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + llama-convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index d1731bba64e1b..e5804be2890f6 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -2,8 +2,7 @@ #include "common.h" #include "llama.h" #include "ggml.h" -#include "pca.hpp" -#include "mean.hpp" +#include "vanilla_pca.hpp" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" diff --git a/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp b/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp new file mode 100644 index 0000000000000..6744b623e22b8 --- /dev/null +++ b/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp @@ -0,0 +1,116 @@ + +#include "common.h" +#include "llama.h" +#include "ggml.h" +#include "../vanilla_pca.hpp" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include + +// Function to initialize ggml with optional GPU backend support +struct ggml_context *initialize_ggml_context() { +#ifdef GGML_USE_CUDA + struct ggml_init_params params = { .mem_size = 1024 * 1024, .mem_buffer = NULL, .use_gpu = true }; + printf("Initializing with GPU backend...\n"); +#else + struct ggml_init_params params = { .mem_size = 1024 * 1024, .mem_buffer = NULL }; + printf("Initializing with CPU backend...\n"); +#endif + return ggml_init(params); +} + +// Helper function to create a tensor from a matrix +struct ggml_tensor *create_tensor(struct ggml_context *ctx, float *data, int rows, int cols) { + struct ggml_tensor *tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols, rows); + memcpy(tensor->data, data, ggml_nbytes(tensor)); + return tensor; } + +// Function to run PCA and print results +void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) { + struct ggml_tensor *input_tensor = create_tensor(ctx, matrix, rows, cols); + + PCA::pca_params pca_params; + 
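+ // The fields below drive the power-iteration loop in vanilla_pca.hpp: n_iterations is the total number of matrix-vector products, executed n_batch at a time inside one ggml graph; tolerance bounds 1 - dot(b_new, b_old) at convergence. + 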
pca_params.n_threads = 8; + pca_params.n_batch = 20; + pca_params.n_iterations = 1000; + pca_params.tolerance = 1e-5; + + PCA::pca_result result; + PCA::run_single_pca(pca_params, input_tensor, result); + + printf("\nPrincipal components:\n"); + float *b = (float *)result.principal_component->data; + for (int i = 0; i < result.principal_component->ne[0]; i++) { + printf("%f ", b[i]); + } + printf("\nEigenvalue: %f\n", result.explained_variance); +} + +int main() { + // Initialize ggml context + struct ggml_context *ctx = initialize_ggml_context(); + if (ctx == NULL) { + printf("Failed to initialize ggml context\n"); + return 1; + } + + // Define matrices + float input_matrix1[16] = { + -0.124132, 0.740341, -0.452462, 0.777050, + 1.045571, -0.342142, -0.926047, -0.512965, + 0.710109, 0.092479, 0.630075, 1.762937, + 0.230954, -0.808937, 1.057424, 0.051361 + }; + + float input_matrix2[100] = { + 440152.493740, 122038.234845, 495176.910111, 34388.521115, 909320.402079, 258779.981600, 662522.284354, 311711.076089, 520068.021178, 546710.279343, + 184854.455526, 969584.627765, 775132.823361, 939498.941564, 894827.350428, 597899.978811, 921874.235023, 88492.502052, 195982.862419, 45227.288911, + 325330.330763, 388677.289689, 271349.031774, 828737.509152, 356753.326694, 280934.509687, 542696.083158, 140924.224975, 802196.980754, 74550.643680, + 986886.936601, 772244.769297, 198715.681534, 5522.117124, 815461.428455, 706857.343848, 729007.168041, 771270.346686, 74044.651734, 358465.728544, + 115869.059525, 863103.425876, 623298.126828, 330898.024853, 63558.350286, 310982.321716, 325183.322027, 729606.178338, 637557.471355, 887212.742576, + 472214.925162, 119594.245938, 713244.787223, 760785.048617, 561277.197569, 770967.179955, 493795.596364, 522732.829382, 427541.018359, 25419.126744, + 107891.426993, 31429.185687, 636410.411264, 314355.981076, 508570.691165, 907566.473926, 249292.229149, 410382.923036, 755551.138543, 228798.165492, + 76979.909829, 289751.452914, 161221.287254, 929697.652343, 808120.379564, 633403.756510, 871460.590188, 803672.076899, 186570.058886, 892558.998490, + 539342.241916, 807440.155164, 896091.299923, 318003.474972, 110051.924528, 227935.162542, 427107.788626, 818014.765922, 860730.583256, 6952.130531, + 510747.302578, 417411.003149, 222107.810471, 119865.367334, 337615.171404, 942909.703913, 323202.932021, 518790.621743, 703018.958895, 363629.602379 + }; + + float input_matrix3[9] = { + 0.374540, 0.950714, 0.731994, + 0.598658, 0.156019, 0.155995, + 0.058084, 0.866176, 0.601115 + }; + + float input_matrix4[9] = { + 10.000000, 0.000000, 0.000000, + 0.000000, 5.000000, 0.000000, + 0.000000, 0.000000, 1.000000 + }; + + // Run PCA for each matrix + printf("Testing Matrix 1:\n"); + run_pca_test(ctx, input_matrix1, 4, 4); + + printf("\nTesting Matrix 2:\n"); + run_pca_test(ctx, input_matrix2, 10, 10); + + printf("\nTesting Matrix 3:\n"); + run_pca_test(ctx, input_matrix3, 3, 3); + + printf("\nTesting Matrix 4:\n"); + run_pca_test(ctx, input_matrix4, 3, 3); + + // Cleanup + ggml_free(ctx); + return 0; +} + diff --git a/examples/cvector-generator/vanilla_pca.hpp b/examples/cvector-generator/vanilla_pca.hpp new file mode 100644 index 0000000000000..b4350db820b7e --- /dev/null +++ b/examples/cvector-generator/vanilla_pca.hpp @@ -0,0 +1,314 @@ +#include "common.h" +#include "llama.h" +#include "ggml.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define DEBUG_POS 5 + +static 
void print_debug_tensor(struct ggml_tensor * t, bool with_data = true) { + printf("%s: %s (%s): [%d, %d]\n", __func__, t->name, ggml_type_name(t->type), (int) t->ne[0], (int) t->ne[1]); + if (!with_data) return; + printf("%s: %s[0] = [", __func__, t->name); + for (size_t i = 0; i <= DEBUG_POS; i++) { + printf(" %f,", ggml_get_f32_nd(t, i, 0, 0, 0)); + } + printf(" ... ]\n"); +} + +// begin vanilla pca namespace +namespace PCA { + +// input params for PCA computations +struct pca_params { + int n_threads = 1; + int n_batch = 20; // number of iterations do to in one batch. larger the batch, more memory is used + int n_iterations = 1000; + float tolerance = 1e-7; +}; + +// result from each iteration +struct pca_result { + struct ggml_tensor * principal_component; // eigenvectors of the covariance matrix + float explained_variance; // eigenvalues of the covariance matrix +}; + +void compute_covariance(struct pca_params &pca_params, + struct ggml_tensor * X, + struct ggml_tensor * covariance, + struct ggml_backend * backend) { + + // Memory allocation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx = NULL; + struct ggml_init_params ctx_params = { + ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + NULL, + true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx = ggml_init(ctx_params); + gf = ggml_new_graph(ctx); + + // Step 0: Transpose the input because of row-major + X = ggml_cont(ctx, ggml_transpose(ctx, X)); + + // Step 1: Compute the mean for each feature + struct ggml_tensor * mean = ggml_repeat(ctx, ggml_mean(ctx, X), X); // mean with trick to make it easier to sub + struct ggml_tensor * centered_data = ggml_sub(ctx, X, mean); + + // Step 2: Compute the covariance matrix + struct ggml_tensor * cov = ggml_mul_mat(ctx, centered_data, centered_data); // C = X * X^T + cov = ggml_scale(ctx, cov, 1.0/(X->ne[0]-1)); + ggml_build_forward_expand(gf, cov); + + // Step 3: Create ggml_gallocr for graph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // Step 4: Check if CPU and compute the result of the graph + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); + } + ggml_backend_graph_compute(backend, gf); + + // Step 5: Store covariance matrix in the data pointer + struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); + float * result_data = (float*) malloc(ggml_nbytes(result)); + ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); + covariance->data = result_data; + + // Step 6: Free memory + ggml_gallocr_free(allocr); + ggml_free(ctx); +} + +static void compute_cross_covariance(struct pca_params &pca_params, + struct ggml_tensor * A, + struct ggml_tensor * B, + struct ggml_tensor * cross_covariance, + struct ggml_backend * backend) { + + // Memory allocation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx = NULL; + struct ggml_init_params ctx_params = { + ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + NULL, + true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx = ggml_init(ctx_params); + gf = ggml_new_graph(ctx); + + // Step 1: Compute matrices of cross_covariance + struct ggml_tensor * AT = ggml_cont(ctx, ggml_transpose(ctx, A)); + struct ggml_tensor * BT = ggml_cont(ctx, ggml_transpose(ctx, B)); + struct ggml_tensor * AT_B = ggml_mul_mat(ctx, AT, BT); + struct 
ggml_tensor * BT_A = ggml_cont(ctx, ggml_transpose(ctx, AT_B)); + + // Step 2: Compute the covariance matrix + struct ggml_tensor * cross_cov = ggml_add(ctx, AT_B, BT_A); + cross_cov = ggml_scale(ctx, cross_cov, 0.5); + ggml_build_forward_expand(gf, cross_cov); + + // Step 3: Create ggml_gallocr for graph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // Step 4: Check if CPU and compute the result of the graph + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); + } + ggml_backend_graph_compute(backend, gf); + + // Step 5: Store covariance matrix in the data pointer + struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); + float * result_data = (float*) malloc(ggml_nbytes(result)); + ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); + cross_covariance->data = result_data; + + // Step 6: Free memory + ggml_gallocr_free(allocr); + ggml_free(ctx); +} + +// Find the dominant eigenvector of tensor M +static void power_iteration(struct pca_params &pca_params, + struct ggml_tensor * M, + struct pca_result &result, + struct ggml_backend * backend) { + + int m = M->ne[1]; + + // Initialize random vector + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(-1.0f, 1.0f); + float * b = (float*) malloc(m * sizeof(float)); + for (int i = 0; i < m; i++) { + b[i] = dist(gen); + }; + float eigenvalue = 0; + + // Iterate + int n_rounds = pca_params.n_iterations / pca_params.n_batch; + for(int i = 0; i < n_rounds; i++) { + + // Memory allocation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx = NULL; + struct ggml_init_params params = { + ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + NULL, + true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx = ggml_init(params); + gf = ggml_new_graph(ctx); + + // Fill current eigen vector + struct ggml_tensor * e_curr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); + struct ggml_tensor * e_prev = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); + + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + + ggml_backend_tensor_set(e_curr, b, 0, ggml_nbytes(e_curr)); + ggml_backend_tensor_set(e_prev, b, 0, ggml_nbytes(e_curr)); + + struct ggml_tensor * e_next = NULL; + struct ggml_tensor * e_norm = NULL; + for(int j = 0; j < pca_params.n_batch; j++) { + // Compute next candidate vector multiplying M with the current vector + e_next = ggml_mul_mat(ctx, M, e_curr); + + // Compute the norm of the new vector (and normalize it) + // this will give us the next eigenvector and eigenvalue + e_norm = ggml_sqrt_inplace(ctx, ggml_sum_rows(ctx, ggml_sqr(ctx, e_next))); + e_curr = ggml_div_inplace(ctx, e_next, e_norm); + ggml_format_name(e_norm, "eigenvalue_%d", j); + ggml_format_name(e_curr, "eigenvector_%d", j); + + // Update graph + ggml_build_forward_expand(gf, e_curr); + } + + // Compute the similarity between the current eigenvector and the previous (dot product) + struct ggml_tensor * similarity = ggml_mul_mat(ctx, e_curr, e_prev); + ggml_build_forward_expand(gf, similarity); + + // Create ggml_gallocr for graph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // Check if CPU and compute the result of the graph + if (ggml_backend_is_cpu(backend)) { + 
ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); + } + ggml_status graph_status = ggml_backend_graph_compute(backend, gf); + + // Get graph results (eigenvector and eigenvalue) and store it in b and eigenvalue + if(graph_status == GGML_STATUS_SUCCESS){ + + // Similarity is the last node in the graph + struct ggml_tensor * similarity_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); + float similarity = (float)((float*) similarity_tensor->data)[0]; + + // Eigenvector is the second last node in the graph + // struct ggml_tensor * eigenvector_tensor = gf->nodes[gf->n_nodes-2]; + struct ggml_tensor * eigenvector_tensor = ggml_graph_node(gf,ggml_graph_n_nodes(gf)-2); + float * eigenvector_data = (float*) malloc(ggml_nbytes(eigenvector_tensor)); + ggml_backend_tensor_get(eigenvector_tensor, eigenvector_data, 0, ggml_nbytes(eigenvector_tensor)); + b = eigenvector_data; + + // Eigenvalue computation is 1 operation before eigenvector computation + // struct ggml_tensor * eigenvalue_tensor = gf->nodes[gf->n_nodes-3]; + struct ggml_tensor * eigenvalue_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-3); + eigenvalue = (float)((float*) eigenvalue_tensor->data)[0]; + + // Check if the similarity is close enough to 1, if so we converged and should break + if(1 - similarity < pca_params.tolerance) + break; + } + + // Free memory + ggml_gallocr_free(allocr); + ggml_free(ctx); + } + + // Store result + result.principal_component->data = b; + result.explained_variance = eigenvalue; + return; +} + +static void run_single_pca(struct pca_params &pca_params, + struct ggml_tensor * X, + struct pca_result &result + ) { + + ggml_set_name(X, "input_tensor"); + + int m = X->ne[1]; // Number of features + + // Step 1. Initialize GGML Backend + ggml_backend_t backend = NULL; + #ifdef GGML_USE_CUDA + fprintf(stderr, "%s: using CUDA backend\n", __func__); + backend = ggml_backend_cuda_init(0); // init device 0 + if (!backend) { fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); } + #endif + // If there aren't GPU Backends fallback to CPU backend + if (!backend) { backend = ggml_backend_cpu_init(); } + + // Compute the context size needed + size_t ctx_size = 2 * ggml_tensor_overhead(); + + // Step 2. Initialize GGML Context + struct ggml_init_params ctx_params { + ctx_size, // mem_size + NULL, // mem_buffer + true, // no_alloc + }; + struct ggml_context * ctx = ggml_init(ctx_params); + + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + + // Step 3. Compute the data covariance matrix + struct ggml_tensor * covariance = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, m); + ggml_set_name(covariance, "covariance_tensor"); + compute_covariance(pca_params, X, covariance, backend); + + // Step 4. 
Power iteration + result.principal_component = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); + power_iteration(pca_params, covariance, result, backend); + + // Free ggml context and backend + ggml_free(ctx); + ggml_backend_free(backend); +} + + +static void run_pca( + struct pca_params & params, + const std::vector & v_input, // shape of v_input[0]: [n_samples, n_embd] + const std::vector & v_output) { + + for (size_t i = 0; i < v_input.size(); i++) { + struct pca_result result; + run_single_pca(params, v_input[i], result); + ggml_backend_tensor_get(result.principal_component, v_output[i]->data, 0, ggml_nbytes(result.principal_component)); + } +} + +// end namespace +} From 1840df1b58967245b11e2a2b1b1fb3f057472c49 Mon Sep 17 00:00:00 2001 From: Lucas Nogueira Date: Sat, 16 Nov 2024 01:13:46 -0300 Subject: [PATCH 2/3] Apply suggestions from the PR: employ CPU buffers to copy results, use correct ctx_size and add GGML_ASSERT to check v_output --- .../cvector-generator/cvector-generator.cpp | 4 +- .../mini-tests/test-vanilla-pca.cpp | 56 +- examples/cvector-generator/pca.hpp | 499 +++++++++--------- 3 files changed, 288 insertions(+), 271 deletions(-) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index e5804be2890f6..e7c924fb4b59e 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -2,7 +2,9 @@ #include "common.h" #include "llama.h" #include "ggml.h" -#include "vanilla_pca.hpp" + +#include "mean.hpp" +#include "pca.hpp" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" diff --git a/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp b/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp index 6744b623e22b8..72405762b9f0a 100644 --- a/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp +++ b/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp @@ -2,7 +2,7 @@ #include "common.h" #include "llama.h" #include "ggml.h" -#include "../vanilla_pca.hpp" +#include "../pca.hpp" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -15,28 +15,11 @@ #include #include -// Function to initialize ggml with optional GPU backend support -struct ggml_context *initialize_ggml_context() { -#ifdef GGML_USE_CUDA - struct ggml_init_params params = { .mem_size = 1024 * 1024, .mem_buffer = NULL, .use_gpu = true }; - printf("Initializing with GPU backend...\n"); -#else - struct ggml_init_params params = { .mem_size = 1024 * 1024, .mem_buffer = NULL }; - printf("Initializing with CPU backend...\n"); -#endif - return ggml_init(params); -} - -// Helper function to create a tensor from a matrix -struct ggml_tensor *create_tensor(struct ggml_context *ctx, float *data, int rows, int cols) { - struct ggml_tensor *tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols, rows); - memcpy(tensor->data, data, ggml_nbytes(tensor)); - return tensor; -} - // Function to run PCA and print results -void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) { - struct ggml_tensor *input_tensor = create_tensor(ctx, matrix, rows, cols); +static void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) { + // struct ggml_tensor *input_tensor = create_tensor(ctx, matrix, rows, cols); + struct ggml_tensor *input_tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, rows, cols); + memcpy(input_tensor->data, matrix, rows * cols * sizeof(float)); PCA::pca_params pca_params; pca_params.n_threads = 8; @@ -44,20 +27,37 @@ void run_pca_test(struct ggml_context *ctx, float 
*matrix, int rows, int cols) { pca_params.n_iterations = 1000; pca_params.tolerance = 1e-5; - PCA::pca_result result; + PCA::pca_result result = {NULL, 0}; PCA::run_single_pca(pca_params, input_tensor, result); - printf("\nPrincipal components:\n"); - float *b = (float *)result.principal_component->data; - for (int i = 0; i < result.principal_component->ne[0]; i++) { - printf("%f ", b[i]); + printf("Principal components:\n"); + for (int i = 0; i < cols; i++) { + printf("%f ", result.principal_component[i]); } printf("\nEigenvalue: %f\n", result.explained_variance); + + free(result.principal_component); } int main() { // Initialize ggml context - struct ggml_context *ctx = initialize_ggml_context(); + size_t ctx_size = 0; + ctx_size += 4 * 4 * ggml_type_size(GGML_TYPE_F32); + ctx_size += 10 * 10 * ggml_type_size(GGML_TYPE_F32); + ctx_size += 3 * 3 * ggml_type_size(GGML_TYPE_F32); + ctx_size += 3 * 3 * ggml_type_size(GGML_TYPE_F32); + ctx_size += 4 * ggml_tensor_overhead(); + ctx_size += 1024; + + // Step 2. Initialize GGML Context + struct ggml_init_params ctx_params { + ctx_size, // mem_size + NULL, // mem_buffer + false, // no_alloc + }; + struct ggml_context * ctx = ggml_init(ctx_params); + + if (ctx == NULL) { printf("Failed to initialize ggml context\n"); return 1; diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index f6e307fbc4970..7a84ac1d569bb 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -6,15 +6,15 @@ #include "ggml-cuda.h" #endif -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif - #include #include #include #include +#include #include +#include +#include +#include #define DEBUG_POS 5 @@ -28,288 +28,303 @@ static void print_debug_tensor(struct ggml_tensor * t, bool with_data = true) { printf(" ... ]\n"); } +// begin vanilla pca namespace namespace PCA { // input params for PCA computations struct pca_params { - int n_threads = 1; - int n_batch = 20; // number of iterations do to in one batch. larger the batch, more memory is used + int n_threads = 1; + int n_batch = 20; // number of iterations to do in one batch. 
larger the batch, more memory is used int n_iterations = 1000; - float tolerance = 1e-7; - - // for debugging - int i_layer = 0; - int n_layers = 0; + float tolerance = 1e-7; }; // result from each iteration struct pca_result { - struct ggml_tensor * calculated_square = NULL; - std::vector eigenvectors; - std::vector distances; + float * principal_component; // eigenvectors of the covariance matrix + float explained_variance; // eigenvalues of the covariance matrix }; -struct pca_model { - ggml_backend_t backend = NULL; - ggml_backend_buffer_t buffer; - struct ggml_context * ctx; // context to compute graph on target device - struct ggml_context * ctx_host; // host context to store results +static void compute_covariance(struct pca_params &pca_params, + struct ggml_tensor * X, + float * covariance, + struct ggml_backend * backend) { + + size_t ctx_size = 0; + ctx_size += 7 * X->ne[0] * X->ne[1] * ggml_type_size(GGML_TYPE_F32); + ctx_size += 7 * ggml_tensor_overhead(); + ctx_size += ggml_graph_overhead(); + ctx_size += 1024; + + // Memory allocation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx = NULL; + struct ggml_init_params ctx_params = { + ctx_size, + NULL, + true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx = ggml_init(ctx_params); + gf = ggml_new_graph(ctx); - // tensors on target device - struct ggml_tensor * dev_input; - struct ggml_tensor * dev_square; - struct ggml_tensor * dev_eigenvector; + // Step 0: Transpose the input because of row-major + X = ggml_cont(ctx, ggml_transpose(ctx, X)); - pca_model(struct ggml_tensor * t_input) { -#ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - if (!backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } -#endif + // Step 1: Compute the mean for each feature + struct ggml_tensor * mean = ggml_repeat(ctx, ggml_mean(ctx, X), X); // mean with trick to make it easier to sub + struct ggml_tensor * centered_data = ggml_sub(ctx, X, mean); -// TODO: enable Metal support when support for GGML_OP_SQRT is added -// #ifdef GGML_USE_METAL -// fprintf(stderr, "%s: using Metal backend\n", __func__); -// backend = ggml_backend_metal_init(); -// if (!backend) { -// fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); -// } -// #endif - - // if there aren't GPU Backends fallback to CPU backend - if (!backend) { - backend = ggml_backend_cpu_init(); - } + // Step 2: Compute the covariance matrix + struct ggml_tensor * cov = ggml_mul_mat(ctx, centered_data, centered_data); // C = X * X^T + cov = ggml_scale(ctx, cov, 1.0/(X->ne[0]-1)); + ggml_build_forward_expand(gf, cov); - const int num_tensors = 4; - struct ggml_init_params params { - /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - ctx = ggml_init(params); + // Step 3: Create ggml_gallocr for graph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); - auto n_samples = t_input->ne[0]; - auto n_embd = t_input->ne[1]; - - dev_input = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_samples, n_embd); - dev_square = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); - dev_eigenvector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); - - ggml_set_name(dev_input, "dev_input"); - ggml_set_name(dev_square, "dev_square"); - ggml_set_name(dev_eigenvector, "dev_eigenvector"); - buffer = 
ggml_backend_alloc_ctx_tensors(ctx, backend); - ggml_backend_tensor_set(dev_input, t_input->data, 0, ggml_nbytes(t_input)); - - // initialize eigenvector to random normalized vector - { - std::vector random_vec(ggml_nelements(dev_eigenvector), 0.0); - std::default_random_engine generator(static_cast(std::time(0))); - std::uniform_real_distribution distribution(0.0, 1.0); - float sum_sqr = 0.0; // for normalizing random_vec - for (size_t i = 0; i < random_vec.size(); ++i) { - float f = distribution(generator); - sum_sqr += f * f; - random_vec[i] = f; - } - // normalize it - float random_vec_norm = std::sqrt(sum_sqr); - for (size_t i = 0; i < random_vec.size(); ++i) { - random_vec[i] /= random_vec_norm; - } - ggml_backend_tensor_set(dev_eigenvector, random_vec.data(), 0, ggml_nbytes(dev_eigenvector)); - } + // Step 4: Check if CPU and compute the result of the graph + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); } + ggml_backend_graph_compute(backend, gf); - ~pca_model() { - ggml_free(ctx); - ggml_backend_buffer_free(buffer); - ggml_backend_free(backend); - } -}; + // Step 5: Store covariance matrix in the data pointer + struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); + ggml_backend_tensor_get(result, covariance, 0, ggml_nbytes(result)); + + // Step 6: Free memory + ggml_gallocr_free(allocr); + ggml_free(ctx); +} -static struct ggml_cgraph * build_graph_piter( - const struct pca_params & params, - const pca_model & model, - bool calc_square = false) { - GGML_ASSERT(params.n_batch > 0); - // TODO: buf_size must be able to scale with params.n_batch - static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); - - struct ggml_init_params params0 = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() +static void compute_cross_covariance(struct pca_params &pca_params, + struct ggml_tensor * A, + struct ggml_tensor * B, + float * cross_covariance, + struct ggml_backend * backend) { + + size_t ctx_size = 0; + ctx_size += 9 * A->ne[0] * B->ne[1] * ggml_type_size(GGML_TYPE_F32); + ctx_size += 9 * ggml_tensor_overhead(); + ctx_size += ggml_graph_overhead(); + ctx_size += 1024; + + // Memory allocation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx = NULL; + struct ggml_init_params ctx_params = { + ctx_size, + NULL, + true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() }; - // create a temporally context to build the graph - struct ggml_context * ctx0 = ggml_init(params0); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - // turn v_diff_original into square matrix if needed - struct ggml_tensor * tmp_square; - if (calc_square) { - tmp_square = ggml_mul_mat(ctx0, model.dev_input, model.dev_input); - ggml_set_name(tmp_square, "tmp_square"); - } + ctx = ggml_init(ctx_params); + gf = ggml_new_graph(ctx); + + // Step 1: Compute matrices of cross_covariance + struct ggml_tensor * AT = ggml_cont(ctx, ggml_transpose(ctx, A)); + struct ggml_tensor * BT = ggml_cont(ctx, ggml_transpose(ctx, B)); + struct ggml_tensor * AT_B = ggml_mul_mat(ctx, AT, BT); + struct ggml_tensor * BT_A = ggml_cont(ctx, ggml_transpose(ctx, AT_B)); + + // Step 2: Compute the covariance matrix + struct ggml_tensor * cross_cov = ggml_add(ctx, AT_B, BT_A); + cross_cov = ggml_scale(ctx, cross_cov, 0.5); + ggml_build_forward_expand(gf, cross_cov); 
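+ // Note: (A^T B + B^T A) / 2 is the symmetric part of A^T B; a symmetric matrix has real eigenvalues and orthogonal eigenvectors, so power iteration on this result is well-behaved.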
+ + // Step 3: Create ggml_gallocr for graph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); - struct ggml_tensor * b_tensor; - struct ggml_tensor * distance; - struct ggml_tensor * old_eigen = model.dev_eigenvector; - struct ggml_tensor * input_square = calc_square ? tmp_square : model.dev_square; - - for (int i = 0; i < params.n_batch; ++i) { - // b_tensor = square * eigenvector^T - b_tensor = ggml_mul_mat(ctx0, input_square, old_eigen); - ggml_set_name(b_tensor, "b_tensor"); - - // normalize - b_tensor = ggml_div_inplace(ctx0, - b_tensor, - ggml_sqrt_inplace(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, b_tensor))) - ); - ggml_format_name(b_tensor, "b_tensor_norm_%d", i); - - // calculate distance(new eigenvector - old eigenvector) - // we don't use ggml_sub because it may not be implemented on GPU backend - struct ggml_tensor * new_sub_old = ggml_add(ctx0, old_eigen, ggml_scale(ctx0, b_tensor, -1)); - distance = ggml_sqrt_inplace(ctx0, - ggml_sum_rows(ctx0, ggml_sqr_inplace(ctx0, new_sub_old))); - ggml_format_name(distance, "distance_%d", i); - - old_eigen = b_tensor; - - // build operations nodes - ggml_build_forward_expand(gf, distance); + // Step 4: Check if CPU and compute the result of the graph + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); } + ggml_backend_graph_compute(backend, gf); + + // Step 5: Store covariance matrix in the data pointer + struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); + ggml_backend_tensor_get(result, cross_covariance, 0, ggml_nbytes(result)); - // delete the temporally context used to build the graph - ggml_free(ctx0); - return gf; + // Step 6: Free memory + ggml_gallocr_free(allocr); + ggml_free(ctx); } -static ggml_status compute_piter( - const struct pca_params & params, - const pca_model & model, - struct ggml_cgraph * gf, - ggml_gallocr_t allocr, - struct pca_result & result) { - // allocate tensors - ggml_gallocr_alloc_graph(allocr, gf); +// Find the dominant eigenvector of tensor M +static void power_iteration(struct pca_params &pca_params, + struct ggml_tensor * M, + struct pca_result &result, + struct ggml_backend * backend) { + + int m = M->ne[1]; + + // Initialize random vector + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dist(-1.0f, 1.0f); + float * b = result.principal_component; + for (int i = 0; i < m; i++) { + b[i] = dist(gen); + }; + float eigenvalue = 0; + + // Iterate + int n_rounds = pca_params.n_iterations / pca_params.n_batch; + for(int i = 0; i < n_rounds; i++) { + + // Memory allocation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx = NULL; + struct ggml_init_params params = { + ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + NULL, + true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx = ggml_init(params); + gf = ggml_new_graph(ctx); - if (ggml_backend_is_cpu(model.backend)) { - ggml_backend_cpu_set_n_threads(model.backend, params.n_threads); - } + // Fill current eigen vector + struct ggml_tensor * e_curr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); + struct ggml_tensor * e_prev = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); - ggml_status res = ggml_backend_graph_compute(model.backend, gf); - if (res == GGML_STATUS_SUCCESS) { - auto extract_i = [](std::string prefix, std::string str) -> int { - int i = -1; - if (str.rfind(prefix, 0) == 0) { - 
sscanf(str.c_str(), (prefix + "%d").c_str(), &i); - } - return i; - }; - result.calculated_square = NULL; - result.eigenvectors.clear(); - result.distances.clear(); - result.eigenvectors.resize(params.n_batch); - result.distances.resize(params.n_batch); - // get output nodes - for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) { - auto node = ggml_graph_node(gf, i); - int iter = -1; - // find b_tensor (without copying data from device) - if ((iter = extract_i("b_tensor_norm_", node->name)) > -1) { - result.eigenvectors[iter] = node; - } - // find distances, then copy data from device - if ((iter = extract_i("distance_", node->name)) > -1) { - float d; - ggml_backend_tensor_get(node, &d, 0, sizeof(float)); - result.distances[iter] = d; - // std::cout << node->name << " = " << d << "\n"; - } - // find tmp_square if it exists (without copying data from device) - if (std::string(node->name) == "tmp_square") { - result.calculated_square = node; - } - } - } - return res; -} + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); -static void power_iteration( - const struct pca_params & params, - struct ggml_tensor * input, // shape of input: [n_samples, n_embd] - struct ggml_tensor * output) { - //printf("in power iteration\n"); - struct pca_model model(input); - - ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); - struct pca_result result; - struct ggml_tensor * last_eigenvector = NULL; - - int n_iters = params.n_iterations / params.n_batch; // more batch, fewer iterations - for (int iter = 0; iter < n_iters; ++iter) { - bool calc_square = (iter == 0); // only need to calculate square for first iteration - struct ggml_cgraph * gf = build_graph_piter(params, model, calc_square); - // ggml_graph_dump_dot(gf, nullptr, "/tmp/_cgraph.dot"); - compute_piter(params, model, gf, allocr, result); - - for (size_t k = 0; k < result.distances.size(); ++k) { - last_eigenvector = result.eigenvectors[k]; - if (result.distances[k] < params.tolerance) { - break; // done - } + ggml_backend_tensor_set(e_curr, b, 0, ggml_nbytes(e_curr)); + ggml_backend_tensor_set(e_prev, b, 0, ggml_nbytes(e_curr)); + + struct ggml_tensor * e_next = NULL; + struct ggml_tensor * e_norm = NULL; + for(int j = 0; j < pca_params.n_batch; j++) { + // Compute next candidate vector multiplying M with the current vector + e_next = ggml_mul_mat(ctx, M, e_curr); + + // Compute the norm of the new vector (and normalize it) + // this will give us the next eigenvector and eigenvalue + e_norm = ggml_sqrt_inplace(ctx, ggml_sum_rows(ctx, ggml_sqr(ctx, e_next))); + e_curr = ggml_div_inplace(ctx, e_next, e_norm); + ggml_format_name(e_norm, "eigenvalue_%d", j); + ggml_format_name(e_curr, "eigenvector_%d", j); + + // Update graph + ggml_build_forward_expand(gf, e_curr); } - if (calc_square) { - // copy and store the square matrix if needed - GGML_ASSERT(result.calculated_square != NULL); - ggml_backend_tensor_copy(result.calculated_square, model.dev_square); + // Compute the similarity between the current eigenvector and the previous (dot product) + struct ggml_tensor * similarity = ggml_mul_mat(ctx, e_curr, e_prev); + ggml_build_forward_expand(gf, similarity); + + // Create ggml_gallocr for graph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // Check if CPU and compute the result of the graph + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, 
pca_params.n_threads); } + ggml_status graph_status = ggml_backend_graph_compute(backend, gf); + + // Get graph results (eigenvector and eigenvalue) and store it in b and eigenvalue + if(graph_status == GGML_STATUS_SUCCESS){ + + // Similarity is the last node in the graph + struct ggml_tensor * similarity_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); + float similarity = (float)((float*) similarity_tensor->data)[0]; - { - // copy last eigen vector and store as input for next iteration - GGML_ASSERT(last_eigenvector != NULL); - ggml_backend_tensor_copy(last_eigenvector, model.dev_eigenvector); + // Eigenvector is the second last node in the graph + // struct ggml_tensor * eigenvector_tensor = gf->nodes[gf->n_nodes-2]; + struct ggml_tensor * eigenvector_tensor = ggml_graph_node(gf,ggml_graph_n_nodes(gf)-2); + ggml_backend_tensor_get(eigenvector_tensor, b, 0, ggml_nbytes(eigenvector_tensor)); + + // Eigenvalue computation is 1 operation before eigenvector computation + // struct ggml_tensor * eigenvalue_tensor = gf->nodes[gf->n_nodes-3]; + struct ggml_tensor * eigenvalue_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-3); + eigenvalue = (float)((float*) eigenvalue_tensor->data)[0]; + + // Check if the similarity is close enough to 1, if so we converged and should break + if(1 - similarity < pca_params.tolerance) + break; } - printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n", - __func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch); + // Free memory + ggml_backend_buffer_free(buffer); + ggml_gallocr_free(allocr); + ggml_free(ctx); } - // get output tensor - GGML_ASSERT(last_eigenvector); - ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector)); - //print_debug_tensor(output); - ggml_gallocr_free(allocr); + // Store result + result.principal_component = b; + result.explained_variance = eigenvalue; + return; +} + +static void run_single_pca(struct pca_params &pca_params, + struct ggml_tensor * X, + struct pca_result &result + ) { - // TODO @ngxson : The output vector is randomly inverted - // Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171 + ggml_set_name(X, "input_tensor"); + + int m = X->ne[1]; // Number of features + + // Step 1. Initialize GGML Backend + ggml_backend_t backend = NULL; + #ifdef GGML_USE_CUDA + fprintf(stderr, "%s: using CUDA backend\n", __func__); + backend = ggml_backend_cuda_init(0); // init device 0 + if (!backend) { fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); } + #endif + + // If there aren't GPU Backends fallback to CPU backend + if (!backend) { backend = ggml_backend_cpu_init(); } + + // Compute the context size needed + size_t ctx_size = 0; + ctx_size += m * m * ggml_type_size(GGML_TYPE_F32); + ctx_size += 1 * ggml_tensor_overhead(); + + // Step 2. Initialize GGML Context + struct ggml_init_params ctx_params { + ctx_size, // mem_size + NULL, // mem_buffer + true, // no_alloc + }; + struct ggml_context * ctx = ggml_init(ctx_params); + + // Step 3. 
Compute the data covariance matrix + // Using a CPU buffer to copy data from the backend + float * covariance = (float *) malloc(m * m * sizeof(float)); + compute_covariance(pca_params, X, covariance, backend); + + // Create covariance tensor on backend + struct ggml_tensor * covariance_tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, m); + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + ggml_backend_tensor_set(covariance_tensor, covariance, 0, ggml_nbytes(covariance_tensor)); + + // Step 4. Power iteration + result.principal_component = (float *) malloc(m * sizeof(float)); + power_iteration(pca_params, covariance_tensor, result, backend); + + // Step 5. Free ggml ctx and backend + ggml_free(ctx); + ggml_backend_buffer_free(buffer); + ggml_backend_free(backend); } static void run_pca( struct pca_params & params, const std::vector & v_input, // shape of v_input[0]: [n_samples, n_embd] const std::vector & v_output) { - printf("%s: Running PCA...\n", __func__); - for (size_t il = 0; il < v_input.size(); ++il) { - - // prepare output vector - struct ggml_tensor * ctrl_out = v_output[il]; - ggml_format_name(ctrl_out, "direction.%ld", il+1); - - // run power_iteration - params.i_layer = il; - params.n_layers = v_input.size(); - power_iteration(params, v_input[il], ctrl_out); - printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size()); + + for (size_t i = 0; i < v_input.size(); i++) { + // Check shape of tensor inside v_output + GGML_ASSERT(v_output[i]->ne[0] == v_input[i]->ne[1]); + struct pca_result result = {NULL, 0}; + run_single_pca(params, v_input[i], result); + ggml_backend_tensor_set(v_output[i], result.principal_component, 0, ggml_nbytes(v_output[i])); + free(result.principal_component); } } +// end namespace } From 82efaafe9d7eaaa52d56a27916f9e7b6bbae9fe6 Mon Sep 17 00:00:00 2001 From: Lucas Nogueira Date: Sat, 16 Nov 2024 20:16:05 -0300 Subject: [PATCH 3/3] Apply suggestions from the PR: refactor test-vanilla-pca and remove unnecessary allocations --- .../mini-tests/test-vanilla-pca.cpp | 56 ++-- examples/cvector-generator/pca.hpp | 3 - examples/cvector-generator/vanilla_pca.hpp | 314 ------------------ 3 files changed, 19 insertions(+), 354 deletions(-) delete mode 100644 examples/cvector-generator/vanilla_pca.hpp diff --git a/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp b/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp index 72405762b9f0a..7a78e2452d89d 100644 --- a/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp +++ b/examples/cvector-generator/mini-tests/test-vanilla-pca.cpp @@ -4,20 +4,25 @@ #include "ggml.h" #include "../pca.hpp" -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -#endif - -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif +#include "ggml-cpp.h" +#include "ggml-backend.h" #include #include // Function to run PCA and print results -static void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) { - // struct ggml_tensor *input_tensor = create_tensor(ctx, matrix, rows, cols); +static void run_pca_test(float *matrix, int rows, int cols) { + // Initialize ggml context + size_t ctx_size = 0; + ctx_size += rows * cols * ggml_type_size(GGML_TYPE_F32); + ctx_size += 1 * ggml_tensor_overhead(); + + struct ggml_init_params ctx_params { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + struct ggml_context * ctx = ggml_init(ctx_params); struct ggml_tensor *input_tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, rows, cols); 
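 // NB: ggml_new_tensor_2d takes (ne0, ne1), and ne0 is the innermost, contiguous dimension; keep that in mind when copying a row-major C array into the tensor with the memcpy below.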
memcpy(input_tensor->data, matrix, rows * cols * sizeof(float)); @@ -37,32 +42,10 @@ static void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int printf("\nEigenvalue: %f\n", result.explained_variance); free(result.principal_component); + ggml_free(ctx); } int main() { - // Initialize ggml context - size_t ctx_size = 0; - ctx_size += 4 * 4 * ggml_type_size(GGML_TYPE_F32); - ctx_size += 10 * 10 * ggml_type_size(GGML_TYPE_F32); - ctx_size += 3 * 3 * ggml_type_size(GGML_TYPE_F32); - ctx_size += 3 * 3 * ggml_type_size(GGML_TYPE_F32); - ctx_size += 4 * ggml_tensor_overhead(); - ctx_size += 1024; - - // Step 2. Initialize GGML Context - struct ggml_init_params ctx_params { - ctx_size, // mem_size - NULL, // mem_buffer - false, // no_alloc - }; - struct ggml_context * ctx = ggml_init(ctx_params); - - - if (ctx == NULL) { - printf("Failed to initialize ggml context\n"); - return 1; - } - // Define matrices float input_matrix1[16] = { -0.124132, 0.740341, -0.452462, 0.777050, @@ -98,19 +81,18 @@ int main() { // Run PCA for each matrix printf("Testing Matrix 1:\n"); - run_pca_test(ctx, input_matrix1, 4, 4); + run_pca_test(input_matrix1, 4, 4); printf("\nTesting Matrix 2:\n"); - run_pca_test(ctx, input_matrix2, 10, 10); + run_pca_test(input_matrix2, 10, 10); printf("\nTesting Matrix 3:\n"); - run_pca_test(ctx, input_matrix3, 3, 3); + run_pca_test(input_matrix3, 3, 3); printf("\nTesting Matrix 4:\n"); - run_pca_test(ctx, input_matrix4, 3, 3); + run_pca_test(input_matrix4, 3, 3); // Cleanup - ggml_free(ctx); return 0; } diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 7a84ac1d569bb..a6ecd22aee3c4 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -51,7 +51,6 @@ static void compute_covariance(struct pca_params &pca_params, struct ggml_backend * backend) { size_t ctx_size = 0; - ctx_size += 7 * X->ne[0] * X->ne[1] * ggml_type_size(GGML_TYPE_F32); ctx_size += 7 * ggml_tensor_overhead(); ctx_size += ggml_graph_overhead(); ctx_size += 1024; @@ -105,7 +104,6 @@ static void compute_cross_covariance(struct pca_params &pca_params, struct ggml_backend * backend) { size_t ctx_size = 0; - ctx_size += 9 * A->ne[0] * B->ne[1] * ggml_type_size(GGML_TYPE_F32); ctx_size += 9 * ggml_tensor_overhead(); ctx_size += ggml_graph_overhead(); ctx_size += 1024; @@ -280,7 +278,6 @@ static void run_single_pca(struct pca_params &pca_params, // Compute the context size needed size_t ctx_size = 0; - ctx_size += m * m * ggml_type_size(GGML_TYPE_F32); ctx_size += 1 * ggml_tensor_overhead(); // Step 2. Initialize GGML Context diff --git a/examples/cvector-generator/vanilla_pca.hpp b/examples/cvector-generator/vanilla_pca.hpp deleted file mode 100644 index b4350db820b7e..0000000000000 --- a/examples/cvector-generator/vanilla_pca.hpp +++ /dev/null @@ -1,314 +0,0 @@ -#include "common.h" -#include "llama.h" -#include "ggml.h" - -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define DEBUG_POS 5 - -static void print_debug_tensor(struct ggml_tensor * t, bool with_data = true) { - printf("%s: %s (%s): [%d, %d]\n", __func__, t->name, ggml_type_name(t->type), (int) t->ne[0], (int) t->ne[1]); - if (!with_data) return; - printf("%s: %s[0] = [", __func__, t->name); - for (size_t i = 0; i <= DEBUG_POS; i++) { - printf(" %f,", ggml_get_f32_nd(t, i, 0, 0, 0)); - } - printf(" ... 
]\n"); -} - -// begin vanilla pca namespace -namespace PCA { - -// input params for PCA computations -struct pca_params { - int n_threads = 1; - int n_batch = 20; // number of iterations do to in one batch. larger the batch, more memory is used - int n_iterations = 1000; - float tolerance = 1e-7; -}; - -// result from each iteration -struct pca_result { - struct ggml_tensor * principal_component; // eigenvectors of the covariance matrix - float explained_variance; // eigenvalues of the covariance matrix -}; - -void compute_covariance(struct pca_params &pca_params, - struct ggml_tensor * X, - struct ggml_tensor * covariance, - struct ggml_backend * backend) { - - // Memory allocation - struct ggml_cgraph * gf = NULL; - struct ggml_context * ctx = NULL; - struct ggml_init_params ctx_params = { - ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), - NULL, - true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - ctx = ggml_init(ctx_params); - gf = ggml_new_graph(ctx); - - // Step 0: Transpose the input because of row-major - X = ggml_cont(ctx, ggml_transpose(ctx, X)); - - // Step 1: Compute the mean for each feature - struct ggml_tensor * mean = ggml_repeat(ctx, ggml_mean(ctx, X), X); // mean with trick to make it easier to sub - struct ggml_tensor * centered_data = ggml_sub(ctx, X, mean); - - // Step 2: Compute the covariance matrix - struct ggml_tensor * cov = ggml_mul_mat(ctx, centered_data, centered_data); // C = X * X^T - cov = ggml_scale(ctx, cov, 1.0/(X->ne[0]-1)); - ggml_build_forward_expand(gf, cov); - - // Step 3: Create ggml_gallocr for graph computation - ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); - ggml_gallocr_alloc_graph(allocr, gf); - - // Step 4: Check if CPU and compute the result of the graph - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); - } - ggml_backend_graph_compute(backend, gf); - - // Step 5: Store covariance matrix in the data pointer - struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); - float * result_data = (float*) malloc(ggml_nbytes(result)); - ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); - covariance->data = result_data; - - // Step 6: Free memory - ggml_gallocr_free(allocr); - ggml_free(ctx); -} - -static void compute_cross_covariance(struct pca_params &pca_params, - struct ggml_tensor * A, - struct ggml_tensor * B, - struct ggml_tensor * cross_covariance, - struct ggml_backend * backend) { - - // Memory allocation - struct ggml_cgraph * gf = NULL; - struct ggml_context * ctx = NULL; - struct ggml_init_params ctx_params = { - ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), - NULL, - true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - ctx = ggml_init(ctx_params); - gf = ggml_new_graph(ctx); - - // Step 1: Compute matrices of cross_covariance - struct ggml_tensor * AT = ggml_cont(ctx, ggml_transpose(ctx, A)); - struct ggml_tensor * BT = ggml_cont(ctx, ggml_transpose(ctx, B)); - struct ggml_tensor * AT_B = ggml_mul_mat(ctx, AT, BT); - struct ggml_tensor * BT_A = ggml_cont(ctx, ggml_transpose(ctx, AT_B)); - - // Step 2: Compute the covariance matrix - struct ggml_tensor * cross_cov = ggml_add(ctx, AT_B, BT_A); - cross_cov = ggml_scale(ctx, cross_cov, 0.5); - ggml_build_forward_expand(gf, cross_cov); - - // Step 3: Create ggml_gallocr for graph computation - ggml_gallocr_t allocr = 
ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); - ggml_gallocr_alloc_graph(allocr, gf); - - // Step 4: Check if CPU and compute the result of the graph - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); - } - ggml_backend_graph_compute(backend, gf); - - // Step 5: Store covariance matrix in the data pointer - struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); - float * result_data = (float*) malloc(ggml_nbytes(result)); - ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); - cross_covariance->data = result_data; - - // Step 6: Free memory - ggml_gallocr_free(allocr); - ggml_free(ctx); -} - -// Find the dominant eigenvector of tensor M -static void power_iteration(struct pca_params &pca_params, - struct ggml_tensor * M, - struct pca_result &result, - struct ggml_backend * backend) { - - int m = M->ne[1]; - - // Initialize random vector - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_real_distribution dist(-1.0f, 1.0f); - float * b = (float*) malloc(m * sizeof(float)); - for (int i = 0; i < m; i++) { - b[i] = dist(gen); - }; - float eigenvalue = 0; - - // Iterate - int n_rounds = pca_params.n_iterations / pca_params.n_batch; - for(int i = 0; i < n_rounds; i++) { - - // Memory allocation - struct ggml_cgraph * gf = NULL; - struct ggml_context * ctx = NULL; - struct ggml_init_params params = { - ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), - NULL, - true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - ctx = ggml_init(params); - gf = ggml_new_graph(ctx); - - // Fill current eigen vector - struct ggml_tensor * e_curr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); - struct ggml_tensor * e_prev = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); - - ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); - - ggml_backend_tensor_set(e_curr, b, 0, ggml_nbytes(e_curr)); - ggml_backend_tensor_set(e_prev, b, 0, ggml_nbytes(e_curr)); - - struct ggml_tensor * e_next = NULL; - struct ggml_tensor * e_norm = NULL; - for(int j = 0; j < pca_params.n_batch; j++) { - // Compute next candidate vector multiplying M with the current vector - e_next = ggml_mul_mat(ctx, M, e_curr); - - // Compute the norm of the new vector (and normalize it) - // this will give us the next eigenvector and eigenvalue - e_norm = ggml_sqrt_inplace(ctx, ggml_sum_rows(ctx, ggml_sqr(ctx, e_next))); - e_curr = ggml_div_inplace(ctx, e_next, e_norm); - ggml_format_name(e_norm, "eigenvalue_%d", j); - ggml_format_name(e_curr, "eigenvector_%d", j); - - // Update graph - ggml_build_forward_expand(gf, e_curr); - } - - // Compute the similarity between the current eigenvector and the previous (dot product) - struct ggml_tensor * similarity = ggml_mul_mat(ctx, e_curr, e_prev); - ggml_build_forward_expand(gf, similarity); - - // Create ggml_gallocr for graph computation - ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); - ggml_gallocr_alloc_graph(allocr, gf); - - // Check if CPU and compute the result of the graph - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads); - } - ggml_status graph_status = ggml_backend_graph_compute(backend, gf); - - // Get graph results (eigenvector and eigenvalue) and store it in b and eigenvalue - if(graph_status == GGML_STATUS_SUCCESS){ - - // Similarity is the last node in the graph - struct ggml_tensor * similarity_tensor = 
ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1); - float similarity = (float)((float*) similarity_tensor->data)[0]; - - // Eigenvector is the second last node in the graph - // struct ggml_tensor * eigenvector_tensor = gf->nodes[gf->n_nodes-2]; - struct ggml_tensor * eigenvector_tensor = ggml_graph_node(gf,ggml_graph_n_nodes(gf)-2); - float * eigenvector_data = (float*) malloc(ggml_nbytes(eigenvector_tensor)); - ggml_backend_tensor_get(eigenvector_tensor, eigenvector_data, 0, ggml_nbytes(eigenvector_tensor)); - b = eigenvector_data; - - // Eigenvalue computation is 1 operation before eigenvector computation - // struct ggml_tensor * eigenvalue_tensor = gf->nodes[gf->n_nodes-3]; - struct ggml_tensor * eigenvalue_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-3); - eigenvalue = (float)((float*) eigenvalue_tensor->data)[0]; - - // Check if the similarity is close enough to 1, if so we converged and should break - if(1 - similarity < pca_params.tolerance) - break; - } - - // Free memory - ggml_gallocr_free(allocr); - ggml_free(ctx); - } - - // Store result - result.principal_component->data = b; - result.explained_variance = eigenvalue; - return; -} - -static void run_single_pca(struct pca_params &pca_params, - struct ggml_tensor * X, - struct pca_result &result - ) { - - ggml_set_name(X, "input_tensor"); - - int m = X->ne[1]; // Number of features - - // Step 1. Initialize GGML Backend - ggml_backend_t backend = NULL; - #ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - if (!backend) { fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); } - #endif - // If there aren't GPU Backends fallback to CPU backend - if (!backend) { backend = ggml_backend_cpu_init(); } - - // Compute the context size needed - size_t ctx_size = 2 * ggml_tensor_overhead(); - - // Step 2. Initialize GGML Context - struct ggml_init_params ctx_params { - ctx_size, // mem_size - NULL, // mem_buffer - true, // no_alloc - }; - struct ggml_context * ctx = ggml_init(ctx_params); - - ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); - - // Step 3. Compute the data covariance matrix - struct ggml_tensor * covariance = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, m); - ggml_set_name(covariance, "covariance_tensor"); - compute_covariance(pca_params, X, covariance, backend); - - // Step 4. Power iteration - result.principal_component = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m); - power_iteration(pca_params, covariance, result, backend); - - // Free ggml context and backend - ggml_free(ctx); - ggml_backend_free(backend); -} - - -static void run_pca( - struct pca_params & params, - const std::vector & v_input, // shape of v_input[0]: [n_samples, n_embd] - const std::vector & v_output) { - - for (size_t i = 0; i < v_input.size(); i++) { - struct pca_result result; - run_single_pca(params, v_input[i], result); - ggml_backend_tensor_get(result.principal_component, v_output[i]->data, 0, ggml_nbytes(result.principal_component)); - } -} - -// end namespace -}
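
Note on verifying the test output: the pipeline above is the covariance of the mean-centered samples, C = X_c^T X_c / (n - 1), followed by power iteration, b_{k+1} = C b_k / ||C b_k||, whose normalization factor converges to the dominant eigenvalue. A dependency-free cross-check can be built from the same recipe; the sketch below is illustrative only (plain C++, no ggml; the file and helper names are not part of the patches):

// reference-pca.cpp (illustrative name): dependency-free cross-check for the
// covariance + power-iteration pipeline used by llama-test-vanilla-pca.
// Rows are samples, columns are features.
#include <cmath>
#include <cstdio>
#include <vector>

// C = X_c^T * X_c / (n - 1), where X_c is the mean-centered n x d data matrix
static std::vector<double> covariance(const std::vector<double> & X, int n, int d) {
    std::vector<double> mean(d, 0.0), C(d * d, 0.0);
    for (int i = 0; i < n; i++)
        for (int j = 0; j < d; j++) mean[j] += X[i * d + j] / n;
    for (int a = 0; a < d; a++)
        for (int b = 0; b < d; b++) {
            double s = 0.0;
            for (int i = 0; i < n; i++)
                s += (X[i * d + a] - mean[a]) * (X[i * d + b] - mean[b]);
            C[a * d + b] = s / (n - 1);
        }
    return C;
}

// b <- C*b / ||C*b|| until the direction stops changing; returns the
// eigenvalue estimate ||C*b||. b must start nonzero (and, generically,
// not orthogonal to the dominant eigenvector).
static double power_iteration(const std::vector<double> & C, int d, std::vector<double> & b,
                              int iters = 1000, double tol = 1e-7) {
    double lambda = 0.0;
    for (int k = 0; k < iters; k++) {
        std::vector<double> next(d, 0.0);
        for (int r = 0; r < d; r++)
            for (int c = 0; c < d; c++) next[r] += C[r * d + c] * b[c];
        double norm = 0.0;
        for (int r = 0; r < d; r++) norm += next[r] * next[r];
        norm = std::sqrt(norm);
        double sim = 0.0;
        for (int r = 0; r < d; r++) { next[r] /= norm; sim += next[r] * b[r]; }
        b      = next;
        lambda = norm;
        if (1.0 - std::fabs(sim) < tol) break; // converged (sign may still flip)
    }
    return lambda;
}

int main() {
    // same data as input_matrix3 in the test file (3 samples x 3 features)
    std::vector<double> X = { 0.374540, 0.950714, 0.731994,
                              0.598658, 0.156019, 0.155995,
                              0.058084, 0.866176, 0.601115 };
    std::vector<double> C = covariance(X, 3, 3);
    std::vector<double> b = { 1.0, 0.0, 0.0 };
    double lambda = power_iteration(C, 3, b);
    printf("eigenvalue: %f\neigenvector:", lambda);
    for (double v : b) printf(" %f", v);
    printf("\n");
    return 0;
}

Up to sign, which power iteration does not pin down because the patch starts from a random vector, the printed eigenvalue/eigenvector pair should match what llama-test-vanilla-pca reports for the same matrix.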