Skip to content

Commit

Permalink
refactor: estimate
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed May 30, 2024
1 parent 67ab801 commit ca3f025
Show file tree
Hide file tree
Showing 11 changed files with 498 additions and 722 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,10 @@ if err != nil {

```

#### Use approximate parsing

> The approximate parsing is faster than the accurate one,
> but the result may not be accurate.
#### Skip large metadata

```go
f, err := ParseGGUFFile("path/to/model.gguf", UseApproximate())
f, err := ParseGGUFFile("path/to/model.gguf", SkipLargeMetadata())
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -124,6 +121,12 @@ spew.Dump(f.Estimate(WithContextSize(4096) /* 4K */))

```

#### Estimate with specific offload layers

```go
spew.Dump(f.Estimate(WithOffloadLayers(10)))
```

## License

MIT
65 changes: 45 additions & 20 deletions cmd/gguf-parser/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ func main() {
url string
repo, model string
// read options
debug bool
approximate = true
mmap = true
skipProxy bool
skipTLS bool
debug bool
mmap = true
skipProxy bool
skipTLS bool
// estimate options
ctxSize = 512
kvType = "f16"
ctxSize = 512
kvType = "f16"
offloadLayers uint64
// output options
skipModel bool
skipArchitecture bool
Expand All @@ -58,12 +58,12 @@ func main() {
fs.StringVar(&model, "model", model, "Model below the --repo, e.g. "+
"Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf")
fs.BoolVar(&debug, "debug", debug, "Debug mode")
fs.BoolVar(&approximate, "approximate", approximate, "Speed up reading")
fs.BoolVar(&mmap, "mmap", mmap, "Use mmap to read the local file")
fs.BoolVar(&skipProxy, "skip-proxy", skipProxy, "Skip using proxy when reading from a remote URL")
fs.BoolVar(&skipTLS, "skip-tls", skipTLS, "Skip TLS verification when reading from a remote URL")
fs.IntVar(&ctxSize, "ctx-size", ctxSize, "Context size to estimate memory usage")
fs.StringVar(&kvType, "kv-type", kvType, "Key-Value cache type, select from [f32, f16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1]")
fs.Uint64Var(&offloadLayers, "offload-layers", offloadLayers, "Specify how many layers to offload, default is fully offloading")
fs.BoolVar(&skipModel, "skip-model", skipModel, "Skip model metadata")
fs.BoolVar(&skipArchitecture, "skip-architecture", skipArchitecture, "Skip architecture metadata")
fs.BoolVar(&skipTokenizer, "skip-tokenizer", skipTokenizer, "Skip tokenizer metadata")
Expand All @@ -77,9 +77,11 @@ func main() {

// Prepare options.

var ropts []GGUFReadOption
if approximate {
ropts = append(ropts, UseApproximate())
ropts := []GGUFReadOption{
SkipLargeMetadata(),
}
if debug {
ropts = append(ropts, UseDebug())
}
if mmap {
ropts = append(ropts, UseMMap())
Expand All @@ -91,7 +93,9 @@ func main() {
ropts = append(ropts, SkipTLSVerification())
}

var eopts []GGUFEstimateOption
eopts := []GGUFEstimateOption{
WithContextSize(512),
}
if ctxSize > 0 {
eopts = append(eopts, WithContextSize(int32(ctxSize)))
}
Expand All @@ -117,6 +121,9 @@ func main() {
}
eopts = append(eopts, WithCacheKeyType(kv), WithCacheValueType(kv))
}
if offloadLayers > 0 {
eopts = append(eopts, WithOffloadLayers(offloadLayers))
}

// Parse GGUF file.

Expand Down Expand Up @@ -190,8 +197,9 @@ func main() {

if !skipModel {
tprintf(
[]string{"Name", "Architecture", "Quantization Version", "File Type", "Little Endian", "Size", "Parameters", "BPW"},
[]string{"", "Name", "Architecture", "Quantization Version", "File Type", "Little Endian", "Size", "Parameters", "BPW"},
[]string{
"MODEL",
m.Name,
m.Architecture,
sprintf(m.QuantizationVersion),
Expand All @@ -205,8 +213,9 @@ func main() {

if !skipArchitecture {
tprintf(
[]string{"Maximum Context Length", "Embedding Length", "Layers", "Feed Forward Length", "Expert Count", "Vocabulary Length"},
[]string{"", "Maximum Context", "Embedding", "Layers", "Feed Forward", "Expert Count", "Vocabulary"},
[]string{
"ARCHITECTURE",
sprintf(a.MaximumContextLength),
sprintf(a.EmbeddingLength),
fmt.Sprintf("%d + 1 = %d",
Expand All @@ -220,8 +229,9 @@ func main() {

if !skipTokenizer {
tprintf(
[]string{"Tokenizer Model", "Tokens Length", "Added Tokens Length", "BOS", "EOS", "Unknown", "Separator", "Padding"},
[]string{"", "Model", "Tokens", "Added Tokens", "BOS", "EOS", "Unknown", "Separator", "Padding"},
[]string{
"TOKENIZER",
t.Model,
sprintf(t.TokensLength),
sprintf(t.AddedTokensLength),
Expand All @@ -235,12 +245,25 @@ func main() {

if !skipEstimate {
tprintf(
[]string{"Load Memory", "KVCache Memory", "Total Memory"},
[]string{"", "KV Cache", "Compute", "IO", "Sum"},
[]string{
e.MemoryLoad.String(),
e.KVCache.MemoryTotal.String(),
e.MemoryTotal.String(),
"ESTIMATE TOTAL",
e.Total.KVCache.Sum().String(),
e.Total.Compute.String(),
e.Total.IO.String(),
e.Total.Sum().String(),
})
if e.Offload != nil {
tprintf(
[]string{"", "KV Cache", "Compute", "IO", "Sum"},
[]string{
"ESTIMATE OFFLOAD",
e.Offload.KVCache.Sum().String(),
e.Offload.Compute.String(),
e.Offload.IO.String(),
e.Offload.Sum().String(),
})
}
}
}

Expand Down Expand Up @@ -277,10 +300,12 @@ func tprintf(headers, rows []string) {
tb := tablewriter.NewWriter(os.Stdout)
tb.SetHeaderAlignment(tablewriter.ALIGN_CENTER)
tb.SetAlignment(tablewriter.ALIGN_CENTER)
tb.SetHeaderLine(true)
tb.SetBorder(true)
tb.SetTablePadding("\t")
tb.SetHeaderLine(true)
tb.SetHeader(headers)
tb.SetAutoMergeCells(true)
tb.SetRowLine(true)
tb.Append(rows)
tb.Render()
fmt.Println()
Expand Down
Loading

0 comments on commit ca3f025

Please sign in to comment.