Commit cde46d2

move BLAS to a separate backend

slaren committed Jun 4, 2024
1 parent bde7cd3 commit cde46d2
Showing 7 changed files with 356 additions and 219 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -381,6 +381,9 @@ if (LLAMA_BLAS)
             add_compile_definitions(GGML_BLAS_USE_MKL)
         endif()
 
+        set(GGML_HEADERS_BLAS ggml-blas.h)
+        set(GGML_SOURCES_BLAS ggml-blas.c)
+
         set(LLAMA_EXTRA_LIBS     ${LLAMA_EXTRA_LIBS}     ${BLAS_LIBRARIES})
         set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
     else()
@@ -1268,6 +1271,7 @@ add_library(ggml OBJECT
             ${GGML_SOURCES_KOMPUTE}   ${GGML_HEADERS_KOMPUTE}
             ${GGML_SOURCES_VULKAN}    ${GGML_HEADERS_VULKAN}
             ${GGML_SOURCES_ROCM}      ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_BLAS}      ${GGML_HEADERS_BLAS}
             ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
             )

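The CMake changes reference two new files, ggml-blas.h and ggml-blas.c, whose contents do not appear in this view. Going by how the other ggml backends expose themselves, the public header plausibly amounts to an init function, a type check, and a thread-count setter; the following is a hedged sketch under that assumption, not the verified contents of the new file:

    // ggml-blas.h -- illustrative sketch only; the real header is not shown in this diff
    #pragma once

    #include "ggml.h"
    #include "ggml-backend.h"

    #ifdef __cplusplus
    extern "C" {
    #endif

    // create a BLAS backend instance
    GGML_API ggml_backend_t ggml_backend_blas_init(void);

    // check whether a backend is the BLAS backend
    GGML_API bool ggml_backend_is_blas(ggml_backend_t backend);

    // assumed setter, mirroring ggml_backend_cpu_set_n_threads on the CPU backend
    GGML_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);

    #ifdef __cplusplus
    }
    #endif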
21 changes: 17 additions & 4 deletions Makefile
@@ -408,6 +408,7 @@ ifndef LLAMA_NO_ACCELERATE
         MK_CPPFLAGS += -DACCELERATE_NEW_LAPACK
         MK_CPPFLAGS += -DACCELERATE_LAPACK_ILP64
         MK_LDFLAGS  += -framework Accelerate
+        OBJS        += ggml-blas.o
     endif
 endif # LLAMA_NO_ACCELERATE

@@ -421,23 +422,35 @@ ifdef LLAMA_OPENBLAS
     MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas)
     MK_CFLAGS   += $(shell pkg-config --cflags-only-other openblas)
     MK_LDFLAGS  += $(shell pkg-config --libs openblas)
+    OBJS        += ggml-blas.o
 endif # LLAMA_OPENBLAS
 
-ifndef LLAMA_NO_LLAMAFILE
-    MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
-    OBJS        += sgemm.o
-endif
 ifdef LLAMA_OPENBLAS64
     MK_CPPFLAGS += -DGGML_USE_OPENBLAS $(shell pkg-config --cflags-only-I openblas64)
     MK_CFLAGS   += $(shell pkg-config --cflags-only-other openblas64)
     MK_LDFLAGS  += $(shell pkg-config --libs openblas64)
+    OBJS        += ggml-blas.o
 endif # LLAMA_OPENBLAS64
 
 ifdef LLAMA_BLIS
     MK_CPPFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis
     MK_LDFLAGS  += -lblis -L/usr/local/lib
+    OBJS        += ggml-blas.o
 endif # LLAMA_BLIS
 
+ifndef LLAMA_NO_LLAMAFILE
+    MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
+    OBJS        += sgemm.o
+endif
+
 ifdef LLAMA_RPC
     MK_CPPFLAGS += -DGGML_USE_RPC
     OBJS        += ggml-rpc.o
 endif # LLAMA_RPC
 
+ggml-blas.o: ggml-blas.c ggml-blas.h
+	$(CC) $(CFLAGS) -c $< -o $@
+
ifdef LLAMA_CUBLAS
# LLAMA_CUBLAS is deprecated and will be removed in the future
LLAMA_CUDA := 1
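With the BLAS paths moved into their own translation unit, each BLAS-flavored configuration above now adds ggml-blas.o to OBJS and builds it through the explicit rule at the end of the hunk. Assuming the flags keep their pre-existing meaning, the build invocations are unchanged, e.g. `make LLAMA_OPENBLAS=1` or `make LLAMA_BLIS=1`.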
86 changes: 61 additions & 25 deletions ggml-backend.c
@@ -640,7 +640,9 @@ GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_
 }
 
 GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_cpu(backend);
+    // HACK
+    static ggml_guid blas_guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return ggml_backend_is_cpu(backend) || ggml_guid_matches(backend->guid, &blas_guid);
 
     GGML_UNUSED(buft);
 }
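The hard-coded GUID above lets CPU buffers report themselves as compatible with the new BLAS backend before a cleaner mechanism exists (hence the HACK comment). In ggml a backend GUID is 16 opaque bytes compared bytewise; a minimal sketch of that machinery, reproduced from memory rather than from this diff:

    // from ggml.h / ggml.c (paraphrased; treat as illustrative)
    #include <stdbool.h>
    #include <stdint.h>
    #include <string.h>

    typedef uint8_t ggml_guid[16];
    typedef ggml_guid * ggml_guid_t;

    bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
        // GUIDs match iff all 16 bytes are equal
        return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
    }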
@@ -1097,15 +1099,16 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
     return -1;
 }
 
-static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) {
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
     ggml_backend_buffer_t buffer = tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }
 
-    // find highest prio backend that supports the buffer type
+    // find highest prio backend that supports the buffer type and the op
     for (int i = 0; i < sched->n_backends; i++) {
-        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i]) &&
+            ggml_backend_supports_op(sched->backends[i], op)) {
             return i;
         }
     }
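The extra ggml_backend_supports_op check is what allows a narrow backend, one that implements only a handful of ops, to sit above the CPU backend in the priority list without capturing ops it cannot run. A hedged sketch of the kind of supports_op gate the new BLAS backend plausibly implements (ggml-blas.c is not shown in this diff; the function name and the restriction to mat mul are assumptions):

    // illustrative only: claim matrix multiplication, defer everything else
    // to lower-priority backends (ultimately the CPU backend)
    static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
        return op->op == GGML_OP_MUL_MAT;

        GGML_UNUSED(backend);
    }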
@@ -1126,20 +1129,25 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED
 #define GET_CAUSE(node) ""
 #endif
 
+//#define DEBUG_PASS1
+//#define DEBUG_PASS2
+//#define DEBUG_PASS3
+//#define DEBUG_PASS4
+
 // returns the backend that should be used for the node based on the current locations
 static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
-    // TODO: use supports_op to check if the backend supports the op
-
     // assign pre-allocated nodes to their backend
-    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor);
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
     if (cur_backend_id != -1) {
         SET_CAUSE(tensor, "1.dst");
         return cur_backend_id;
     }
 
     // view_src
     if (tensor->view_src != NULL) {
-        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
         if (cur_backend_id != -1) {
             SET_CAUSE(tensor, "1.vsrc");
             return cur_backend_id;
@@ -1161,7 +1169,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
             continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src);
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
             if (src_backend_id == sched->n_backends - 1) {
                 for (int b = 0; b < src_backend_id; b++) {
@@ -1223,10 +1231,30 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
     }
 }
 
-//#define DEBUG_PASS1
-//#define DEBUG_PASS2
-//#define DEBUG_PASS3
-//#define DEBUG_PASS4
+static int set_if_supports(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+        *node_backend_id = cur_backend_id;
+        SET_CAUSE(node, "2.2");
+    } else {
+        for (int b = 0; b < sched->n_backends; b++) {
+            if (b == cur_backend_id) {
+                continue;
+            }
+            if (ggml_backend_supports_op(sched->backends[b], node)) {
+                *node_backend_id = b;
+                cur_backend_id = b;
+                SET_CAUSE(node, "2.2");
+                break;
+            }
+        }
+    }
+    return cur_backend_id;
+}
+
+static bool buffer_supported(ggml_backend_sched_t sched, const struct ggml_tensor * t, int cur_backend_id) {
+    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+    return buf != NULL && ggml_backend_buft_supports_backend(buf->buft, sched->backends[cur_backend_id]);
+}
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
 static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
@@ -1306,9 +1334,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             } else {
                 cur_backend_id = *node_backend_id;
             }
-        } else {
-            *node_backend_id = cur_backend_id;
-            SET_CAUSE(node, "2.2");
+        } else if (cur_backend_id != -1) {
+            // FIXME: clean this
+            cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
+            if (cur_backend_id == sched->n_backends - 1) {
+                // skip cpu (lowest prio backend)
+                cur_backend_id = -1;
+            }
         }
     }
 }
@@ -1328,9 +1360,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             } else {
                 cur_backend_id = *node_backend_id;
             }
-        } else {
-            *node_backend_id = cur_backend_id;
-            SET_CAUSE(node, "2.1");
+        } else if (cur_backend_id != -1) {
+            cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
+            if (cur_backend_id == sched->n_backends - 1) {
+                // skip cpu (lowest prio backend)
+                cur_backend_id = -1;
+            }
         }
     }
 }
@@ -1345,9 +1380,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         int * node_backend_id = &tensor_backend_id(node);
         if (*node_backend_id != -1) {
             cur_backend_id = *node_backend_id;
-        } else {
-            *node_backend_id = cur_backend_id;
-            SET_CAUSE(node, "2.4");
+        } else if (cur_backend_id != -1) {
+            cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
         }
     }
 }
@@ -1362,9 +1396,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         int * node_backend_id = &tensor_backend_id(node);
         if (*node_backend_id != -1) {
             cur_backend_id = *node_backend_id;
-        } else {
-            *node_backend_id = cur_backend_id;
-            SET_CAUSE(node, "2.3");
+        } else if (cur_backend_id != -1) {
+            cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
         }
     }
 }
@@ -1448,10 +1481,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             }
             // check if the split has too many inputs
+            // FIXME: count the number of inputs instead of only checking when full
             if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
                 const size_t id = hash_id(src);
                 int src_backend_id = sched->tensor_backend_id[id];
-                if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
+                bool supported = buffer_supported(sched, src, cur_backend_id);
+                if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL && !supported) {
                     //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
                     need_new_split = true;
                     break;
@@ -1511,7 +1546,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             }
 
-            if (src_backend_id != node_backend_id) {
+            bool supported = buffer_supported(sched, src, cur_backend_id);
+            if (src_backend_id != cur_backend_id && !supported) {
                 // create a copy of the input in the split's backend
                 const size_t id = hash_id(src);
                 if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {
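Taken together, the scheduler changes mean a backend no longer has to implement every op: each node is assigned to the highest-priority backend that both holds the data and supports the op, and unsupported nodes fall through to the next backend in the list. A hedged sketch of how the separate BLAS backend would plausibly be wired in ahead of the CPU backend (ggml_backend_blas_init is an assumed entry point from the new header, and the ggml_backend_sched_new signature is quoted from memory):

    // illustrative wiring, not taken from this commit
    ggml_backend_t backends[2] = {
        ggml_backend_blas_init(), // higher priority: takes the mat muls it supports
        ggml_backend_cpu_init(),  // lowest priority: catches everything else
    };
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, /*bufts*/ NULL, /*n_backends*/ 2, GGML_DEFAULT_GRAPH_SIZE, /*parallel*/ false);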