From a34900aad194ae0239b3fad94c48dc82b9f3a1a1 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Wed, 10 Jul 2024 03:29:12 -0700 Subject: [PATCH] restrict to nsplit=2 --- src/llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 6bd0863c63e8c..4a309b999205a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14556,10 +14556,12 @@ static int llama_decode_internal( ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active); // Disable future graph caching in presence of env var, - // if there are multiple devices, or if batch size is greater than 1 + // if there are multiple devices, if batch size is greater than 1, + // or if nsplits is not 2. // TO DO enable graph caching for these cases bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr) - || (llama_get_device_count(model) > 1); + || (llama_get_device_count(model) > 1) + || (ggml_backend_sched_get_n_splits(lctx.sched) != 2); for (int i = 0 ; i < gf->n_nodes; i++) { if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) { disable_cached_ggml_graph = true;