I did some prototyping for llama.cpp training support to determine what would be needed to make it work. I think the changes would be relatively extensive, so it will be easier to first implement training support for the GGML GPT-2 example. I think the following modifications to `ggml_opt` will be needed:
Extend `ggml_opt_dataset` with support for sequential data and sparse labels. Currently a sequence of tokens can only be handled with a workaround where you set the context as the input data and the token following the context as the label (encoded as a one-hot vector with one value set to 1 and all others set to 0). This workaround greatly increases the memory usage though.
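For illustration, here is a rough back-of-the-envelope sketch (plain C++, not `ggml_opt` code; the sizes are made up) of why the one-hot workaround is so memory-hungry compared to a token stream with sparse labels:

```cpp
// Rough, illustrative arithmetic (not ggml_opt API code) for why dense one-hot
// labels and per-datapoint context copies blow up dataset memory.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_ctx   = 1024;    // context length per datapoint (assumed)
    const int64_t n_vocab = 50257;   // GPT-2 vocabulary size
    const int64_t n_data  = 1000000; // number of training datapoints/tokens (assumed)

    // Current workaround: every datapoint stores its full context as input
    // and a dense one-hot vector over the vocabulary as its label.
    const double workaround_bytes =
        double(n_data) * n_ctx   * sizeof(int32_t) +  // duplicated contexts
        double(n_data) * n_vocab * sizeof(float);     // one-hot labels

    // With sequential data + sparse labels: the token stream is stored once
    // and each label is a single token id (the next token in the stream).
    const double sparse_bytes =
        double(n_data) * sizeof(int32_t) +            // token stream
        double(n_data) * sizeof(int32_t);             // sparse labels

    printf("workaround : %8.1f GiB\n", workaround_bytes / (1024.0*1024.0*1024.0));
    printf("sparse     : %8.1f MiB\n", sparse_bytes     / (1024.0*1024.0));
    return 0;
}
```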
Implement support for defining a new forward graph in user code. The current structure, where the forward graph is defined once in user code and then passed to `ggml_opt_init`, has comparatively little overhead and is suitable for e.g. image classification models that have no internal state and fixed inputs and outputs. But with e.g. language models the situation is more complicated because there is a KV cache. In llama.cpp there is a function that dynamically rebuilds the forward graph, and I think this function can be re-used in a minimally invasive way (for llama.cpp) by modularizing the code in `ggml_opt_init`: basically, move the code for constructing the backward graphs to dedicated functions and call them whenever the user explicitly allocates a new graph. I think the only challenge will be to correctly re-use the same buffers for gradient accumulation and optimizer momenta after allocating a new graph from user code. However, if the callback always returns the same operations in the same order (which to my knowledge it does), I think this can be done relatively easily. Care also needs to be taken in user code because creating and allocating a new graph invalidates references to e.g. the model predictions.

Apart from the changes to `ggml_opt`, I think llama.cpp would need a function like `llama_opt_epoch`, equivalent to `ggml_opt_epoch`, that loads data from a dataset, repackages it as a llama.cpp batch, builds and allocates a new graph, and then evaluates the graph.
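To make the intended control flow concrete, here is a hypothetical sketch; none of the helper names below are real API, they only stand in for the existing llama.cpp graph builder and the refactored `ggml_opt` code:

```cpp
// Hypothetical sketch only: every name below except ggml_cgraph is made up
// for illustration and stands in for existing llama.cpp / refactored ggml_opt code.
#include <cstdint>

struct ggml_cgraph;                                                      // from ggml
struct opt_batch { /* tokens, positions, logit mask, ... */ };          // hypothetical
struct opt_state { /* gradient accumulators, optimizer momenta, ... */ }; // hypothetical

opt_batch     load_and_repack_batch(int64_t ibatch);                    // dataset -> llama.cpp batch
ggml_cgraph * build_forward_graph  (const opt_batch & batch);           // reuse llama.cpp graph builder
void          build_alloc_backward (ggml_cgraph * gf, opt_state & st);  // refactored ggml_opt_init code
void          forward_backward_step(ggml_cgraph * gf, opt_state & st);  // eval + optimizer step

void llama_opt_epoch_sketch(int64_t n_batches, opt_state & st) {
    for (int64_t i = 0; i < n_batches; ++i) {
        // 1. Load a shard from the dataset and repackage it as a llama.cpp batch.
        opt_batch batch = load_and_repack_batch(i);

        // 2. Rebuild the forward graph for this batch (accounts for the KV cache).
        ggml_cgraph * gf = build_forward_graph(batch);

        // 3. Build/allocate the matching backward graph. Gradient accumulators and
        //    optimizer momenta in `st` keep their existing buffers because the
        //    callback returns the same operations in the same order.
        build_alloc_backward(gf, st);

        // NOTE: any references into the previous graph (e.g. the predictions
        //       tensor) are invalidated at this point and must be re-fetched.

        // 4. Forward pass, loss, backward pass, optimizer step.
        forward_backward_step(gf, st);
    }
}
```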
For the GPT-2 GGML examples I plan to:
Refactor and deduplicate the code, and update the documentation. My understanding is that going forward the intended use of GGML is via `ggml_backend`, so I will remove the code using `no_alloc=false`. I plan to have one example with `ggml_backend_sched`, one where everything is allocated to a single backend, and the examples for batched inference and quantization.
Use GGUF for model storage.
Add training support. Training a GPT-2 model from scratch will not be viable given the current code, but I think finetuning can already be demonstrated. The training code will essentially give you a way to calculate perplexity for free by only doing the forward pass, so the way I imagine it is that users would take the base model, evaluate perplexity on wikitext, finetune the model on the same dataset, and then evaluate perplexity again to see that the value went down. Long-term I think the llama.cpp perplexity example can be simplified by reusing the code in `ggml_opt`. It would then also be possible to evaluate models on arbitrary multiple-choice benchmarks as long as you can repackage the benchmark as a `ggml_opt_dataset` (should be possible to do in Python if GGUF support is added).
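To spell out the "perplexity for free" part: perplexity is just the exponential of the mean per-token cross-entropy loss that the forward pass already produces. A minimal standalone sketch (dummy loss values, no GGML code):

```cpp
// Minimal sketch: perplexity from per-token cross-entropy (negative log-likelihood).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Assume these per-token losses came out of a forward-only evaluation run.
    std::vector<double> token_nll = {3.1, 2.7, 4.0, 3.3, 2.9}; // dummy values

    double sum = 0.0;
    for (double nll : token_nll) {
        sum += nll;
    }
    const double mean_nll   = sum / token_nll.size();
    const double perplexity = std::exp(mean_nll);

    printf("mean cross-entropy: %.3f  perplexity: %.2f\n", mean_nll, perplexity);
    return 0;
}
```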