Skip to content

Commit

Permalink
support M1/M2
Browse files Browse the repository at this point in the history
  • Loading branch information
simonJJJ committed Nov 15, 2023
1 parent 9648bf6 commit 6c9cccf
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 2 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ if (GGML_CUBLAS)
set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES ${CUDA_ARCHITECTURES})
endif ()

# Optional Apple Metal backend: define GGML_USE_METAL for all translation units
# and stage the Metal shader source next to the built library, since ggml loads
# ggml-metal.metal at runtime from the library output directory.
if (GGML_METAL)
add_compile_definitions(GGML_USE_METAL)
# COPYONLY: no variable substitution — the .metal file is copied verbatim.
configure_file(third_party/ggml/src/ggml-metal.metal ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
endif ()

file(GLOB CPP_SOURCES
${PROJECT_SOURCE_DIR}/*.h
${PROJECT_SOURCE_DIR}/*.cpp)
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ cuBLAS uses NVIDIA GPU to accelerate BLAS. Add the CMake flag `-DGGML_CUBLAS=ON`
cmake -B build -DGGML_CUBLAS=ON && cmake --build build -j
```

**Metal**

MPS (Metal Performance Shaders) allows computation to run on the powerful Apple Silicon GPU. Add the CMake flag `-DGGML_METAL=ON` to enable it.
```sh
cmake -B build -DGGML_METAL=ON && cmake --build build -j
```

## Python Binding

The Python binding provides high-level `chat` and `stream_chat` interface similar to the original Hugging Face Qwen-7B.
Expand Down
22 changes: 20 additions & 2 deletions qwen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,21 @@ auto ggml_graph_compute_helper(std::vector<uninitialized_char> &buf, ggml_cgraph
ggml_graph_compute(graph, &plan);
}

auto ModelContext::init_device_context() -> void {}
// Initialize backend-specific device state. When built with Metal support this
// creates the Metal context and registers every host-side allocation (weights,
// kv cache, compute buffer, scratch) with it via ggml_metal_add_buffer; in
// CPU/CUDA builds it is a no-op.
auto ModelContext::init_device_context() -> void {
#ifdef GGML_USE_METAL
  ctx_metal = make_unique_ggml_metal_context(1);

  // Weights come either from the ggml weight context or from a separately
  // owned byte buffer, whichever is populated.
  void *w_ptr = nullptr;
  size_t w_bytes = 0;
  if (weight_buffer.empty()) {
    w_ptr = ggml_get_mem_buffer(ctx_w.get());
    w_bytes = ggml_get_mem_size(ctx_w.get());
  } else {
    w_ptr = (void *)weight_buffer.data();
    w_bytes = weight_buffer.size();
  }
  // The size of the largest tensor in the weight context is forwarded as the
  // max_size argument when registering the weight buffer.
  const size_t largest_tensor = ggml_get_max_tensor_size(ctx_w.get());
  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "weights", w_ptr, w_bytes, largest_tensor));

  // KV cache always lives in its own ggml context.
  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "kv", ggml_get_mem_buffer(ctx_kv.get()),
                                   ggml_get_mem_size(ctx_kv.get()), 0));

  // Intermediate compute memory: prefer the ggml buffer context when present,
  // otherwise fall back to the raw compute (BLAS) buffer.
  void *c_ptr = compute_buffer.data();
  size_t c_bytes = compute_buffer.size();
  if (ctx_b) {
    c_ptr = ggml_get_mem_buffer(ctx_b.get());
    c_bytes = ggml_get_mem_size(ctx_b.get());
  }
  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "compute", c_ptr, c_bytes, 0));

  QWEN_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "scratch", scratch.data, scratch.size, 0));
#endif
}

// ===== streamer =====

Expand Down Expand Up @@ -482,7 +496,7 @@ auto get_num_physical_cores() -> int {
}

auto get_default_num_threads() -> int {
#ifdef GGML_USE_CUBLAS
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_METAL)
return 1;
#else
return std::min(get_num_physical_cores(), 16);
Expand Down Expand Up @@ -583,7 +597,11 @@ auto QwenForCausalLM::generate_next_token(
}

ggml_build_forward_expand(&ctx_.gf, lm_logits);
#ifdef GGML_USE_METAL
ggml_metal_graph_compute(ctx_.ctx_metal.get(), &ctx_.gf);
#else
ggml_graph_compute_helper(ctx_.work_buffer, &ctx_.gf, n_threads);
#endif

int vocab_size = lm_logits->ne[0];
float *next_token_logits = (float *)lm_logits->data;
Expand Down
21 changes: 21 additions & 0 deletions qwen.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
#include <ggml-cuda.h>
#endif

#ifdef GGML_USE_METAL
#include <ggml-metal.h>
#endif

namespace qwen {

class QwenTokenizer;
Expand Down Expand Up @@ -58,6 +62,20 @@ static inline auto make_unique_ggml_context(
return unique_ggml_context_t(ggml_init({mem_size, mem_buffer, no_alloc}));
}

#ifdef GGML_USE_METAL
// Deleter functor so a ggml Metal context can be owned by std::unique_ptr;
// releases the context through ggml_metal_free.
struct ggml_metal_context_deleter_t {
  auto operator()(ggml_metal_context *ctx) const noexcept -> void {
    ggml_metal_free(ctx);
  }
};

// Owning smart-pointer alias for a ggml Metal context.
using unique_ggml_metal_context_t = std::unique_ptr<ggml_metal_context, ggml_metal_context_deleter_t>;

// Create an owned Metal context; n_cb is forwarded unchanged to
// ggml_metal_init. Mirrors make_unique_ggml_context above.
static inline auto make_unique_ggml_metal_context(int n_cb) -> unique_ggml_metal_context_t {
  return unique_ggml_metal_context_t(ggml_metal_init(n_cb));
}
#endif

struct uninitialized_char {
char m;
uninitialized_char() {}
Expand All @@ -70,6 +88,9 @@ struct ModelContext {
unique_ggml_context_t ctx_w; // weight
unique_ggml_context_t ctx_kv; // kv cache
unique_ggml_context_t ctx_b; // buffer
#ifdef GGML_USE_METAL
unique_ggml_metal_context_t ctx_metal;
#endif
ggml_cgraph gf;
ggml_scratch scratch;
std::vector<uninitialized_char> compute_buffer; // BLAS buffer
Expand Down

0 comments on commit 6c9cccf

Please sign in to comment.