Skip to content

Commit

Permalink
Server: enable lookup decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Apr 22, 2024
1 parent 40f74e4 commit 87e5656
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o ngram-cache.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

Expand Down
7 changes: 3 additions & 4 deletions common/ngram-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,11 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen

}

llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename) {
std::ifstream hashmap_file(filename, std::ios::binary);
if (!hashmap_file) {
throw std::ifstream::failure("Unable to open file " + filename);
return false;
}
llama_ngram_cache ngram_cache;

llama_ngram ngram;
int32_t ntokens;
Expand Down Expand Up @@ -251,7 +250,7 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
}
GGML_ASSERT(hashmap_file.eof());

return ngram_cache;
return true;
}

void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
Expand Down
3 changes: 2 additions & 1 deletion common/ngram-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ void llama_ngram_cache_draft(
void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename);

// Load an ngram cache saved with llama_ngram_cache_save.
// ngram_cache: the ngram cache to load the data into.
// filename: the path from which to load the ngram cache.
// returns: an ngram cache containing the information saved to filename.
llama_ngram_cache llama_ngram_cache_load(std::string & filename);
bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename);

// Merge two ngram caches.
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
Expand Down
6 changes: 4 additions & 2 deletions examples/lookup/lookup-merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ int main(int argc, char ** argv){
}

fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
llama_ngram_cache ngram_cache_merged;
GGML_ASSERT(llama_ngram_cache_load(ngram_cache_merged, args[0]));

for (size_t i = 1; i < args.size()-1; ++i) {
fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
llama_ngram_cache ngram_cache;
GGML_ASSERT(llama_ngram_cache_load(ngram_cache, args[i]));

llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
}
Expand Down
9 changes: 3 additions & 6 deletions examples/lookup/lookup-stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,15 @@ int main(int argc, char ** argv){
const int64_t t_start_draft_us = ggml_time_us();

if (!params.lookup_cache_static.empty()) {
try {
ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
} catch (std::ifstream::failure const &) {
if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) {
fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
exit(1);
}
}

if (!params.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
// If the dynamic lookup cache doesn't exist it will be created at the end of the program:
llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic);
}

t_draft_flat_us += ggml_time_us() - t_start_draft_us;
Expand Down
9 changes: 3 additions & 6 deletions examples/lookup/lookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,15 @@ int main(int argc, char ** argv){
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);

if (!params.lookup_cache_static.empty()) {
try {
ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
} catch (std::ifstream::failure const &) {
if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) {
fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
exit(1);
}
}

if (!params.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
// If the dynamic lookup cache doesn't exist it will be created at the end of the program:
llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic);
}

t_draft_flat_us += ggml_time_us() - t_start_draft_us;
Expand Down
8 changes: 8 additions & 0 deletions examples/server/bench/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True)
parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
parser.add_argument("--duration", type=str, help="Bench scenario", required=True)
parser.add_argument("--draft", type=int, help="Max. number of additional tokens to draft for lookup decoding", required=False, default=5)
parser.add_argument("-lcs", "--lookup-cache-static", type=str, help="Path to optional static lookup cache to use.", required=False, default=None)
parser.add_argument("-lcd", "--lookup-cache-dynamic", type=str, help="Path to optional dynamic lookup cache to use. Will be overwritten upon server shutdown.", required=False, default=None)

args = parser.parse_args(args_in)

Expand Down Expand Up @@ -269,6 +272,11 @@ def start_server_background(args):
server_args.append('--cont-batching')
server_args.append('--metrics')
server_args.extend(['--log-format', "text"])
server_args.extend(['--draft', args.draft])
if args.lookup_cache_static is not None:
server_args.extend(['--lookup-cache-static', args.lookup_cache_static])
if args.lookup_cache_dynamic is not None:
server_args.extend(['--lookup-cache-dynamic', args.lookup_cache_dynamic])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
pkwargs = {
Expand Down
Loading

0 comments on commit 87e5656

Please sign in to comment.