Skip to content

Commit

Permalink
llama : refactor session file management
Browse files Browse the repository at this point in the history
* llama : saving and restoring state checks for overflow

The size of the buffers should now be given to the functions working
with them, otherwise a truncated file could cause out of bound reads.

* llama : stream from session file instead of copying into a big buffer

Loading session files should no longer cause a memory usage spike.

* llama : llama_state_get_size returns the actual size instead of max

This is a breaking change, but makes that function *much* easier
to keep up to date, and it also makes it reflect the behavior
of llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

Both session file types now use a more similar format.

* llama : no longer store all hparams in session files

Instead, the model arch name is stored.
The layer count and the embedding dimensions of the KV cache
are still verified when loading.
Storing all the hparams is not necessary.
  • Loading branch information
compilade committed Jul 25, 2024
1 parent de28008 commit 17a5ee2
Show file tree
Hide file tree
Showing 3 changed files with 604 additions and 705 deletions.
20 changes: 13 additions & 7 deletions examples/save-load-state/save-load-state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data());
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
Expand Down Expand Up @@ -99,13 +99,16 @@ int main(int argc, char ** argv) {

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
std::vector<uint8_t> state_mem;

FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx2, state_mem.data())) {
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
Expand Down Expand Up @@ -159,13 +162,16 @@ int main(int argc, char ** argv) {

// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
std::vector<uint8_t> state_mem;

FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);

if (read != llama_state_set_data(ctx3, state_mem.data())) {
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx3);
llama_free_model(model);
Expand All @@ -182,7 +188,7 @@ int main(int argc, char ** argv) {
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
Expand All @@ -196,7 +202,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : kv cache cleared\n", __func__);

// restore kv into seq 1
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);
Expand Down
21 changes: 13 additions & 8 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 7
#define LLAMA_SESSION_VERSION 8

#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 1
#define LLAMA_STATE_SEQ_VERSION 2

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -687,18 +687,20 @@ extern "C" {
// State / sessions
//

// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
// Returns the *actual* size in bytes of the state
// (rng, logits, embedding and kv_cache)
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
"use llama_state_get_size instead");

// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
LLAMA_API size_t llama_state_get_data(
struct llama_context * ctx,
uint8_t * dst);
uint8_t * dst,
size_t size);
LLAMA_API DEPRECATED(size_t llama_copy_state_data(
struct llama_context * ctx,
uint8_t * dst),
Expand All @@ -708,7 +710,8 @@ extern "C" {
// Returns the number of bytes read
LLAMA_API size_t llama_state_set_data(
struct llama_context * ctx,
const uint8_t * src);
const uint8_t * src,
size_t size);
LLAMA_API DEPRECATED(size_t llama_set_state_data(
struct llama_context * ctx,
const uint8_t * src),
Expand Down Expand Up @@ -750,6 +753,7 @@ extern "C" {
LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx,
uint8_t * dst,
size_t size,
llama_seq_id seq_id);

// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
Expand All @@ -759,6 +763,7 @@ extern "C" {
LLAMA_API size_t llama_state_seq_set_data(
struct llama_context * ctx,
const uint8_t * src,
size_t size,
llama_seq_id dest_seq_id);

LLAMA_API size_t llama_state_seq_save_file(
Expand Down
Loading

0 comments on commit 17a5ee2

Please sign in to comment.