From c5a7ff2ec0be567d4a4a45027e3a314e5b9a62c9 Mon Sep 17 00:00:00 2001 From: thxCode Date: Tue, 28 May 2024 18:29:23 +0800 Subject: [PATCH] feat: first commit Signed-off-by: thxCode --- .gitattributes | 4 + .github/workflows/ci.yml | 47 + .github/workflows/cmd.yml | 44 + .gitignore | 31 + .golangci.yaml | 146 ++ LICENSE | 21 + Makefile | 86 ++ README.md | 131 ++ cmd/gguf-parser/go.mod | 22 + cmd/gguf-parser/go.sum | 28 + cmd/gguf-parser/main.go | 175 +++ file.go | 1311 +++++++++++++++++ file_architecture.go | 246 ++++ file_architecture_test.go | 46 + file_estimate.go | 540 +++++++ file_estimate_option.go | 60 + file_estimate_test.go | 237 +++ file_general.go | 314 ++++ file_general_test.go | 81 + file_option.go | 76 + file_test.go | 203 +++ file_tokenizer.go | 89 ++ file_tokenizer_test.go | 46 + filename.go | 115 ++ filename_test.go | 137 ++ gen.go | 2 + gen.stringer.go | 10 + go.mod | 21 + go.sum | 26 + util/bytex/pool.go | 91 ++ util/funcx/error.go | 65 + util/httpx/client.go | 226 +++ util/httpx/client_helper.go | 60 + util/httpx/client_options.go | 112 ++ util/httpx/file.go | 197 +++ util/httpx/proxy.go | 37 + util/httpx/transport.go | 25 + util/httpx/transport_options.go | 193 +++ util/osx/env.go | 29 + util/osx/file.go | 84 ++ util/osx/file_mmap.go | 109 ++ util/osx/file_mmap_js.go | 27 + util/osx/file_mmap_unix.go | 30 + util/osx/file_mmap_windows.go | 33 + util/osx/file_mmap_windows_386.go | 16 + util/osx/file_mmap_windows_non386.go | 18 + util/ptr/pointer.go | 163 ++ util/stringx/string.go | 91 ++ zz_generated.ggmltype.stringer.go | 54 + zz_generated.gguffiletype.stringer.go | 48 + zz_generated.ggufmagic.stringer.go | 41 + ...enerated.ggufmetadatavaluetype.stringer.go | 36 + zz_generated.ggufversion.stringer.go | 26 + 53 files changed, 6106 insertions(+) create mode 100644 .gitattributes create mode 100644 .github/workflows/ci.yml create mode 100644 .github/workflows/cmd.yml create mode 100644 .gitignore create mode 100644 .golangci.yaml create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 cmd/gguf-parser/go.mod create mode 100644 cmd/gguf-parser/go.sum create mode 100644 cmd/gguf-parser/main.go create mode 100644 file.go create mode 100644 file_architecture.go create mode 100644 file_architecture_test.go create mode 100644 file_estimate.go create mode 100644 file_estimate_option.go create mode 100644 file_estimate_test.go create mode 100644 file_general.go create mode 100644 file_general_test.go create mode 100644 file_option.go create mode 100644 file_test.go create mode 100644 file_tokenizer.go create mode 100644 file_tokenizer_test.go create mode 100644 filename.go create mode 100644 filename_test.go create mode 100644 gen.go create mode 100644 gen.stringer.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 util/bytex/pool.go create mode 100644 util/funcx/error.go create mode 100644 util/httpx/client.go create mode 100644 util/httpx/client_helper.go create mode 100644 util/httpx/client_options.go create mode 100644 util/httpx/file.go create mode 100644 util/httpx/proxy.go create mode 100644 util/httpx/transport.go create mode 100644 util/httpx/transport_options.go create mode 100644 util/osx/env.go create mode 100644 util/osx/file.go create mode 100644 util/osx/file_mmap.go create mode 100644 util/osx/file_mmap_js.go create mode 100644 util/osx/file_mmap_unix.go create mode 100644 util/osx/file_mmap_windows.go create mode 100644 util/osx/file_mmap_windows_386.go create mode 100644 util/osx/file_mmap_windows_non386.go create mode 100644 util/ptr/pointer.go create mode 100644 util/stringx/string.go create mode 100644 zz_generated.ggmltype.stringer.go create mode 100644 zz_generated.gguffiletype.stringer.go create mode 100644 zz_generated.ggufmagic.stringer.go create mode 100644 zz_generated.ggufmetadatavaluetype.stringer.go create mode 100644 zz_generated.ggufversion.stringer.go diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..45b58eb --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +* text=auto eol=lf + +**/go.sum linguist-generated=true +**/zz_generated.*.go linguist-generated=true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..041d6c0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,47 @@ +name: ci + +permissions: + contents: read + pull-requests: read + actions: read + +defaults: + run: + shell: bash + +on: + push: + branches: + - 'main' + pull_request: + branches: + - 'main' + +jobs: + ci: + timeout-minutes: 15 + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 1 + persist-credentials: false + - name: Setup Go + timeout-minutes: 15 + uses: actions/setup-go@v5 + with: + go-version: "1.22.3" + cache-dependency-path: | + **/go.sum + - name: Setup Toolbox + timeout-minutes: 5 + uses: actions/cache@v3 + with: + key: toolbox-${{ runner.os }} + path: | + ${{ github.workspace }}/.sbin + - name: Make + run: make ci + env: + LINT_DIRTY: "true" diff --git a/.github/workflows/cmd.yml b/.github/workflows/cmd.yml new file mode 100644 index 0000000..8be8cc6 --- /dev/null +++ b/.github/workflows/cmd.yml @@ -0,0 +1,44 @@ +name: cmd + +permissions: + contents: write + actions: read + id-token: write + +defaults: + run: + shell: bash + +on: + push: + tags: + - "v*.*.*" + +jobs: + build: + timeout-minutes: 15 + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 1 + persist-credentials: false + - name: Setup Go + timeout-minutes: 15 + uses: actions/setup-go@v5 + with: + go-version: "1.22.3" + cache-dependency-path: | + **/go.sum + - name: Make + run: make build + env: + VERSION: "${{ github.ref_name }}" + - name: Release + uses: softprops/action-gh-release@v1 + with: + fail_on_unmatched_files: true + tag_name: "${{ github.ref_name }}" + prerelease: ${{ contains(github.ref, 'rc') }} + files: ${{ github.workspace }}/.dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..03725f2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +# Files +.DS_Store +*.lock +*.test +*.out +*.swp +*.swo +*.db +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.log +go.work +go.work.* + +# Dirs +/.idea +/.vscode +/.kube +/.terraform +/.vagrant +/.bundle +/.cache +/.docker +/.entc +/.sbin +/.dist +/log +/certs diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..480355e --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,146 @@ +run: + timeout: 10m + tests: true + modules-download-mode: readonly + go: "1.22" + +# output configuration options +output: + print-issued-lines: true + print-linter-name: true + uniq-by-line: true + path-prefix: "" + sort-results: true + +linters: + disable-all: true + enable: + - asciicheck + - bidichk + - decorder + - durationcheck + - errcheck + - errname + - errorlint + - exportloopref + - godot + - goconst + - gocritic + - gosimple + - gosec + - govet + - gofumpt + - gofmt + - ineffassign + - importas + - lll + - makezero + - misspell + - nakedret + - nilerr + - prealloc + - predeclared + - revive + - staticcheck + - stylecheck + - typecheck + - unconvert + - unparam + - unused + - usestdlibvars + - whitespace + +linters-settings: + decorder: + dec-order: + - const + - var + - func + disable-init-func-first-check: false + disable-dec-order-check: true + errorlint: + errorf: true + asserts: true + comparison: true + godot: + scope: all + exclude: + - "(?i)^ FIXME:" + - "(?i)^ TODO:" + - "(?i)^ SPDX\\-License\\-Identifier:" + - "(?i)^ +" + period: true + capital: false + goconst: + min-len: 3 + min-occurrences: 10 + gosimple: + checks: [ "all" ] + gosec: + severity: "low" + confidence: "low" + excludes: + - G101 + - G107 + - G112 + - G404 + gofumpt: + extra-rules: true + gofmt: + simplify: true + rewrite-rules: + - pattern: 'interface{}' + replacement: 'any' + - pattern: 'a[b:len(a)]' + replacement: 'a[b:]' + importas: + no-unaliased: true + lll: + line-length: 150 + tab-width: 1 + makezero: + always: false + misspell: + locale: US + nakedret: + max-func-lines: 60 + revive: + rules: + - name: var-naming + disabled: true + arguments: + - [ "HTTP", "ID", "TLS", "TCP", "UDP", "API", "CA", "URL", "DNS" ] + staticcheck: + checks: [ "all", "-SA1019", "-SA2002", "-SA5008" ] + stylecheck: + checks: [ "all", "-ST1003" ] + unparam: + check-exported: false + unused: + field-writes-are-uses: true + post-statements-are-reads: true + exported-is-used: true + exported-fields-are-used: true + parameters-are-used: true + local-variables-are-used: true + generated-is-used: true + usestdlibvars: + http-method: true + http-status-code: true + time-weekday: true + time-month: true + time-layout: true + crypto-hash: true + +issues: + exclude-files: + - "doc.go" + - "zz_generated.*.go" + - "gen.*.go" + exclude-rules: + - path: _test\.go + linters: + - errcheck + - gosec + - makezero + - lll diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..804750e --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 gguf-parser-go authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..561ecc5 --- /dev/null +++ b/Makefile @@ -0,0 +1,86 @@ +.SILENT: +.DEFAULT_GOAL := ci + +SHELL := /bin/bash + +SRCDIR := $(patsubst %/,%,$(dir $(abspath $(lastword $(MAKEFILE_LIST))))) +GOOS ?= $(shell go env GOOS) +GOARCH ?= $(shell go env GOARCH) +LINT_DIRTY ?= false +VERSION ?= $(shell git rev-parse --abbrev-ref HEAD 2>/dev/null | tr '[:upper:]' '[:lower:]' || echo "unknown") + +deps: + @echo "+++ deps +++" + + go mod tidy + go mod download + + @echo "--- deps ---" + +generate: + @echo "+++ generate +++" + + go generate $(SRCDIR)/... + + @echo "--- generate ---" + +lint: + @echo "+++ lint +++" + + if [[ "$(LINT_DIRTY)" == "true" ]]; then \ + if [[ -n $$(git status --porcelain) ]]; then \ + echo "Code tree is dirty."; \ + exit 1; \ + fi; \ + fi + + [[ -f "$(SRCDIR)/.sbin/goimports-reviser" ]] || \ + curl --retry 3 --retry-all-errors --retry-delay 3 -sSfL "https://github.com/incu6us/goimports-reviser/releases/download/v3.6.4/goimports-reviser_3.6.4_$(GOOS)_$(GOARCH).tar.gz" \ + | tar -zxvf - --directory "$(SRCDIR)/.sbin" --no-same-owner --exclude ./LICENSE --exclude ./README.md && chmod +x "$(SRCDIR)/.sbin/goimports-reviser" + go list -f "{{.Dir}}" $(SRCDIR)/... | xargs -I {} find {} -maxdepth 1 -type f -name '*.go' ! -name 'gen.*' ! -name 'zz_generated.*' \ + | xargs -I {} "$(SRCDIR)/.sbin/goimports-reviser" -use-cache -imports-order=std,general,company,project,blanked,dotted -output=file {} + + [[ -f "$(SRCDIR)/.sbin/golangci-lint" ]] || \ + curl --retry 3 --retry-all-errors --retry-delay 3 -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh \ + | sh -s -- -b "$(SRCDIR)/.sbin" "v1.57.2" + "$(SRCDIR)/.sbin/golangci-lint" run --fix $(SRCDIR)/... + + @echo "--- lint ---" + +test: + @echo "+++ test +++" + + go test -v -failfast -race -cover -timeout=30m $(SRCDIR)/... + + @echo "--- test ---" + +benchmark: + @echo "+++ benchmark +++" + + go test -v -failfast -run="^Benchmark[A-Z]+" -bench=. -benchmem -timeout=30m $(SRCDIR)/... + + @echo "--- benchmark ---" + +gguf-parser: + @echo "+++ gguf-parser +++" + [[ -d "$(SRCDIR)/.dist" ]] || mkdir -p "$(SRCDIR)/.dist" + + cd "$(SRCDIR)/cmd/gguf-parser" && for GOOS in darwin linux windows; do \ + for GOARCH in amd64 arm64; do \ + echo "Building gguf-parser for $$GOOS-$$GOARCH"; \ + if [[ $$GOOS == "windows" ]]; then \ + SUFFIX=".exe"; \ + else \ + SUFFIX=""; \ + fi; \ + GOOS="$$GOOS" GOARCH="$$GOARCH" CGO_ENABLED=1 go build \ + -trimpath \ + -ldflags="-w -s -X main.Version=${GIT_VERSION}" \ + -tags="netgo" \ + -o $(SRCDIR)/.dist/gguf-parser-$$GOOS-$$GOARCH$$SUFFIX; \ + done; \ + done + +ci: deps generate test lint + +build: gguf-parser diff --git a/README.md b/README.md new file mode 100644 index 0000000..5330b98 --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +# GGUF Parser + +> tl;dr, Go parser for the [GGUR](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md). + +[GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) is a file format for storing models for inference +with GGML and executors based on GGML. GGUF is a binary format that is designed for fast loading and saving of models, +and for ease of reading. Models are traditionally developed using PyTorch or another framework, and then converted to +GGUF for use in GGML. + +GGUF Parser provides some functions to parse the GGUF file in Go for the following purposes: + +- Read metadata from the GGUF file without downloading the whole model remotely. +- Estimate the model usage. + +Import the package as below. + +```shell +go get github.com/thxcode/gguf-parser-go +``` + +## Examples + +### Load model + +```go +import ( + "github.com/davecgh/go-spew/spew" + . "github.com/thxcode/gguf-parser-go" +) + +f, err := ParseGGUFFile("path/to/model.gguf") +if err != nil { + panic(err) +} + +spew.Dump(f) + +``` + +#### Use MMap + +```go +f, err := ParseGGUFFile("path/to/model.gguf", UseMMap()) +if err != nil { + panic(err) +} + +``` + +#### Skip reading tedious array metadata + +```go +f, err := ParseGGUFFile("path/to/model.gguf", UseDeferReading()) +if err != nil { + panic(err) +} + +``` + +### Load model from remote + +```go +import ( + "context" + "github.com/davecgh/go-spew/spew" + . "github.com/thxcode/gguf-parser-go" +) + +f, err := ParseGGUFFileRemote(context.Background(), "https://example.com/model.gguf") +if err != nil { + panic(err) +} + +spew.Dump(f) + +``` + +#### Adjust requesting buffer size + +```go +f, err := ParseGGUFFileRemote(context.Background(), "https://example.com/model.gguf", UseBufferSize(1 * 1024 * 1024) /* 1M */) +if err != nil { + panic(err) +} + +``` + +### View information + +```go +// General +spew.Dump(f.General()) + +// Architecture +spew.Dump(f.Architecture()) + +// Tokenizer +spew.Dump(f.Tokenizer()) + +``` + +### Estimate usage + +```go +g, a := f.General(), f.Architecture() +spew.Dump(f.Estimate(g, a)) + +``` + +#### Estimate with larger prompt + +```go +g, a := f.General(), f.Architecture() +spew.Dump(f.Estimate(g, a, WithContextSize(4096) /* 4K */)) + +``` + +#### Estimate approximately to speed up + +> The approximate estimation is faster than the accurate one, +> but the result may not be accurate. + +```go +g, a := f.General(), f.Architecture() +spew.Dump(f.Estimate(g, a, WithApproximate())) + +``` + +## License + +MIT diff --git a/cmd/gguf-parser/go.mod b/cmd/gguf-parser/go.mod new file mode 100644 index 0000000..a8cbfc5 --- /dev/null +++ b/cmd/gguf-parser/go.mod @@ -0,0 +1,22 @@ +module github.com/thxcode/gguf-parser-go/cmd/gguf-parser + +go 1.22 + +replace github.com/thxcode/gguf-parser-go => ../../ + +require ( + github.com/dustin/go-humanize v1.0.1 + github.com/olekukonko/tablewriter v0.0.5 + github.com/thxcode/gguf-parser-go v0.0.0-00010101000000-000000000000 +) + +require ( + github.com/henvic/httpretty v0.1.3 // indirect + github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/tools v0.21.0 // indirect +) diff --git a/cmd/gguf-parser/go.sum b/cmd/gguf-parser/go.sum new file mode 100644 index 0000000..59b81c1 --- /dev/null +++ b/cmd/gguf-parser/go.sum @@ -0,0 +1,28 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/henvic/httpretty v0.1.3 h1:4A6vigjz6Q/+yAfTD4wqipCv+Px69C7Th/NhT0ApuU8= +github.com/henvic/httpretty v0.1.3/go.mod h1:UUEv7c2kHZ5SPQ51uS3wBpzPDibg2U3Y+IaXyHy5GBg= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b h1:e9eeuSYSLmUKxy7ALzKcxo7ggTceQaVcBhjDIcewa9c= +github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/cmd/gguf-parser/main.go b/cmd/gguf-parser/main.go new file mode 100644 index 0000000..a2fa5e8 --- /dev/null +++ b/cmd/gguf-parser/main.go @@ -0,0 +1,175 @@ +package main + +import ( + "flag" + "os" + "fmt" + "context" + "strconv" + + "github.com/olekukonko/tablewriter" + "github.com/dustin/go-humanize" + + . "github.com/thxcode/gguf-parser-go" +) + +func main() { + ctx := context.Background() + + var ( + // model + path string + url string + repo, model string + // read options + deferReading = true + mmap = true + skipProxy bool + skipTLS bool + // estimate options + ctxSize = 512 + approximate bool + ) + fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + fs.StringVar(&path, "path", path, "Path to load model, e.g. ~/.cache"+ + "/lm-studio/models/NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF/"+ + "Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf") + fs.StringVar(&url, "url", url, "Url to load model, e.g. "+ + "https://huggingface.co/NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF"+ + "/resolve/main/Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf") + fs.StringVar(&repo, "repo", repo, "Repo of HuggingFace, e.g. "+ + "NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF") + fs.StringVar(&model, "model", model, "Model below the --repo, e.g. "+ + "Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf") + fs.BoolVar(&deferReading, "defer-reading", deferReading, "Defer reading "+ + "large or complex data value to speed up the process") + fs.BoolVar(&mmap, "mmap", mmap, "Use mmap to read the local file") + fs.BoolVar(&skipProxy, "skip-proxy", skipProxy, "Skip using proxy when reading from a remote URL") + fs.BoolVar(&skipTLS, "skip-tls", skipTLS, "Skip TLS verification when reading from a remote URL") + fs.IntVar(&ctxSize, "ctx-size", ctxSize, "Number of tokens to predict") + fs.BoolVar(&approximate, "approximate", approximate, "Enable approximate estimate") + if err := fs.Parse(os.Args[1:]); err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } + + var ropts []GGUFReadOption + if deferReading { + ropts = append(ropts, UseDeferReading()) + } + if mmap { + ropts = append(ropts, UseMMap()) + } + if skipProxy { + ropts = append(ropts, SkipProxy()) + } + if skipTLS { + ropts = append(ropts, SkipTLSVerification()) + } + + var gf *GGUFFile + { + var err error + switch { + default: + _, _ = fmt.Fprintf(os.Stderr, "no model specified\n") + os.Exit(1) + case path != "": + gf, err = ParseGGUFFile(path, ropts...) + case url != "": + gf, err = ParseGGUFFileRemote(ctx, url, ropts...) + case repo != "" && model != "": + gf, err = ParseGGUFFileFromHuggingFace(ctx, repo, model, ropts...) + } + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "failed to parse GGUF file: %s\n", err.Error()) + os.Exit(1) + } + } + + g := gf.General() + tprintf( + []string{"Name", "Architecture", "Little Endian", "Quantization", "Size", "Type"}, + []string{ + g.Name, + g.Architecture, + sprintf(g.LittleEndian), + sprintf(g.QuantizationVersion), + humanize.IBytes(uint64(g.Size)), + sprintf(g.FileType), + }) + + a := gf.Architecture() + tprintf( + []string{"Context", "Embedding", "Layer", "Feed Forward", "Expert", "Vocabulary"}, + []string{ + sprintf(a.ContextLength), + sprintf(a.EmbeddingLength), + sprintf(a.BlockCount), + sprintf(a.FeedForwardLength), + sprintf(a.ExpertCount), + sprintf(a.VocabularyLength), + }) + + var eopts []GGUFEstimateOption + if ctxSize > 0 { + eopts = append(eopts, WithContextSize(int32(ctxSize))) + } + if approximate { + eopts = append(eopts, WithApproximate()) + } + + e := gf.Estimate(g, a, eopts...) + tprintf( + []string{"Parameters", "BPW", "Inference Memory"}, + []string{ + e.Parameters.String(), + e.BitsPerWeight.String(), + fmt.Sprintf("%s + %s ≈ %s", + e.InferenceUsage.Memory, + e.InferenceUsage.KVCache.Total(), + e.InferenceUsage.Total()), + }) +} + +func sprintf(a any) string { + switch v := a.(type) { + case string: + return v + case []byte: + return string(v) + case int: + return strconv.Itoa(v) + case int32: + return strconv.Itoa(int(v)) + case int64: + return strconv.Itoa(int(v)) + case uint: + return strconv.Itoa(int(v)) + case uint32: + return strconv.Itoa(int(v)) + case uint64: + return strconv.Itoa(int(v)) + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case bool: + return strconv.FormatBool(v) + default: + return fmt.Sprintf("%v", v) + } +} + +func tprintf(headers, rows []string) { + tb := tablewriter.NewWriter(os.Stdout) + tb.SetHeaderAlignment(tablewriter.ALIGN_CENTER) + tb.SetAlignment(tablewriter.ALIGN_CENTER) + tb.SetHeaderLine(true) + tb.SetBorder(true) + tb.SetTablePadding("\t") + tb.SetHeader(headers) + tb.Append(rows) + tb.Render() + fmt.Println() +} diff --git a/file.go b/file.go new file mode 100644 index 0000000..9d91108 --- /dev/null +++ b/file.go @@ -0,0 +1,1311 @@ +package gguf_parser + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net/http" + "regexp" + "time" + + "golang.org/x/exp/constraints" + + "github.com/thxcode/gguf-parser-go/util/bytex" + "github.com/thxcode/gguf-parser-go/util/funcx" + "github.com/thxcode/gguf-parser-go/util/httpx" + "github.com/thxcode/gguf-parser-go/util/osx" +) + +// GGUFFile represents a GGUF file, +// see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#file-structure. +// +// Compared with the complete GGUF file, +// this structure lacks the tensor data part. +type GGUFFile struct { + // Size is the size of the GGUF file in bytes. + Size int64 `json:"size"` + // Header is the header of the GGUF file. + Header GGUFHeader `json:"header"` + // TensorInfos are the tensor infos of the GGUF file, + // the size of TensorInfos is equal to `Header.TensorCount`. + TensorInfos GGUFTensorInfos `json:"tensorInfos"` + // Padding is the padding size of the GGUF file, + // which is used to split Header and TensorInfos from tensor data. + Padding int64 `json:"padding"` + // TensorDataStartOffset is the offset in bytes of the tensor data in this file. + // + // The offset is the start of the file. + TensorDataStartOffset int64 `json:"tensorDataStartOffset"` +} + +// GGUFMagic is a magic number of GGUF file, +// see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#historical-state-of-affairs. +type GGUFMagic uint32 + +// GGUFMagic constants. +const ( + GGUFMagicGGML GGUFMagic = 0x67676d6c + GGUFMagicGGMF GGUFMagic = 0x67676d66 + GGUFMagicGGJT GGUFMagic = 0x67676a74 + GGUFMagicGGUFLe GGUFMagic = 0x46554747 // GGUF + GGUFMagicGGUFBe GGUFMagic = 0x47475546 // GGUF +) + +// GGUFVersion is a version of GGUF file format, +// see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#version-history. +type GGUFVersion uint32 + +// GGUFVersion constants. +const ( + GGUFVersionV1 GGUFVersion = iota + 1 + GGUFVersionV2 + GGUFVersionV3 +) + +// GGUFHeader represents the header of a GGUF file. +type GGUFHeader struct { + // Magic is a magic number that announces that this is a GGUF file. + Magic GGUFMagic `json:"magic"` + // Version is a version of the GGUF file format. + Version GGUFVersion `json:"version"` + // TensorCount is the number of tensors in the file. + TensorCount uint64 `json:"tensorCount"` + // MetadataKVCount is the number of key-value pairs in the metadata. + MetadataKVCount uint64 `json:"metadataKVCount"` + // MetadataKV are the key-value pairs in the metadata, + MetadataKV GGUFMetadataKVs `json:"metadataKV"` +} + +// GGUFMetadataValueType is a type of GGUF metadata value, +// see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#file-structure. +type GGUFMetadataValueType uint32 + +// GGUFMetadataValueType constants. +const ( + GGUFMetadataValueTypeUint8 GGUFMetadataValueType = iota + GGUFMetadataValueTypeInt8 + GGUFMetadataValueTypeUint16 + GGUFMetadataValueTypeInt16 + GGUFMetadataValueTypeUint32 + GGUFMetadataValueTypeInt32 + GGUFMetadataValueTypeFloat32 + GGUFMetadataValueTypeBool + GGUFMetadataValueTypeString + GGUFMetadataValueTypeArray + GGUFMetadataValueTypeUint64 + GGUFMetadataValueTypeInt64 + GGUFMetadataValueTypeFloat64 + _GGUFMetadataValueTypeCount // Unknown +) + +// Types for GGUFMetadataKV. +type ( + // GGUFMetadataKV is a key-value pair in the metadata of a GGUF file. + GGUFMetadataKV struct { + // Key is the key of the metadata key-value pair, + // which is no larger than 64 bytes long. + Key string `json:"key"` + // ValueType is the type of the metadata value. + ValueType GGUFMetadataValueType `json:"valueType"` + // Value is the value of the metadata key-value pair. + Value any `json:"value"` + } + + // GGUFMetadataKVArrayValue is a value of a GGUFMetadataKV with type GGUFMetadataValueTypeArray. + GGUFMetadataKVArrayValue struct { + // StartOffset is the offset in bytes of the GGUFMetadataKVArrayValue in the GGUFFile file. + // + // The offset is the start of the file. + StartOffset int64 `json:"startOffset"` + // Type is the type of the array item. + Type GGUFMetadataValueType `json:"type"` + // Len is the length of the array. + Len uint64 `json:"len"` + // Array holds all array items. + Array []any `json:"array,omitempty"` + } + + // GGUFMetadataKVs is a list of GGUFMetadataKV. + GGUFMetadataKVs []GGUFMetadataKV +) + +// Types for GGMLType. +type ( + // GGMLType is a type of GGML tensor, + // see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#file-structure. + GGMLType uint32 + + // GGMLTypeTrait holds the trait of a GGMLType, + // see https://github.com/ggerganov/ggml/blob/0cbb7c0e053f5419cfbebb46fbf4d4ed60182cf5/src/ggml.c#L564-L918. + GGMLTypeTrait struct { + BlockSize uint64 // Original is int, in order to reduce conversion, here we use uint64. + TypeSize uint64 // Original is uint32, in order to reduce conversion, here we use uint64. + Quantized bool + } +) + +// GGMLType constants. +// +// GGMLTypeQ4_2, GGMLTypeQ4_3 are deprecated. +const ( + GGMLTypeF32 GGMLType = iota + GGMLTypeF16 + GGMLTypeQ4_0 + GGMLTypeQ4_1 + GGMLTypeQ4_2 + GGMLTypeQ4_3 + GGMLTypeQ5_0 + GGMLTypeQ5_1 + GGMLTypeQ8_0 + GGMLTypeQ8_1 + GGMLTypeQ2_K + GGMLTypeQ3_K + GGMLTypeQ4_K + GGMLTypeQ5_K + GGMLTypeQ6_K + GGMLTypeQ8_K + GGMLTypeIQ2_XXS + GGMLTypeIQ2_XS + GGMLTypeIQ3_XXS + GGMLTypeIQ1_S + GGMLTypeIQ4_NL + GGMLTypeIQ3_S + GGMLTypeIQ2_S + GGMLTypeIQ4_XS + GGMLTypeI8 + GGMLTypeI16 + GGMLTypeI32 + GGMLTypeI64 + GGMLTypeF64 + GGMLTypeIQ1_M + GGMLTypeBF16 + _GGMLTypeCount // Unknown +) + +// Sizes for GGML constant. +const ( + // GGMLTensorSize is the size of a GGML tensor in bytes, + // see https://github.com/ggerganov/ggml/blob/0cbb7c0e053f5419cfbebb46fbf4d4ed60182cf5/include/ggml/ggml.h#L606. + GGMLTensorSize = 368 + + // GGMLObjectSize is the size of a GGML object in bytes, + // see https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/include/ggml/ggml.h#L563. + GGMLObjectSize = 32 +) + +// Types for GGUFTensorInfo. +type ( + // GGUFTensorInfo represents a tensor info in a GGUF file. + GGUFTensorInfo struct { + // StartOffset is the offset in bytes of the GGUFTensorInfo in the GGUFFile file. + // + // The offset is the start of the file. + StartOffset int64 `json:"startOffset"` + // Name is the name of the tensor, + // which is no larger than 64 bytes long. + Name string `json:"name"` + // NDimensions is the number of dimensions of the tensor. + NDimensions uint32 `json:"nDimensions"` + // Dimensions is the dimensions of the tensor, + // the length is NDimensions. + Dimensions []uint64 `json:"dimensions"` + // Type is the type of the tensor. + Type GGMLType `json:"type"` + // Offset is the offset in bytes of the tensor's data in this file. + // + // The offset is relative to tensor data, not to the start of the file. + Offset uint64 `json:"offset"` + } + + // GGUFTensorInfos is a list of GGUFTensorInfo. + GGUFTensorInfos []GGUFTensorInfo +) + +var ErrGGUFFileInvalidFormat = errors.New("invalid GGUF format") + +// ParseGGUFFile parses a GGUF file from the local given path, +// and returns the GGUFFile, or an error if any. +func ParseGGUFFile(path string, opts ...GGUFReadOption) (*GGUFFile, error) { + var o _GGUFReadOptions + for _, opt := range opts { + opt(&o) + } + + var ( + f io.ReadSeeker + s int64 + ) + if o.UseMMap { + mf, err := osx.OpenMmapFile(path) + if err != nil { + return nil, fmt.Errorf("open mmap file: %w", err) + } + defer osx.Close(mf) + f = io.NewSectionReader(mf, 0, mf.Len()) + s = mf.Len() + } else { + ff, err := osx.Open(path) + if err != nil { + return nil, fmt.Errorf("open file: %w", err) + } + defer osx.Close(ff) + f = ff + s = funcx.MustNoError(ff.Stat()).Size() + } + + return parseGGUFFile(s, f, o) +} + +// ParseGGUFFileRemote parses a GGUF file from a remote URL, +// and returns a GGUFFile, or an error if any. +func ParseGGUFFileRemote(ctx context.Context, url string, opts ...GGUFReadOption) (*GGUFFile, error) { + var o _GGUFReadOptions + for _, opt := range opts { + opt(&o) + } + + cli := httpx.Client( + httpx.ClientOptions(). + WithUserAgent("gguf-parser-go"). + If(o.UseDebug, func(x *httpx.ClientOption) *httpx.ClientOption { + return x.WithDebug() + }). + WithTimeout(0). + WithTransport( + httpx.TransportOptions(). + WithoutKeepalive(). + TimeoutForDial(5*time.Second). + TimeoutForTLSHandshake(5*time.Second). + TimeoutForResponseHeader(5*time.Second). + If(o.SkipProxy, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithoutProxy() + }). + If(o.ProxyURL != nil, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithProxy(http.ProxyURL(o.ProxyURL)) + }). + If(o.SkipTLSVerification, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithoutInsecureVerify() + }))) + + var ( + f io.ReadSeeker + s int64 + ) + { + req, err := httpx.NewGetRequestWithContext(ctx, url) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + + var sf *httpx.SeekerFile + if o.BufferSize > 0 { + sf, err = httpx.OpenSeekerFileWithSize(cli, req, o.BufferSize, 0) + } else { + sf, err = httpx.OpenSeekerFile(cli, req) + } + if err != nil { + return nil, fmt.Errorf("open http file: %w", err) + } + defer osx.Close(sf) + f = io.NewSectionReader(sf, 0, sf.Len()) + s = sf.Len() + } + + return parseGGUFFile(s, f, o) +} + +// ParseGGUFFileFromHuggingFace parses a GGUF file from Hugging Face, +// and returns a GGUFFile, or an error if any. +func ParseGGUFFileFromHuggingFace(ctx context.Context, repo, model string, opts ...GGUFReadOption) (*GGUFFile, error) { + return ParseGGUFFileRemote(ctx, fmt.Sprintf("https://huggingface.co/%s/resolve/main/%s", repo, model), opts...) +} + +func parseGGUFFile(s int64, f io.ReadSeeker, o _GGUFReadOptions) (_ *GGUFFile, err error) { + var gf GGUFFile + var bo binary.ByteOrder = binary.LittleEndian + + // size + gf.Size = s + + // magic + if err = binary.Read(f, bo, &gf.Header.Magic); err != nil { + return nil, fmt.Errorf("read magic: %w", err) + } + switch gf.Header.Magic { + default: + return nil, ErrGGUFFileInvalidFormat + case GGUFMagicGGML, GGUFMagicGGMF, GGUFMagicGGJT: + return nil, fmt.Errorf("unsupported format: %s", gf.Header.Magic) + case GGUFMagicGGUFLe: + case GGUFMagicGGUFBe: + bo = binary.BigEndian + } + + // version + if err = binary.Read(f, bo, &gf.Header.Version); err != nil { + return nil, fmt.Errorf("read version: %w", err) + } + + rd := _GGUFReader{v: gf.Header.Version, o: o, f: f, bo: bo} + + // tensor count + if gf.Header.Version <= GGUFVersionV1 { + gf.Header.TensorCount, err = rd.ReadUint64FromUint32() + } else { + gf.Header.TensorCount, err = rd.ReadUint64() + } + if err != nil { + return nil, fmt.Errorf("read tensor count: %w", err) + } + + // metadata kv count + if gf.Header.Version <= GGUFVersionV1 { + gf.Header.MetadataKVCount, err = rd.ReadUint64FromUint32() + } else { + gf.Header.MetadataKVCount, err = rd.ReadUint64() + } + if err != nil { + return nil, fmt.Errorf("read metadata kv count: %w", err) + } + + // metadata kv + { + rd := _GGUFMetadataReader{_GGUFReader: rd} + kvs := make(GGUFMetadataKVs, gf.Header.MetadataKVCount) + for i := uint64(0); i < gf.Header.MetadataKVCount; i++ { + kvs[i], err = rd.Read() + if err != nil { + return nil, fmt.Errorf("read metadata kv %d: %w", i, err) + } + } + gf.Header.MetadataKV = kvs + } + + // tensor infos + { + rd := _GGUFTensorInfoReader{_GGUFReader: rd} + tis := make(GGUFTensorInfos, gf.Header.TensorCount) + for i := uint64(0); i < gf.Header.TensorCount; i++ { + tis[i], err = rd.Read() + if err != nil { + return nil, fmt.Errorf("read tensor info %d: %w", i, err) + } + } + gf.TensorInfos = tis + } + + pds, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return nil, fmt.Errorf("seek padding start: %w", err) + } + + // padding + { + // The global alignment to use, as described above. + // This can vary to allow for different alignment schemes, but it must be a multiple of 8. + // Some writers may not write the alignment. + // If the alignment is not specified, assume it is 32. + var ag uint32 = 32 + if v, ok := gf.Header.MetadataKV.Get("general.alignment"); ok { + ag = v.ValueUint32() + } + gf.Padding = int64(ag) - (pds % int64(ag)) + } + + // tensor data offset + gf.TensorDataStartOffset = pds + gf.Padding + + return &gf, nil +} + +// LittleEndian returns true if the GGUF file is little-endian, +// and false for big-endian. +func (gf *GGUFFile) LittleEndian() bool { + return gf.Header.Version < GGUFVersionV3 || gf.Header.Magic == GGUFMagicGGUFLe +} + +func (kv GGUFMetadataKV) ValueUint8() uint8 { + if kv.ValueType != GGUFMetadataValueTypeUint8 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(uint8) +} + +func (kv GGUFMetadataKV) ValueInt8() int8 { + if kv.ValueType != GGUFMetadataValueTypeInt8 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(int8) +} + +func (kv GGUFMetadataKV) ValueUint16() uint16 { + if kv.ValueType != GGUFMetadataValueTypeUint16 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(uint16) +} + +func (kv GGUFMetadataKV) ValueInt16() int16 { + if kv.ValueType != GGUFMetadataValueTypeInt16 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(int16) +} + +func (kv GGUFMetadataKV) ValueUint32() uint32 { + if kv.ValueType != GGUFMetadataValueTypeUint32 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(uint32) +} + +func (kv GGUFMetadataKV) ValueInt32() int32 { + if kv.ValueType != GGUFMetadataValueTypeInt32 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(int32) +} + +func (kv GGUFMetadataKV) ValueFloat32() float32 { + if kv.ValueType != GGUFMetadataValueTypeFloat32 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(float32) +} + +func (kv GGUFMetadataKV) ValueBool() bool { + if kv.ValueType != GGUFMetadataValueTypeBool { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(bool) +} + +func (kv GGUFMetadataKV) ValueString() string { + if kv.ValueType != GGUFMetadataValueTypeString { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(string) +} + +func (kv GGUFMetadataKV) ValueArray() GGUFMetadataKVArrayValue { + if kv.ValueType != GGUFMetadataValueTypeArray { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(GGUFMetadataKVArrayValue) +} + +func (kv GGUFMetadataKV) ValueUint64() uint64 { + if kv.ValueType != GGUFMetadataValueTypeUint64 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(uint64) +} + +func (kv GGUFMetadataKV) ValueInt64() int64 { + if kv.ValueType != GGUFMetadataValueTypeInt64 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(int64) +} + +func (kv GGUFMetadataKV) ValueFloat64() float64 { + if kv.ValueType != GGUFMetadataValueTypeFloat64 { + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) + } + return kv.Value.(float64) +} + +// ValueNumeric returns the numeric values of the GGUFMetadataKV, +// and panics if the value type is not numeric. +// +// ValueNumeric is a generic function, and the type T must be constraints.Integer or constraints.Float. +// +// Compare to the GGUFMetadataKV's Value* functions, +// ValueNumeric will cast the original value to the target type. +func ValueNumeric[T constraints.Integer | constraints.Float](kv GGUFMetadataKV) T { + switch kv.ValueType { + case GGUFMetadataValueTypeUint8: + return T(kv.Value.(uint8)) + case GGUFMetadataValueTypeInt8: + return T(kv.Value.(int8)) + case GGUFMetadataValueTypeUint16: + return T(kv.Value.(int16)) + case GGUFMetadataValueTypeInt16: + return T(kv.Value.(int16)) + case GGUFMetadataValueTypeUint32: + return T(kv.Value.(uint32)) + case GGUFMetadataValueTypeInt32: + return T(kv.Value.(int32)) + case GGUFMetadataValueTypeFloat32: + return T(kv.Value.(float32)) + case GGUFMetadataValueTypeUint64: + return T(kv.Value.(uint64)) + case GGUFMetadataValueTypeInt64: + return T(kv.Value.(int64)) + case GGUFMetadataValueTypeFloat64: + return T(kv.Value.(float64)) + default: + } + panic(fmt.Errorf("invalid type: %v", kv.ValueType)) +} + +func (av GGUFMetadataKVArrayValue) ValuesUint8() []uint8 { + if av.Type != GGUFMetadataValueTypeUint8 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]uint8, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(uint8) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesInt8() []int8 { + if av.Type != GGUFMetadataValueTypeInt8 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]int8, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(int8) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesUint16() []uint16 { + if av.Type != GGUFMetadataValueTypeUint16 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]uint16, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(uint16) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesInt16() []int16 { + if av.Type != GGUFMetadataValueTypeInt16 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]int16, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(int16) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesUint32() []uint32 { + if av.Type != GGUFMetadataValueTypeUint32 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]uint32, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(uint32) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesInt32() []int32 { + if av.Type != GGUFMetadataValueTypeInt32 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]int32, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(int32) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesFloat32() []float32 { + if av.Type != GGUFMetadataValueTypeFloat32 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]float32, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(float32) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesBool() []bool { + if av.Type != GGUFMetadataValueTypeBool { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]bool, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(bool) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesString() []string { + if av.Type != GGUFMetadataValueTypeString { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]string, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(string) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesArray() []GGUFMetadataKVArrayValue { + if av.Type != GGUFMetadataValueTypeArray { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]GGUFMetadataKVArrayValue, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(GGUFMetadataKVArrayValue) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesUint64() []uint64 { + if av.Type != GGUFMetadataValueTypeUint64 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]uint64, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(uint64) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesInt64() []int64 { + if av.Type != GGUFMetadataValueTypeInt64 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]int64, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(int64) + } + return v +} + +func (av GGUFMetadataKVArrayValue) ValuesFloat64() []float64 { + if av.Type != GGUFMetadataValueTypeFloat64 { + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + v := make([]float64, av.Len) + for i := uint64(0); i < av.Len; i++ { + v[i] = av.Array[i].(float64) + } + return v +} + +// ValuesNumeric returns the numeric values of the GGUFMetadataKVArrayValue, +// and panics if the value type is not numeric. +// +// ValuesNumeric is a generic function, and the type T must be constraints.Integer or constraints.Float. +// +// Compare to the GGUFMetadataKVArrayValue's Value* functions, +// ValuesNumeric will cast the original value to the target type. +func ValuesNumeric[T constraints.Integer | constraints.Float](av GGUFMetadataKVArrayValue) []T { + v := make([]T, av.Len) + for i := uint64(0); i < av.Len; i++ { + switch av.Type { + case GGUFMetadataValueTypeUint8: + v[i] = T(av.Array[i].(uint8)) + case GGUFMetadataValueTypeInt8: + v[i] = T(av.Array[i].(int8)) + case GGUFMetadataValueTypeUint16: + v[i] = T(av.Array[i].(uint16)) + case GGUFMetadataValueTypeInt16: + v[i] = T(av.Array[i].(int16)) + case GGUFMetadataValueTypeUint32: + v[i] = T(av.Array[i].(uint32)) + case GGUFMetadataValueTypeInt32: + v[i] = T(av.Array[i].(int32)) + case GGUFMetadataValueTypeFloat32: + v[i] = T(av.Array[i].(float32)) + case GGUFMetadataValueTypeUint64: + v[i] = T(av.Array[i].(uint64)) + case GGUFMetadataValueTypeInt64: + v[i] = T(av.Array[i].(int64)) + case GGUFMetadataValueTypeFloat64: + v[i] = T(av.Array[i].(float64)) + default: + panic(fmt.Errorf("invalid type: %v", av.Type)) + } + } + return v +} + +// Load loads the value of the GGUFMetadataKVArrayValue from the GGUF file, +// returns the GGUFMetadataKVArrayValue with the loaded rest part(i.e. Array), +// or an error too if raised. +// +// Load is a defer reading function to compensate for the lack of data. +func (av GGUFMetadataKVArrayValue) Load(v GGUFVersion, f io.ReadSeeker, bo binary.ByteOrder) (GGUFMetadataKVArrayValue, error) { + _, err := f.Seek(av.StartOffset, io.SeekStart) + if err != nil { + return av, fmt.Errorf("seek array start: %w", err) + } + rd := _GGUFReader{v: v, o: _GGUFReadOptions{UseMMap: true}, f: f, bo: bo} + + nav, err := rd.ReadArray() + if err != nil { + return av, fmt.Errorf("read array: %w", err) + } + return nav, nil +} + +// HasAll returns true if the GGUFMetadataKVs has all the given keys, +// and false otherwise. +func (kvs GGUFMetadataKVs) HasAll(keys []string) bool { + ks := make(map[string]struct{}, len(keys)) + for i := range keys { + ks[keys[i]] = struct{}{} + } + for i := range kvs { + k := kvs[i].Key + if _, ok := ks[k]; !ok { + continue + } + delete(ks, k) + if len(ks) == 0 { + break + } + } + return len(ks) == 0 +} + +// Get returns the GGUFMetadataKV with the given key, +// and true if found, and false otherwise. +func (kvs GGUFMetadataKVs) Get(key string) (value GGUFMetadataKV, found bool) { + for i := range kvs { + if kvs[i].Key == key { + return kvs[i], true + } + } + return GGUFMetadataKV{}, false +} + +// Search returns a list of GGUFMetadataKV with the keys that match the given regex. +func (kvs GGUFMetadataKVs) Search(keyRegex *regexp.Regexp) (values []GGUFMetadataKV) { + for i := range kvs { + if keyRegex.MatchString(kvs[i].Key) { + values = append(values, kvs[i]) + } + } + return values +} + +// Index returns a map value to the GGUFMetadataKVs with the given keys, +// and the number of keys found. +func (kvs GGUFMetadataKVs) Index(keys []string) (values map[string]GGUFMetadataKV, found int) { + ks := make(map[string]struct{}, len(keys)) + for i := range keys { + ks[keys[i]] = struct{}{} + } + values = make(map[string]GGUFMetadataKV) + for i := range kvs { + if _, ok := ks[kvs[i].Key]; ok { + values[kvs[i].Key] = kvs[i] + found++ + } + if found == len(ks) { + break + } + } + return values, found +} + +// _GGMLTypeTraits is a table of GGMLTypeTrait for GGMLType. +var _GGMLTypeTraits = map[GGMLType]GGMLTypeTrait{ + GGMLTypeF32: {BlockSize: 1, TypeSize: 4}, + GGMLTypeF16: {BlockSize: 1, TypeSize: 2}, + GGMLTypeQ4_0: {BlockSize: 32, TypeSize: 18, Quantized: true}, + GGMLTypeQ4_1: {BlockSize: 32, TypeSize: 20, Quantized: true}, + GGMLTypeQ4_2: {BlockSize: 0, TypeSize: 0}, // Deprecated + GGMLTypeQ4_3: {BlockSize: 0, TypeSize: 0}, // Deprecated + GGMLTypeQ5_0: {BlockSize: 32, TypeSize: 22, Quantized: true}, + GGMLTypeQ5_1: {BlockSize: 32, TypeSize: 24, Quantized: true}, + GGMLTypeQ8_0: {BlockSize: 32, TypeSize: 34, Quantized: true}, + GGMLTypeQ8_1: {BlockSize: 32, TypeSize: 36, Quantized: true}, + GGMLTypeQ2_K: {BlockSize: 256, TypeSize: 84, Quantized: true}, + GGMLTypeQ3_K: {BlockSize: 256, TypeSize: 110, Quantized: true}, + GGMLTypeQ4_K: {BlockSize: 256, TypeSize: 144, Quantized: true}, + GGMLTypeQ5_K: {BlockSize: 256, TypeSize: 176, Quantized: true}, + GGMLTypeQ6_K: {BlockSize: 256, TypeSize: 210, Quantized: true}, + GGMLTypeQ8_K: {BlockSize: 256, TypeSize: 292, Quantized: true}, + GGMLTypeIQ2_XXS: {BlockSize: 256, TypeSize: 66, Quantized: true}, + GGMLTypeIQ2_XS: {BlockSize: 256, TypeSize: 74, Quantized: true}, + GGMLTypeIQ3_XXS: {BlockSize: 256, TypeSize: 98, Quantized: true}, + GGMLTypeIQ1_S: {BlockSize: 256, TypeSize: 50, Quantized: true}, + GGMLTypeIQ4_NL: {BlockSize: 32, TypeSize: 18, Quantized: true}, + GGMLTypeIQ3_S: {BlockSize: 256, TypeSize: 110, Quantized: true}, + GGMLTypeIQ2_S: {BlockSize: 256, TypeSize: 82, Quantized: true}, + GGMLTypeIQ4_XS: {BlockSize: 256, TypeSize: 136, Quantized: true}, + GGMLTypeI8: {BlockSize: 1, TypeSize: 1}, + GGMLTypeI16: {BlockSize: 1, TypeSize: 2}, + GGMLTypeI32: {BlockSize: 1, TypeSize: 4}, + GGMLTypeI64: {BlockSize: 1, TypeSize: 8}, + GGMLTypeF64: {BlockSize: 1, TypeSize: 8}, + GGMLTypeIQ1_M: {BlockSize: 256, TypeSize: 56, Quantized: true}, + GGMLTypeBF16: {BlockSize: 1, TypeSize: 2}, +} + +// Trait returns the GGMLTypeTrait of the GGMLType. +func (t GGMLType) Trait() (GGMLTypeTrait, bool) { + tt, ok := _GGMLTypeTraits[t] + return tt, ok +} + +// RowSizeOf returns the size of the given dimensions according to the GGMLType's GGMLTypeTrait, +// which is inspired by +// https://github.com/ggerganov/ggml/blob/0cbb7c0e053f5419cfbebb46fbf4d4ed60182cf5/src/ggml.c#L3142-L3145. +// +// The index of the given dimensions means the number of dimension, +// i.e. 0 is the first dimension, 1 is the second dimension, and so on. +// +// The value of the item is the number of elements in the corresponding dimension. +func (t GGMLType) RowSizeOf(dimensions []uint64) uint64 { + if len(dimensions) == 0 { + panic(errors.New("no dimensions")) + } + + tt, ok := t.Trait() + if !ok { + panic(fmt.Errorf("invalid type: %v", t)) + } + + // https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L2640-L2643 + ds := tt.TypeSize * dimensions[0] / tt.BlockSize // Row size + for i := 1; i < len(dimensions); i++ { + ds *= dimensions[i] + } + return ds +} + +// Elements returns the number of elements of the GGUFTensorInfo, +// which is inspired by +// https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L2597-L2601. +func (ti GGUFTensorInfo) Elements() uint64 { + if ti.NDimensions == 0 { + panic(errors.New("no dimensions")) + } + + ret := uint64(1) + for i := uint32(0); i < ti.NDimensions; i++ { + ret *= ti.Dimensions[i] + } + return ret +} + +// Bytes returns the number of bytes of the GGUFTensorInfo, +// which is inspired by +// https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L2609-L2626. +func (ti GGUFTensorInfo) Bytes() uint64 { + if ti.NDimensions == 0 { + panic(errors.New("no dimensions")) + } + + tt, ok := ti.Type.Trait() + if !ok { + panic(fmt.Errorf("invalid type: %v", ti.Type)) + } + + // https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L3210-L3214 + nb := make([]uint64, 0, ti.NDimensions) + { + nb = append(nb, tt.TypeSize) + nb = append(nb, nb[0]*(ti.Dimensions[0]/tt.BlockSize)) + for i := uint32(2); i < ti.NDimensions; i++ { + nb = append(nb, nb[i-1]*ti.Dimensions[i-1]) + } + } + + var ret uint64 + if tt.BlockSize == 1 { + ret = tt.TypeSize + for i := uint32(0); i < ti.NDimensions; i++ { + ret += (ti.Dimensions[i] - 1) * nb[i] + } + return ret + } + + ret = ti.Dimensions[0] * nb[0] / tt.BlockSize + for i := uint32(1); i < ti.NDimensions; i++ { + ret += (ti.Dimensions[i] - 1) * nb[i] + } + return ret +} + +// HasAll returns true if the GGUFTensorInfos has all the given names, +// and false otherwise. +func (tis GGUFTensorInfos) HasAll(names []string) bool { + ns := make(map[string]struct{}, len(names)) + for i := range names { + ns[names[i]] = struct{}{} + } + for i := range tis { + n := tis[i].Name + if _, ok := ns[n]; !ok { + continue + } + delete(ns, n) + if len(ns) == 0 { + break + } + } + return len(ns) == 0 +} + +// Get returns the GGUFTensorInfo with the given name, +// and true if found, and false otherwise. +func (tis GGUFTensorInfos) Get(name string) (info GGUFTensorInfo, found bool) { + for i := range tis { + if tis[i].Name == name { + return tis[i], true + } + } + return GGUFTensorInfo{}, false +} + +// Search returns a list of GGUFTensorInfo with the names that match the given regex. +func (tis GGUFTensorInfos) Search(nameRegex *regexp.Regexp) (infos []GGUFTensorInfo) { + for i := range tis { + if nameRegex.MatchString(tis[i].Name) { + infos = append(infos, tis[i]) + } + } + return infos +} + +// Index returns a map value to the GGUFTensorInfos with the given names, +// and the number of names found. +func (tis GGUFTensorInfos) Index(names []string) (infos map[string]GGUFTensorInfo, found int) { + ns := make(map[string]struct{}, len(names)) + for i := range names { + ns[names[i]] = struct{}{} + } + infos = make(map[string]GGUFTensorInfo) + for i := range tis { + if _, ok := ns[tis[i].Name]; ok { + infos[tis[i].Name] = tis[i] + found++ + } + if found == len(ns) { + break + } + } + return infos, found +} + +type _GGUFReader struct { + v GGUFVersion + o _GGUFReadOptions + f io.ReadSeeker + bo binary.ByteOrder +} + +func (rd _GGUFReader) ReadUint8() (v uint8, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read uint8: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadInt8() (v int8, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read int8: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadUint16() (v uint16, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read uint16: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadInt16() (v int16, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read int16: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadUint32() (v uint32, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read uint32: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadUint64FromUint32() (uint64, error) { + v, err := rd.ReadUint32() + return uint64(v), err +} + +func (rd _GGUFReader) ReadInt32() (v int32, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read int32: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadFloat32() (v float32, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read float32: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadBool() (v bool, err error) { + b, err := rd.ReadUint8() + if err != nil { + return false, fmt.Errorf("read bool: %w", err) + } + return b != 0, nil +} + +func (rd _GGUFReader) ReadString() (v string, err error) { + var l uint64 + if rd.v <= GGUFVersionV1 { + l, err = rd.ReadUint64FromUint32() + } else { + l, err = rd.ReadUint64() + } + if err != nil { + return "", fmt.Errorf("read string length: %w", err) + } + + b := bytex.GetBytes(l) + defer bytex.Put(b) + if _, err = rd.f.Read(b); err != nil { + return "", fmt.Errorf("read string: %w", err) + } + + return string(bytes.TrimSpace(b)), nil +} + +func (rd _GGUFReader) ReadArray() (v GGUFMetadataKVArrayValue, err error) { + v.StartOffset, err = rd.f.Seek(0, io.SeekCurrent) + if err != nil { + return v, fmt.Errorf("read array start: %w", err) + } + + if err = binary.Read(rd.f, rd.bo, &v.Type); err != nil { + return v, fmt.Errorf("read array item type: %w", err) + } + + if rd.v <= GGUFVersionV1 { + v.Len, err = rd.ReadUint64FromUint32() + } else { + v.Len, err = rd.ReadUint64() + } + if err != nil { + return v, fmt.Errorf("read array length: %w", err) + } + + if !rd.o.UseDeferReading { + v.Array = make([]any, v.Len) + for i := uint64(0); i < v.Len; i++ { + v.Array[i], err = rd.ReadValue(v.Type) + if err != nil { + return v, fmt.Errorf("read array item %d: %w", i, err) + } + } + + return v, nil + } + + switch v.Type { + case GGUFMetadataValueTypeUint8, GGUFMetadataValueTypeInt8, GGUFMetadataValueTypeBool: + _, err = rd.f.Seek(int64(v.Len), io.SeekCurrent) + case GGUFMetadataValueTypeUint16, GGUFMetadataValueTypeInt16: + _, err = rd.f.Seek(int64(v.Len)*2, io.SeekCurrent) + case GGUFMetadataValueTypeUint32, GGUFMetadataValueTypeInt32, GGUFMetadataValueTypeFloat32: + _, err = rd.f.Seek(int64(v.Len)*4, io.SeekCurrent) + case GGUFMetadataValueTypeUint64, GGUFMetadataValueTypeInt64, GGUFMetadataValueTypeFloat64: + _, err = rd.f.Seek(int64(v.Len)*8, io.SeekCurrent) + case GGUFMetadataValueTypeString: + err = func() error { + for i := uint64(0); i < v.Len; i++ { + var l uint64 + if rd.v <= GGUFVersionV1 { + l, err = rd.ReadUint64FromUint32() + } else { + l, err = rd.ReadUint64() + } + if err != nil { + return fmt.Errorf("read array[string] %d length: %w", i, err) + } + _, err = rd.f.Seek(int64(l), io.SeekCurrent) + if err != nil { + return fmt.Errorf("seek array[string] %d: %w", i, err) + } + } + return nil + }() + if err == nil { + _, err = rd.f.Seek(0, io.SeekCurrent) + } + default: + // Should not happen. + panic(fmt.Errorf("invalid type: %v", v.Type)) + } + if err != nil { + return v, fmt.Errorf("seek array end: %w", err) + } + + return v, nil +} + +func (rd _GGUFReader) ReadUint64() (v uint64, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read uint64: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadInt64() (v int64, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read int64: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadFloat64() (v float64, err error) { + err = binary.Read(rd.f, rd.bo, &v) + if err != nil { + return 0, fmt.Errorf("read float64: %w", err) + } + return v, nil +} + +func (rd _GGUFReader) ReadValue(vt GGUFMetadataValueType) (v any, err error) { + if vt >= _GGUFMetadataValueTypeCount { + return nil, fmt.Errorf("invalid type: %v", vt) + } + + switch vt { + case GGUFMetadataValueTypeUint8: + v, err = rd.ReadUint8() + case GGUFMetadataValueTypeInt8: + v, err = rd.ReadInt8() + case GGUFMetadataValueTypeUint16: + v, err = rd.ReadUint16() + case GGUFMetadataValueTypeInt16: + v, err = rd.ReadInt16() + case GGUFMetadataValueTypeUint32: + v, err = rd.ReadUint32() + case GGUFMetadataValueTypeInt32: + v, err = rd.ReadInt32() + case GGUFMetadataValueTypeFloat32: + v, err = rd.ReadFloat32() + case GGUFMetadataValueTypeBool: + v, err = rd.ReadBool() + case GGUFMetadataValueTypeString: + v, err = rd.ReadString() + case GGUFMetadataValueTypeArray: + v, err = rd.ReadArray() + case GGUFMetadataValueTypeUint64: + v, err = rd.ReadUint64() + case GGUFMetadataValueTypeInt64: + v, err = rd.ReadInt64() + case GGUFMetadataValueTypeFloat64: + v, err = rd.ReadFloat64() + default: + // Should not happen. + panic(fmt.Errorf("invalid type: %v", vt)) + } + if err != nil { + return nil, err + } + return v, nil +} + +type _GGUFMetadataReader struct { + _GGUFReader +} + +func (rd _GGUFMetadataReader) Read() (kv GGUFMetadataKV, err error) { + kv.Key, err = rd.ReadString() + if err != nil { + return kv, fmt.Errorf("read key: %w", err) + } + + { + vt, err := rd.ReadUint32() + if err != nil { + return kv, fmt.Errorf("read value type: %w", err) + } + kv.ValueType = GGUFMetadataValueType(vt) + if kv.ValueType >= _GGUFMetadataValueTypeCount { + return kv, fmt.Errorf("invalid value type: %v", kv.ValueType) + } + } + + kv.Value, err = rd.ReadValue(kv.ValueType) + if err != nil { + return kv, fmt.Errorf("read %s value: %w", kv.Key, err) + } + + return kv, nil +} + +type _GGUFTensorInfoReader struct { + _GGUFReader +} + +func (rd _GGUFTensorInfoReader) Read() (ti GGUFTensorInfo, err error) { + ti.StartOffset, err = rd.f.Seek(0, io.SeekCurrent) + if err != nil { + return ti, fmt.Errorf("seek tensor info start: %w", err) + } + + ti.Name, err = rd.ReadString() + if err != nil { + return ti, fmt.Errorf("read name: %w", err) + } + + ti.NDimensions, err = rd.ReadUint32() + if err != nil { + return ti, fmt.Errorf("read n dimensions: %w", err) + } + + ti.Dimensions = make([]uint64, ti.NDimensions) + for i := uint32(0); i < ti.NDimensions; i++ { + if rd.v <= GGUFVersionV1 { + ti.Dimensions[i], err = rd.ReadUint64FromUint32() + } else { + ti.Dimensions[i], err = rd.ReadUint64() + } + if err != nil { + return ti, fmt.Errorf("read dimension %d: %w", i, err) + } + } + + { + v, err := rd.ReadUint32() + if err != nil { + return ti, fmt.Errorf("read type: %w", err) + } + ti.Type = GGMLType(v) + if ti.Type >= _GGMLTypeCount { + return ti, fmt.Errorf("invalid type: %v", ti.Type) + } + } + + ti.Offset, err = rd.ReadUint64() + if err != nil { + return ti, fmt.Errorf("read offset: %w", err) + } + + return ti, nil +} diff --git a/file_architecture.go b/file_architecture.go new file mode 100644 index 0000000..2f67230 --- /dev/null +++ b/file_architecture.go @@ -0,0 +1,246 @@ +package gguf_parser + +// GGUFArchitectureMetadata represents the architecture metadata of a GGUF file. +type GGUFArchitectureMetadata struct { + // ContextLength(n_ctx_train) is the context length of the model. + // + // For most architectures, this is the hard limit on the length of the input. + // Architectures, like RWKV, + // that are not reliant on transformer-style attention may be able to handle larger inputs, + // but this is not guaranteed. + ContextLength uint64 `json:"contextLength"` + // EmbeddingLength(n_embd) is the length of the embedding layer. + EmbeddingLength uint64 `json:"embeddingLength"` + // BlockCount(n_layer) is the number of blocks of attention and feed-forward layers, + // i.e. the bulk of the LLM. + // This does not include the input or embedding layers. + BlockCount uint64 `json:"blockCount"` + // FeedForwardLength(n_ff) is the length of the feed-forward layer. + FeedForwardLength uint64 `json:"feedForwardLength,omitempty"` + // ExpertCount(n_expert) is the number of experts in MoE models. + ExpertCount uint32 `json:"expertCount,omitempty"` + // ExpertUsedCount(n_expert_used) is the number of experts used during each token evaluation in MoE models. + ExpertUsedCount uint32 `json:"expertUsedCount,omitempty"` + // AttentionHeadCount(n_head) is the number of attention heads. + AttentionHeadCount uint64 `json:"attentionHeadCount,omitempty"` + // AttentionHeadCountKV(n_head_kv) is the number of attention heads per group used in Grouped-Query-Attention. + // + // If not provided or equal to AttentionHeadCount, + // the model does not use Grouped-Query-Attention. + AttentionHeadCountKV uint64 `json:"attentionHeadCountKV,omitempty"` + // AttentionMaxALiBIBias is the maximum bias to use for ALiBI. + AttentionMaxALiBIBias float32 `json:"attentionMaxALiBIBias,omitempty"` + // AttentionClampKQV describes a value `C`, + // which is used to clamp the values of the `Q`, `K` and `V` tensors between `[-C, C]`. + AttentionClampKQV float32 `json:"attentionClampKQV,omitempty"` + // AttentionLayerNormEpsilon is the epsilon value used in the LayerNorm(Layer Normalization). + AttentionLayerNormEpsilon float32 `json:"attentionLayerNormEpsilon,omitempty"` + // AttentionLayerNormRMSEpsilon is the epsilon value used in the RMSNorm(Root Mean Square Layer Normalization), + // which is a simplification of the original LayerNorm. + AttentionLayerNormRMSEpsilon float32 `json:"attentionLayerNormRMSEpsilon,omitempty"` + // AttentionKeyLength is the size of a key head. + // + // Defaults to `EmbeddingLength / AttentionHeadCount`. + AttentionKeyLength uint32 `json:"attentionKeyLength"` + // AttentionValueLength is the size of a value head. + // + // Defaults to `EmbeddingLength / AttentionHeadCount`. + AttentionValueLength uint32 `json:"attentionValueLength"` + // RoPEDimensionCount is the number of dimensions in the RoPE(Rotary Positional Encoding). + RoPEDimensionCount uint64 `json:"ropeDimensionCount,omitempty"` + // RoPEFrequencyBase is the base frequency of the RoPE. + RoPEFrequencyBase float32 `json:"ropeFrequencyBase,omitempty"` + // RoPEFrequencyScale is the frequency scale of the RoPE. + RoPEScalingType string `json:"ropeScalingType,omitempty"` + // RoPEScalingFactor is the scaling factor of the RoPE. + RoPEScalingFactor float32 `json:"ropeScalingFactor,omitempty"` + // RoPEScalingOriginalContextLength is the original context length of the RoPE scaling. + RoPEScalingOriginalContextLength uint64 `json:"ropeScalingOriginalContextLength,omitempty"` + // RoPEScalingFinetuned is true if the RoPE scaling is fine-tuned. + RoPEScalingFinetuned bool `json:"ropeScalingFinetuned,omitempty"` + // SSMConvolutionKernel is the size of the convolution kernel used in the SSM(Selective State Space Model). + SSMConvolutionKernel uint32 `json:"ssmConvolutionKernel,omitempty"` + // SSMInnerSize is the embedding size of the state in SSM. + SSMInnerSize uint32 `json:"ssmInnerSize,omitempty"` + // SSMStateSize is the size of the recurrent state in SSM. + SSMStateSize uint32 `json:"ssmStateSize,omitempty"` + // SSMTimeStepRank is the rank of the time steps in SSM. + SSMTimeStepRank uint32 `json:"ssmTimeStepRank,omitempty"` + // VocabularyLength is the size of the vocabulary. + // + // VocabularyLength is the same as the tokenizer's token size. + VocabularyLength uint64 `json:"vocabularyLength"` +} + +// Architecture returns the architecture metadata of the GGUF file. +func (gf *GGUFFile) Architecture() (ga GGUFArchitectureMetadata) { + arch := "llama" + if v, ok := gf.Header.MetadataKV.Get("general.architecture"); ok { + arch = v.ValueString() + } + var ( + contextLengthKey = arch + ".context_length" + embeddingLengthKey = arch + ".embedding_length" + blockCountKey = arch + ".block_count" + feedForwardLengthKey = arch + ".feed_forward_length" + expertCountKey = arch + ".expert_count" + expertUsedCountKey = arch + ".expert_used_count" + + attentionHeadCountKey = arch + ".attention.head_count" + attentionHeadCountKVKey = arch + ".attention.head_count_kv" + attentionMaxALiBIBiasKey = arch + ".attention.max_alibi_bias" + attentionMaxALiBIBiasKey2 = arch + ".attention.alibi_bias_max" + attentionClampKQVKey = arch + ".attention.clamp_kqv" + attentionClampKQVKey2 = arch + ".attention.clip_kqv" + attentionLayerNormEpsilonKey = arch + ".attention.layer_norm_epsilon" + attentionLayerNormRMSEpsilonKey = arch + ".attention.layer_norm_rms_epsilon" + attentionKeyLengthKey = arch + ".attention.key_length" + attentionValueLengthKey = arch + ".attention.value_length" + + ropeDimensionCountKey = arch + ".rope.dimension_count" + ropeFrequencyBaseKey = arch + ".rope.freq_base" + ropeScaleLinearKey = arch + ".rope.scale_linear" + ropeScalingTypeKey = arch + ".rope.scaling.type" + ropeScalingFactorKey = arch + ".rope.scaling.factor" + ropeScalingOriginalContextKey = arch + ".rope.scaling.original_context_length" // uint32 maybe + ropeScalingFinetunedKey = arch + ".rope.scaling.finetuned" + + ssmConvolutionKernelKey = arch + ".ssm.conv_kernel" + ssmInnerSizeKey = arch + ".ssm.inner_size" + ssmStateSizeKey = arch + ".ssm.state_size" + ssmTimeStepRankKey = arch + ".ssm.time_step_rank" + + vocabularyLengthKey = arch + ".vocab_size" // uint32 maybe + tokenizerGGMLTokensKey = "tokenizer.ggml.tokens" + ) + + m, _ := gf.Header.MetadataKV.Index([]string{ + contextLengthKey, + embeddingLengthKey, + blockCountKey, + feedForwardLengthKey, + expertCountKey, + expertUsedCountKey, + attentionHeadCountKey, + attentionHeadCountKVKey, + attentionMaxALiBIBiasKey, + attentionMaxALiBIBiasKey2, + attentionClampKQVKey, + attentionClampKQVKey2, + attentionLayerNormEpsilonKey, + attentionLayerNormRMSEpsilonKey, + attentionKeyLengthKey, + attentionValueLengthKey, + ropeDimensionCountKey, + ropeFrequencyBaseKey, + ropeScaleLinearKey, + ropeScalingTypeKey, + ropeScalingFactorKey, + ropeScalingOriginalContextKey, + ropeScalingFinetunedKey, + ssmConvolutionKernelKey, + ssmInnerSizeKey, + ssmStateSizeKey, + ssmTimeStepRankKey, + vocabularyLengthKey, + tokenizerGGMLTokensKey, + }) + + if v, ok := m[contextLengthKey]; ok { + ga.ContextLength = ValueNumeric[uint64](v) + } + if v, ok := m[embeddingLengthKey]; ok { + ga.EmbeddingLength = ValueNumeric[uint64](v) + } + if v, ok := m[blockCountKey]; ok { + ga.BlockCount = ValueNumeric[uint64](v) + } + if v, ok := m[feedForwardLengthKey]; ok { + ga.FeedForwardLength = ValueNumeric[uint64](v) + } + if v, ok := m[expertCountKey]; ok { + ga.ExpertCount = ValueNumeric[uint32](v) + } + if v, ok := m[expertUsedCountKey]; ok { + ga.ExpertUsedCount = ValueNumeric[uint32](v) + } + + if v, ok := m[attentionHeadCountKey]; ok { + ga.AttentionHeadCount = ValueNumeric[uint64](v) + } + if v, ok := m[attentionHeadCountKVKey]; ok { + ga.AttentionHeadCountKV = ValueNumeric[uint64](v) + } else { + ga.AttentionHeadCountKV = ga.AttentionHeadCount + } + if v, ok := m[attentionMaxALiBIBiasKey]; ok { + ga.AttentionMaxALiBIBias = ValueNumeric[float32](v) + } else if v, ok := m[attentionMaxALiBIBiasKey2]; ok { + ga.AttentionMaxALiBIBias = ValueNumeric[float32](v) + } + if v, ok := m[attentionClampKQVKey]; ok { + ga.AttentionClampKQV = ValueNumeric[float32](v) + } else if v, ok := m[attentionClampKQVKey2]; ok { + ga.AttentionClampKQV = ValueNumeric[float32](v) + } + if v, ok := m[attentionLayerNormEpsilonKey]; ok { + ga.AttentionLayerNormEpsilon = ValueNumeric[float32](v) + } + if v, ok := m[attentionLayerNormRMSEpsilonKey]; ok { + ga.AttentionLayerNormRMSEpsilon = ValueNumeric[float32](v) + } + if v, ok := m[attentionKeyLengthKey]; ok { + ga.AttentionKeyLength = ValueNumeric[uint32](v) + } else if ga.AttentionHeadCount != 0 { + ga.AttentionKeyLength = uint32(ga.EmbeddingLength / ga.AttentionHeadCount) + } + if v, ok := m[attentionValueLengthKey]; ok { + ga.AttentionValueLength = ValueNumeric[uint32](v) + } else if ga.AttentionHeadCount != 0 { + ga.AttentionValueLength = uint32(ga.EmbeddingLength / ga.AttentionHeadCount) + } + + if v, ok := m[ropeDimensionCountKey]; ok { + ga.RoPEDimensionCount = ValueNumeric[uint64](v) + } + if v, ok := m[ropeFrequencyBaseKey]; ok { + ga.RoPEFrequencyBase = ValueNumeric[float32](v) + } + if v, ok := m[ropeScaleLinearKey]; ok { + ga.RoPEScalingType = "linear" + ga.RoPEScalingFactor = ValueNumeric[float32](v) + } + if v, ok := m[ropeScalingTypeKey]; ok { + ga.RoPEScalingType = v.ValueString() + } + if v, ok := m[ropeScalingFactorKey]; ok { + ga.RoPEScalingFactor = ValueNumeric[float32](v) + } + if v, ok := m[ropeScalingOriginalContextKey]; ok { + ga.RoPEScalingOriginalContextLength = ValueNumeric[uint64](v) + } + if v, ok := m[ropeScalingFinetunedKey]; ok { + ga.RoPEScalingFinetuned = v.ValueBool() + } + + if v, ok := m[ssmConvolutionKernelKey]; ok { + ga.SSMConvolutionKernel = ValueNumeric[uint32](v) + } + if v, ok := m[ssmInnerSizeKey]; ok { + ga.SSMInnerSize = ValueNumeric[uint32](v) + } + if v, ok := m[ssmStateSizeKey]; ok { + ga.SSMStateSize = ValueNumeric[uint32](v) + } + if v, ok := m[ssmTimeStepRankKey]; ok { + ga.SSMTimeStepRank = ValueNumeric[uint32](v) + } + + if v, ok := m[vocabularyLengthKey]; ok { + ga.VocabularyLength = ValueNumeric[uint64](v) + } else if v, ok := m[tokenizerGGMLTokensKey]; ok { + ga.VocabularyLength = v.ValueArray().Len + } + + return ga +} diff --git a/file_architecture_test.go b/file_architecture_test.go new file mode 100644 index 0000000..a4aad9f --- /dev/null +++ b/file_architecture_test.go @@ -0,0 +1,46 @@ +package gguf_parser + +import ( + "context" + "os" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestGGUFFile_Architecture(t *testing.T) { + ctx := context.Background() + + f, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + t.Log("\n", spew.Sdump(f.Architecture()), "\n") +} + +func BenchmarkGGUFFile_Architecture(b *testing.B) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + b.Skip("TEST_MODEL_PATH is not set") + return + } + + f, err := ParseGGUFFile(mp, UseMMap(), UseDeferReading()) + if err != nil { + b.Fatal(err) + return + } + + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = f.Architecture() + } +} diff --git a/file_estimate.go b/file_estimate.go new file mode 100644 index 0000000..c34f67f --- /dev/null +++ b/file_estimate.go @@ -0,0 +1,540 @@ +package gguf_parser + +import ( + "strconv" + + "github.com/dustin/go-humanize" +) + +// GGUFEstimate represents the estimated result of the GGUF file. +type GGUFEstimate struct { + // Parameters is the number of parameters. + Parameters GGUFEstimateParametersScalar `json:"parameters"` + // BitsPerWeight is the bits per weight, + // which describes how many bits are used to store a weight. + // + // In general, a higher bpw is better, + // as it means that the model has more precision and accuracy in representing weights. + BitsPerWeight GGUFEstimateBitsPerWeightScalar `json:"bitsPerWeight"` + // InferenceUsage is the usage of inference. + InferenceUsage GGUFEstimateInferenceUsage `json:"inferenceCache"` +} + +type ( + // GGUFEstimateParametersScalar is the estimated scalar for parameters. + GGUFEstimateParametersScalar uint64 + + // GGUFEstimateBitsPerWeightScalar is the estimated scalar for bits per weight. + GGUFEstimateBitsPerWeightScalar float64 + + // GGUFEstimateBytesScalar is the estimated scalar for bytes. + GGUFEstimateBytesScalar uint64 + + // GGUFEstimateInferenceUsage represents the usage of inference. + GGUFEstimateInferenceUsage struct { + // Memory is the memory usage. + Memory GGUFEstimateBytesScalar `json:"memory"` + // KVCache is the key-value cache. + KVCache GGUFEstimateKVCache `json:"kvCache"` + } + + // GGUFEstimateKVCache represents the usage of kv-cache. + GGUFEstimateKVCache struct { + // KeyType is the type of the cached key. + KeyType GGMLType `json:"keyType"` + // KeySize is the size of the cached key. + KeySize GGUFEstimateBytesScalar `json:"keySize"` + // ValueType is the type of the cached value. + ValueType GGMLType `json:"valueType"` + // ValueSize is the size of the cached value. + ValueSize GGUFEstimateBytesScalar `json:"valueSize"` + } +) + +// Estimate returns the estimated result of the GGUF file. +func (gf *GGUFFile) Estimate(g GGUFGeneralMetadata, a GGUFArchitectureMetadata, opts ...GGUFEstimateOption) (ge GGUFEstimate) { + var o _GGUFEstimateOptions + for _, opt := range opts { + opt(&o) + } + + // Inference. + if o.Approximate { + ge.Parameters = gf.estimateParameters(g, a) + ge.InferenceUsage.Memory = GGUFEstimateBytesScalar(g.FileType.GGMLType().RowSizeOf([]uint64{uint64(ge.Parameters)})) + } else { + for i := range gf.TensorInfos { + ge.Parameters += GGUFEstimateParametersScalar(gf.TensorInfos[i].Elements()) + ge.InferenceUsage.Memory += GGUFEstimateBytesScalar(gf.TensorInfos[i].Bytes()) + } + } + ge.InferenceUsage.KVCache = gf.estimateKVCache(a, o) + ge.BitsPerWeight = GGUFEstimateBitsPerWeightScalar(float64(ge.InferenceUsage.Memory*8) / float64(ge.Parameters)) + + return ge +} + +// estimateKVCache estimates the key-value cache, +// which is inspired by https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L2479-L2501 +func (gf *GGUFFile) estimateKVCache(a GGUFArchitectureMetadata, o _GGUFEstimateOptions) (kv GGUFEstimateKVCache) { + kv.KeyType = GGMLTypeF16 + kv.ValueType = GGMLTypeF16 + + if o.CacheKeyType != nil { + kv.KeyType = *o.CacheKeyType + } + if o.CacheValueType != nil { + kv.ValueType = *o.CacheValueType + } + + var ( + embedKeyGQA = uint64(a.AttentionKeyLength) * a.AttentionHeadCountKV + embedValGQA = uint64(a.AttentionValueLength) * a.AttentionHeadCountKV + kvSize = a.ContextLength + ) + { + // Correct. + if a.SSMConvolutionKernel > 0 { + embedKeyGQA += uint64(a.SSMConvolutionKernel - 1*a.SSMInnerSize) + embedValGQA += uint64(a.SSMStateSize * a.SSMInnerSize) + } + if o.ContextSize != nil { + kvSize = uint64(*o.ContextSize) + } + } + + kv.KeySize = GGUFEstimateBytesScalar(kv.KeyType.RowSizeOf([]uint64{embedKeyGQA * kvSize}) * a.BlockCount) + kv.ValueSize = GGUFEstimateBytesScalar(kv.ValueType.RowSizeOf([]uint64{embedValGQA * kvSize}) * a.BlockCount) + + return kv +} + +// estimateParameters estimates the number of parameters, +// which is inspired by https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L3969-L4388. +func (gf *GGUFFile) estimateParameters(g GGUFGeneralMetadata, a GGUFArchitectureMetadata) GGUFEstimateParametersScalar { + const ( + K = 1e3 + M = 1e3 * K + B = 1e3 * M + ) + + // https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L1718-L1761 + const ( + _14M = 14 * M + _17M = 17 * M + _22M = 22 * M + _33M = 33 * M + _70M = 70 * M + _109M = 109 * M + _137M = 137 * M + _160M = 160 * M + _335M = 335 * M + _410M = 410 * M + _0_5B = 0.5 * B + _1B = 1 * B + _1_4B = 1.4 * B + _2B = 2 * B + _2_8B = 2.8 * B + _3B = 3 * B + _4B = 4 * B + _6_9B = 6.9 * B + _7B = 7 * B + _8B = 8 * B + _12B = 12 * B + _13B = 13 * B + _14B = 14 * B + _15B = 15 * B + _20B = 20 * B + _30B = 30 * B + _34B = 34 * B + _35B = 35 * B + _40B = 40 * B + _65B = 65 * B + _70B = 70 * B + _314B = 314 * B + _SMALL = 0.1 * B + _MEDIUM = 0.4 * B + _LARGE = 0.8 * B + _XL = 1.5 * B + _A2_7B = 14.3 * B // Guess + _8x7B = 47 * B // Guess + _8x22B = 141 * B // Guess + _16x12B = 132 * B // Guess + _10B_128x3_66B = 480 * B // Guess + ) + + // Try historical statistics, + // https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L228-L263 + switch g.Architecture { + case "llama": + if a.ExpertCount == 8 { + switch a.BlockCount { + case 32: + return _8x7B + case 56: + return _8x22B + } + } else { + switch a.BlockCount { + case 22: + return _1B + case 26: + return _3B + case 32: + if a.VocabularyLength < 40000 { + return _7B + } + return _8B + case 40: + return _13B + case 48: + return _34B + case 60: + return _30B + case 80: + if a.AttentionHeadCount == a.AttentionHeadCountKV { + return _65B + } + return _70B + } + } + case "falcon": + switch a.BlockCount { + case 32: + return _7B + case 60: + return _40B + } + case "grok": + if a.BlockCount == 64 { + return _314B + } + case "gpt2": + switch a.BlockCount { + case 12: + return _SMALL + case 24: + return _MEDIUM + case 36: + return _LARGE + case 48: + return _XL + } + case "gptj": + case "gptneox": + switch a.BlockCount { + case 6: + switch a.FeedForwardLength { + case 512: + return _14M + case 2048: + return _70M + } + case 12: + if a.FeedForwardLength == 3072 { + return _160M + } + case 16: + if a.FeedForwardLength == 8192 { + return _1B + } + case 24: + switch a.FeedForwardLength { + case 4096: + return _410M + case 8192: + return _1_4B + } + case 32: + switch a.FeedForwardLength { + case 10240: + return _2_8B + case 16384: + return _6_9B + } + case 36: + if a.FeedForwardLength == 20480 { + return _12B + } + case 44: + if a.FeedForwardLength == 24576 { + return _20B + } + } + case "mpt": + switch a.BlockCount { + case 32: + return _7B + case 48: + return _30B + } + case "baichuan": + switch a.BlockCount { + case 32: + return _7B + case 40: + return _13B + } + case "starcoder": + switch a.BlockCount { + case 24: + return _1B + case 36: + return _3B + case 42: + return _7B + case 40: + return _15B + } + case "refact": + if a.BlockCount == 32 { + return _1B + } + case "bert": + switch a.BlockCount { + case 3: + return _17M + case 6: + return _22M + case 12: + switch a.EmbeddingLength { + case 384: + return _33M + case 768: + return _109M + } + case 24: + return _335M + } + case "nomic-bert": + if a.BlockCount == 12 && a.EmbeddingLength == 768 { + return _137M + } + case "jina-bert-v2": + switch a.BlockCount { + case 4: + return _33M + case 12: + return _137M + } + case "bloom": + switch a.BlockCount { + case 24: + return _1B + case 30: + switch a.EmbeddingLength { + case 2560: + return _3B + case 4096: + return _7B + } + } + case "stablelm": + switch a.BlockCount { + case 24: + return _1B + case 32: + return _3B + case 40: + return _12B + } + case "qwen": + switch a.BlockCount { + case 32: + return _7B + case 40: + return _13B + } + case "qwen2": + switch a.BlockCount { + case 24: + if a.EmbeddingLength == 1024 { + return _0_5B + } + return _1B + case 32: + return _7B + case 40: + if a.AttentionHeadCount == 20 { + return _4B + } + return _13B + case 80: + return _70B + } + case "qwen2moe": + if a.BlockCount == 24 { + return _A2_7B + } + case "phi2": + switch a.BlockCount { + case 24: + return _1B + case 32: + return _3B + } + case "phi3": + switch a.BlockCount { + case 24: + return _1B + case 32: + return _3B + case 40: + return _14B + } + case "plamo": + if a.BlockCount == 40 { + return _13B + } + case "codeshell": + if a.BlockCount == 42 { + return _SMALL + } + case "orion": + if a.BlockCount == 40 { + return _14B + } + case "internlm2": + switch a.BlockCount { + case 32: + return _7B + case 48: + return _20B + } + case "minicpm": + if a.BlockCount == 40 { + return _2B + } + case "gemma": + switch a.BlockCount { + case 18: + return _2B + case 28: + return _7B + } + case "starcoder2": + switch a.BlockCount { + case 30: + return _3B + case 32: + return _7B + case 40: + return _15B + } + case "mamba": + switch a.BlockCount { + case 24: + if a.EmbeddingLength == 768 { + return _SMALL + } + case 48: + switch a.EmbeddingLength { + case 1024: + return _MEDIUM + case 1536: + return _LARGE + case 2048: + return _XL + } + case 64: + if a.EmbeddingLength == 2560 { + return _3B + } + } + case "xverse": + switch a.BlockCount { + case 32: + return _7B + case 40: + return _13B + case 80: + return _65B + } + case "command-r": + if a.BlockCount == 40 { + return _35B + } + case "dbrx": + if a.BlockCount == 40 { + return _16x12B + } + case "olmo": + switch a.BlockCount { + case 22: + return _1B + case 32: + return _7B + case 80: + return _70B + } + case "arctic": + if a.ExpertCount == 128 && a.BlockCount == 35 { + return _10B_128x3_66B + } + } + + // Otherwise, calculate by experience. + // + // Let's say, the model is based on Transformer architecture, + // and use decoder-only. + // + // Vocabulary embedding parameter number(VeP), mainly includes the embedding matrix. + // The embedding matrix shape is [VocabularyLength, EmbeddingLength]. + // So the VeP value is VocabularyLength * EmbeddingLength. + // + // Self-Attention parameter number(SaP), includes Wq, Wk, Wv, Wo, and their bias. + // The all weight matrix shapes are [EmbeddingLength, EmbeddingLength], + // and the bias shapes are [EmbeddingLength]. + // So the SaP value is 4 * (EmbeddingLength * EmbeddingLength) + 4 * EmbeddingLength. + // + // Feed-Forward parameter number(FfP), includes W1, W2, and their bias. + // The W1 shape is [EmbeddingLength, 4*EmbeddingLength], its bias shape is [4*EmbeddingLength]. + // The W2 shape is [4*EmbeddingLength, EmbeddingLength], its bias shape is [EmbeddingLength]. + // So the FfP value is (EmbeddingLength * 4 * EmbeddingLength) + 4 * EmbeddingLength + (4 * EmbeddingLength * EmbeddingLength) + EmbeddingLength. + // + // There are two LayerNorm, one for Self-Attention, and another for Feed-Forward. + // Layer Normalization parameter number(LnP), includes scale and bias. + // The scale and bias shapes are [EmbeddingLength]. + // So the LnP value is 2 * (2 * EmbeddingLength). + // + // So the total parameters of a decoder-only model can estimate as below. + // Parameters = BlockCount * (SaP + FfP + LnP) + VeP + // = BlockCount * (12 * EmbeddingLength * EmbeddingLength + 13 * EmbeddingLength) + VocabularyLength * EmbeddingLength + + ret := a.BlockCount*(12*a.EmbeddingLength*a.EmbeddingLength+13*a.EmbeddingLength) + a.VocabularyLength*a.EmbeddingLength + // TODO MoE + return GGUFEstimateParametersScalar(ret) +} + +func (s GGUFEstimateParametersScalar) String() string { + switch { + case s >= 1e15: + return humanize.CommafWithDigits(float64(s)/1e15, 1) + " Q" + case s >= 1e12: + return humanize.CommafWithDigits(float64(s)/1e12, 1) + " T" + case s >= 1e9: + return humanize.CommafWithDigits(float64(s)/1e9, 1) + " B" + case s >= 1e6: + return humanize.CommafWithDigits(float64(s)/1e6, 1) + " M" + case s >= 1e3: + return humanize.CommafWithDigits(float64(s)/1e3, 1) + " K" + default: + return strconv.Itoa(int(s)) + } +} + +func (s GGUFEstimateBitsPerWeightScalar) String() string { + return strconv.FormatFloat(float64(s), 'f', 2, 64) + " bpw" +} + +func (s GGUFEstimateBytesScalar) String() string { + return humanize.IBytes(uint64(s)) +} + +func (c GGUFEstimateInferenceUsage) Total() GGUFEstimateBytesScalar { + return c.Memory + c.KVCache.Total() +} + +func (c GGUFEstimateKVCache) Total() GGUFEstimateBytesScalar { + return c.KeySize + c.ValueSize +} diff --git a/file_estimate_option.go b/file_estimate_option.go new file mode 100644 index 0000000..f6fe98d --- /dev/null +++ b/file_estimate_option.go @@ -0,0 +1,60 @@ +package gguf_parser + +import ( + "slices" +) + +type ( + _GGUFEstimateOptions struct { + ContextSize *int32 + CacheKeyType *GGMLType + CacheValueType *GGMLType + Approximate bool + } + GGUFEstimateOption func(*_GGUFEstimateOptions) +) + +// WithContextSize sets the context size for the estimate. +func WithContextSize(size int32) GGUFEstimateOption { + return func(o *_GGUFEstimateOptions) { + if size <= 0 { + return + } + o.ContextSize = &size + } +} + +// _GGUFEstimateCacheTypeAllowList is the allow list of cache key and value types. +var _GGUFEstimateCacheTypeAllowList = []GGMLType{ + GGMLTypeF32, + GGMLTypeF16, + GGMLTypeQ8_0, + GGMLTypeQ4_0, GGMLTypeQ4_1, + GGMLTypeIQ4_NL, + GGMLTypeQ5_0, GGMLTypeQ5_1, +} + +// WithCacheKeyType sets the cache key type for the estimate. +func WithCacheKeyType(t GGMLType) GGUFEstimateOption { + return func(o *_GGUFEstimateOptions) { + if slices.Contains(_GGUFEstimateCacheTypeAllowList, t) { + o.CacheKeyType = &t + } + } +} + +// WithCacheValueType sets the cache value type for the estimate. +func WithCacheValueType(t GGMLType) GGUFEstimateOption { + return func(o *_GGUFEstimateOptions) { + if slices.Contains(_GGUFEstimateCacheTypeAllowList, t) { + o.CacheValueType = &t + } + } +} + +// WithApproximate enable approximate estimate. +func WithApproximate() GGUFEstimateOption { + return func(o *_GGUFEstimateOptions) { + o.Approximate = true + } +} diff --git a/file_estimate_test.go b/file_estimate_test.go new file mode 100644 index 0000000..215b01d --- /dev/null +++ b/file_estimate_test.go @@ -0,0 +1,237 @@ +package gguf_parser + +import ( + "context" + "os" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestGGUFFile_Estimate(t *testing.T) { + ctx := context.Background() + + f, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q5_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + g, a := f.General(), f.Architecture() + t.Log("\n", spew.Sdump(f.Estimate(g, a, WithApproximate())), "\n") +} + +func BenchmarkGGUFFile_EstimateApproximate(b *testing.B) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + b.Skip("TEST_MODEL_PATH is not set") + return + } + + f, err := ParseGGUFFile(mp, UseMMap(), UseDeferReading()) + if err != nil { + b.Fatal(err) + return + } + g, a := f.General(), f.Architecture() + + b.ReportAllocs() + + b.ResetTimer() + b.Run("WithoutApproximate", func(b *testing.B) { + for i := 0; i < b.N; i++ { + f.Estimate(g, a) + } + }) + + b.ResetTimer() + b.Run("WithApproximate", func(b *testing.B) { + for i := 0; i < b.N; i++ { + f.Estimate(g, a, WithApproximate()) + } + }) +} + +func TestGGUFFile_Estimate_Parameters(t *testing.T) { + ctx := context.Background() + + gfMixtral7B, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + gfMixtral8x7B, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO-GGUF", + "Nous-Hermes-2-Mixtral-8x7B-DPO.Q5_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + gfWizardLM8x22B, err := ParseGGUFFileFromHuggingFace( + ctx, + "MaziyarPanahi/WizardLM-2-8x22B-GGUF", + "WizardLM-2-8x22B.IQ1_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + cases := []struct { + name string + given *GGUFFile + opts []GGUFEstimateOption + }{ + { + name: "mixtral 7B", + given: gfMixtral7B, + }, + { + name: "mixtral 7B with approximate", + given: gfMixtral7B, + opts: []GGUFEstimateOption{WithApproximate()}, + }, + { + name: "mixtral 8x7B", + given: gfMixtral8x7B, + }, + { + name: "mixtral 8x7B with approximate", + given: gfMixtral8x7B, + opts: []GGUFEstimateOption{WithApproximate()}, + }, + { + name: "wizardlm 8x22B", + given: gfWizardLM8x22B, + }, + { + name: "wizardlm 8x22B with approximate", + given: gfWizardLM8x22B, + opts: []GGUFEstimateOption{WithApproximate()}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + f := tc.given + g, a := f.General(), f.Architecture() + t.Log("\n", spew.Sdump(f.Estimate(g, a, tc.opts...).Parameters), "\n") + }) + } +} + +func TestGGUFFile_Estimate_InferenceUsage_Memory(t *testing.T) { + ctx := context.Background() + + gfMixtral7B, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + gfMixtral8x7B, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO-GGUF", + "Nous-Hermes-2-Mixtral-8x7B-DPO.Q5_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + gfWizardLM8x22B, err := ParseGGUFFileFromHuggingFace( + ctx, + "MaziyarPanahi/WizardLM-2-8x22B-GGUF", + "WizardLM-2-8x22B.IQ1_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + cases := []struct { + name string + given *GGUFFile + opts []GGUFEstimateOption + }{ + { + name: "mixtral 7B", + given: gfMixtral7B, + }, + { + name: "mixtral 7B with approximate", + given: gfMixtral7B, + opts: []GGUFEstimateOption{WithApproximate()}, + }, + { + name: "mixtral 8x7B", + given: gfMixtral8x7B, + }, + { + name: "mixtral 8x7B with approximate", + given: gfMixtral8x7B, + opts: []GGUFEstimateOption{WithApproximate()}, + }, + { + name: "wizardlm 8x22B", + given: gfWizardLM8x22B, + }, + { + name: "wizardlm 8x22B with approximate", + given: gfWizardLM8x22B, + opts: []GGUFEstimateOption{WithApproximate()}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + f := tc.given + g, a := f.General(), f.Architecture() + t.Log("\n", spew.Sdump(f.Estimate(g, a, tc.opts...).InferenceUsage.Memory), "\n") + }) + } +} + +func TestGGUFFile_Estimate_InferenceUsage_KVCache(t *testing.T) { + ctx := context.Background() + + f, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + g, a := f.General(), f.Architecture() + + cases := []struct { + name string + opts []GGUFEstimateOption + }{ + {"1024(fp16)", []GGUFEstimateOption{WithContextSize(1024)}}, + {"1024(fp32)", []GGUFEstimateOption{WithContextSize(1024), WithCacheKeyType(GGMLTypeF32), WithCacheValueType(GGMLTypeF32)}}, + {"4096(fp16)", []GGUFEstimateOption{WithContextSize(4096)}}, + {"4096(fp32)", []GGUFEstimateOption{WithContextSize(4096), WithCacheKeyType(GGMLTypeF32), WithCacheValueType(GGMLTypeF32)}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Log("\n", spew.Sdump(f.Estimate(g, a, tc.opts...).InferenceUsage.KVCache), "\n") + }) + } +} diff --git a/file_general.go b/file_general.go new file mode 100644 index 0000000..567154a --- /dev/null +++ b/file_general.go @@ -0,0 +1,314 @@ +package gguf_parser + +import ( + "sort" + "strings" +) + +// GGUFGeneralMetadata represents the general metadata of a GGUF file. +type GGUFGeneralMetadata struct { + // Size is the size of the GGUF file in bytes. + Size int64 `json:"size"` + // LittleEndian is true if the GGUF file is little-endian, + // and false for big-endian. + LittleEndian bool `json:"littleEndian"` + // Architecture describes what architecture this model implements. + // + // All lowercase ASCII, with only [a-z0-9]+ characters allowed. + Architecture string `json:"architecture"` + // QuantizationVersion describes the version of the quantization format. + // + // Not required if the model is not quantized (i.e. no tensors are quantized). + // If any tensors are quantized, this must be present. + // This is separate to the quantization scheme of the tensors itself, + // + // the quantization version may change without changing the scheme's name, + // e.g. the quantization scheme is Q5_K, and the QuantizationVersion is 4. + QuantizationVersion uint32 `json:"quantizationVersion,omitempty"` + // Alignment describes the alignment of the GGUF file. + // + // This can vary to allow for different alignment schemes, but it must be a multiple of 8. + // Some writers may not write the alignment. + // + // Default is 32. + Alignment uint32 `json:"alignment"` + // Name to the model. + // + // This should be a human-readable name that can be used to identify the model. + // It should be unique within the community that the model is defined in. + Name string `json:"name"` + // Author to the model. + Author string `json:"author,omitempty"` + // URL to the model's homepage. + // + // This can be a GitHub repo, a paper, etc. + URL string `json:"url,omitempty"` + // Description to the model. + Description string `json:"description,omitempty"` + // License to the model. + // + // This is expressed as a SPDX license expression, e.g. "MIT OR Apache-2.0". + License string `json:"license,omitempty"` + // FileType describes the type of the majority of the tensors in the GGUF file. + FileType GGUFFileType `json:"fileType"` + // TensorTypesDistribution describes the distribution of tensor types in the GGUF file. + TensorTypesDistribution map[GGMLType]float64 `json:"tensorTypesDistribution"` +} + +// GGUFFileType is a type of GGUF file, +// see https://github.com/ggerganov/ggml/blob/0cbb7c0e053f5419cfbebb46fbf4d4ed60182cf5/include/ggml/ggml.h#L396-L421. +type GGUFFileType uint32 + +// GGUFFileType constants. +// +// GGUFFileTypeMostlyQ4_2, GGUFFileTypeMostlyQ4_3 are deprecated. +// +// GGUFFileTypeMostlyQ4_1_F16 is a special case where the majority of the tensors are Q4_1, +// but 'token_embd.weight' and 'output.weight' tensors are F16. +const ( + GGUFFileTypeAllF32 GGUFFileType = iota // F32 + GGUFFileTypeMostlyF16 // F16 + GGUFFileTypeMostlyQ4_0 // Q4_0 + GGUFFileTypeMostlyQ4_1 // Q4_1 + GGUFFileTypeMostlyQ4_1_F16 // Q4_1_F16 + GGUFFileTypeMostlyQ4_2 // Q4_2 + GGUFFileTypeMostlyQ4_3 // Q4_3 + GGUFFileTypeMostlyQ8_0 // Q8_0 + GGUFFileTypeMostlyQ5_0 // Q5_0 + GGUFFileTypeMostlyQ5_1 // Q5_1 + GGUFFileTypeMostlyQ2_K // Q2_K + GGUFFileTypeMostlyQ3_K // Q3_K/Q3_K_S + GGUFFileTypeMostlyQ4_K // Q4_K/Q3_K_M + GGUFFileTypeMostlyQ5_K // Q5_K/Q3_K_L + GGUFFileTypeMostlyQ6_K // Q6_K/Q4_K_S + GGUFFileTypeMostlyIQ2_XXS // IQ2_XXS/Q4_K_M + GGUFFileTypeMostlyIQ2_XS // IQ2_XS/Q5_K_S + GGUFFileTypeMostlyIQ3_XXS // IQ3_XXS/Q5_K_M + GGUFFileTypeMostlyIQ1_S // IQ1_S/Q6_K + GGUFFileTypeMostlyIQ4_NL // IQ4_NL + GGUFFileTypeMostlyIQ3_S // IQ3_S + GGUFFileTypeMostlyIQ2_S // IQ2_S + GGUFFileTypeMostlyIQ4_XS // IQ4_XS + GGUFFileTypeMostlyIQ1_M // IQ1_M + GGUFFileTypeMostlyBF16 // BF16 + _GGUFFileTypeCount // Unknown +) + +// General returns the general metadata of the GGUF file. +func (gf *GGUFFile) General() (gm GGUFGeneralMetadata) { + const ( + architectureKey = "general.architecture" + quantizationKey = "general.quantization_version" + alignmentKey = "general.alignment" + nameKey = "general.name" + authorKey = "general.author" + urlKey = "general.url" + descriptionKey = "general.description" + licenseKey = "general.license" + fileTypeKey = "general.file_type" + ) + + gm.Size = gf.Size + gm.LittleEndian = gf.LittleEndian() + gm.FileType = _GGUFFileTypeCount + gm.TensorTypesDistribution = map[GGMLType]float64{} + + m, _ := gf.Header.MetadataKV.Index([]string{ + architectureKey, + quantizationKey, + alignmentKey, + nameKey, + authorKey, + urlKey, + descriptionKey, + licenseKey, + fileTypeKey, + }) + + if v, ok := m[architectureKey]; ok { + gm.Architecture = v.ValueString() + } + if v, ok := m[quantizationKey]; ok { + gm.QuantizationVersion = ValueNumeric[uint32](v) + } + if v, ok := m[alignmentKey]; ok { + gm.Alignment = ValueNumeric[uint32](v) + } else { + gm.Alignment = 32 + } + if v, ok := m[nameKey]; ok { + gm.Name = v.ValueString() + } + if v, ok := m[authorKey]; ok { + gm.Author = v.ValueString() + } + if v, ok := m[urlKey]; ok { + gm.URL = v.ValueString() + } + if v, ok := m[descriptionKey]; ok { + gm.Description = v.ValueString() + } + if v, ok := m[licenseKey]; ok { + gm.License = v.ValueString() + } + if v, ok := m[fileTypeKey]; ok { + gm.FileType = GGUFFileType(ValueNumeric[uint32](v)) + } + + if gm.FileType >= _GGUFFileTypeCount { + gm.FileType = gf.guessGGUFFileType() + } + + return gm +} + +// GGMLType returns the GGMLType of the GGUFFileType, +// which is inspired by +// https://github.com/ggerganov/ggml/blob/a10a8b880c059b3b29356eb9a9f8df72f03cdb6a/src/ggml.c#L2730-L2763. +func (t GGUFFileType) GGMLType() GGMLType { + switch t { + case GGUFFileTypeAllF32: + return GGMLTypeF32 + case GGUFFileTypeMostlyF16: + return GGMLTypeF16 + case GGUFFileTypeMostlyQ4_0: + return GGMLTypeQ4_0 + case GGUFFileTypeMostlyQ4_1: + return GGMLTypeQ4_1 + case GGUFFileTypeMostlyQ4_2: + return GGMLTypeQ4_2 + case GGUFFileTypeMostlyQ4_3: + return GGMLTypeQ4_3 + case GGUFFileTypeMostlyQ8_0: + return GGMLTypeQ8_0 + case GGUFFileTypeMostlyQ5_0: + return GGMLTypeQ5_0 + case GGUFFileTypeMostlyQ5_1: + return GGMLTypeQ5_1 + case GGUFFileTypeMostlyQ2_K: + return GGMLTypeQ2_K + case GGUFFileTypeMostlyQ3_K: + return GGMLTypeQ3_K + case GGUFFileTypeMostlyQ4_K: + return GGMLTypeQ4_K + case GGUFFileTypeMostlyQ5_K: + return GGMLTypeQ5_K + case GGUFFileTypeMostlyQ6_K: + return GGMLTypeQ6_K + case GGUFFileTypeMostlyIQ2_XXS: + return GGMLTypeIQ2_XXS + case GGUFFileTypeMostlyIQ2_XS: + return GGMLTypeIQ2_XS + case GGUFFileTypeMostlyIQ3_XXS: + return GGMLTypeIQ3_XXS + case GGUFFileTypeMostlyIQ1_S: + return GGMLTypeIQ1_S + case GGUFFileTypeMostlyIQ4_NL: + return GGMLTypeIQ4_NL + case GGUFFileTypeMostlyIQ3_S: + return GGMLTypeIQ3_S + case GGUFFileTypeMostlyIQ2_S: + return GGMLTypeIQ2_S + case GGUFFileTypeMostlyIQ4_XS: + return GGMLTypeIQ4_XS + case GGUFFileTypeMostlyIQ1_M: + return GGMLTypeIQ1_M + case GGUFFileTypeMostlyBF16: + return GGMLTypeBF16 + default: + } + return _GGMLTypeCount +} + +// guessGGUFFileType guesses the GGUF file type by +// statistically analyzing the tensor types, +// which is inspired by +// https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML#provided-files. +func (gf *GGUFFile) guessGGUFFileType() GGUFFileType { + var ts []GGMLType + { + // Count. + cm := map[GGMLType]int{} + for i := range gf.TensorInfos { + if !strings.HasPrefix(gf.TensorInfos[i].Name, "blk") { + continue + } + cm[gf.TensorInfos[i].Type] = cm[gf.TensorInfos[i].Type] + 1 + } + + // Calculate. + ts = make([]GGMLType, 0, len(cm)) + for t := range cm { + ts = append(ts, t) + } + sort.Slice(ts, func(i, j int) bool { + return cm[ts[i]] > cm[ts[j]] + }) + } + + switch ts[0] { + case GGMLTypeF32: + return GGUFFileTypeAllF32 + case GGMLTypeF16: + return GGUFFileTypeMostlyF16 + case GGMLTypeQ4_0: + return GGUFFileTypeMostlyQ4_0 + case GGMLTypeQ4_1: + return GGUFFileTypeMostlyQ4_1 + case GGMLTypeQ4_2: + return GGUFFileTypeMostlyQ4_2 + case GGMLTypeQ4_3: + return GGUFFileTypeMostlyQ4_3 + case GGMLTypeQ5_0: + return GGUFFileTypeMostlyQ5_0 + case GGMLTypeQ5_1: + return GGUFFileTypeMostlyQ5_1 + case GGMLTypeQ8_0: + return GGUFFileTypeMostlyQ8_0 + case GGMLTypeQ2_K: + return GGUFFileTypeMostlyQ2_K + case GGMLTypeQ3_K: + switch ts[1] { + case GGMLTypeQ4_K: // Legacy, Q3_K_M. + return GGUFFileTypeMostlyQ4_K + case GGMLTypeQ5_K: // Legacy, Q3_K_L. + return GGUFFileTypeMostlyQ5_K + default: // Legacy. Q3_K_S + return GGUFFileTypeMostlyQ3_K + } + case GGMLTypeQ4_K: + if len(ts) > 2 && ts[2] == GGMLTypeQ6_K { // Legacy, Q4_K_M. + return GGUFFileTypeMostlyIQ2_XXS + } + return GGUFFileTypeMostlyQ6_K // Legacy. Q4_K_S + case GGMLTypeQ5_K: + if len(ts) > 2 && ts[2] == GGMLTypeQ6_K { // Legacy, Q5_K_M. + return GGUFFileTypeMostlyIQ3_XXS + } + return GGUFFileTypeMostlyIQ2_XS // Legacy. Q5_K_S + case GGMLTypeQ6_K: + return GGUFFileTypeMostlyIQ1_S // Legacy. Q6_K + case GGMLTypeIQ2_XXS: + return GGUFFileTypeMostlyIQ2_XXS + case GGMLTypeIQ2_XS: + return GGUFFileTypeMostlyIQ2_XS + case GGMLTypeIQ3_XXS: + return GGUFFileTypeMostlyIQ3_XXS + case GGMLTypeIQ1_S: + return GGUFFileTypeMostlyIQ1_S + case GGMLTypeIQ4_NL: + return GGUFFileTypeMostlyIQ4_NL + case GGMLTypeIQ3_S: + return GGUFFileTypeMostlyIQ3_S + case GGMLTypeIQ2_S: + return GGUFFileTypeMostlyIQ2_S + case GGMLTypeIQ4_XS: + return GGUFFileTypeMostlyIQ4_XS + case GGMLTypeIQ1_M: + return GGUFFileTypeMostlyIQ1_M + case GGMLTypeBF16: + return GGUFFileTypeMostlyBF16 + default: + } + return _GGUFFileTypeCount +} diff --git a/file_general_test.go b/file_general_test.go new file mode 100644 index 0000000..7be4184 --- /dev/null +++ b/file_general_test.go @@ -0,0 +1,81 @@ +package gguf_parser + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/stretchr/testify/assert" +) + +func TestGGUFFile_General(t *testing.T) { + ctx := context.Background() + + f, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + t.Log("\n", spew.Sdump(f.General()), "\n") +} + +func BenchmarkGGUFFile_General(b *testing.B) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + b.Skip("TEST_MODEL_PATH is not set") + return + } + + f, err := ParseGGUFFile(mp, UseMMap(), UseDeferReading()) + if err != nil { + b.Fatal(err) + return + } + + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = f.General() + } +} + +func TestGGUFFile_guessGGUFFileType(t *testing.T) { + ctx := context.Background() + + cases := []string{ + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + } + for _, tc := range cases { + t.Run(tc, func(t *testing.T) { + gf, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + fmt.Sprintf("Hermes-2-Pro-Mistral-7B.%s.gguf", tc), + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + assert.Equal(t, gf.General().FileType.String(), gf.guessGGUFFileType().String(), tc+" file type should be equal") + }) + } +} diff --git a/file_option.go b/file_option.go new file mode 100644 index 0000000..5c5adfc --- /dev/null +++ b/file_option.go @@ -0,0 +1,76 @@ +package gguf_parser + +import "net/url" + +type ( + _GGUFReadOptions struct { + UseDebug bool + UseDeferReading bool + + // Local. + UseMMap bool + + // Remote. + ProxyURL *url.URL + SkipProxy bool + SkipTLSVerification bool + BufferSize int + } + GGUFReadOption func(o *_GGUFReadOptions) +) + +// UseDebug uses debug mode to read the file. +func UseDebug() GGUFReadOption { + return func(o *_GGUFReadOptions) { + o.UseDebug = true + } +} + +// UseDeferReading skips reading large or complex data value to speed up the process. +// +// For examples, to skip reading the items of array metadata, +// to skip reading the tensor values, etc. +func UseDeferReading() GGUFReadOption { + return func(o *_GGUFReadOptions) { + o.UseDeferReading = true + } +} + +// UseMMap uses mmap to read the local file. +func UseMMap() GGUFReadOption { + return func(o *_GGUFReadOptions) { + o.UseMMap = true + } +} + +// UseProxy uses the given url as a proxy when reading from a remote URL. +func UseProxy(url *url.URL) GGUFReadOption { + return func(o *_GGUFReadOptions) { + o.ProxyURL = url + } +} + +// SkipProxy skips the proxy when reading from a remote URL. +func SkipProxy() GGUFReadOption { + return func(o *_GGUFReadOptions) { + o.SkipProxy = true + } +} + +// SkipTLSVerification skips the TLS verification when reading from a remote URL. +func SkipTLSVerification() GGUFReadOption { + return func(o *_GGUFReadOptions) { + o.SkipTLSVerification = true + } +} + +// UseBufferSize sets the buffer size when reading from a remote URL. +func UseBufferSize(size int) GGUFReadOption { + const minSize = 32 * 1024 + if size < minSize { + size = minSize + } + return func(o *_GGUFReadOptions) { + o.BufferSize = size + } +} diff --git a/file_test.go b/file_test.go new file mode 100644 index 0000000..b467426 --- /dev/null +++ b/file_test.go @@ -0,0 +1,203 @@ +package gguf_parser + +import ( + "context" + "os" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestParseGGUFFile(t *testing.T) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + t.Skip("TEST_MODEL_PATH is not set") + return + } + + // Slow read. + { + f, err := ParseGGUFFile(mp) + if err != nil { + t.Fatal(err) + return + } + s := spew.ConfigState{ + Indent: " ", + MaxDepth: 5, // Avoid console overflow. + } + t.Log("\n", s.Sdump(f), "\n") + } + + // Fast read. + { + f, err := ParseGGUFFile(mp, UseMMap(), UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + t.Log("\n", spew.Sdump(f), "\n") + } +} + +func BenchmarkParseGGUFFileMMap(b *testing.B) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + b.Skip("TEST_MODEL_PATH is not set") + return + } + + b.ReportAllocs() + + b.ResetTimer() + b.Run("WithoutMMap", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFile(mp) + if err != nil { + b.Fatal(err) + return + } + } + }) + + b.ResetTimer() + b.Run("WithMMap", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFile(mp, UseMMap()) + if err != nil { + b.Fatal(err) + return + } + } + }) +} + +func BenchmarkParseGGUFFileDeferReading(b *testing.B) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + b.Skip("TEST_MODEL_PATH is not set") + return + } + + b.ReportAllocs() + + b.ResetTimer() + b.Run("WithoutDeferReading", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFile(mp) + if err != nil { + b.Fatal(err) + return + } + } + }) + + b.ResetTimer() + b.Run("WithDeferReading", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFile(mp, UseDeferReading()) + if err != nil { + b.Fatal(err) + return + } + } + }) +} + +func TestParseGGUFFileRemote(t *testing.T) { + const u = "https://huggingface.co/NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF" + + "/resolve/main/Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf" + + ctx := context.Background() + + // Slow read. + { + f, err := ParseGGUFFileRemote(ctx, u, UseDebug()) + if err != nil { + t.Fatal(err) + return + } + s := spew.ConfigState{ + Indent: " ", + MaxDepth: 5, // Avoid console overflow. + } + t.Log("\n", s.Sdump(f), "\n") + } + + // Fast read. + { + f, err := ParseGGUFFileRemote(ctx, u, UseDebug(), UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + t.Log("\n", spew.Sdump(f), "\n") + } +} + +func BenchmarkParseGGUFFileRemoteWithBufferSize(b *testing.B) { + const u = "https://huggingface.co/NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF" + + "/resolve/main/Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf" + + ctx := context.Background() + + b.ReportAllocs() + + b.ResetTimer() + b.Run("256KibBuffer", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFileRemote(ctx, u, UseDeferReading(), UseBufferSize(256*1024)) + if err != nil { + b.Fatal(err) + return + } + } + }) + + b.ResetTimer() + b.Run("1MibBuffer", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFileRemote(ctx, u, UseDeferReading(), UseBufferSize(1024*1024)) + if err != nil { + b.Fatal(err) + return + } + } + }) + + b.ResetTimer() + b.Run("4MibBuffer", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFileRemote(ctx, u, UseDeferReading(), UseBufferSize(4*1024*1024)) + if err != nil { + b.Fatal(err) + return + } + } + }) +} + +func TestParseGGUFFileFromHuggingFace(t *testing.T) { + ctx := context.Background() + + cases := [][2]string{ + { + "TheBloke/Llama-2-13B-chat-GGUF", + "llama-2-13b-chat.Q8_0.gguf", + }, + { + "lmstudio-community/Yi-1.5-9B-Chat-GGUF", + "Yi-1.5-9B-Chat-Q5_K_M.gguf", + }, + } + for _, tc := range cases { + t.Run(tc[0]+"/"+tc[1], func(t *testing.T) { + f, err := ParseGGUFFileFromHuggingFace(ctx, tc[0], tc[1], UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + t.Log("\n", spew.Sdump(f), "\n") + }) + } +} diff --git a/file_tokenizer.go b/file_tokenizer.go new file mode 100644 index 0000000..2865a01 --- /dev/null +++ b/file_tokenizer.go @@ -0,0 +1,89 @@ +package gguf_parser + +// GGUFTokenizerMetadata represents the tokenizer metadata of a GGUF file. +type GGUFTokenizerMetadata struct { + // Model is the model of the tokenizer. + Model string `json:"model"` + // TokensLength is the size of tokens. + TokensLength uint64 `json:"tokenLength"` + // AddedTokensLength is the size of added tokens after training. + AddedTokensLength uint64 `json:"addedTokenLength"` + // BOSTokenID is the ID of the beginning of sentence token. + // + // Use -1 if the token is not found. + BOSTokenID int64 `json:"bosTokenID"` + // EOSTokenID is the ID of the end of sentence token. + // + // Use -1 if the token is not found. + EOSTokenID int64 `json:"eosTokenID"` + // UnknownTokenID is the ID of the unknown token. + // + // Use -1 if the token is not found. + UnknownTokenID int64 `json:"unknownTokenID"` + // SeparatorTokenID is the ID of the separator token. + // + // Use -1 if the token is not found. + SeparatorTokenID int64 `json:"separatorTokenID"` + // PaddingTokenID is the ID of the padding token. + // + // Use -1 if the token is not found. + PaddingTokenID int64 `json:"paddingTokenID"` +} + +// Tokenizer returns the tokenizer metadata of a GGUF file. +func (gf *GGUFFile) Tokenizer() (gt GGUFTokenizerMetadata) { + const ( + modelKey = "tokenizer.ggml.model" + tokensKey = "tokenizer.ggml.tokens" + addedTokensKey = "tokenizer.ggml.added_tokens" + bosTokenIDKey = "tokenizer.ggml.bos_token_id" + eosTokenIDKey = "tokenizer.ggml.eos_token_id" + unknownTokenIDKey = "tokenizer.ggml.unknown_token_id" + separatorTokenIDKey = "tokenizer.ggml.separator_token_id" + paddingTokenIDKey = "tokenizer.ggml.padding_token_id" + ) + + m, _ := gf.Header.MetadataKV.Index([]string{ + modelKey, + tokensKey, + addedTokensKey, + bosTokenIDKey, + eosTokenIDKey, + unknownTokenIDKey, + separatorTokenIDKey, + paddingTokenIDKey, + }) + + gt.BOSTokenID = -1 + gt.EOSTokenID = -1 + gt.UnknownTokenID = -1 + gt.SeparatorTokenID = -1 + gt.PaddingTokenID = -1 + + if v, ok := m[modelKey]; ok { + gt.Model = v.ValueString() + } + if v, ok := m[tokensKey]; ok { + gt.TokensLength = v.ValueArray().Len + } + if v, ok := m[addedTokensKey]; ok { + gt.AddedTokensLength = v.ValueArray().Len + } + if v, ok := m[bosTokenIDKey]; ok { + gt.BOSTokenID = ValueNumeric[int64](v) + } + if v, ok := m[eosTokenIDKey]; ok { + gt.EOSTokenID = ValueNumeric[int64](v) + } + if v, ok := m[unknownTokenIDKey]; ok { + gt.UnknownTokenID = ValueNumeric[int64](v) + } + if v, ok := m[separatorTokenIDKey]; ok { + gt.SeparatorTokenID = ValueNumeric[int64](v) + } + if v, ok := m[paddingTokenIDKey]; ok { + gt.PaddingTokenID = ValueNumeric[int64](v) + } + + return gt +} diff --git a/file_tokenizer_test.go b/file_tokenizer_test.go new file mode 100644 index 0000000..9ee685b --- /dev/null +++ b/file_tokenizer_test.go @@ -0,0 +1,46 @@ +package gguf_parser + +import ( + "context" + "os" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestGGUFFile_Tokenizer(t *testing.T) { + ctx := context.Background() + + f, err := ParseGGUFFileFromHuggingFace( + ctx, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + "Hermes-2-Pro-Mistral-7B.Q4_K_M.gguf", + UseDeferReading()) + if err != nil { + t.Fatal(err) + return + } + + t.Log("\n", spew.Sdump(f.Tokenizer()), "\n") +} + +func BenchmarkGGUFFile_Tokenizer(b *testing.B) { + mp, ok := os.LookupEnv("TEST_MODEL_PATH") + if !ok { + b.Skip("TEST_MODEL_PATH is not set") + return + } + + f, err := ParseGGUFFile(mp, UseMMap(), UseDeferReading()) + if err != nil { + b.Fatal(err) + return + } + + b.ReportAllocs() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = f.Tokenizer() + } +} diff --git a/filename.go b/filename.go new file mode 100644 index 0000000..df3772f --- /dev/null +++ b/filename.go @@ -0,0 +1,115 @@ +package gguf_parser + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/thxcode/gguf-parser-go/util/funcx" + "github.com/thxcode/gguf-parser-go/util/ptr" +) + +// GGUFFilename represents a GGUF filename, +// see https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention. +type GGUFFilename struct { + ModelName string `json:"modelName"` + Major *int `json:"major"` + Minor *int `json:"minor"` + ExpertsCount *int `json:"expertsCount,omitempty"` + Parameters string `json:"parameters"` + EncodingScheme string `json:"encodingScheme"` + Shard *int `json:"shard,omitempty"` + ShardTotal *int `json:"shardTotal,omitempty"` +} + +var GGUFFilenameRegex = regexp.MustCompile(`^(?P[A-Za-z0-9\s-]+)(?:-v(?P\d+)\.(?P\d+))?-(?:(?P\d+)x)?(?P\d+[A-Za-z]?)-(?P[\w_]+)(?:-(?P\d{5})-of-(?P\d{5}))?\.gguf$`) // nolint:lll + +// ParseGGUFFilename parses the given GGUF filename string, +// and returns the GGUFFilename, or nil if the filename is invalid. +func ParseGGUFFilename(name string) *GGUFFilename { + parseInt := func(v string) int { + return int(funcx.MustNoError(strconv.ParseInt(v, 10, 64))) + } + + n := name + if !strings.HasSuffix(n, ".gguf") { + n += ".gguf" + } + + m := make(map[string]string) + { + r := GGUFFilenameRegex.FindStringSubmatch(n) + for i, ne := range GGUFFilenameRegex.SubexpNames() { + if i != 0 && i <= len(r) { + m[ne] = r[i] + } + } + } + + if m["model_name"] == "" || m["parameters"] == "" || m["encoding_scheme"] == "" { + return nil + } + + var gn GGUFFilename + + gn.ModelName = strings.ReplaceAll(m["model_name"], "-", " ") + if v := m["major"]; v != "" { + gn.Major = ptr.To(parseInt(v)) + } + if v := m["minor"]; v != "" { + gn.Minor = ptr.To(parseInt(v)) + } + if v := m["experts_count"]; v != "" { + gn.ExpertsCount = ptr.To(parseInt(v)) + } + gn.Parameters = m["parameters"] + gn.EncodingScheme = m["encoding_scheme"] + if v := m["shard"]; v != "" { + gn.Shard = ptr.To(parseInt(v)) + } + if v := m["shardTotal"]; v != "" { + gn.ShardTotal = ptr.To(parseInt(v)) + } + return &gn +} + +func (gn GGUFFilename) String() string { + if gn.ModelName == "" || gn.Parameters == "" || gn.EncodingScheme == "" { + return "" + } + + var sb strings.Builder + sb.WriteString(strings.ReplaceAll(gn.ModelName, " ", "-")) + sb.WriteString("-") + if gn.Major != nil { + sb.WriteString("v") + sb.WriteString(strconv.Itoa(ptr.Deref(gn.Major, 0))) + sb.WriteString(".") + sb.WriteString(strconv.Itoa(ptr.Deref(gn.Minor, 0))) + sb.WriteString("-") + } + if v := ptr.Deref(gn.ExpertsCount, 0); v > 0 { + sb.WriteString(strconv.Itoa(v)) + sb.WriteString("x") + } + sb.WriteString(gn.Parameters) + sb.WriteString("-") + sb.WriteString(gn.EncodingScheme) + if m, n := ptr.Deref(gn.Shard, 0), ptr.Deref(gn.ShardTotal, 0); m > 0 && n > 0 { + sb.WriteString("-") + sb.WriteString(fmt.Sprintf("%05d", m)) + sb.WriteString("-of-") + sb.WriteString(fmt.Sprintf("%05d", n)) + } + sb.WriteString(".gguf") + return sb.String() +} + +func (gn GGUFFilename) IsPreRelease() bool { + return ptr.Deref(gn.Major, 0) == 0 && ptr.Deref(gn.Minor, 0) == 0 +} + +func (gn GGUFFilename) IsSharding() bool { + return ptr.Deref(gn.Shard, 0) > 0 && ptr.Deref(gn.ShardTotal, 0) > 0 +} diff --git a/filename_test.go b/filename_test.go new file mode 100644 index 0000000..8bbcbe1 --- /dev/null +++ b/filename_test.go @@ -0,0 +1,137 @@ +package gguf_parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/thxcode/gguf-parser-go/util/ptr" +) + +func TestParseGGUFFilename(t *testing.T) { + cases := []struct { + given string + expected *GGUFFilename + }{ + { + given: "Mixtral-v0.1-8x7B-KQ2.gguf", + expected: &GGUFFilename{ + ModelName: "Mixtral", + Major: ptr.To(0), + Minor: ptr.To(1), + ExpertsCount: ptr.To(8), + Parameters: "7B", + EncodingScheme: "KQ2", + }, + }, + { + given: "Grok-v1.0-100B-Q4_0-00003-of-00009.gguf", + expected: &GGUFFilename{ + ModelName: "Grok", + Major: ptr.To(1), + Minor: ptr.To(0), + Parameters: "100B", + EncodingScheme: "Q4_0", + Shard: ptr.To(3), + ShardTotal: ptr.To(9), + }, + }, + { + given: "Hermes-2-Pro-Llama-3-8B-F16.gguf", + expected: &GGUFFilename{ + ModelName: "Hermes 2 Pro Llama 3", + Parameters: "8B", + EncodingScheme: "F16", + }, + }, + { + given: "Hermes-2-Pro-Llama-3-v32.33-8Q-F16.gguf", + expected: &GGUFFilename{ + ModelName: "Hermes 2 Pro Llama 3", + Major: ptr.To(32), + Minor: ptr.To(33), + Parameters: "8Q", + EncodingScheme: "F16", + }, + }, + { + given: "not-a-known-arrangement.gguf", + expected: nil, + }, + } + for _, tc := range cases { + t.Run(tc.given, func(t *testing.T) { + actual := ParseGGUFFilename(tc.given) + assert.Equal(t, tc.expected, actual) + }) + } +} + +func TestGGUFFilenameString(t *testing.T) { + cases := []struct { + given GGUFFilename + expected string + }{ + { + given: GGUFFilename{ + ModelName: "Mixtral", + Major: ptr.To(0), + Minor: ptr.To(1), + ExpertsCount: ptr.To(8), + Parameters: "7B", + EncodingScheme: "KQ2", + }, + expected: "Mixtral-v0.1-8x7B-KQ2.gguf", + }, + { + given: GGUFFilename{ + ModelName: "Grok", + Major: ptr.To(1), + Minor: ptr.To(0), + Parameters: "100B", + EncodingScheme: "Q4_0", + Shard: ptr.To(3), + ShardTotal: ptr.To(9), + }, + expected: "Grok-v1.0-100B-Q4_0-00003-of-00009.gguf", + }, + { + given: GGUFFilename{ + ModelName: "Hermes 2 Pro Llama 3", + Parameters: "8B", + EncodingScheme: "F16", + }, + expected: "Hermes-2-Pro-Llama-3-8B-F16.gguf", + }, + { + given: GGUFFilename{ + ModelName: "Hermes 2 Pro Llama 3", + Major: ptr.To(0), + Minor: ptr.To(0), + Parameters: "8B", + EncodingScheme: "F16", + }, + expected: "Hermes-2-Pro-Llama-3-v0.0-8B-F16.gguf", + }, + { + given: GGUFFilename{ + ModelName: "Hermes 2 Pro Llama 3", + Major: ptr.To(32), + Minor: ptr.To(33), + Parameters: "8Q", + EncodingScheme: "F16", + }, + expected: "Hermes-2-Pro-Llama-3-v32.33-8Q-F16.gguf", + }, + { + given: GGUFFilename{}, + expected: "", + }, + } + for _, tc := range cases { + t.Run(tc.expected, func(t *testing.T) { + actual := tc.given.String() + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/gen.go b/gen.go new file mode 100644 index 0000000..6dfd072 --- /dev/null +++ b/gen.go @@ -0,0 +1,2 @@ +//go:generate go generate -tags tools gen.stringer.go +package gguf_parser diff --git a/gen.stringer.go b/gen.stringer.go new file mode 100644 index 0000000..faa8692 --- /dev/null +++ b/gen.stringer.go @@ -0,0 +1,10 @@ +//go:build tools + +//go:generate go run golang.org/x/tools/cmd/stringer -linecomment -type GGUFMagic -output zz_generated.ggufmagic.stringer.go -trimprefix GGUFMagic +//go:generate go run golang.org/x/tools/cmd/stringer -linecomment -type GGUFVersion -output zz_generated.ggufversion.stringer.go -trimprefix GGUFVersion +//go:generate go run golang.org/x/tools/cmd/stringer -linecomment -type GGUFMetadataValueType -output zz_generated.ggufmetadatavaluetype.stringer.go -trimprefix GGUFMetadataValueType +//go:generate go run golang.org/x/tools/cmd/stringer -linecomment -type GGUFFileType -output zz_generated.gguffiletype.stringer.go -trimprefix GGUFFileType +//go:generate go run golang.org/x/tools/cmd/stringer -linecomment -type GGMLType -output zz_generated.ggmltype.stringer.go -trimprefix GGMLType +package gguf_parser + +import _ "golang.org/x/tools/cmd/stringer" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..df64f18 --- /dev/null +++ b/go.mod @@ -0,0 +1,21 @@ +module github.com/thxcode/gguf-parser-go + +go 1.22 + +require ( + github.com/davecgh/go-spew v1.1.1 + github.com/dustin/go-humanize v1.0.1 + github.com/henvic/httpretty v0.1.3 + github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b + github.com/stretchr/testify v1.9.0 + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 + golang.org/x/sys v0.20.0 + golang.org/x/tools v0.21.0 +) + +require ( + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/sync v0.7.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e2f09c1 --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/henvic/httpretty v0.1.3 h1:4A6vigjz6Q/+yAfTD4wqipCv+Px69C7Th/NhT0ApuU8= +github.com/henvic/httpretty v0.1.3/go.mod h1:UUEv7c2kHZ5SPQ51uS3wBpzPDibg2U3Y+IaXyHy5GBg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b h1:e9eeuSYSLmUKxy7ALzKcxo7ggTceQaVcBhjDIcewa9c= +github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw= +golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/util/bytex/pool.go b/util/bytex/pool.go new file mode 100644 index 0000000..6178c04 --- /dev/null +++ b/util/bytex/pool.go @@ -0,0 +1,91 @@ +package bytex + +import ( + "bytes" + "sync" +) + +const defaultSize = 32 * 1024 + +type ( + Bytes = []byte + BytesBuffer = *bytes.Buffer +) + +var gp = sync.Pool{ + New: func() any { + buf := make(Bytes, defaultSize) + return &buf + }, +} + +// GetBytes gets a bytes buffer from the pool, +// which can specify with a size, +// default is 32k. +func GetBytes(size ...uint64) Bytes { + buf := *(gp.Get().(*Bytes)) + + s := defaultSize + if len(size) != 0 { + s = int(size[0]) + if s == 0 { + s = defaultSize + } + } + if cap(buf) >= s { + return buf[:s] + } + + gp.Put(&buf) + + ns := s + if ns < defaultSize { + ns = defaultSize + } + buf = make(Bytes, ns) + return buf[:s] +} + +// WithBytes relies on GetBytes to get a buffer, +// calls the function with the buffer, +// finally, puts it back to the pool after the function returns. +func WithBytes(fn func(Bytes) error, size ...uint64) error { + if fn == nil { + return nil + } + + buf := GetBytes(size...) + defer Put(buf) + return fn(buf) +} + +// GetBuffer is similar to GetBytes, +// but it returns the bytes buffer wrapped by bytes.Buffer. +func GetBuffer(size ...uint64) BytesBuffer { + return bytes.NewBuffer(GetBytes(size...)[:0]) +} + +// WithBuffer relies on GetBuffer to get a buffer, +// calls the function with the buffer, +// finally, puts it back to the pool after the function returns. +func WithBuffer(fn func(BytesBuffer) error, size ...uint64) error { + if fn == nil { + return nil + } + + buf := GetBuffer(size...) + defer Put(buf) + return fn(buf) +} + +// Put puts the buffer(either Bytes or BytesBuffer) back to the pool. +func Put[T Bytes | BytesBuffer](buf T) { + switch v := any(buf).(type) { + case Bytes: + gp.Put(&v) + case BytesBuffer: + bs := v.Bytes() + gp.Put(&bs) + v.Reset() + } +} diff --git a/util/funcx/error.go b/util/funcx/error.go new file mode 100644 index 0000000..1023553 --- /dev/null +++ b/util/funcx/error.go @@ -0,0 +1,65 @@ +package funcx + +// NoError ignores the given error, +// it is usually a nice helper for chain function calling. +func NoError[T any](t T, _ error) T { + return t +} + +// NoError2 ignores the given error, +// it is usually a nice helper for chain function calling. +func NoError2[T, U any](t T, u U, _ error) (T, U) { + return t, u +} + +// NoError3 ignores the given error, +// it is usually a nice helper for chain function calling. +func NoError3[T, U, V any](t T, u U, v V, _ error) (T, U, V) { + return t, u, v +} + +// NoError4 ignores the given error, +// it is usually a nice helper for chain function calling. +func NoError4[T, U, V, W any](t T, u U, v V, w W, _ error) (T, U, V, W) { + return t, u, v, w +} + +// MustNoError is similar to NoError, +// but it panics if the given error is not nil, +// it is usually a nice helper for chain function calling. +func MustNoError[T any](t T, e error) T { + if e != nil { + panic(e) + } + return t +} + +// MustNoError2 is similar to NoError2, +// but it panics if the given error is not nil, +// it is usually a nice helper for chain function calling. +func MustNoError2[T, U any](t T, u U, e error) (T, U) { + if e != nil { + panic(e) + } + return t, u +} + +// MustNoError3 is similar to NoError3, +// but it panics if the given error is not nil, +// it is usually a nice helper for chain function calling. +func MustNoError3[T, U, V any](t T, u U, v V, e error) (T, U, V) { + if e != nil { + panic(e) + } + return t, u, v +} + +// MustNoError4 is similar to NoError4, +// but it panics if the given error is not nil, +// it is usually a nice helper for chain function calling. +func MustNoError4[T, U, V, W any](t T, u U, v V, w W, e error) (T, U, V, W) { + if e != nil { + panic(e) + } + return t, u, v, w +} diff --git a/util/httpx/client.go b/util/httpx/client.go new file mode 100644 index 0000000..3c182df --- /dev/null +++ b/util/httpx/client.go @@ -0,0 +1,226 @@ +package httpx + +import ( + "context" + "fmt" + "io" + "net/http" + + "github.com/henvic/httpretty" + + "github.com/thxcode/gguf-parser-go/util/bytex" +) + +// DefaultClient is similar to the default http.Client used by the package. +// +// It is used for requests pooling. +var DefaultClient = &http.Client{ + Transport: DefaultTransport, +} + +// DefaultInsecureClient is the default http.Client used by the package, +// with TLS insecure skip verify. +// +// It is used for requests pooling. +var DefaultInsecureClient = &http.Client{ + Transport: DefaultInsecureTransport, +} + +// Client returns a new http.Client with the given options, +// the result http.Client is used for fast-consuming requests. +// +// If you want a requests pool management, use DefaultClient instead. +func Client(opts ...*ClientOption) *http.Client { + var o *ClientOption + if len(opts) > 0 { + o = opts[0] + } else { + o = ClientOptions() + } + + root := DefaultTransport + if o.transport != nil { + root = o.transport + } + + if o.debug { + pretty := &httpretty.Logger{ + Time: true, + TLS: true, + RequestHeader: true, + RequestBody: true, + MaxRequestBody: 1024, + ResponseHeader: true, + ResponseBody: true, + MaxResponseBody: 1024, + Formatters: []httpretty.Formatter{&JSONFormatter{}}, + } + root = pretty.RoundTripper(root) + } + + transport := RoundTripperChain{ + Next: root, + } + for i := range o.roundTrips { + transport = RoundTripperChain{ + Do: o.roundTrips[i], + Next: transport, + } + } + + return &http.Client{ + Transport: transport, + Timeout: o.timeout, + } +} + +// NewGetRequestWithContext returns a new http.MethodGet request, +// which is saving your life from http.NewRequestWithContext. +func NewGetRequestWithContext(ctx context.Context, uri string) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) +} + +// NewGetRequest returns a new http.MethodGet request, +// which is saving your life from http.NewRequest. +func NewGetRequest(uri string) (*http.Request, error) { + return http.NewRequest(http.MethodGet, uri, nil) +} + +// NewHeadRequestWithContext returns a new http.MethodHead request, +// which is saving your life from http.NewRequestWithContext. +func NewHeadRequestWithContext(ctx context.Context, uri string) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodHead, uri, nil) +} + +// NewHeadRequest returns a new http.MethodHead request, +// which is saving your life from http.NewRequest. +func NewHeadRequest(uri string) (*http.Request, error) { + return http.NewRequest(http.MethodHead, uri, nil) +} + +// NewPostRequestWithContext returns a new http.MethodPost request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewPostRequestWithContext(ctx context.Context, uri string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodPost, uri, body) +} + +// NewPostRequest returns a new http.MethodPost request, +// which is saving your life from http.NewRequest. +func NewPostRequest(uri string, body io.Reader) (*http.Request, error) { + return http.NewRequest(http.MethodPost, uri, body) +} + +// NewPutRequestWithContext returns a new http.MethodPut request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewPutRequestWithContext(ctx context.Context, uri string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodPut, uri, body) +} + +// NewPutRequest returns a new http.MethodPut request, +// which is saving your life from http.NewRequest. +func NewPutRequest(uri string, body io.Reader) (*http.Request, error) { + return http.NewRequest(http.MethodPut, uri, body) +} + +// NewPatchRequestWithContext returns a new http.MethodPatch request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewPatchRequestWithContext(ctx context.Context, uri string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodPatch, uri, body) +} + +// NewPatchRequest returns a new http.MethodPatch request, +// which is saving your life from http.NewRequest. +func NewPatchRequest(uri string, body io.Reader) (*http.Request, error) { + return http.NewRequest(http.MethodPatch, uri, body) +} + +// NewDeleteRequestWithContext returns a new http.MethodDelete request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewDeleteRequestWithContext(ctx context.Context, uri string) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodDelete, uri, nil) +} + +// NewDeleteRequest returns a new http.MethodDelete request, +// which is saving your life from http.NewRequest. +func NewDeleteRequest(uri string) (*http.Request, error) { + return http.NewRequest(http.MethodDelete, uri, nil) +} + +// NewConnectRequestWithContext returns a new http.MethodConnect request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewConnectRequestWithContext(ctx context.Context, uri string) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodConnect, uri, nil) +} + +// NewConnectRequest returns a new http.MethodConnect request, +// which is saving your life from http.NewRequest. +func NewConnectRequest(uri string) (*http.Request, error) { + return http.NewRequest(http.MethodConnect, uri, nil) +} + +// NewOptionsRequestWithContext returns a new http.MethodOptions request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewOptionsRequestWithContext(ctx context.Context, uri string) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodOptions, uri, nil) +} + +// NewOptionsRequest returns a new http.MethodOptions request, +// which is saving your life from http.NewRequest. +func NewOptionsRequest(uri string) (*http.Request, error) { + return http.NewRequest(http.MethodOptions, uri, nil) +} + +// NewTraceRequestWithContext returns a new http.MethodTrace request with the given context, +// which is saving your life from http.NewRequestWithContext. +func NewTraceRequestWithContext(ctx context.Context, uri string) (*http.Request, error) { + return http.NewRequestWithContext(ctx, http.MethodTrace, uri, nil) +} + +// NewTraceRequest returns a new http.MethodTrace request, +// which is saving your life from http.NewRequest. +func NewTraceRequest(uri string) (*http.Request, error) { + return http.NewRequest(http.MethodTrace, uri, nil) +} + +// Error is similar to http.Error, +// but it can get the error message by the given code. +func Error(rw http.ResponseWriter, code int) { + http.Error(rw, http.StatusText(code), code) +} + +// Close closes the http response body without error. +func Close(resp *http.Response) { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } +} + +// BodyBytes returns the body of the http response as a byte slice. +func BodyBytes(resp *http.Response) []byte { + buf := bytex.GetBytes() + defer bytex.Put(buf) + + w := bytex.GetBuffer() + _, _ = io.CopyBuffer(w, resp.Body, buf) + return w.Bytes() +} + +// BodyString returns the body of the http response as a string. +func BodyString(resp *http.Response) string { + return string(BodyBytes(resp)) +} + +// Do is a helper function to execute the given http request with the given http client, +// and execute the given function with the http response. +// +// It is useful to avoid forgetting to close the http response body. +// +// Do will return the error if failed to execute the http request or the given function. +func Do(cli *http.Client, req *http.Request, respFunc func(*http.Response) error) error { + resp, err := cli.Do(req) + if err != nil { + return fmt.Errorf("do request: %w", err) + } + defer Close(resp) + return respFunc(resp) +} diff --git a/util/httpx/client_helper.go b/util/httpx/client_helper.go new file mode 100644 index 0000000..afe6ce5 --- /dev/null +++ b/util/httpx/client_helper.go @@ -0,0 +1,60 @@ +package httpx + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "regexp" + + "github.com/henvic/httpretty" +) + +var _ httpretty.Formatter = (*JSONFormatter)(nil) + +// JSONFormatter is copied from httpretty.JSONFormatter, +// but use our own json package. +type JSONFormatter struct{} + +var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) + +// Match JSON media type. +func (j *JSONFormatter) Match(mediatype string) bool { + return jsonTypeRE.MatchString(mediatype) +} + +// Format JSON content. +func (j *JSONFormatter) Format(w io.Writer, src []byte) error { + if !json.Valid(src) { + // We want to get the error of json.checkValid, not unmarshal it. + // The happy path has been optimized, maybe prematurely. + if err := json.Unmarshal(src, &json.RawMessage{}); err != nil { + return err + } + } + // Avoiding allocation as we use *bytes.Buffer to store the formatted body before printing + dst, ok := w.(*bytes.Buffer) + if !ok { + // Mitigating panic to avoid upsetting anyone who uses this directly + return errors.New("underlying writer for JSONFormatter must be *bytes.Buffer") + } + return json.Indent(dst, src, "", " ") +} + +type RoundTripperChain struct { + Do func(req *http.Request) error + Next http.RoundTripper +} + +func (c RoundTripperChain) RoundTrip(req *http.Request) (*http.Response, error) { + if c.Do != nil { + if err := c.Do(req); err != nil { + return nil, err + } + } + if c.Next != nil { + return c.Next.RoundTrip(req) + } + return nil, nil +} diff --git a/util/httpx/client_options.go b/util/httpx/client_options.go new file mode 100644 index 0000000..2c2fa16 --- /dev/null +++ b/util/httpx/client_options.go @@ -0,0 +1,112 @@ +package httpx + +import ( + "net/http" + "time" +) + +type ClientOption struct { + *TransportOption + + timeout time.Duration + debug bool + roundTrips []func(req *http.Request) error +} + +func ClientOptions() *ClientOption { + return &ClientOption{ + TransportOption: TransportOptions().WithoutKeepalive(), + timeout: 30 * time.Second, + } +} + +// WithTransport sets the TransportOption. +func (o *ClientOption) WithTransport(opt *TransportOption) *ClientOption { + if o == nil || opt == nil { + return o + } + o.TransportOption = opt + return o +} + +// WithTimeout sets the request timeout. +// +// This timeout controls the sum of [network dial], [tls handshake], [request], [response header reading] and [response body reading]. +// +// Use 0 to disable timeout. +func (o *ClientOption) WithTimeout(timeout time.Duration) *ClientOption { + if o == nil || timeout < 0 { + return o + } + o.timeout = timeout + return o +} + +// WithDebug sets the debug mode. +func (o *ClientOption) WithDebug() *ClientOption { + if o == nil { + return o + } + o.debug = true + return o +} + +// WithRoundTrip sets the round trip function. +func (o *ClientOption) WithRoundTrip(rt func(req *http.Request) error) *ClientOption { + if o == nil || rt == nil { + return o + } + o.roundTrips = append(o.roundTrips, rt) + return o +} + +// WithUserAgent sets the user agent. +func (o *ClientOption) WithUserAgent(ua string) *ClientOption { + return o.WithRoundTrip(func(req *http.Request) error { + req.Header.Set("User-Agent", ua) + return nil + }) +} + +// WithBearerAuth sets the bearer token. +func (o *ClientOption) WithBearerAuth(token string) *ClientOption { + return o.WithRoundTrip(func(req *http.Request) error { + req.Header.Set("Authorization", "Bearer "+token) + return nil + }) +} + +// WithBasicAuth sets the basic authentication. +func (o *ClientOption) WithBasicAuth(username, password string) *ClientOption { + return o.WithRoundTrip(func(req *http.Request) error { + req.SetBasicAuth(username, password) + return nil + }) +} + +// WithHeader sets the header. +func (o *ClientOption) WithHeader(key, value string) *ClientOption { + return o.WithRoundTrip(func(req *http.Request) error { + req.Header.Set(key, value) + return nil + }) +} + +// WithHeaders sets the headers. +func (o *ClientOption) WithHeaders(headers map[string]string) *ClientOption { + return o.WithRoundTrip(func(req *http.Request) error { + for k, v := range headers { + req.Header.Set(k, v) + } + return nil + }) +} + +// If is a conditional option, +// which receives a boolean condition to trigger the given function or not. +func (o *ClientOption) If(condition bool, then func(*ClientOption) *ClientOption) *ClientOption { + if condition { + return then(o) + } + return o +} diff --git a/util/httpx/file.go b/util/httpx/file.go new file mode 100644 index 0000000..5a2bf09 --- /dev/null +++ b/util/httpx/file.go @@ -0,0 +1,197 @@ +package httpx + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + "syscall" + + "github.com/smallnest/ringbuffer" + + "github.com/thxcode/gguf-parser-go/util/bytex" +) + +type SeekerFile struct { + cli *http.Client + req *http.Request + b *ringbuffer.RingBuffer + c int64 + l int64 +} + +func OpenSeekerFile(cli *http.Client, req *http.Request) (*SeekerFile, error) { + return OpenSeekerFileWithSize(cli, req, 0, 0) +} + +func OpenSeekerFileWithSize(cli *http.Client, req *http.Request, bufSize, size int) (*SeekerFile, error) { + if cli == nil { + return nil, errors.New("client is nil") + } + if req == nil { + return nil, errors.New("request is nil") + } + if req.Method != http.MethodGet { + return nil, errors.New("request method is not GET") + } + + var l int64 + { + req := req.Clone(req.Context()) + req.Method = http.MethodHead + err := Do(cli, req, func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("stat: status code %d", resp.StatusCode) + } + if !strings.EqualFold(resp.Header.Get("Accept-Ranges"), "bytes") { + return fmt.Errorf("stat: not support range download") + } + l = resp.ContentLength + return nil + }) + if err != nil { + return nil, fmt.Errorf("stat: do head request: %w", err) + } + switch sz := int64(size); { + case sz > l: + return nil, fmt.Errorf("size %d is greater than limit %d", size, l) + case sz <= 0: + default: + l = sz + } + } + + if bufSize <= 0 { + bufSize = 4 * 1024 * 1024 // 4mb + } + + b := ringbuffer.New(bufSize).WithCancel(req.Context()) + return &SeekerFile{cli: cli, req: req, b: b, c: 1<<63 - 1, l: l}, nil +} + +func (f *SeekerFile) Close() error { + if f.b != nil { + f.b.CloseWriter() + } + return nil +} + +func (f *SeekerFile) Len() int64 { + return f.l +} + +func (f *SeekerFile) ReadAt(p []byte, off int64) (int, error) { + if off < 0 { + return 0, syscall.EINVAL + } + if off > f.Len() { + return 0, io.EOF + } + + // Sync and move to new offset, if backward or empty buffer. + if f.c > off || f.b.IsEmpty() { + if err := f.sync(off, true); err != nil { + return 0, err + } + } + + var ( + remain = int64(f.b.Length()) + capacity = int64(f.b.Capacity()) + need = int64(len(p)) + ) + + switch { + case f.c+remain >= off+need: // Skip and move to new offset, if enough to forward. + if err := f.skip(off - f.c); err != nil { + return 0, err + } + return f.Read(p) + case f.c+capacity >= off+need: // Sync and move to new offset, if enough to forward after synced. + if err := f.sync(f.c+remain, false); err != nil { + return 0, err + } + if err := f.skip(off - f.c); err != nil { + return 0, err + } + return f.Read(p) + default: + } + + // Otherwise, read directly. + + f.b.Reset() + f.c = off + + // Request remain needing. + lim := off + int64(len(p)) - 1 + if lim > f.Len() { + lim = f.Len() + } + req := f.req.Clone(f.req.Context()) + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", off, lim)) + resp, err := f.cli.Do(req) + if err != nil { + return 0, err + } + defer Close(resp) + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + return 0, errors.New(resp.Status) + } + n, err := resp.Body.Read(p) + f.c += int64(n) + return n, err +} + +func (f *SeekerFile) Read(p []byte) (int, error) { + n, err := f.b.Read(p) + f.c += int64(n) + return n, err +} + +func (f *SeekerFile) sync(off int64, reset bool) error { + lim := off + int64(f.b.Free()) - 1 + if lim > f.Len() { + lim = f.Len() + } + req := f.req.Clone(f.req.Context()) + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", off, lim)) + + resp, err := f.cli.Do(req) + if err != nil { + return err + } + defer Close(resp) + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + + buf := bytex.GetBytes() + defer bytex.Put(buf) + if reset { + f.b.Reset() + f.c = off + } + _, err = io.CopyBuffer(f.b, resp.Body, buf) + if err != nil { + return err + } + + return nil +} + +func (f *SeekerFile) skip(dif int64) error { + if dif <= 0 { + return nil + } + + buf := bytex.GetBytes(uint64(dif)) + defer bytex.Put(buf) + n, err := f.b.Read(buf) + f.c += int64(n) + if err != nil { + return err + } + return nil +} diff --git a/util/httpx/proxy.go b/util/httpx/proxy.go new file mode 100644 index 0000000..a7c33e0 --- /dev/null +++ b/util/httpx/proxy.go @@ -0,0 +1,37 @@ +package httpx + +import ( + "net" + "net/http" + "net/url" + "strings" + + "github.com/thxcode/gguf-parser-go/util/osx" +) + +var noProxies []*net.IPNet + +func init() { + noProxyEnv := osx.Getenv("NO_PROXY", osx.Getenv("no_proxy")) + noProxyRules := strings.Split(noProxyEnv, ",") + for i := range noProxyRules { + _, cidr, _ := net.ParseCIDR(noProxyRules[i]) + if cidr != nil { + noProxies = append(noProxies, cidr) + } + } +} + +// ProxyFromEnvironment is similar to http.ProxyFromEnvironment, +// but it also respects the NO_PROXY environment variable. +func ProxyFromEnvironment(r *http.Request) (*url.URL, error) { + if ip := net.ParseIP(r.URL.Hostname()); ip != nil { + for i := range noProxies { + if noProxies[i].Contains(ip) { + return nil, nil + } + } + } + + return http.ProxyFromEnvironment(r) +} diff --git a/util/httpx/transport.go b/util/httpx/transport.go new file mode 100644 index 0000000..2e30eb6 --- /dev/null +++ b/util/httpx/transport.go @@ -0,0 +1,25 @@ +package httpx + +import ( + "net/http" +) + +// DefaultTransport is similar to the default http.DefaultTransport used by the package. +var DefaultTransport http.RoundTripper = Transport() + +// DefaultInsecureTransport is the default http.DefaultTransport used by the package, +// with TLS insecure skip verify. +var DefaultInsecureTransport http.RoundTripper = Transport(TransportOptions().WithoutInsecureVerify()) + +// Transport returns a new http.Transport with the given options, +// the result http.Transport is used for constructing http.Client. +func Transport(opts ...*TransportOption) *http.Transport { + var o *TransportOption + if len(opts) > 0 { + o = opts[0] + } else { + o = TransportOptions() + } + + return o.transport +} diff --git a/util/httpx/transport_options.go b/util/httpx/transport_options.go new file mode 100644 index 0000000..fd0f224 --- /dev/null +++ b/util/httpx/transport_options.go @@ -0,0 +1,193 @@ +package httpx + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "time" +) + +type TransportOption struct { + dialer *net.Dialer + transport *http.Transport +} + +func TransportOptions() *TransportOption { + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + transport := &http.Transport{ + Proxy: ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + DialContext: dialer.DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + return &TransportOption{ + dialer: dialer, + transport: transport, + } +} + +// WithProxy sets the proxy. +func (o *TransportOption) WithProxy(proxy func(*http.Request) (*url.URL, error)) *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.transport.Proxy = proxy + return o +} + +// WithoutProxy disables the proxy. +func (o *TransportOption) WithoutProxy() *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.transport.Proxy = nil + return o +} + +// WithKeepalive sets the keepalive. +func (o *TransportOption) WithKeepalive(timeoutAndKeepalive ...time.Duration) *TransportOption { + if o == nil || o.transport == nil || o.dialer == nil { + return o + } + tak := [2]time.Duration{30 * time.Second, 30 * time.Second} + if len(timeoutAndKeepalive) > 0 { + tak[0] = timeoutAndKeepalive[0] + if len(timeoutAndKeepalive) > 1 { + tak[1] = timeoutAndKeepalive[1] + } + } + o.dialer.Timeout, o.dialer.KeepAlive = tak[0], tak[1] + o.transport.MaxIdleConns = 100 + o.transport.IdleConnTimeout = 90 * time.Second + return o +} + +// WithoutKeepalive disables the keepalive. +func (o *TransportOption) WithoutKeepalive() *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.dialer.KeepAlive = -1 + o.transport.MaxIdleConns = 0 + o.transport.IdleConnTimeout = 0 + return o +} + +// WithInsecureVerify verifies the insecure connection. +func (o *TransportOption) WithInsecureVerify() *TransportOption { + if o == nil || o.transport == nil || o.transport.TLSClientConfig == nil { + return o + } + o.transport.TLSClientConfig.InsecureSkipVerify = false + return o +} + +// WithoutInsecureVerify skips the insecure connection verify. +func (o *TransportOption) WithoutInsecureVerify() *TransportOption { + if o == nil || o.transport == nil || o.transport.TLSClientConfig == nil { + return o + } + o.transport.TLSClientConfig.InsecureSkipVerify = true + return o +} + +// TimeoutForDial sets the timeout for network dial. +// +// This timeout controls the [network dial] only. +// +// Use 0 to disable timeout. +func (o *TransportOption) TimeoutForDial(timeout time.Duration) *TransportOption { + if o == nil || o.dialer == nil { + return o + } + o.dialer.Timeout = timeout + return o +} + +// TimeoutForResponseHeader sets the timeout for response header. +// +// This timeout controls the [response header reading] only. +// +// Use 0 to disable timeout. +func (o *TransportOption) TimeoutForResponseHeader(timeout time.Duration) *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.transport.ResponseHeaderTimeout = timeout + return o +} + +// TimeoutForTLSHandshake sets the timeout for tls handshake. +// +// This timeout controls the [tls handshake] only. +// +// Use 0 to disable timeout. +func (o *TransportOption) TimeoutForTLSHandshake(timeout time.Duration) *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.transport.TLSHandshakeTimeout = timeout + return o +} + +// TimeoutForIdleConn sets the timeout for idle connection. +// +// This timeout controls the [idle connection lifetime] only. +// +// Use 0 to disable timeout. +func (o *TransportOption) TimeoutForIdleConn(timeout time.Duration) *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.transport.IdleConnTimeout = timeout + return o +} + +// WithTLSClientConfig sets the tls.Config. +func (o *TransportOption) WithTLSClientConfig(config *tls.Config) *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.transport.TLSClientConfig = config + return o +} + +// WithDialer sets the dialer. +func (o *TransportOption) WithDialer(dialer *net.Dialer) *TransportOption { + if o == nil || o.transport == nil || dialer == nil { + return o + } + o.dialer = dialer + o.transport.DialContext = dialer.DialContext + return o +} + +// Customize sets the transport. +func (o *TransportOption) Customize(fn func(*http.Transport)) *TransportOption { + if o == nil || o.transport == nil { + return o + } + o.dialer = nil + fn(o.transport) + return o +} + +// If is a conditional option, +// which receives a boolean condition to trigger the given function or not. +func (o *TransportOption) If(condition bool, then func(*TransportOption) *TransportOption) *TransportOption { + if condition { + return then(o) + } + return o +} diff --git a/util/osx/env.go b/util/osx/env.go new file mode 100644 index 0000000..553aa45 --- /dev/null +++ b/util/osx/env.go @@ -0,0 +1,29 @@ +package osx + +import ( + "os" +) + +// ExistEnv checks if the environment variable named by the key exists. +func ExistEnv(key string) bool { + _, ok := os.LookupEnv(key) + return ok +} + +// Getenv retrieves the value of the environment variable named by the key. +// It returns the default, which will be empty if the variable is not present. +// To distinguish between an empty value and an unset value, use LookupEnv. +func Getenv(key string, def ...string) string { + e, ok := os.LookupEnv(key) + if !ok && len(def) != 0 { + return def[0] + } + + return e +} + +// ExpandEnv is similar to Getenv, +// but replaces ${var} or $var in the result. +func ExpandEnv(key string, def ...string) string { + return os.ExpandEnv(Getenv(key, def...)) +} diff --git a/util/osx/file.go b/util/osx/file.go new file mode 100644 index 0000000..e226076 --- /dev/null +++ b/util/osx/file.go @@ -0,0 +1,84 @@ +package osx + +import ( + "io" + "os" + "path/filepath" + "strings" +) + +// Open is similar to os.Open but supports ~ as the home directory. +func Open(path string) (*os.File, error) { + p := filepath.Clean(path) + if strings.HasPrefix(p, "~"+string(filepath.Separator)) { + hd, err := os.UserHomeDir() + if err != nil { + return nil, err + } + p = filepath.Join(hd, p[2:]) + } + return os.Open(p) +} + +// Exists checks if the given path exists. +func Exists(path string, checks ...func(os.FileInfo) bool) bool { + stat, err := os.Lstat(path) + if err != nil { + return false + } + + for i := range checks { + if checks[i] == nil { + continue + } + + if !checks[i](stat) { + return false + } + } + + return true +} + +// ExistsDir checks if the given path exists and is a directory. +func ExistsDir(path string) bool { + return Exists(path, func(stat os.FileInfo) bool { + return stat.Mode().IsDir() + }) +} + +// ExistsLink checks if the given path exists and is a symbolic link. +func ExistsLink(path string) bool { + return Exists(path, func(stat os.FileInfo) bool { + return stat.Mode()&os.ModeSymlink != 0 + }) +} + +// ExistsFile checks if the given path exists and is a regular file. +func ExistsFile(path string) bool { + return Exists(path, func(stat os.FileInfo) bool { + return stat.Mode().IsRegular() + }) +} + +// ExistsSocket checks if the given path exists and is a socket. +func ExistsSocket(path string) bool { + return Exists(path, func(stat os.FileInfo) bool { + return stat.Mode()&os.ModeSocket != 0 + }) +} + +// ExistsDevice checks if the given path exists and is a device. +func ExistsDevice(path string) bool { + return Exists(path, func(stat os.FileInfo) bool { + return stat.Mode()&os.ModeDevice != 0 + }) +} + +// Close closes the given io.Closer without error. +func Close(c io.Closer) { + if c == nil { + return + } + _ = c.Close() +} diff --git a/util/osx/file_mmap.go b/util/osx/file_mmap.go new file mode 100644 index 0000000..93f7072 --- /dev/null +++ b/util/osx/file_mmap.go @@ -0,0 +1,109 @@ +// Copyright 2018 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package osx + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "runtime/debug" + "strings" + "syscall" +) + +type MmapFile struct { + f *os.File + b []byte +} + +func OpenMmapFile(path string) (*MmapFile, error) { + return OpenMmapFileWithSize(path, 0) +} + +func OpenMmapFileWithSize(path string, size int) (*MmapFile, error) { + p := filepath.Clean(path) + if strings.HasPrefix(p, "~"+string(filepath.Separator)) { + hd, err := os.UserHomeDir() + if err != nil { + return nil, err + } + p = filepath.Join(hd, p[2:]) + } + + f, err := os.Open(p) + if err != nil { + return nil, fmt.Errorf("try lock file: %w", err) + } + if size <= 0 { + info, err := f.Stat() + if err != nil { + Close(f) + return nil, fmt.Errorf("stat: %w", err) + } + size = int(info.Size()) + } + + b, err := mmap(f, size) + if err != nil { + Close(f) + return nil, fmt.Errorf("mmap, size %d: %w", size, err) + } + + return &MmapFile{f: f, b: b}, nil +} + +func (f *MmapFile) Close() error { + err0 := munmap(f.b) + err1 := f.f.Close() + + if err0 != nil { + return err0 + } + return err1 +} + +func (f *MmapFile) Bytes() []byte { + return f.b +} + +func (f *MmapFile) Len() int64 { + return int64(len(f.b)) +} + +var ErrPageFault = errors.New("page fault occurred while reading from memory map") + +func (f *MmapFile) ReadAt(p []byte, off int64) (_ int, err error) { + if off < 0 { + return 0, syscall.EINVAL + } + if off > f.Len() { + return 0, io.EOF + } + + old := debug.SetPanicOnFault(true) + defer func() { + debug.SetPanicOnFault(old) + if recover() != nil { + err = ErrPageFault + } + }() + + n := copy(p, f.b[off:]) + if n < len(p) { + err = io.EOF + } + return n, err +} diff --git a/util/osx/file_mmap_js.go b/util/osx/file_mmap_js.go new file mode 100644 index 0000000..9172a88 --- /dev/null +++ b/util/osx/file_mmap_js.go @@ -0,0 +1,27 @@ +// Copyright 2022 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package osx + +import ( + "errors" + "os" +) + +func mmap(f *os.File, length int) ([]byte, error) { + return nil, errors.New("unsupported") +} + +func munmap(b []byte) (err error) { + return errors.New("unsupported") +} diff --git a/util/osx/file_mmap_unix.go b/util/osx/file_mmap_unix.go new file mode 100644 index 0000000..18725bb --- /dev/null +++ b/util/osx/file_mmap_unix.go @@ -0,0 +1,30 @@ +// Copyright 2017 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris + +package osx + +import ( + "os" + + "golang.org/x/sys/unix" +) + +func mmap(f *os.File, length int) ([]byte, error) { + return unix.Mmap(int(f.Fd()), 0, length, unix.PROT_READ, unix.MAP_SHARED) +} + +func munmap(b []byte) (err error) { + return unix.Munmap(b) +} diff --git a/util/osx/file_mmap_windows.go b/util/osx/file_mmap_windows.go new file mode 100644 index 0000000..b9879fc --- /dev/null +++ b/util/osx/file_mmap_windows.go @@ -0,0 +1,33 @@ +package osx + +import ( + "os" + "syscall" + "unsafe" +) + +func mmap(f *os.File, size int) ([]byte, error) { + low, high := uint32(size), uint32(size>>32) + h, errno := syscall.CreateFileMapping(syscall.Handle(f.Fd()), nil, syscall.PAGE_READONLY, high, low, nil) + if h == 0 { + return nil, os.NewSyscallError("CreateFileMapping", errno) + } + + addr, errno := syscall.MapViewOfFile(h, syscall.FILE_MAP_READ, 0, 0, uintptr(size)) + if addr == 0 { + return nil, os.NewSyscallError("MapViewOfFile", errno) + } + + if err := syscall.CloseHandle(h); err != nil { + return nil, os.NewSyscallError("CloseHandle", err) + } + + return (*[maxMapSize]byte)(unsafe.Pointer(uintptr(addr)))[:size], nil +} + +func munmap(b []byte) error { + if err := syscall.UnmapViewOfFile((uintptr)(unsafe.Pointer(&b[0]))); err != nil { + return os.NewSyscallError("UnmapViewOfFile", err) + } + return nil +} diff --git a/util/osx/file_mmap_windows_386.go b/util/osx/file_mmap_windows_386.go new file mode 100644 index 0000000..aab64c3 --- /dev/null +++ b/util/osx/file_mmap_windows_386.go @@ -0,0 +1,16 @@ +// Copyright 2018 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package osx + +const maxMapSize = 0x7FFFFFFF // 2GB diff --git a/util/osx/file_mmap_windows_non386.go b/util/osx/file_mmap_windows_non386.go new file mode 100644 index 0000000..c194070 --- /dev/null +++ b/util/osx/file_mmap_windows_non386.go @@ -0,0 +1,18 @@ +// Copyright 2018 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows && !386 + +package osx + +const maxMapSize = 0xFFFFFFFFFFFF // 256TB diff --git a/util/ptr/pointer.go b/util/ptr/pointer.go new file mode 100644 index 0000000..b7f1cc4 --- /dev/null +++ b/util/ptr/pointer.go @@ -0,0 +1,163 @@ +package ptr + +import ( + "time" + + "golang.org/x/exp/constraints" +) + +func Int(v int) *int { + return Ref(v) +} + +func IntDeref(v *int, def int) int { + return Deref(v, def) +} + +func Int8(v int8) *int8 { + return Ref(v) +} + +func Int8Deref(v *int8, def int8) int8 { + return Deref(v, def) +} + +func Int16(v int16) *int16 { + return Ref(v) +} + +func Int16Deref(v *int16, def int16) int16 { + return Deref(v, def) +} + +func Int32(v int32) *int32 { + return Ref(v) +} + +func Int32Deref(v *int32, def int32) int32 { + return Deref(v, def) +} + +func Int64(v int64) *int64 { + return Ref(v) +} + +func Int64Deref(v *int64, def int64) int64 { + return Deref(v, def) +} + +func Uint(v uint) *uint { + return Ref(v) +} + +func UintDeref(v *uint, def uint) uint { + return Deref(v, def) +} + +func Uint8(v uint8) *uint8 { + return Ref(v) +} + +func Uint8Deref(v *uint8, def uint8) uint8 { + return Deref(v, def) +} + +func Uint16(v uint16) *uint16 { + return Ref(v) +} + +func Uint16Deref(v *uint16, def uint16) uint16 { + return Deref(v, def) +} + +func Uint32(v uint32) *uint32 { + return Ref(v) +} + +func Uint32Deref(v *uint32, def uint32) uint32 { + return Deref(v, def) +} + +func Uint64(v uint64) *uint64 { + return Ref(v) +} + +func Uint64Deref(v *uint64, def uint64) uint64 { + return Deref(v, def) +} + +func Float32(v float32) *float32 { + return Ref(v) +} + +func Float32Deref(v *float32, def float32) float32 { + return Deref(v, def) +} + +func Float64(v float64) *float64 { + return Ref(v) +} + +func Float64Deref(v *float64, def float64) float64 { + return Deref(v, def) +} + +func String(v string) *string { + return Ref(v) +} + +func StringDeref(v *string, def string) string { + return Deref(v, def) +} + +func Bool(v bool) *bool { + return Ref(v) +} + +func BoolDeref(v *bool, def bool) bool { + return Deref(v, def) +} + +func Duration(v time.Duration) *time.Duration { + return Ref(v) +} + +func DurationDeref(v *time.Duration, def time.Duration) time.Duration { + return Deref(v, def) +} + +func Time(v time.Time) *time.Time { + return Ref(v) +} + +func TimeDeref(v *time.Time, def time.Time) time.Time { + return Deref(v, def) +} + +type Pointerable interface { + constraints.Ordered | ~bool | time.Time +} + +func Ref[T Pointerable](v T) *T { + return &v +} + +func To[T Pointerable](v T) *T { + return Ref(v) +} + +func Deref[T Pointerable](ptr *T, def T) T { + if ptr != nil { + return *ptr + } + + return def +} + +func Equal[T Pointerable](a, b *T) bool { + if a != nil && b != nil { + return *a == *b + } + + return false +} diff --git a/util/stringx/string.go b/util/stringx/string.go new file mode 100644 index 0000000..24449d7 --- /dev/null +++ b/util/stringx/string.go @@ -0,0 +1,91 @@ +package stringx + +import ( + "fmt" + "strings" + "unicode" + "unicode/utf8" +) + +// Join concatenates the elements of strs to create a single string. +func Join[T ~string](sep string, strs ...T) string { + switch len(strs) { + case 0: + return "" + case 1: + return string(strs[0]) + } + + n := len(sep) * (len(strs) - 1) + for i := 0; i < len(strs); i++ { + n += len(strs[i]) + } + + var b strings.Builder + + b.Grow(n) + b.WriteString(string(strs[0])) + + for i := range strs[1:] { + b.WriteString(sep) + b.WriteString(string(strs[i+1])) + } + + return b.String() +} + +// Strings converts a slice of fmt.Stringer to a slice of string. +func Strings[T fmt.Stringer](v []T) []string { + if len(v) == 0 { + return nil + } + + s := make([]string, len(v)) + for i := range v { + s[i] = v[i].String() + } + + return s +} + +var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} + +// IsSpace reports whether the rune is a space character. +func IsSpace(r rune) bool { + if r > utf8.RuneSelf { + return unicode.IsSpace(r) + } + return asciiSpace[r] == 1 +} + +// TrimAllSpace removes space chars from the given string. +func TrimAllSpace(s string) string { + if len(s) == 0 { + return s + } + + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + if IsSpace(r) { + continue + } + b.WriteRune(r) + } + return b.String() +} + +// ReplaceFunc returns a copy of the string s with all +// non-overlapping instances of old replaced by new. +func ReplaceFunc(s string, rp func(rune) rune) string { + if len(s) == 0 || rp == nil { + return s + } + + var b strings.Builder + b.Grow(len(s)) + for _, r := range s { + b.WriteRune(rp(r)) + } + return b.String() +} diff --git a/zz_generated.ggmltype.stringer.go b/zz_generated.ggmltype.stringer.go new file mode 100644 index 0000000..5770a63 --- /dev/null +++ b/zz_generated.ggmltype.stringer.go @@ -0,0 +1,54 @@ +// Code generated by "stringer -linecomment -type GGMLType -output zz_generated.ggmltype.stringer.go -trimprefix GGMLType"; DO NOT EDIT. + +package gguf_parser + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[GGMLTypeF32-0] + _ = x[GGMLTypeF16-1] + _ = x[GGMLTypeQ4_0-2] + _ = x[GGMLTypeQ4_1-3] + _ = x[GGMLTypeQ4_2-4] + _ = x[GGMLTypeQ4_3-5] + _ = x[GGMLTypeQ5_0-6] + _ = x[GGMLTypeQ5_1-7] + _ = x[GGMLTypeQ8_0-8] + _ = x[GGMLTypeQ8_1-9] + _ = x[GGMLTypeQ2_K-10] + _ = x[GGMLTypeQ3_K-11] + _ = x[GGMLTypeQ4_K-12] + _ = x[GGMLTypeQ5_K-13] + _ = x[GGMLTypeQ6_K-14] + _ = x[GGMLTypeQ8_K-15] + _ = x[GGMLTypeIQ2_XXS-16] + _ = x[GGMLTypeIQ2_XS-17] + _ = x[GGMLTypeIQ3_XXS-18] + _ = x[GGMLTypeIQ1_S-19] + _ = x[GGMLTypeIQ4_NL-20] + _ = x[GGMLTypeIQ3_S-21] + _ = x[GGMLTypeIQ2_S-22] + _ = x[GGMLTypeIQ4_XS-23] + _ = x[GGMLTypeI8-24] + _ = x[GGMLTypeI16-25] + _ = x[GGMLTypeI32-26] + _ = x[GGMLTypeI64-27] + _ = x[GGMLTypeF64-28] + _ = x[GGMLTypeIQ1_M-29] + _ = x[GGMLTypeBF16-30] + _ = x[_GGMLTypeCount-31] +} + +const _GGMLType_name = "F32F16Q4_0Q4_1Q4_2Q4_3Q5_0Q5_1Q8_0Q8_1Q2_KQ3_KQ4_KQ5_KQ6_KQ8_KIQ2_XXSIQ2_XSIQ3_XXSIQ1_SIQ4_NLIQ3_SIQ2_SIQ4_XSI8I16I32I64F64IQ1_MBF16Unknown" + +var _GGMLType_index = [...]uint8{0, 3, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 69, 75, 82, 87, 93, 98, 103, 109, 111, 114, 117, 120, 123, 128, 132, 139} + +func (i GGMLType) String() string { + if i >= GGMLType(len(_GGMLType_index)-1) { + return "GGMLType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _GGMLType_name[_GGMLType_index[i]:_GGMLType_index[i+1]] +} diff --git a/zz_generated.gguffiletype.stringer.go b/zz_generated.gguffiletype.stringer.go new file mode 100644 index 0000000..9b98ca1 --- /dev/null +++ b/zz_generated.gguffiletype.stringer.go @@ -0,0 +1,48 @@ +// Code generated by "stringer -linecomment -type GGUFFileType -output zz_generated.gguffiletype.stringer.go -trimprefix GGUFFileType"; DO NOT EDIT. + +package gguf_parser + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[GGUFFileTypeAllF32-0] + _ = x[GGUFFileTypeMostlyF16-1] + _ = x[GGUFFileTypeMostlyQ4_0-2] + _ = x[GGUFFileTypeMostlyQ4_1-3] + _ = x[GGUFFileTypeMostlyQ4_1_F16-4] + _ = x[GGUFFileTypeMostlyQ4_2-5] + _ = x[GGUFFileTypeMostlyQ4_3-6] + _ = x[GGUFFileTypeMostlyQ8_0-7] + _ = x[GGUFFileTypeMostlyQ5_0-8] + _ = x[GGUFFileTypeMostlyQ5_1-9] + _ = x[GGUFFileTypeMostlyQ2_K-10] + _ = x[GGUFFileTypeMostlyQ3_K-11] + _ = x[GGUFFileTypeMostlyQ4_K-12] + _ = x[GGUFFileTypeMostlyQ5_K-13] + _ = x[GGUFFileTypeMostlyQ6_K-14] + _ = x[GGUFFileTypeMostlyIQ2_XXS-15] + _ = x[GGUFFileTypeMostlyIQ2_XS-16] + _ = x[GGUFFileTypeMostlyIQ3_XXS-17] + _ = x[GGUFFileTypeMostlyIQ1_S-18] + _ = x[GGUFFileTypeMostlyIQ4_NL-19] + _ = x[GGUFFileTypeMostlyIQ3_S-20] + _ = x[GGUFFileTypeMostlyIQ2_S-21] + _ = x[GGUFFileTypeMostlyIQ4_XS-22] + _ = x[GGUFFileTypeMostlyIQ1_M-23] + _ = x[GGUFFileTypeMostlyBF16-24] + _ = x[_GGUFFileTypeCount-25] +} + +const _GGUFFileType_name = "F32F16Q4_0Q4_1Q4_1_F16Q4_2Q4_3Q8_0Q5_0Q5_1Q2_KQ3_K/Q3_K_SQ4_K/Q3_K_MQ5_K/Q3_K_LQ6_K/Q4_K_SIQ2_XXS/Q4_K_MIQ2_XS/Q5_K_SIQ3_XXS/Q5_K_MIQ1_S/Q6_KIQ4_NLIQ3_SIQ2_SIQ4_XSIQ1_MBF16Unknown" + +var _GGUFFileType_index = [...]uint8{0, 3, 6, 10, 14, 22, 26, 30, 34, 38, 42, 46, 57, 68, 79, 90, 104, 117, 131, 141, 147, 152, 157, 163, 168, 172, 179} + +func (i GGUFFileType) String() string { + if i >= GGUFFileType(len(_GGUFFileType_index)-1) { + return "GGUFFileType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _GGUFFileType_name[_GGUFFileType_index[i]:_GGUFFileType_index[i+1]] +} diff --git a/zz_generated.ggufmagic.stringer.go b/zz_generated.ggufmagic.stringer.go new file mode 100644 index 0000000..e6227e4 --- /dev/null +++ b/zz_generated.ggufmagic.stringer.go @@ -0,0 +1,41 @@ +// Code generated by "stringer -linecomment -type GGUFMagic -output zz_generated.ggufmagic.stringer.go -trimprefix GGUFMagic"; DO NOT EDIT. + +package gguf_parser + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[GGUFMagicGGML-1734831468] + _ = x[GGUFMagicGGMF-1734831462] + _ = x[GGUFMagicGGJT-1734830708] + _ = x[GGUFMagicGGUFLe-1179993927] + _ = x[GGUFMagicGGUFBe-1195857222] +} + +const ( + _GGUFMagic_name_0 = "GGUF" + _GGUFMagic_name_1 = "GGUF" + _GGUFMagic_name_2 = "GGJT" + _GGUFMagic_name_3 = "GGMF" + _GGUFMagic_name_4 = "GGML" +) + +func (i GGUFMagic) String() string { + switch { + case i == 1179993927: + return _GGUFMagic_name_0 + case i == 1195857222: + return _GGUFMagic_name_1 + case i == 1734830708: + return _GGUFMagic_name_2 + case i == 1734831462: + return _GGUFMagic_name_3 + case i == 1734831468: + return _GGUFMagic_name_4 + default: + return "GGUFMagic(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/zz_generated.ggufmetadatavaluetype.stringer.go b/zz_generated.ggufmetadatavaluetype.stringer.go new file mode 100644 index 0000000..78760c6 --- /dev/null +++ b/zz_generated.ggufmetadatavaluetype.stringer.go @@ -0,0 +1,36 @@ +// Code generated by "stringer -linecomment -type GGUFMetadataValueType -output zz_generated.ggufmetadatavaluetype.stringer.go -trimprefix GGUFMetadataValueType"; DO NOT EDIT. + +package gguf_parser + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[GGUFMetadataValueTypeUint8-0] + _ = x[GGUFMetadataValueTypeInt8-1] + _ = x[GGUFMetadataValueTypeUint16-2] + _ = x[GGUFMetadataValueTypeInt16-3] + _ = x[GGUFMetadataValueTypeUint32-4] + _ = x[GGUFMetadataValueTypeInt32-5] + _ = x[GGUFMetadataValueTypeFloat32-6] + _ = x[GGUFMetadataValueTypeBool-7] + _ = x[GGUFMetadataValueTypeString-8] + _ = x[GGUFMetadataValueTypeArray-9] + _ = x[GGUFMetadataValueTypeUint64-10] + _ = x[GGUFMetadataValueTypeInt64-11] + _ = x[GGUFMetadataValueTypeFloat64-12] + _ = x[_GGUFMetadataValueTypeCount-13] +} + +const _GGUFMetadataValueType_name = "Uint8Int8Uint16Int16Uint32Int32Float32BoolStringArrayUint64Int64Float64Unknown" + +var _GGUFMetadataValueType_index = [...]uint8{0, 5, 9, 15, 20, 26, 31, 38, 42, 48, 53, 59, 64, 71, 78} + +func (i GGUFMetadataValueType) String() string { + if i >= GGUFMetadataValueType(len(_GGUFMetadataValueType_index)-1) { + return "GGUFMetadataValueType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _GGUFMetadataValueType_name[_GGUFMetadataValueType_index[i]:_GGUFMetadataValueType_index[i+1]] +} diff --git a/zz_generated.ggufversion.stringer.go b/zz_generated.ggufversion.stringer.go new file mode 100644 index 0000000..a54ffe9 --- /dev/null +++ b/zz_generated.ggufversion.stringer.go @@ -0,0 +1,26 @@ +// Code generated by "stringer -linecomment -type GGUFVersion -output zz_generated.ggufversion.stringer.go -trimprefix GGUFVersion"; DO NOT EDIT. + +package gguf_parser + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[GGUFVersionV1-1] + _ = x[GGUFVersionV2-2] + _ = x[GGUFVersionV3-3] +} + +const _GGUFVersion_name = "V1V2V3" + +var _GGUFVersion_index = [...]uint8{0, 2, 4, 6} + +func (i GGUFVersion) String() string { + i -= 1 + if i >= GGUFVersion(len(_GGUFVersion_index)-1) { + return "GGUFVersion(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _GGUFVersion_name[_GGUFVersion_index[i]:_GGUFVersion_index[i+1]] +}