Skip to content

Commit

Permalink
update rwkv.cpp h file
Browse files Browse the repository at this point in the history
--todo: impl rwkv_eval_sequence_in_chunks
  • Loading branch information
Cyberhan123 committed Oct 22, 2023
1 parent 2eef7df commit 2aeb25c
Showing 1 changed file with 72 additions and 13 deletions.
85 changes: 72 additions & 13 deletions deps/rwkv.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
#include <stdint.h>
#include <stdbool.h>

#ifdef RWKV_SHARED
#if defined(RWKV_SHARED)
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef RWKV_BUILD
# if defined(RWKV_BUILD)
# define RWKV_API __declspec(dllexport)
# else
# define RWKV_API __declspec(dllimport)
Expand All @@ -32,7 +32,7 @@
// Default file version is the latest version.
#define RWKV_FILE_VERSION RWKV_FILE_VERSION_MAX

#ifdef __cplusplus
#if defined(__cplusplus)
extern "C" {
#endif
Expand Down Expand Up @@ -76,11 +76,11 @@ extern "C" {
// If NULL, affects model load (rwkv_init_from_file) and quantization (rwkv_quantize_model_file) errors,
// as well as the default for new context.
// - print_errors: whether error messages should be automatically printed.
RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, bool print_errors);
RWKV_API void rwkv_set_print_errors(struct rwkv_context * ctx, const bool print_errors);

// Gets whether errors are automatically printed to stderr.
// - ctx: the context to retrieve the setting for, or NULL for the global setting.
RWKV_API bool rwkv_get_print_errors(struct rwkv_context * ctx);
RWKV_API bool rwkv_get_print_errors(const struct rwkv_context * ctx);

// Retrieves and clears the error flags.
// - ctx: the context the retrieve the error for, or NULL for the global error.
Expand All @@ -100,30 +100,87 @@ extern "C" {
// - n_threads: count of threads to use, must be positive.
RWKV_API struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32_t n_threads);

// Offloads specified layers of context onto GPU using cuBLAS, if it is enabled.
// If rwkv.cpp was compiled without cuBLAS support, this function is a no-op.
RWKV_API bool rwkv_gpu_offload_layers(const struct rwkv_context * ctx, const uint32_t n_gpu_layers);
// Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
// For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
// - pass `rwkv_get_n_layer(ctx)` to offload all layers except model head
// - pass `rwkv_get_n_layer(ctx) + 1` to offload all layers, including model head
// Returns true if at least one layer was offloaded.
// If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
RWKV_API bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers);

// Evaluates the model for a single token.
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
// Returns false on any error.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_len(); or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval(struct rwkv_context * ctx, const uint32_t token, const float * state_in, float * state_out, float * logits_out);
RWKV_API bool rwkv_eval(
struct rwkv_context * ctx,
const uint32_t token,
const float * state_in,
float * state_out,
float * logits_out
);

// Evaluates the model for a sequence of tokens.
// Uses a faster algorithm than rwkv_eval if you do not need the state and logits for every token. Best used with batch sizes of 64 or so.
// Uses a faster algorithm than `rwkv_eval` if you do not need the state and logits for every token. Best used with sequence lengths of 64 or so.
// Has to build a computation graph on the first call for a given sequence, but will use this cached graph for subsequent calls of the same sequence length.
// Not thread-safe. For parallel inference, call rwkv_clone_context to create one rwkv_context for each thread.
//
// NOTE ON GGML NODE LIMIT
//
// ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceedes
// this limit when using large models and/or large sequence lengths.
// Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.
//
// If you get `GGML_ASSERT: ...\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
// To get rid of the assertion failure, reduce the model size and/or sequence length.
//
// TODO When Metal (MPS) support is implemented, check that large sequence lengths work
//
// You can pass NULL to logits_out whenever logits are not needed. This can improve speed by ~10 ms per iteration, because logits are not calculated.
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence(
struct rwkv_context * ctx,
const uint32_t * tokens,
const size_t sequence_len,
const float * state_in,
float * state_out,
float * logits_out
);

// Evaluates the model for a sequence of tokens using `rwkv_eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
// This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
// It is recommended to use this function instead of `rwkv_eval_sequence` to avoid mistakes and get maximum performance.
//
// Chunking allows processing sequences of thousands of tokens, while not reaching the ggml's node limit and not consuming too much memory.
// A reasonable and recommended value of chunk size is 16. If you want maximum performance, try different chunk sizes in range [2..64]
// and choose one that works the best in your use case.
//
// Not thread-safe. For parallel inference, call `rwkv_clone_context` to create one rwkv_context for each thread.
// Returns false on any error.
// - tokens: pointer to an array of tokens. If NULL, the graph will be built and cached, but not executed: this can be useful for initialization.
// - sequence_len: number of tokens to read from the array.
// - chunk_size: size of each chunk in tokens, must be positive.
// - state_in: FP32 buffer of size rwkv_get_state_len(), or NULL if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_len(). This buffer will be written to if non-NULL.
// - logits_out: FP32 buffer of size rwkv_get_logits_len(). This buffer will be written to if non-NULL.
RWKV_API bool rwkv_eval_sequence(struct rwkv_context * ctx, const uint32_t * tokens, size_t sequence_len, const float * state_in, float * state_out, float * logits_out);
RWKV_API bool rwkv_eval_sequence_in_chunks(
struct rwkv_context * ctx,
const uint32_t * tokens,
const size_t sequence_len,
const size_t chunk_size,
const float * state_in,
float * state_out,
float * logits_out
);

// Returns the number of tokens in the given model's vocabulary.
// Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536).
Expand All @@ -134,6 +191,8 @@ extern "C" {
RWKV_API size_t rwkv_get_n_embed(const struct rwkv_context * ctx);

// Returns the number of layers in the given model.
// A layer is a pair of RWKV and FFN operations, stacked multiple times throughout the model.
// Embedding matrix and model head (unembedding matrix) are NOT counted in `n_layer`.
// Useful for always offloading the entire model to GPU.
RWKV_API size_t rwkv_get_n_layer(const struct rwkv_context * ctx);

Expand Down Expand Up @@ -171,7 +230,7 @@ extern "C" {
// Returns system information string.
RWKV_API const char * rwkv_get_system_info_string(void);

#ifdef __cplusplus
#if defined(__cplusplus)
}
#endif

Expand Down

0 comments on commit 2aeb25c

Please sign in to comment.