From ca3f025d54fcd5e94d1a345cbe5befdcfdfd7192 Mon Sep 17 00:00:00 2001 From: thxCode Date: Thu, 30 May 2024 11:24:08 +0800 Subject: [PATCH] refactor: estimate Signed-off-by: thxCode --- README.md | 13 +- cmd/gguf-parser/main.go | 65 ++- file.go | 865 +++++++++++++------------------------- file_architecture_test.go | 4 +- file_estimate.go | 141 +++++-- file_estimate_option.go | 11 + file_estimate_test.go | 78 ++-- file_model_test.go | 4 +- file_option.go | 17 +- file_test.go | 18 +- file_tokenizer_test.go | 4 +- 11 files changed, 498 insertions(+), 722 deletions(-) diff --git a/README.md b/README.md index d9b117f..4182db1 100644 --- a/README.md +++ b/README.md @@ -55,13 +55,10 @@ if err != nil { ``` -#### Use approximate parsing - -> The approximate parsing is faster than the accurate one, -> but the result may not be accurate. +#### Skip large metadata ```go -f, err := ParseGGUFFile("path/to/model.gguf", UseApproximate()) +f, err := ParseGGUFFile("path/to/model.gguf", SkipLargeMetadata()) if err != nil { panic(err) } @@ -124,6 +121,12 @@ spew.Dump(f.Estimate(WithContextSize(4096) /* 4K */)) ``` +#### Estimate with specific offload layers + +```go +spew.Dump(f.Estimate(WithOffloadLayers(10))) +``` + ## License MIT diff --git a/cmd/gguf-parser/main.go b/cmd/gguf-parser/main.go index 86c63f9..0aac7eb 100644 --- a/cmd/gguf-parser/main.go +++ b/cmd/gguf-parser/main.go @@ -26,14 +26,14 @@ func main() { url string repo, model string // read options - debug bool - approximate = true - mmap = true - skipProxy bool - skipTLS bool + debug bool + mmap = true + skipProxy bool + skipTLS bool // estimate options - ctxSize = 512 - kvType = "f16" + ctxSize = 512 + kvType = "f16" + offloadLayers uint64 // output options skipModel bool skipArchitecture bool @@ -58,12 +58,12 @@ func main() { fs.StringVar(&model, "model", model, "Model below the --repo, e.g. "+ "Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf") fs.BoolVar(&debug, "debug", debug, "Debug mode") - fs.BoolVar(&approximate, "approximate", approximate, "Speed up reading") fs.BoolVar(&mmap, "mmap", mmap, "Use mmap to read the local file") fs.BoolVar(&skipProxy, "skip-proxy", skipProxy, "Skip using proxy when reading from a remote URL") fs.BoolVar(&skipTLS, "skip-tls", skipTLS, "Skip TLS verification when reading from a remote URL") fs.IntVar(&ctxSize, "ctx-size", ctxSize, "Context size to estimate memory usage") fs.StringVar(&kvType, "kv-type", kvType, "Key-Value cache type, select from [f32, f16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1]") + fs.Uint64Var(&offloadLayers, "offload-layers", offloadLayers, "Specify how many layers to offload, default is fully offloading") fs.BoolVar(&skipModel, "skip-model", skipModel, "Skip model metadata") fs.BoolVar(&skipArchitecture, "skip-architecture", skipArchitecture, "Skip architecture metadata") fs.BoolVar(&skipTokenizer, "skip-tokenizer", skipTokenizer, "Skip tokenizer metadata") @@ -77,9 +77,11 @@ func main() { // Prepare options. 
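	// As a sketch of how the flags above combine (the repo/model names are
	// illustrative, taken from this repository's test fixtures):
	//
	//	gguf-parser --repo "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF" \
	//	  --model "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf" \
	//	  --ctx-size 4096 --kv-type q8_0 --offload-layers 10
	//
	// prints the MODEL, ARCHITECTURE, and TOKENIZER tables, plus both the
	// ESTIMATE TOTAL and ESTIMATE OFFLOAD rows, since --offload-layers is set.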
- var ropts []GGUFReadOption - if approximate { - ropts = append(ropts, UseApproximate()) + ropts := []GGUFReadOption{ + SkipLargeMetadata(), + } + if debug { + ropts = append(ropts, UseDebug()) } if mmap { ropts = append(ropts, UseMMap()) @@ -91,7 +93,9 @@ func main() { ropts = append(ropts, SkipTLSVerification()) } - var eopts []GGUFEstimateOption + eopts := []GGUFEstimateOption{ + WithContextSize(512), + } if ctxSize > 0 { eopts = append(eopts, WithContextSize(int32(ctxSize))) } @@ -117,6 +121,9 @@ func main() { } eopts = append(eopts, WithCacheKeyType(kv), WithCacheValueType(kv)) } + if offloadLayers > 0 { + eopts = append(eopts, WithOffloadLayers(offloadLayers)) + } // Parse GGUF file. @@ -190,8 +197,9 @@ func main() { if !skipModel { tprintf( - []string{"Name", "Architecture", "Quantization Version", "File Type", "Little Endian", "Size", "Parameters", "BPW"}, + []string{"", "Name", "Architecture", "Quantization Version", "File Type", "Little Endian", "Size", "Parameters", "BPW"}, []string{ + "MODEL", m.Name, m.Architecture, sprintf(m.QuantizationVersion), @@ -205,8 +213,9 @@ func main() { if !skipArchitecture { tprintf( - []string{"Maximum Context Length", "Embedding Length", "Layers", "Feed Forward Length", "Expert Count", "Vocabulary Length"}, + []string{"", "Maximum Context", "Embedding", "Layers", "Feed Forward", "Expert Count", "Vocabulary"}, []string{ + "ARCHITECTURE", sprintf(a.MaximumContextLength), sprintf(a.EmbeddingLength), fmt.Sprintf("%d + 1 = %d", @@ -220,8 +229,9 @@ func main() { if !skipTokenizer { tprintf( - []string{"Tokenizer Model", "Tokens Length", "Added Tokens Length", "BOS", "EOS", "Unknown", "Separator", "Padding"}, + []string{"", "Model", "Tokens", "Added Tokens", "BOS", "EOS", "Unknown", "Separator", "Padding"}, []string{ + "TOKENIZER", t.Model, sprintf(t.TokensLength), sprintf(t.AddedTokensLength), @@ -235,12 +245,25 @@ func main() { if !skipEstimate { tprintf( - []string{"Load Memory", "KVCache Memory", "Total Memory"}, + []string{"", "KV Cache", "Compute", "IO", "Sum"}, []string{ - e.MemoryLoad.String(), - e.KVCache.MemoryTotal.String(), - e.MemoryTotal.String(), + "ESTIMATE TOTAL", + e.Total.KVCache.Sum().String(), + e.Total.Compute.String(), + e.Total.IO.String(), + e.Total.Sum().String(), }) + if e.Offload != nil { + tprintf( + []string{"", "KV Cache", "Compute", "IO", "Sum"}, + []string{ + "ESTIMATE OFFLOAD", + e.Offload.KVCache.Sum().String(), + e.Offload.Compute.String(), + e.Offload.IO.String(), + e.Offload.Sum().String(), + }) + } } } @@ -277,10 +300,12 @@ func tprintf(headers, rows []string) { tb := tablewriter.NewWriter(os.Stdout) tb.SetHeaderAlignment(tablewriter.ALIGN_CENTER) tb.SetAlignment(tablewriter.ALIGN_CENTER) - tb.SetHeaderLine(true) tb.SetBorder(true) tb.SetTablePadding("\t") + tb.SetHeaderLine(true) tb.SetHeader(headers) + tb.SetAutoMergeCells(true) + tb.SetRowLine(true) tb.Append(rows) tb.Render() fmt.Println() diff --git a/file.go b/file.go index 8ef076a..c9a6ffb 100644 --- a/file.go +++ b/file.go @@ -10,6 +10,7 @@ import ( "net/http" "regexp" "strconv" + "strings" "time" "github.com/dustin/go-humanize" @@ -33,9 +34,7 @@ type GGUFFile struct { Header GGUFHeader `json:"header"` // TensorInfos are the tensor infos of the GGUF file, // the size of TensorInfos is equal to `Header.TensorCount`. - // - // TensorInfos may be empty if read approximately. 
- TensorInfos GGUFTensorInfos `json:"tensorInfos,omitempty"` + TensorInfos GGUFTensorInfos `json:"tensorInfos"` // Padding is the padding size of the GGUF file, // which is used to split Header and TensorInfos from tensor data. Padding int64 `json:"padding"` @@ -151,7 +150,7 @@ type ( Len uint64 `json:"len"` // Array holds all array items. // - // Array may be empty if read approximately. + // Array may be empty if skipping. Array []any `json:"array,omitempty"` /* Appendix */ @@ -219,17 +218,6 @@ const ( _GGMLTypeCount // Unknown ) -// Sizes for GGML constant. -const ( - // GGMLTensorSize is the size of a GGML tensor in bytes, - // see https://github.com/ggerganov/ggml/blob/0cbb7c0e053f5419cfbebb46fbf4d4ed60182cf5/include/ggml/ggml.h#L606. - GGMLTensorSize = 368 - - // GGMLObjectSize is the size of a GGML object in bytes, - // see https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/include/ggml/ggml.h#L563. - GGMLObjectSize = 32 -) - // Types for GGUFTensorInfo. type ( // GGUFTensorInfo represents a tensor info in a GGUF file. @@ -423,23 +411,14 @@ func parseGGUFFile(s int64, f io.ReadSeeker, o _GGUFReadOptions) (_ *GGUFFile, e // tensor infos { rd := _GGUFTensorInfoReader{_GGUFReader: rd} - if !o.Approximate { - tis := make(GGUFTensorInfos, gf.Header.TensorCount) - for i := uint64(0); i < gf.Header.TensorCount; i++ { - tis[i], err = rd.Read() - if err != nil { - return nil, fmt.Errorf("read tensor info %d: %w", i, err) - } - } - gf.TensorInfos = tis - } else { - for i := uint64(0); i < gf.Header.TensorCount; i++ { - _, err = rd.Read() - if err != nil { - return nil, fmt.Errorf("read tensor info %d: %w", i, err) - } + tis := make(GGUFTensorInfos, gf.Header.TensorCount) + for i := uint64(0); i < gf.Header.TensorCount; i++ { + tis[i], err = rd.Read() + if err != nil { + return nil, fmt.Errorf("read tensor info %d: %w", i, err) } } + gf.TensorInfos = tis } pds, err := f.Seek(0, io.SeekCurrent) @@ -463,19 +442,11 @@ func parseGGUFFile(s int64, f io.ReadSeeker, o _GGUFReadOptions) (_ *GGUFFile, e // tensor data offset gf.TensorDataStartOffset = pds + gf.Padding - if o.Approximate { - // size - gf.ModelSize = GGUFBytesScalar(s - gf.TensorDataStartOffset) - // parameters - gf.ModelParameters = gf.guessParameters() - } else { - for i := range gf.TensorInfos { - // size - gf.ModelSize += GGUFBytesScalar(gf.TensorInfos[i].Bytes()) - // parameters - gf.ModelParameters += GGUFParametersScalar(gf.TensorInfos[i].Elements()) - } - } + // model size + gf.ModelSize = GGUFBytesScalar(s - gf.TensorDataStartOffset) + + // model parameters + gf.ModelParameters = GGUFParametersScalar(gf.TensorInfos.Elements()) // bpw if gf.ModelParameters != 0 { @@ -485,469 +456,101 @@ func parseGGUFFile(s int64, f io.ReadSeeker, o _GGUFReadOptions) (_ *GGUFFile, e return &gf, nil } -// guessParameters guesses the number of parameters, -// which is inspired by https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L3969-L4388. 
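// With tensor infos now always parsed, the guessing below becomes unnecessary:
// the parameter count is the exact sum of tensor elements, and bits-per-weight
// follows from the arithmetic above. A minimal sketch of that arithmetic, with
// illustrative numbers rather than values from a real file:
//
//	size := uint64(4_690_000_000)            // tensor data bytes
//	params := uint64(7_240_000_000)          // sum of tensor elements
//	bpw := float64(8*size) / float64(params) // ≈ 5.18 bits per weight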
-func (gf *GGUFFile) guessParameters() GGUFParametersScalar {
-	const (
-		K = 1e3
-		M = 1e3 * K
-		B = 1e3 * M
-
-		// https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L1718-L1761
-		_14M  = 14 * M
-		_17M  = 17 * M
-		_22M  = 22 * M
-		_33M  = 33 * M
-		_70M  = 70 * M
-		_109M = 109 * M
-		_137M = 137 * M
-		_160M = 160 * M
-		_335M = 335 * M
-		_410M = 410 * M
-		_0_5B = 0.5 * B
-		_1B   = 1 * B
-		_1_4B = 1.4 * B
-		_2B   = 2 * B
-		_2_8B = 2.8 * B
-		_3B   = 3 * B
-		_4B   = 4 * B
-		_6_9B = 6.9 * B
-		_7B   = 7 * B
-		_8B   = 8 * B
-		_12B  = 12 * B
-		_13B  = 13 * B
-		_14B  = 14 * B
-		_15B  = 15 * B
-		_20B  = 20 * B
-		_30B  = 30 * B
-		_34B  = 34 * B
-		_35B  = 35 * B
-		_40B  = 40 * B
-		_65B  = 65 * B
-		_70B  = 70 * B
-		_314B = 314 * B
-		_SMALL  = 0.1 * B
-		_MEDIUM = 0.4 * B
-		_LARGE  = 0.8 * B
-		_XL     = 1.5 * B
-		_A2_7B         = 14.3 * B // Guess
-		_8x7B          = 47 * B   // Guess
-		_8x22B         = 141 * B  // Guess
-		_16x12B        = 132 * B  // Guess
-		_10B_128x3_66B = 480 * B  // Guess
-	)
-
-	arch := "llama"
-	if v, ok := gf.Header.MetadataKV.Get("general.architecture"); ok {
-		arch = v.ValueString()
+// Types for GGUF hierarchical tensors.
+type (
+	// IGGUFTensorInfos is the interface implemented by GGUFTensorInfo,
+	// GGUFTensorInfos, GGUFLayerTensorInfos, and GGUFNamedTensorInfos.
+	IGGUFTensorInfos interface {
+		// Get returns the GGUFTensorInfo with the given name,
+		// and true if found, and false otherwise.
+		Get(name string) (info GGUFTensorInfo, found bool)
+		// Search returns a list of GGUFTensorInfo with the names that match the given regex.
+		Search(nameRegex *regexp.Regexp) (infos []GGUFTensorInfo)
+		// Index returns a map value to the GGUFTensorInfo with the given names,
+		// and the number of names found.
+		Index(names []string) (infos map[string]GGUFTensorInfo, found int)
+		// Elements returns the number of elements of the GGUFTensorInfo.
+		Elements() uint64
+		// Bytes returns the number of bytes of the GGUFTensorInfo.
+		Bytes() uint64
+	}
+
+	// GGUFLayerTensorInfos represents hierarchical tensor infos of a GGUF file,
+	// which can hold GGUFNamedTensorInfos, GGUFTensorInfos, and GGUFTensorInfo items.
+	GGUFLayerTensorInfos []IGGUFTensorInfos
+
+	// GGUFNamedTensorInfos is the namespace for relevant tensors,
+	// which must have a name.
+	GGUFNamedTensorInfos struct {
+		// Name is the name of the namespace.
+		Name string `json:"name"`
+		// GGUFLayerTensorInfos can hold GGUFNamedTensorInfos, GGUFTensorInfos, or GGUFTensorInfo items.
+		//
+		// An item of type GGUFTensorInfo is always a leaf node.
+		//
+		// Branch nodes are of type GGUFNamedTensorInfos or GGUFTensorInfos,
+		// and can be nested.
+		//
+		// Branch nodes are stored as pointers.
+		GGUFLayerTensorInfos `json:"items,omitempty"`
 	}
+)

-	var (
-		contextLengthKey         = arch + ".context_length"
-		embeddingLengthKey       = arch + ".embedding_length"
-		blockCountKey            = arch + ".block_count"
-		feedForwardLengthKey     = arch + ".feed_forward_length"
-		expertCountKey           = arch + ".expert_count"
-		attentionHeadCountKey    = arch + ".attention.head_count"
-		attentionHeadCountKVKey  = arch + ".attention.head_count_kv"
-		vocabularyLengthKey      = arch + ".vocab_size" // uint32 maybe
-		tokenizerGGMLTokensKey   = "tokenizer.ggml.tokens"
-	)
-	m, _ := gf.Header.MetadataKV.Index([]string{
-		contextLengthKey,
-		embeddingLengthKey,
-		blockCountKey,
-		feedForwardLengthKey,
-		expertCountKey,
-		attentionHeadCountKey,
-		attentionHeadCountKVKey,
-		vocabularyLengthKey,
-		tokenizerGGMLTokensKey,
-	})
+// Layers converts the GGUFTensorInfos to GGUFLayerTensorInfos.
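+// For example, a file whose tensors are named
+//
+//	token_embd.weight
+//	blk.0.attn_q.weight
+//	blk.0.ffn_up.weight
+//	output.weight
+//
+// (illustrative names, following the usual GGUF layout) is grouped into a
+// "blk.0" namespace holding the two block tensors, while token_embd.weight
+// and output.weight remain top-level leaves.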
+func (gf *GGUFFile) Layers() GGUFLayerTensorInfos { + var ret GGUFLayerTensorInfos - var ( - embeddingLength uint64 - blockCount uint64 - feedForwardLength uint64 - expertCount uint32 - attentionHeadCount uint64 - attentionHeadCountKV uint64 - vocabularyLength uint64 - ) - if v, ok := m[embeddingLengthKey]; ok { - embeddingLength = ValueNumeric[uint64](v) - } - if v, ok := m[blockCountKey]; ok { - blockCount = ValueNumeric[uint64](v) - } - if v, ok := m[feedForwardLengthKey]; ok { - feedForwardLength = ValueNumeric[uint64](v) - } - if v, ok := m[expertCountKey]; ok { - expertCount = ValueNumeric[uint32](v) - } - if v, ok := m[attentionHeadCountKey]; ok { - attentionHeadCount = ValueNumeric[uint64](v) - } - if v, ok := m[attentionHeadCountKVKey]; ok { - attentionHeadCountKV = ValueNumeric[uint64](v) - } else { - attentionHeadCountKV = attentionHeadCount - } - if v, ok := m[vocabularyLengthKey]; ok { - vocabularyLength = ValueNumeric[uint64](v) - } else if v, ok := m[tokenizerGGMLTokensKey]; ok { - vocabularyLength = v.ValueArray().Len - } - - // Try historical statistics, - // https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L228-L263 - switch arch { - case "llama": - if expertCount == 8 { - switch blockCount { - case 32: - return _8x7B - case 56: - return _8x22B - } - } else { - switch blockCount { - case 22: - return _1B - case 26: - return _3B - case 32: - if vocabularyLength < 40000 { - return _7B - } - return _8B - case 40: - return _13B - case 48: - return _34B - case 60: - return _30B - case 80: - if attentionHeadCount == attentionHeadCountKV { - return _65B - } - return _70B - } - } - case "falcon": - switch blockCount { - case 32: - return _7B - case 60: - return _40B - } - case "grok": - if blockCount == 64 { - return _314B - } - case "gpt2": - switch blockCount { - case 12: - return _SMALL - case 24: - return _MEDIUM - case 36: - return _LARGE - case 48: - return _XL - } - case "gptj": - case "gptneox": - switch blockCount { - case 6: - switch feedForwardLength { - case 512: - return _14M - case 2048: - return _70M - } - case 12: - if feedForwardLength == 3072 { - return _160M - } - case 16: - if feedForwardLength == 8192 { - return _1B - } - case 24: - switch feedForwardLength { - case 4096: - return _410M - case 8192: - return _1_4B - } - case 32: - switch feedForwardLength { - case 10240: - return _2_8B - case 16384: - return _6_9B - } - case 36: - if feedForwardLength == 20480 { - return _12B - } - case 44: - if feedForwardLength == 24576 { - return _20B - } - } - case "mpt": - switch blockCount { - case 32: - return _7B - case 48: - return _30B - } - case "baichuan": - switch blockCount { - case 32: - return _7B - case 40: - return _13B - } - case "starcoder": - switch blockCount { - case 24: - return _1B - case 36: - return _3B - case 42: - return _7B - case 40: - return _15B - } - case "refact": - if blockCount == 32 { - return _1B - } - case "bert": - switch blockCount { - case 3: - return _17M - case 6: - return _22M - case 12: - switch embeddingLength { - case 384: - return _33M - case 768: - return _109M - } - case 24: - return _335M - } - case "nomic-bert": - if blockCount == 12 && embeddingLength == 768 { - return _137M - } - case "jina-bert-v2": - switch blockCount { - case 4: - return _33M - case 12: - return _137M - } - case "bloom": - switch blockCount { - case 24: - return _1B - case 30: - switch embeddingLength { - case 2560: - return _3B - case 4096: - return _7B - } - } - case "stablelm": - switch blockCount { - 
case 24: - return _1B - case 32: - return _3B - case 40: - return _12B - } - case "qwen": - switch blockCount { - case 32: - return _7B - case 40: - return _13B - } - case "qwen2": - switch blockCount { - case 24: - if embeddingLength == 1024 { - return _0_5B - } - return _1B - case 32: - return _7B - case 40: - if attentionHeadCount == 20 { - return _4B + pm := make(map[string]any) + for i := range gf.TensorInfos { + ps := strings.Split(gf.TensorInfos[i].Name, ".") + switch { + default: + ret = append(ret, gf.TensorInfos[i]) + continue + case len(ps) >= 2 && ps[0] == "blk": + p := strings.Join([]string{ps[0], ps[1]}, ".") + if _, ok := pm[p]; !ok { + l := &GGUFNamedTensorInfos{Name: p} + pm[p] = l + ret = append(ret, l) } - return _13B - case 80: - return _70B - } - case "qwen2moe": - if blockCount == 24 { - return _A2_7B - } - case "phi2": - switch blockCount { - case 24: - return _1B - case 32: - return _3B - } - case "phi3": - switch blockCount { - case 24: - return _1B - case 32: - return _3B - case 40: - return _14B - } - case "plamo": - if blockCount == 40 { - return _13B - } - case "codeshell": - if blockCount == 42 { - return _SMALL - } - case "orion": - if blockCount == 40 { - return _14B - } - case "internlm2": - switch blockCount { - case 32: - return _7B - case 48: - return _20B - } - case "minicpm": - if blockCount == 40 { - return _2B - } - case "gemma": - switch blockCount { - case 18: - return _2B - case 28: - return _7B - } - case "starcoder2": - switch blockCount { - case 30: - return _3B - case 32: - return _7B - case 40: - return _15B - } - case "mamba": - switch blockCount { - case 24: - if embeddingLength == 768 { - return _SMALL + l := pm[p].(*GGUFNamedTensorInfos) + l.GGUFLayerTensorInfos = append(l.GGUFLayerTensorInfos, gf.TensorInfos[i]) + case len(ps) >= 3 && (ps[0] == "decoder" || ps[0] == "encoder"): + p := ps[0] + if _, ok := pm[p]; !ok { + xl := &GGUFNamedTensorInfos{Name: p} + pm[p] = xl + ret = append(ret, xl) } - case 48: - switch embeddingLength { - case 1024: - return _MEDIUM - case 1536: - return _LARGE - case 2048: - return _XL + xl := pm[p].(*GGUFNamedTensorInfos) + if ps[1] != "block" { + xl.GGUFLayerTensorInfos = append(xl.GGUFLayerTensorInfos, gf.TensorInfos[i]) + continue } - case 64: - if embeddingLength == 2560 { - return _3B + p = strings.Join([]string{ps[0], ps[1], ps[2]}, ".") + if _, ok := pm[p]; !ok { + l := &GGUFNamedTensorInfos{Name: p} + pm[p] = l + xl.GGUFLayerTensorInfos = append(xl.GGUFLayerTensorInfos, l) } - } - case "xverse": - switch blockCount { - case 32: - return _7B - case 40: - return _13B - case 80: - return _65B - } - case "command-r": - if blockCount == 40 { - return _35B - } - case "dbrx": - if blockCount == 40 { - return _16x12B - } - case "olmo": - switch blockCount { - case 22: - return _1B - case 32: - return _7B - case 80: - return _70B - } - case "arctic": - if expertCount == 128 && blockCount == 35 { - return _10B_128x3_66B + l := pm[p].(*GGUFNamedTensorInfos) + l.GGUFLayerTensorInfos = append(l.GGUFLayerTensorInfos, gf.TensorInfos[i]) } } - - // Otherwise, calculate by experience. - // - // Let's say, the model is based on Transformer architecture, - // and use decoder-only. - // - // Vocabulary embedding parameter number(VeP), mainly includes the embedding matrix. - // The embedding matrix shape is [VocabularyLength, EmbeddingLength]. - // So the VeP value is VocabularyLength * EmbeddingLength. - // - // Self-Attention parameter number(SaP), includes Wq, Wk, Wv, Wo, and their bias. 
- // The all weight matrix shapes are [EmbeddingLength, EmbeddingLength], - // and the bias shapes are [EmbeddingLength]. - // So the SaP value is 4 * (EmbeddingLength * EmbeddingLength) + 4 * EmbeddingLength. - // - // Feed-Forward parameter number(FfP), includes W1, W2, and their bias. - // The W1 shape is [EmbeddingLength, 4*EmbeddingLength], its bias shape is [4*EmbeddingLength]. - // The W2 shape is [4*EmbeddingLength, EmbeddingLength], its bias shape is [EmbeddingLength]. - // So the FfP value is (EmbeddingLength * 4 * EmbeddingLength) + 4 * EmbeddingLength + (4 * EmbeddingLength * EmbeddingLength) + EmbeddingLength. - // - // There are two LayerNorm, one for Self-Attention, and another for Feed-Forward. - // Layer Normalization parameter number(LnP), includes scale and bias. - // The scale and bias shapes are [EmbeddingLength]. - // So the LnP value is 2 * (2 * EmbeddingLength). - // - // So the total parameters of a decoder-only model can estimate as below. - // Parameters = BlockCount * (SaP + FfP + LnP) + VeP - // = BlockCount * (12 * EmbeddingLength * EmbeddingLength + 13 * EmbeddingLength) + VocabularyLength * EmbeddingLength - - ret := blockCount*(12*embeddingLength*embeddingLength+13*embeddingLength) + vocabularyLength*embeddingLength - // TODO MoE / SSM / RoPE. - return GGUFParametersScalar(ret) + return ret } func (s GGUFBytesScalar) String() string { + if s == 0 { + return "0 B" + } return humanize.IBytes(uint64(s)) } func (s GGUFParametersScalar) String() string { + if s == 0 { + return "0" + } switch { case s >= 1e15: return humanize.CommafWithDigits(float64(s)/1e15, 1) + " Q" @@ -966,7 +569,7 @@ func (s GGUFParametersScalar) String() string { func (s GGUFBitsPerWeightScalar) String() string { if s == 0 { - return "Unknown" + return "0 bpw" } return strconv.FormatFloat(float64(s), 'f', 2, 64) + " bpw" } @@ -1277,26 +880,6 @@ func ValuesNumeric[T constraints.Integer | constraints.Float](av GGUFMetadataKVA return v } -// HasAll returns true if the GGUFMetadataKVs has all the given keys, -// and false otherwise. -func (kvs GGUFMetadataKVs) HasAll(keys []string) bool { - ks := make(map[string]struct{}, len(keys)) - for i := range keys { - ks[keys[i]] = struct{}{} - } - for i := range kvs { - k := kvs[i].Key - if _, ok := ks[k]; !ok { - continue - } - delete(ks, k) - if len(ks) == 0 { - break - } - } - return len(ks) == 0 -} - // Get returns the GGUFMetadataKV with the given key, // and true if found, and false otherwise. func (kvs GGUFMetadataKVs) Get(key string) (value GGUFMetadataKV, found bool) { @@ -1405,12 +988,41 @@ func (t GGMLType) RowSizeOf(dimensions []uint64) uint64 { return ds } +// Get returns the GGUFTensorInfo with the given name, +// and true if found, and false otherwise. +func (ti GGUFTensorInfo) Get(name string) (info GGUFTensorInfo, found bool) { + if ti.Name == name { + return ti, true + } + return GGUFTensorInfo{}, false +} + +// Search returns a list of GGUFTensorInfo with the names that match the given regex. +func (ti GGUFTensorInfo) Search(nameRegex *regexp.Regexp) (infos []GGUFTensorInfo) { + if nameRegex.MatchString(ti.Name) { + return []GGUFTensorInfo{ti} + } + return nil +} + +// Index returns a map value to the GGUFTensorInfo with the given names, +// and the number of names found. 
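+// For example, a hypothetical lookup against a single leaf:
+//
+//	infos, n := ti.Index([]string{"output.weight", "token_embd.weight"})
+//	// If ti.Name is one of the requested names, n == 1 and infos maps
+//	// that name to ti; otherwise n == 0 and infos is nil.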
+func (ti GGUFTensorInfo) Index(names []string) (infos map[string]GGUFTensorInfo, found int) {
+	for i := range names {
+		if names[i] == ti.Name {
+			return map[string]GGUFTensorInfo{ti.Name: ti}, 1
+		}
+	}
+	return nil, 0
+}
+
 // Elements returns the number of elements of the GGUFTensorInfo,
 // which is inspired by
 // https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L2597-L2601.
 func (ti GGUFTensorInfo) Elements() uint64 {
 	if ti.NDimensions == 0 {
-		panic(errors.New("no dimensions"))
+		return 0
 	}
 
 	ret := uint64(1)
@@ -1425,7 +1037,7 @@
 // https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L2609-L2626.
 func (ti GGUFTensorInfo) Bytes() uint64 {
 	if ti.NDimensions == 0 {
-		panic(errors.New("no dimensions"))
+		return 0
 	}
 
 	tt, ok := ti.Type.Trait()
@@ -1459,26 +1071,6 @@
 	return ret
 }
 
-// HasAll returns true if the GGUFTensorInfos has all the given names,
-// and false otherwise.
-func (tis GGUFTensorInfos) HasAll(names []string) bool {
-	ns := make(map[string]struct{}, len(names))
-	for i := range names {
-		ns[names[i]] = struct{}{}
-	}
-	for i := range tis {
-		n := tis[i].Name
-		if _, ok := ns[n]; !ok {
-			continue
-		}
-		delete(ns, n)
-		if len(ns) == 0 {
-			break
-		}
-	}
-	return len(ns) == 0
-}
-
 // Get returns the GGUFTensorInfo with the given name,
 // and true if found, and false otherwise.
 func (tis GGUFTensorInfos) Get(name string) (info GGUFTensorInfo, found bool) {
@@ -1520,6 +1112,136 @@ func (tis GGUFTensorInfos) Index(names []string) (infos map[string]GGUFTensorInf
 	return infos, found
 }
 
+// Elements returns the number of elements of the GGUFTensorInfos.
+func (tis GGUFTensorInfos) Elements() uint64 {
+	var ret uint64
+	for i := range tis {
+		ret += tis[i].Elements()
+	}
+	return ret
+}
+
+// Bytes returns the number of bytes of the GGUFTensorInfos.
+func (tis GGUFTensorInfos) Bytes() uint64 {
+	var ret uint64
+	for i := range tis {
+		ret += tis[i].Bytes()
+	}
+	return ret
+}
+
+// Get returns the GGUFTensorInfo with the given name,
+// and true if found, and false otherwise.
+func (ltis GGUFLayerTensorInfos) Get(name string) (info GGUFTensorInfo, found bool) {
+	for i := range ltis {
+		switch v := ltis[i].(type) {
+		case GGUFTensorInfo:
+			if v.Name == name {
+				return v, true
+			}
+		case *GGUFNamedTensorInfos:
+			info, found = v.GGUFLayerTensorInfos.Get(name)
+			if found {
+				return info, true
+			}
+		}
+	}
+	return GGUFTensorInfo{}, false
+}
+
+// Search returns a list of GGUFTensorInfo with the names that match the given regex.
+func (ltis GGUFLayerTensorInfos) Search(nameRegex *regexp.Regexp) (infos []GGUFTensorInfo) {
+	for i := range ltis {
+		switch v := ltis[i].(type) {
+		case GGUFTensorInfo:
+			if nameRegex.MatchString(v.Name) {
+				infos = append(infos, v)
+			}
+		case *GGUFNamedTensorInfos:
+			infos = append(infos, v.Search(nameRegex)...)
+		}
+	}
+	return infos
+}
+
+// Index returns a map value to the GGUFTensorInfos with the given names,
+// and the number of names found.
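+// For example, collecting the input/output tensors that the estimation code
+// cuts away (the same names it uses):
+//
+//	infos, n := ltis.Index([]string{"token_embd.weight", "output.weight", "output_norm.weight"})
+//	// n counts how many of the names were found anywhere in the tree,
+//	// including inside nested namespaces.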
+func (ltis GGUFLayerTensorInfos) Index(names []string) (infos map[string]GGUFTensorInfo, found int) { + ns := make(map[string]struct{}, len(names)) + for i := range names { + ns[names[i]] = struct{}{} + } + infos = make(map[string]GGUFTensorInfo) + for i := range ltis { + switch v := ltis[i].(type) { + case GGUFTensorInfo: + if _, ok := ns[v.Name]; ok { + infos[v.Name] = v + found++ + } + case *GGUFNamedTensorInfos: + inf, _ := v.Index(names) + for k := range inf { + infos[k] = inf[k] + found++ + } + } + if found == len(ns) { + break + } + } + return infos, found +} + +// Elements returns the number of elements of the GGUFLayerTensorInfos. +func (ltis GGUFLayerTensorInfos) Elements() uint64 { + var ret uint64 + for i := range ltis { + ret += ltis[i].Elements() + } + return ret +} + +// Bytes returns the number of bytes of the GGUFLayerTensorInfos. +func (ltis GGUFLayerTensorInfos) Bytes() uint64 { + var ret uint64 + for i := range ltis { + ret += ltis[i].Bytes() + } + return ret +} + +// Cut splits the GGUFLayerTensorInfos into two parts, +// and returns the GGUFLayerTensorInfos with the names that match the given names at first, +// and the GGUFLayerTensorInfos without the names at second, +// and true if the GGUFLayerTensorInfos with the names are found, and false otherwise. +func (ltis GGUFLayerTensorInfos) Cut(names []string) (before, after GGUFLayerTensorInfos, found bool) { + ns := make(map[string]struct{}, len(names)) + for i := range names { + ns[names[i]] = struct{}{} + } + before = make(GGUFLayerTensorInfos, 0, len(names)) + after = make(GGUFLayerTensorInfos, 0, len(ltis)) + + for i := range ltis { + switch v := ltis[i].(type) { + case GGUFTensorInfo: + if _, ok := ns[v.Name]; ok { + before = append(before, v) + continue + } + after = append(after, v) + case *GGUFNamedTensorInfos: + if _, ok := ns[v.Name]; ok { + before = append(before, v) + continue + } + after = append(after, v) + } + } + return before, after, len(before) > 0 +} + type _GGUFReader struct { v GGUFVersion o _GGUFReadOptions @@ -1652,7 +1374,7 @@ func (rd _GGUFReader) ReadArray() (v GGUFMetadataKVArrayValue, err error) { return v, fmt.Errorf("read array length: %w", err) } - if !rd.o.Approximate { + if !rd.o.SkipLargeMetadata { v.Array = make([]any, v.Len) for i := uint64(0); i < v.Len; i++ { v.Array[i], err = rd.ReadValue(v.Type) @@ -1795,65 +1517,42 @@ func (rd _GGUFTensorInfoReader) Read() (ti GGUFTensorInfo, err error) { return ti, fmt.Errorf("seek tensor info start: %w", err) } - if !rd.o.Approximate { - ti.Name, err = rd.ReadString() - if err != nil { - return ti, fmt.Errorf("read name: %w", err) - } - - ti.NDimensions, err = rd.ReadUint32() - if err != nil { - return ti, fmt.Errorf("read n dimensions: %w", err) - } + ti.Name, err = rd.ReadString() + if err != nil { + return ti, fmt.Errorf("read name: %w", err) + } - ti.Dimensions = make([]uint64, ti.NDimensions) - for i := uint32(0); i < ti.NDimensions; i++ { - if rd.v <= GGUFVersionV1 { - ti.Dimensions[i], err = rd.ReadUint64FromUint32() - } else { - ti.Dimensions[i], err = rd.ReadUint64() - } - if err != nil { - return ti, fmt.Errorf("read dimension %d: %w", i, err) - } - } + ti.NDimensions, err = rd.ReadUint32() + if err != nil { + return ti, fmt.Errorf("read n dimensions: %w", err) + } - { - v, err := rd.ReadUint32() - if err != nil { - return ti, fmt.Errorf("read type: %w", err) - } - ti.Type = GGMLType(v) - if ti.Type >= _GGMLTypeCount { - return ti, fmt.Errorf("invalid type: %v", ti.Type) - } + ti.Dimensions = make([]uint64, ti.NDimensions) + for i 
:= uint32(0); i < ti.NDimensions; i++ { + if rd.v <= GGUFVersionV1 { + ti.Dimensions[i], err = rd.ReadUint64FromUint32() + } else { + ti.Dimensions[i], err = rd.ReadUint64() } - - ti.Offset, err = rd.ReadUint64() if err != nil { - return ti, fmt.Errorf("read offset: %w", err) + return ti, fmt.Errorf("read dimension %d: %w", i, err) } - - return ti, nil } - err = rd.SkipReadingString() - if err != nil { - return ti, fmt.Errorf("seek name: %w", err) - } - - nd, err := rd.ReadUint32() - if err != nil { - return ti, fmt.Errorf("seek n dimensions: %w", err) + { + v, err := rd.ReadUint32() + if err != nil { + return ti, fmt.Errorf("read type: %w", err) + } + ti.Type = GGMLType(v) + if ti.Type >= _GGMLTypeCount { + return ti, fmt.Errorf("invalid type: %v", ti.Type) + } } - if rd.v <= GGUFVersionV1 { - _, err = rd.f.Seek(int64(nd)*4 + /* Dimension */ +4 /* Type */ + 8 /* Offset */, io.SeekCurrent) - } else { - _, err = rd.f.Seek(int64(nd)*8 /* Dimension */ +4 /* Type */ +8 /* Offset */, io.SeekCurrent) - } + ti.Offset, err = rd.ReadUint64() if err != nil { - return ti, fmt.Errorf("seek dimensions/type/offset: %w", err) + return ti, fmt.Errorf("read offset: %w", err) } return ti, nil diff --git a/file_architecture_test.go b/file_architecture_test.go index f1c0247..84fa433 100644 --- a/file_architecture_test.go +++ b/file_architecture_test.go @@ -15,7 +15,7 @@ func TestGGUFFile_Architecture(t *testing.T) { ctx, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) return @@ -31,7 +31,7 @@ func BenchmarkGGUFFile_Architecture(b *testing.B) { return } - f, err := ParseGGUFFile(mp, UseApproximate(), UseMMap()) + f, err := ParseGGUFFile(mp, SkipLargeMetadata(), UseMMap()) if err != nil { b.Fatal(err) return diff --git a/file_estimate.go b/file_estimate.go index e3398a6..366aa85 100644 --- a/file_estimate.go +++ b/file_estimate.go @@ -2,69 +2,122 @@ package gguf_parser // GGUFEstimate represents the estimated result of the GGUF file. type GGUFEstimate struct { - // MemoryTotal is the total memory usage. - MemoryTotal GGUFBytesScalar `json:"memoryTotal"` - // MemoryLoad is memory usage to load the model. - MemoryLoad GGUFBytesScalar `json:"memoryLoad"` - // KVCache is the usage of key-value cache. - KVCache GGUFEstimateKVCache `json:"kvCache"` + // Offload is the offloaded layers usage. + Offload *GGUFMemoryUsage `json:"offload,omitempty"` + // Total is the total memory usage. + Total GGUFMemoryUsage `json:"total"` } -// GGUFEstimateKVCache represents the usage of kv-cache. -type GGUFEstimateKVCache struct { - // MemoryTotal is the total memory usage. - MemoryTotal GGUFBytesScalar `json:"memoryTotal"` - // MemoryKey is the memory usage of the cached key. - MemoryKey GGUFBytesScalar `json:"memoryKey"` - // MemoryValue is the memory usage of the cached value. - MemoryValue GGUFBytesScalar `json:"memoryValue"` -} +type ( + // GGUFMemoryUsage represents the memory usage of the GGUF file. + GGUFMemoryUsage struct { + // KVCache is the usage of key-value cache. + KVCache GGUFKVCacheUsage `json:"kvCache"` + // Compute is the usage of transformer layers. + Compute GGUFBytesScalar `json:"compute"` + // IO is the usage of input/output layers. + IO GGUFBytesScalar `json:"io"` + } -// Estimate returns the estimated result of the GGUF file. + // GGUFKVCacheUsage represents the usage of kv-cache. + GGUFKVCacheUsage struct { + // Key is the memory usage of the cached key. 
+		Key GGUFBytesScalar `json:"key"`
+		// Value is the memory usage of the cached value.
+		Value GGUFBytesScalar `json:"value"`
+	}
+)
+
+// Estimate returns the estimated inference memory usage of the GGUF file.
 func (gf *GGUFFile) Estimate(opts ...GGUFEstimateOption) (ge GGUFEstimate) {
 	var o _GGUFEstimateOptions
 	for _, opt := range opts {
 		opt(&o)
 	}
 
-	ge.MemoryLoad = gf.ModelSize
-	ge.KVCache = gf.estimateKVCache(gf.Architecture(), o)
-	ge.MemoryTotal = ge.MemoryLoad + ge.KVCache.MemoryTotal
-
+	ge.Offload, ge.Total = gf.estimateMemoryUsage(gf.Architecture(), o)
	return ge
 }
 
-// estimateKVCache estimates the key-value cache,
-// which is inspired by https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L2479-L2501
-func (gf *GGUFFile) estimateKVCache(a GGUFArchitectureMetadata, o _GGUFEstimateOptions) (kv GGUFEstimateKVCache) {
-	kt, vt := GGMLTypeF16, GGMLTypeF16
+// Sum returns the total of the memory usage.
+func (m GGUFMemoryUsage) Sum() GGUFBytesScalar {
+	return m.Compute + m.KVCache.Sum() + m.IO
+}
 
-	if o.CacheKeyType != nil {
-		kt = *o.CacheKeyType
-	}
-	if o.CacheValueType != nil {
-		vt = *o.CacheValueType
+// Sum returns the total of the key-value cache usage.
+func (c GGUFKVCacheUsage) Sum() GGUFBytesScalar {
+	return c.Key + c.Value
+}
+
+func (gf *GGUFFile) estimateMemoryUsage(a GGUFArchitectureMetadata, o _GGUFEstimateOptions) (offload *GGUFMemoryUsage, total GGUFMemoryUsage) {
+	if o.OffloadLayers != nil {
+		offload = &GGUFMemoryUsage{}
 	}
 
-	var (
-		embedKeyGQA = uint64(a.AttentionKeyLength) * a.AttentionHeadCountKV
-		embedValGQA = uint64(a.AttentionValueLength) * a.AttentionHeadCountKV
-		kvSize      = a.MaximumContextLength
-	)
+	// KV cache.
+	// https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L2479-L2501
 	{
-		// Correct.
-		if a.SSMConvolutionKernel > 0 {
-			embedKeyGQA += uint64(a.SSMConvolutionKernel - 1*a.SSMInnerSize)
-			embedValGQA += uint64(a.SSMStateSize * a.SSMInnerSize)
+		kt, vt := GGMLTypeF16, GGMLTypeF16
+
+		if o.CacheKeyType != nil {
+			kt = *o.CacheKeyType
 		}
-		if o.ContextSize != nil {
-			kvSize = uint64(*o.ContextSize)
+		if o.CacheValueType != nil {
+			vt = *o.CacheValueType
 		}
+
+		var (
+			embedKeyGQA = uint64(a.AttentionKeyLength) * a.AttentionHeadCountKV
+			embedValGQA = uint64(a.AttentionValueLength) * a.AttentionHeadCountKV
+			kvSize      = a.MaximumContextLength
+		)
+		{
+			// Correct the GQA sizes for recurrent (SSM) models: the key side
+			// holds the convolution state, (kernel-1)*inner, and the value
+			// side holds the recurrent state, state*inner.
+			if a.SSMConvolutionKernel > 0 {
+				embedKeyGQA += uint64((a.SSMConvolutionKernel - 1) * a.SSMInnerSize)
+				embedValGQA += uint64(a.SSMStateSize * a.SSMInnerSize)
+			}
+			if o.ContextSize != nil {
+				kvSize = uint64(*o.ContextSize)
+			}
+		}
+
+		krs := kt.RowSizeOf([]uint64{embedKeyGQA * kvSize})
+		vrs := vt.RowSizeOf([]uint64{embedValGQA * kvSize})
+
+		if offload != nil {
+			v := *o.OffloadLayers
+			if v > a.BlockCount {
+				v = a.BlockCount
+			}
+			offload.KVCache.Key = GGUFBytesScalar(krs * v)
+			offload.KVCache.Value = GGUFBytesScalar(vrs * v)
+		}
+
+		total.KVCache.Key = GGUFBytesScalar(krs * a.BlockCount)
+		total.KVCache.Value = GGUFBytesScalar(vrs * a.BlockCount)
 	}
 
-	kv.MemoryKey = GGUFBytesScalar(kt.RowSizeOf([]uint64{embedKeyGQA * kvSize}) * a.BlockCount)
-	kv.MemoryValue = GGUFBytesScalar(vt.RowSizeOf([]uint64{embedValGQA * kvSize}) * a.BlockCount)
-	kv.MemoryTotal = kv.MemoryKey + kv.MemoryValue
+	ls := gf.Layers()
+	bls, als, _ := ls.Cut([]string{
+		"token_embd.weight",
+		"output.weight",
+		"output_norm.weight",
+	})
+
+	// IO.
+	total.IO = GGUFBytesScalar(bls.Bytes())
+
+	// Compute.
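+	// Only the trailing *o.OffloadLayers entries of als are attributed to the
+	// offload side; for example, with 32 blocks, WithOffloadLayers(10)
+	// typically covers the last ten block namespaces, blk.22 through blk.31.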
+ if offload != nil { + v := *o.OffloadLayers + if v >= a.BlockCount { + offload.Compute = GGUFBytesScalar(als.Bytes()) + } else { + for i := uint64(len(als) - 1); i >= uint64(len(als))-v; i-- { + offload.Compute += GGUFBytesScalar(als[i].Bytes()) + } + } + } + total.Compute = GGUFBytesScalar(als.Bytes()) - return kv + return offload, total } diff --git a/file_estimate_option.go b/file_estimate_option.go index d233515..d63ef41 100644 --- a/file_estimate_option.go +++ b/file_estimate_option.go @@ -9,6 +9,7 @@ type ( ContextSize *int32 CacheKeyType *GGMLType CacheValueType *GGMLType + OffloadLayers *uint64 } GGUFEstimateOption func(*_GGUFEstimateOptions) ) @@ -50,3 +51,13 @@ func WithCacheValueType(t GGMLType) GGUFEstimateOption { } } } + +// WithOffloadLayers sets the number of layers to offload. +func WithOffloadLayers(layers uint64) GGUFEstimateOption { + return func(o *_GGUFEstimateOptions) { + if layers <= 0 { + return + } + o.OffloadLayers = &layers + } +} diff --git a/file_estimate_test.go b/file_estimate_test.go index 760723a..d922174 100644 --- a/file_estimate_test.go +++ b/file_estimate_test.go @@ -16,25 +16,12 @@ func TestGGUFFile_Estimate(t *testing.T) { }{ { name: "mixtral 7B", - given: func() *GGUFFile { - f, err := ParseGGUFFileFromHuggingFace( - ctx, - "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", - "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf") - if err != nil { - t.Fatal(err) - } - return f - }(), - }, - { - name: "mixtral 7B with approximate", given: func() *GGUFFile { f, err := ParseGGUFFileFromHuggingFace( ctx, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) } @@ -43,25 +30,12 @@ func TestGGUFFile_Estimate(t *testing.T) { }, { name: "mixtral 8x7B", - given: func() *GGUFFile { - f, err := ParseGGUFFileFromHuggingFace( - ctx, - "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO-GGUF", - "Nous-Hermes-2-Mixtral-8x7B-DPO.Q5_K_M.gguf") - if err != nil { - t.Fatal(err) - } - return f - }(), - }, - { - name: "mixtral 8x7B with approximate", given: func() *GGUFFile { f, err := ParseGGUFFileFromHuggingFace( ctx, "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO-GGUF", "Nous-Hermes-2-Mixtral-8x7B-DPO.Q5_K_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) } @@ -70,25 +44,12 @@ func TestGGUFFile_Estimate(t *testing.T) { }, { name: "wizardlm 8x22B", - given: func() *GGUFFile { - f, err := ParseGGUFFileFromHuggingFace( - ctx, - "MaziyarPanahi/WizardLM-2-8x22B-GGUF", - "WizardLM-2-8x22B.IQ1_M.gguf") - if err != nil { - t.Fatal(err) - } - return f - }(), - }, - { - name: "wizardlm 8x22B with approximate", given: func() *GGUFFile { f, err := ParseGGUFFileFromHuggingFace( ctx, "MaziyarPanahi/WizardLM-2-8x22B-GGUF", "WizardLM-2-8x22B.IQ1_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) } @@ -111,7 +72,7 @@ func TestGGUFFile_Estimate_KVCache(t *testing.T) { ctx, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) return @@ -128,7 +89,36 @@ func TestGGUFFile_Estimate_KVCache(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - t.Log("\n", spew.Sdump(f.Estimate(tc.opts...).KVCache), "\n") + t.Log("\n", spew.Sdump(f.Estimate(tc.opts...)), "\n") + }) + } +} + +func TestGGUFFile_Estimate_Offload(t *testing.T) { + ctx := context.Background() + + f, err := ParseGGUFFileFromHuggingFace( + ctx, + 
"NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + SkipLargeMetadata()) + if err != nil { + t.Fatal(err) + return + } + + cases := []struct { + name string + opts []GGUFEstimateOption + }{ + {"offload 0 layer", []GGUFEstimateOption{WithContextSize(512), WithOffloadLayers(0)}}, + {"offload 1 layer", []GGUFEstimateOption{WithContextSize(512), WithOffloadLayers(1)}}, + {"offload 10 layers", []GGUFEstimateOption{WithContextSize(512), WithOffloadLayers(10)}}, + {"offload 33 layers", []GGUFEstimateOption{WithContextSize(512), WithOffloadLayers(33)}}, // exceeds the number of layers + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Log("\n", spew.Sdump(f.Estimate(tc.opts...)), "\n") }) } } diff --git a/file_model_test.go b/file_model_test.go index 3122fb3..c52019a 100644 --- a/file_model_test.go +++ b/file_model_test.go @@ -17,7 +17,7 @@ func TestGGUFFile_Model(t *testing.T) { ctx, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) return @@ -33,7 +33,7 @@ func BenchmarkGGUFFile_Model(b *testing.B) { return } - f, err := ParseGGUFFile(mp, UseMMap(), UseApproximate()) + f, err := ParseGGUFFile(mp, UseMMap(), SkipLargeMetadata()) if err != nil { b.Fatal(err) return diff --git a/file_option.go b/file_option.go index 28a3e0d..df5bce0 100644 --- a/file_option.go +++ b/file_option.go @@ -4,8 +4,8 @@ import "net/url" type ( _GGUFReadOptions struct { - Debug bool - Approximate bool + Debug bool + SkipLargeMetadata bool // Local. MMap bool @@ -26,16 +26,11 @@ func UseDebug() GGUFReadOption { } } -// UseApproximate uses approximate mode to read the file. -// -// With this, the file is read in a faster way, -// for example, -// skips reading tedious GGUFMetadataKV items, -// skips reading GGUFTensorInfos, -// guess model size/parameters/bpw, etc. -func UseApproximate() GGUFReadOption { +// SkipLargeMetadata skips reading large GGUFMetadataKV items, +// which are not necessary for most cases. +func SkipLargeMetadata() GGUFReadOption { return func(o *_GGUFReadOptions) { - o.Approximate = true + o.SkipLargeMetadata = true } } diff --git a/file_test.go b/file_test.go index 83ed85d..7991c39 100644 --- a/file_test.go +++ b/file_test.go @@ -31,7 +31,7 @@ func TestParseGGUFFile(t *testing.T) { // Fast read. { - f, err := ParseGGUFFile(mp, UseApproximate(), UseMMap()) + f, err := ParseGGUFFile(mp, SkipLargeMetadata(), UseMMap()) if err != nil { t.Fatal(err) return @@ -72,7 +72,7 @@ func BenchmarkParseGGUFFileMMap(b *testing.B) { }) } -func BenchmarkParseGGUFFileApproximate(b *testing.B) { +func BenchmarkParseGGUFFileSkipLargeMetadata(b *testing.B) { mp, ok := os.LookupEnv("TEST_MODEL_PATH") if !ok { b.Skip("TEST_MODEL_PATH is not set") @@ -93,9 +93,9 @@ func BenchmarkParseGGUFFileApproximate(b *testing.B) { }) b.ResetTimer() - b.Run("UseApproximate", func(b *testing.B) { + b.Run("SkipLargeMetadata", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := ParseGGUFFile(mp, UseMMap(), UseApproximate()) + _, err := ParseGGUFFile(mp, SkipLargeMetadata(), UseMMap()) if err != nil { b.Fatal(err) return @@ -126,7 +126,7 @@ func TestParseGGUFFileRemote(t *testing.T) { // Fast read. 
{ - f, err := ParseGGUFFileRemote(ctx, u, UseDebug(), UseApproximate()) + f, err := ParseGGUFFileRemote(ctx, u, UseDebug(), SkipLargeMetadata()) if err != nil { t.Fatal(err) return @@ -146,7 +146,7 @@ func BenchmarkParseGGUFFileRemoteWithBufferSize(b *testing.B) { b.ResetTimer() b.Run("256KibBuffer", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := ParseGGUFFileRemote(ctx, u, UseApproximate(), UseBufferSize(256*1024)) + _, err := ParseGGUFFileRemote(ctx, u, SkipLargeMetadata(), UseBufferSize(256*1024)) if err != nil { b.Fatal(err) return @@ -157,7 +157,7 @@ func BenchmarkParseGGUFFileRemoteWithBufferSize(b *testing.B) { b.ResetTimer() b.Run("1MibBuffer", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := ParseGGUFFileRemote(ctx, u, UseApproximate(), UseBufferSize(1024*1024)) + _, err := ParseGGUFFileRemote(ctx, u, SkipLargeMetadata(), UseBufferSize(1024*1024)) if err != nil { b.Fatal(err) return @@ -168,7 +168,7 @@ func BenchmarkParseGGUFFileRemoteWithBufferSize(b *testing.B) { b.ResetTimer() b.Run("4MibBuffer", func(b *testing.B) { for i := 0; i < b.N; i++ { - _, err := ParseGGUFFileRemote(ctx, u, UseApproximate(), UseBufferSize(4*1024*1024)) + _, err := ParseGGUFFileRemote(ctx, u, SkipLargeMetadata(), UseBufferSize(4*1024*1024)) if err != nil { b.Fatal(err) return @@ -192,7 +192,7 @@ func TestParseGGUFFileFromHuggingFace(t *testing.T) { } for _, tc := range cases { t.Run(tc[0]+"/"+tc[1], func(t *testing.T) { - f, err := ParseGGUFFileFromHuggingFace(ctx, tc[0], tc[1], UseApproximate()) + f, err := ParseGGUFFileFromHuggingFace(ctx, tc[0], tc[1], SkipLargeMetadata()) if err != nil { t.Fatal(err) return diff --git a/file_tokenizer_test.go b/file_tokenizer_test.go index a888575..39cdf44 100644 --- a/file_tokenizer_test.go +++ b/file_tokenizer_test.go @@ -15,7 +15,7 @@ func TestGGUFFile_Tokenizer(t *testing.T) { ctx, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", - UseApproximate()) + SkipLargeMetadata()) if err != nil { t.Fatal(err) return @@ -31,7 +31,7 @@ func BenchmarkGGUFFile_Tokenizer(b *testing.B) { return } - f, err := ParseGGUFFile(mp, UseApproximate(), UseMMap()) + f, err := ParseGGUFFile(mp, SkipLargeMetadata(), UseMMap()) if err != nil { b.Fatal(err) return