diff --git a/cmd/gguf-parser/README.md b/cmd/gguf-parser/README.md index 7167c8a..2d7c1a6 100644 --- a/cmd/gguf-parser/README.md +++ b/cmd/gguf-parser/README.md @@ -37,6 +37,10 @@ Usage of gguf-parser ...: Specify how many layers to offload, which is used to estimate the usage, default is full offloaded. [Deprecated, use --gpu-layers instead] (default -1) -offload-layers-step uint Specify the step of layers to offload, works with --offload-layers. [Deprecated, use --gpu-layers-step instead] + -ol-crawl + Crawl the Ollama model instead of blobs fetching, which will be more efficient and faster, but lossy. + -ol-model string + Model name of Ollama, e.g. gemma2. -parallel-size int Specify the number of parallel sequences to decode, which is used to estimate the usage, default is 1. (default 1) -path string @@ -61,6 +65,7 @@ Usage of gguf-parser ...: Url where the GGUF file to load, e.g. https://huggingface.co/NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF/resolve/main/Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf. Note that gguf-parser does not need to download the entire GGUF file. -version Show gguf-parser version. + ``` ### Parse @@ -155,6 +160,62 @@ $ gguf-parser --hf-repo="openbmb/MiniCPM-Llama3-V-2_5-gguf" --hf-file="ggml-mode ``` +#### Parse Ollama model + +```shell +$ gguf-parser --ol-model="gemma2" ++-------+--------+--------+----------------------+-----------+---------------+----------+------------+----------+ +| MODEL | NAME | ARCH | QUANTIZATION VERSION | FILE TYPE | LITTLE ENDIAN | SIZE | PARAMETERS | BPW | ++ +--------+--------+----------------------+-----------+---------------+----------+------------+----------+ +| | gemma2 | gemma2 | 2 | Q4_0 | true | 5.06 GiB | 9.24 B | 4.71 bpw | ++-------+--------+--------+----------------------+-----------+---------------+----------+------------+----------+ + ++--------------+-----------------+---------------+---------------+--------------------+--------+------------------+------------+----------------+ +| ARCHITECTURE | MAX CONTEXT LEN | EMBEDDING LEN | EMBEDDING GQA | ATTENTION HEAD CNT | LAYERS | FEED FORWARD LEN | EXPERT CNT | VOCABULARY LEN | ++ +-----------------+---------------+---------------+--------------------+--------+------------------+------------+----------------+ +| | 8192 | 3584 | 2048 | 16 | 42 | 14336 | 0 | 256000 | ++--------------+-----------------+---------------+---------------+--------------------+--------+------------------+------------+----------------+ + ++-----------+-------+-------------+------------+------------------+-----------+-----------+---------------+-----------------+---------------+ +| TOKENIZER | MODEL | TOKENS SIZE | TOKENS LEN | ADDED TOKENS LEN | BOS TOKEN | EOS TOKEN | UNKNOWN TOKEN | SEPARATOR TOKEN | PADDING TOKEN | ++ +-------+-------------+------------+------------------+-----------+-----------+---------------+-----------------+---------------+ +| | llama | 3.80 MiB | 256000 | 0 | 2 | 1 | 3 | N/A | 0 | ++-----------+-------+-------------+------------+------------------+-----------+-----------+---------------+-----------------+---------------+ + ++----------+--------+--------------+-----------------+--------------+----------------+----------------+----------+------------+-------------+ +| ESTIMATE | ARCH | CONTEXT SIZE | FLASH ATTENTION | MMAP SUPPORT | OFFLOAD LAYERS | FULL OFFLOADED | UMA RAM | NONUMA RAM | NONUMA VRAM | ++ +--------+--------------+-----------------+--------------+----------------+----------------+----------+------------+-------------+ +| | gemma2 | 8192 | false | true | 43 (42 + 1) | Yes | 2.69 GiB | 215.97 MiB | 8.43 GiB | ++----------+--------+--------------+-----------------+--------------+----------------+----------------+----------+------------+-------------+ + + +$ gguf-parser --ol-model="gemma2" --ol-crawl ++-------+--------+--------+----------------------+-----------+---------------+----------+------------+----------+ +| MODEL | NAME | ARCH | QUANTIZATION VERSION | FILE TYPE | LITTLE ENDIAN | SIZE | PARAMETERS | BPW | ++ +--------+--------+----------------------+-----------+---------------+----------+------------+----------+ +| | gemma2 | gemma2 | 2 | Q4_0 | true | 5.06 GiB | 9.24 B | 4.71 bpw | ++-------+--------+--------+----------------------+-----------+---------------+----------+------------+----------+ + ++--------------+-----------------+---------------+---------------+--------------------+--------+------------------+------------+----------------+ +| ARCHITECTURE | MAX CONTEXT LEN | EMBEDDING LEN | EMBEDDING GQA | ATTENTION HEAD CNT | LAYERS | FEED FORWARD LEN | EXPERT CNT | VOCABULARY LEN | ++ +-----------------+---------------+---------------+--------------------+--------+------------------+------------+----------------+ +| | 8192 | 3584 | 2048 | 16 | 42 | 14336 | 0 | 256000 | ++--------------+-----------------+---------------+---------------+--------------------+--------+------------------+------------+----------------+ + ++-----------+-------+-------------+------------+------------------+-----------+-----------+---------------+-----------------+---------------+ +| TOKENIZER | MODEL | TOKENS SIZE | TOKENS LEN | ADDED TOKENS LEN | BOS TOKEN | EOS TOKEN | UNKNOWN TOKEN | SEPARATOR TOKEN | PADDING TOKEN | ++ +-------+-------------+------------+------------------+-----------+-----------+---------------+-----------------+---------------+ +| | llama | 0 B | 256000 | 0 | 2 | 1 | 3 | N/A | 0 | ++-----------+-------+-------------+------------+------------------+-----------+-----------+---------------+-----------------+---------------+ + ++----------+--------+--------------+-----------------+--------------+----------------+----------------+----------+------------+-------------+ +| ESTIMATE | ARCH | CONTEXT SIZE | FLASH ATTENTION | MMAP SUPPORT | OFFLOAD LAYERS | FULL OFFLOADED | UMA RAM | NONUMA RAM | NONUMA VRAM | ++ +--------+--------------+-----------------+--------------+----------------+----------------+----------+------------+-------------+ +| | gemma2 | 8192 | false | true | 43 (42 + 1) | Yes | 2.69 GiB | 215.99 MiB | 8.12 GiB | ++----------+--------+--------------+-----------------+--------------+----------------+----------------+----------+------------+-------------+ + +``` + ### Estimate #### Estimate with zero layers offload diff --git a/cmd/gguf-parser/go.mod b/cmd/gguf-parser/go.mod index f6441fe..fde6162 100644 --- a/cmd/gguf-parser/go.mod +++ b/cmd/gguf-parser/go.mod @@ -11,10 +11,15 @@ require ( require ( github.com/henvic/httpretty v0.1.3 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 // indirect github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.25.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.20.0 // indirect golang.org/x/tools v0.21.0 // indirect diff --git a/cmd/gguf-parser/go.sum b/cmd/gguf-parser/go.sum index 2d428fb..833e61b 100644 --- a/cmd/gguf-parser/go.sum +++ b/cmd/gguf-parser/go.sum @@ -1,21 +1,36 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/henvic/httpretty v0.1.3 h1:4A6vigjz6Q/+yAfTD4wqipCv+Px69C7Th/NhT0ApuU8= github.com/henvic/httpretty v0.1.3/go.mod h1:UUEv7c2kHZ5SPQ51uS3wBpzPDibg2U3Y+IaXyHy5GBg= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 h1:18kd+8ZUlt/ARXhljq+14TwAoKa61q6dX8jtwOf6DH8= +github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA= github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b h1:e9eeuSYSLmUKxy7ALzKcxo7ggTceQaVcBhjDIcewa9c= github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= diff --git a/cmd/gguf-parser/main.go b/cmd/gguf-parser/main.go index 829cbf4..028c354 100644 --- a/cmd/gguf-parser/main.go +++ b/cmd/gguf-parser/main.go @@ -8,10 +8,11 @@ import ( "strconv" "strings" "sync" - stdjson "encoding/json" "github.com/olekukonko/tablewriter" + "github.com/thxcode/gguf-parser-go/util/json" + . "github.com/thxcode/gguf-parser-go" ) @@ -24,9 +25,12 @@ func main() { var ( // model options - path string - url string - repo, file string + path string + url string + hfRepo string + hfFile string + olModel string + olCrawl bool // read options debug bool skipTLSVerify bool @@ -47,8 +51,8 @@ func main() { skipTokenizer bool skipEstimate bool inMib bool - json bool - jsonPretty = true + inJson bool + inPrettyJson = true ) fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError) fs.Usage = func() { @@ -62,14 +66,18 @@ func main() { "https://huggingface.co/NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF"+ "/resolve/main/Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf. "+ "Note that gguf-parser does not need to download the entire GGUF file.") - fs.StringVar(&repo, "repo", repo, "Repository of HuggingFace which the GGUF file store, e.g. "+ + fs.StringVar(&hfRepo, "repo", hfRepo, "Repository of HuggingFace which the GGUF file store, e.g. "+ "NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF, works with --file. [Deprecated, use --hf-repo instead]") - fs.StringVar(&file, "file", file, "Model file below the --repo, e.g. "+ + fs.StringVar(&hfFile, "file", hfFile, "Model file below the --repo, e.g. "+ "Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf. [Deprecated, use --hf-file instead]") // Deprecated. - fs.StringVar(&repo, "hf-repo", repo, "Repository of HuggingFace which the GGUF file store, e.g. "+ + fs.StringVar(&hfRepo, "hf-repo", hfRepo, "Repository of HuggingFace which the GGUF file store, e.g. "+ "NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF, works with --hf-file.") // Deprecated. - fs.StringVar(&file, "hf-file", file, "Model file below the --repo, e.g. "+ + fs.StringVar(&hfFile, "hf-file", hfFile, "Model file below the --repo, e.g. "+ "Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q4_K_M.gguf.") + fs.StringVar(&olModel, "ol-model", olModel, "Model name of Ollama, e.g. "+ + "gemma2.") + fs.BoolVar(&olCrawl, "ol-crawl", olCrawl, "Crawl the Ollama model instead of blobs fetching, "+ + "which will be more efficient and faster, but lossy.") fs.BoolVar(&debug, "debug", debug, "Enable debugging, verbosity.") fs.BoolVar(&skipTLSVerify, "skip-tls-verify", skipTLSVerify, "Skip TLS verification, works with --url.") fs.IntVar(&ctxSize, "ctx-size", ctxSize, "Specify the size of prompt context, "+ @@ -113,8 +121,8 @@ func main() { fs.BoolVar(&skipTokenizer, "skip-tokenizer", skipTokenizer, "Skip to display tokenizer metadata") fs.BoolVar(&skipEstimate, "skip-estimate", skipEstimate, "Skip to estimate.") fs.BoolVar(&inMib, "in-mib", inMib, "Display the estimated result in table with MiB.") - fs.BoolVar(&json, "json", json, "Output as JSON.") - fs.BoolVar(&jsonPretty, "json-pretty", jsonPretty, "Output as pretty JSON.") + fs.BoolVar(&inJson, "json", inJson, "Output as JSON.") + fs.BoolVar(&inPrettyJson, "json-pretty", inPrettyJson, "Output as pretty JSON.") if err := fs.Parse(os.Args[1:]); err != nil { fmt.Println(err.Error()) os.Exit(1) @@ -192,8 +200,10 @@ func main() { gf, err = ParseGGUFFile(path, ropts...) case url != "": gf, err = ParseGGUFFileRemote(ctx, url, ropts...) - case repo != "" && file != "": - gf, err = ParseGGUFFileFromHuggingFace(ctx, repo, file, ropts...) + case hfRepo != "" && hfFile != "": + gf, err = ParseGGUFFileFromHuggingFace(ctx, hfRepo, hfFile, ropts...) + case olModel != "": + gf, err = ParseGGUFFileFromOllama(ctx, olModel, olCrawl, ropts...) } if err != nil { _, _ = fmt.Fprintf(os.Stderr, "failed to parse GGUF file: %s\n", err.Error()) @@ -244,7 +254,7 @@ func main() { } } - if json { + if inJson { o := map[string]any{} if !skipModel { o["model"] = m @@ -286,8 +296,8 @@ func main() { o["estimate"] = es } - enc := stdjson.NewEncoder(os.Stdout) - if jsonPretty { + enc := json.NewEncoder(os.Stdout) + if inPrettyJson { enc.SetIndent("", " ") } if err := enc.Encode(o); err != nil { diff --git a/file.go b/file.go index 7980b5e..8425c2d 100644 --- a/file.go +++ b/file.go @@ -2,22 +2,18 @@ package gguf_parser import ( "bytes" - "context" "encoding/binary" "errors" "fmt" "io" - "net/http" "regexp" "strconv" "strings" - "time" "golang.org/x/exp/constraints" "github.com/thxcode/gguf-parser-go/util/bytex" "github.com/thxcode/gguf-parser-go/util/funcx" - "github.com/thxcode/gguf-parser-go/util/httpx" "github.com/thxcode/gguf-parser-go/util/osx" ) @@ -36,10 +32,14 @@ type GGUFFile struct { TensorInfos GGUFTensorInfos `json:"tensorInfos"` // Padding is the padding size of the GGUF file, // which is used to split Header and TensorInfos from tensor data. + // + // This might be empty if parse from crawler. Padding int64 `json:"padding"` // TensorDataStartOffset is the offset in bytes of the tensor data in this file. // // The offset is the start of the file. + // + // This might be lossy if parse from crawler. TensorDataStartOffset int64 `json:"tensorDataStartOffset"` /* Appendix */ @@ -151,7 +151,7 @@ type ( Len uint64 `json:"len"` // Array holds all array items. // - // Array may be empty if skipping. + // This might be empty if skipping or parse from crawler. Array []any `json:"array,omitempty"` /* Appendix */ @@ -159,9 +159,13 @@ type ( // StartOffset is the offset in bytes of the GGUFMetadataKVArrayValue in the GGUFFile file. // // The offset is the start of the file. + // + // This might be empty if parse from crawler. StartOffset int64 `json:"startOffset"` // Size is the size of the array in bytes. + // + // This might be empty if parse from crawler. Size int64 `json:"endOffset"` } @@ -195,6 +199,8 @@ type ( // StartOffset is the offset in bytes of the GGUFTensorInfo in the GGUFFile file. // // The offset is the start of the file. + // + // This might be empty if parse from crawler. StartOffset int64 `json:"startOffset"` } @@ -237,70 +243,6 @@ func ParseGGUFFile(path string, opts ...GGUFReadOption) (*GGUFFile, error) { return parseGGUFFile(s, f, o) } -// ParseGGUFFileRemote parses a GGUF file from a remote URL, -// and returns a GGUFFile, or an error if any. -func ParseGGUFFileRemote(ctx context.Context, url string, opts ...GGUFReadOption) (*GGUFFile, error) { - var o _GGUFReadOptions - for _, opt := range opts { - opt(&o) - } - - cli := httpx.Client( - httpx.ClientOptions(). - WithUserAgent("gguf-parser-go"). - If(o.Debug, func(x *httpx.ClientOption) *httpx.ClientOption { - return x.WithDebug() - }). - WithTimeout(0). - WithTransport( - httpx.TransportOptions(). - WithoutKeepalive(). - TimeoutForDial(5*time.Second). - TimeoutForTLSHandshake(5*time.Second). - TimeoutForResponseHeader(5*time.Second). - If(o.SkipProxy, func(x *httpx.TransportOption) *httpx.TransportOption { - return x.WithoutProxy() - }). - If(o.ProxyURL != nil, func(x *httpx.TransportOption) *httpx.TransportOption { - return x.WithProxy(http.ProxyURL(o.ProxyURL)) - }). - If(o.SkipTLSVerification, func(x *httpx.TransportOption) *httpx.TransportOption { - return x.WithoutInsecureVerify() - }))) - - var ( - f io.ReadSeeker - s int64 - ) - { - req, err := httpx.NewGetRequestWithContext(ctx, url) - if err != nil { - return nil, fmt.Errorf("new request: %w", err) - } - - var sf *httpx.SeekerFile - if o.BufferSize > 0 { - sf, err = httpx.OpenSeekerFileWithSize(cli, req, o.BufferSize, 0) - } else { - sf, err = httpx.OpenSeekerFile(cli, req) - } - if err != nil { - return nil, fmt.Errorf("open http file: %w", err) - } - defer osx.Close(sf) - f = io.NewSectionReader(sf, 0, sf.Len()) - s = sf.Len() - } - - return parseGGUFFile(s, f, o) -} - -// ParseGGUFFileFromHuggingFace parses a GGUF file from Hugging Face, -// and returns a GGUFFile, or an error if any. -func ParseGGUFFileFromHuggingFace(ctx context.Context, repo, file string, opts ...GGUFReadOption) (*GGUFFile, error) { - return ParseGGUFFileRemote(ctx, fmt.Sprintf("https://huggingface.co/%s/resolve/main/%s", repo, file), opts...) -} - func parseGGUFFile(s int64, f io.ReadSeeker, o _GGUFReadOptions) (_ *GGUFFile, err error) { var gf GGUFFile var bo binary.ByteOrder = binary.LittleEndian diff --git a/file_from_metadata.go b/file_from_metadata.go new file mode 100644 index 0000000..3e30c55 --- /dev/null +++ b/file_from_metadata.go @@ -0,0 +1,244 @@ +package gguf_parser + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "sort" + "strconv" + "time" + + "golang.org/x/exp/maps" + "golang.org/x/net/html" + + "github.com/thxcode/gguf-parser-go/util/funcx" + "github.com/thxcode/gguf-parser-go/util/httpx" + "github.com/thxcode/gguf-parser-go/util/json" + "github.com/thxcode/gguf-parser-go/util/stringx" +) + +var ( + ErrOllamaInvalidModel = errors.New("ollama invalid model") + ErrOllamaBaseLayerNotFound = errors.New("ollama base layer not found") +) + +// ParseGGUFFileFromOllama parses a GGUF file from Ollama model's base layer, +// and returns a GGUFFile, or an error if any. +// +// If the crawl is true, it will try to crawl the metadata from Ollama website instead of blobs fetching, +// which will be more efficient and faster, but lossy. +// If the crawling fails, it will fall back to the default behavior. +func ParseGGUFFileFromOllama(ctx context.Context, model string, crawl bool, opts ...GGUFReadOption) (*GGUFFile, error) { + var o _GGUFReadOptions + for _, opt := range opts { + opt(&o) + } + + om := ParseOllamaModel(model) + if om == nil { + return nil, ErrOllamaInvalidModel + } + + cli := httpx.Client( + httpx.ClientOptions(). + WithUserAgent("gguf-parser-go"). + If(o.Debug, func(x *httpx.ClientOption) *httpx.ClientOption { + return x.WithDebug() + }). + WithTimeout(0). + WithTransport( + httpx.TransportOptions(). + WithoutKeepalive(). + TimeoutForDial(5*time.Second). + TimeoutForTLSHandshake(5*time.Second). + TimeoutForResponseHeader(5*time.Second). + If(o.SkipProxy, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithoutProxy() + }). + If(o.ProxyURL != nil, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithProxy(http.ProxyURL(o.ProxyURL)) + }). + If(o.SkipTLSVerification, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithoutInsecureVerify() + }))) + + var ml OllamaModelLayer + { + err := om.Complete(ctx, cli) + if err != nil { + return nil, fmt.Errorf("complete ollama model: %w", err) + } + + var ok bool + ml, ok = om.GetLayer("application/vnd.ollama.image.model") + if !ok { + return nil, ErrOllamaBaseLayerNotFound + } + } + + if crawl { + mwu, lwu := om.WebURL().String(), ml.WebURL().String() + req, err := httpx.NewGetRequestWithContext(ctx, lwu) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + req.Header.Add("Referer", mwu) + req.Header.Add("Hx-Current-Url", mwu) + req.Header.Add("Hx-Request", "true") + req.Header.Add("Hx-Target", "file-explorer") + + var n *html.Node + err = httpx.Do(cli, req, func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status code %d", resp.StatusCode) + } + + n, err = html.Parse(resp.Body) + if err != nil { + return fmt.Errorf("parse html: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("do crawl request: %w", err) + } + + var wk func(*html.Node) string + wk = func(n *html.Node) string { + if n.Type == html.ElementNode && n.Data == "div" { + for i := range n.Attr { + if n.Attr[i].Key == "class" && n.Attr[i].Val == "whitespace-pre-wrap" { + return n.FirstChild.Data + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + if r := wk(c); r != "" { + return r + } + } + return "" + } + + r := wk(n) + if r != "" { + return parseGGUFFileFromMetadata("ollama", r, ml.Size) + } + + // Fallback to the default behavior. + } + + return parseGGUFFileFromRemote(ctx, cli, ml.URL().String(), o) +} + +type _OllamaMetadata struct { + Metadata map[string]any `json:"metadata"` + NumParams uint64 `json:"num_params"` + Tensors []struct { + Name string `json:"name"` + Shape []uint64 `json:"shape"` + Offset uint64 `json:"offset"` + Type uint32 `json:"type"` + } `json:"tensors"` + Version uint32 `json:"version"` +} + +func parseGGUFFileFromMetadata(source, data string, size uint64) (*GGUFFile, error) { + if source != "ollama" { + return nil, fmt.Errorf("invalid source %q", source) + } + + var m _OllamaMetadata + if err := json.Unmarshal([]byte(data), &m); err != nil { + return nil, fmt.Errorf("unmarshal metadata: %w", err) + } + + arrayMetadataValueRegex := regexp.MustCompile(`^\.{3} \((?P\d+) values\)$`) + + var gf GGUFFile + + gf.Header.Magic = GGUFMagicGGUFLe + gf.Header.Version = GGUFVersion(m.Version) + gf.Header.TensorCount = uint64(len(m.Tensors)) + gf.Header.MetadataKVCount = uint64(len(m.Metadata) + 1 /* tokenizer.chat_template */) + gf.Size = GGUFBytesScalar(size) + gf.ModelParameters = GGUFParametersScalar(m.NumParams) + + gf.Header.MetadataKV = make([]GGUFMetadataKV, 0, len(m.Metadata)) + for _, k := range func() []string { + ks := maps.Keys(m.Metadata) + ks = append(ks, "tokenizer.chat_template") + sort.Strings(ks) + return ks + }() { + if k == "tokenizer.chat_template" { + gf.Header.MetadataKV = append(gf.Header.MetadataKV, GGUFMetadataKV{ + Key: k, + ValueType: GGUFMetadataValueTypeString, + Value: "!!! tokenizer.chat_template !!!", + }) + continue + } + + var ( + vt GGUFMetadataValueType + v = m.Metadata[k] + ) + switch vv := v.(type) { + case bool: + vt = GGUFMetadataValueTypeBool + case float64: + vt = GGUFMetadataValueTypeFloat32 + v = float32(vv) + case int64: + vt = GGUFMetadataValueTypeUint32 + v = uint32(vv) + case string: + vt = GGUFMetadataValueTypeString + if r := arrayMetadataValueRegex.FindStringSubmatch(vv); len(r) == 2 { + vt = GGUFMetadataValueTypeArray + av := GGUFMetadataKVArrayValue{ + Type: GGUFMetadataValueTypeString, + Len: funcx.MustNoError(strconv.ParseUint(r[1], 10, 64)), + } + switch _, d, _ := stringx.CutFromRight(k, "."); d { + case "scores": + av.Type = GGUFMetadataValueTypeFloat32 + case "token_type": + av.Type = GGUFMetadataValueTypeInt32 + } + v = av + } + } + gf.Header.MetadataKV = append(gf.Header.MetadataKV, GGUFMetadataKV{ + Key: k, + ValueType: vt, + Value: v, + }) + } + + gf.TensorInfos = make([]GGUFTensorInfo, 0, len(m.Tensors)) + for i := range m.Tensors { + t := m.Tensors[i] + ti := GGUFTensorInfo{ + Name: t.Name, + NDimensions: uint32(len(t.Shape)), + Dimensions: t.Shape, + Offset: t.Offset, + Type: GGMLType(t.Type), + } + gf.TensorInfos = append(gf.TensorInfos, ti) + gf.ModelSize += GGUFBytesScalar(ti.Bytes()) + } + + gf.TensorDataStartOffset = int64(gf.Size - gf.ModelSize) + + if gf.ModelParameters != 0 { + gf.ModelBitsPerWeight = GGUFBitsPerWeightScalar(float64(gf.ModelSize) * 8 / float64(gf.ModelParameters)) + } + + return &gf, nil +} diff --git a/file_from_remote.go b/file_from_remote.go new file mode 100644 index 0000000..0e067e3 --- /dev/null +++ b/file_from_remote.go @@ -0,0 +1,80 @@ +package gguf_parser + +import ( + "context" + "fmt" + "io" + "net/http" + "time" + + "github.com/thxcode/gguf-parser-go/util/httpx" + "github.com/thxcode/gguf-parser-go/util/osx" +) + +// ParseGGUFFileFromHuggingFace parses a GGUF file from Hugging Face, +// and returns a GGUFFile, or an error if any. +func ParseGGUFFileFromHuggingFace(ctx context.Context, repo, file string, opts ...GGUFReadOption) (*GGUFFile, error) { + return ParseGGUFFileRemote(ctx, fmt.Sprintf("https://huggingface.co/%s/resolve/main/%s", repo, file), opts...) +} + +// ParseGGUFFileRemote parses a GGUF file from a remote URL, +// and returns a GGUFFile, or an error if any. +func ParseGGUFFileRemote(ctx context.Context, url string, opts ...GGUFReadOption) (*GGUFFile, error) { + var o _GGUFReadOptions + for _, opt := range opts { + opt(&o) + } + + cli := httpx.Client( + httpx.ClientOptions(). + WithUserAgent("gguf-parser-go"). + If(o.Debug, func(x *httpx.ClientOption) *httpx.ClientOption { + return x.WithDebug() + }). + WithTimeout(0). + WithTransport( + httpx.TransportOptions(). + WithoutKeepalive(). + TimeoutForDial(5*time.Second). + TimeoutForTLSHandshake(5*time.Second). + TimeoutForResponseHeader(5*time.Second). + If(o.SkipProxy, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithoutProxy() + }). + If(o.ProxyURL != nil, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithProxy(http.ProxyURL(o.ProxyURL)) + }). + If(o.SkipTLSVerification, func(x *httpx.TransportOption) *httpx.TransportOption { + return x.WithoutInsecureVerify() + }))) + + return parseGGUFFileFromRemote(ctx, cli, url, o) +} + +func parseGGUFFileFromRemote(ctx context.Context, cli *http.Client, url string, o _GGUFReadOptions) (*GGUFFile, error) { + var ( + f io.ReadSeeker + s int64 + ) + { + req, err := httpx.NewGetRequestWithContext(ctx, url) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + + var sf *httpx.SeekerFile + if o.BufferSize > 0 { + sf, err = httpx.OpenSeekerFileWithSize(cli, req, o.BufferSize, 0) + } else { + sf, err = httpx.OpenSeekerFile(cli, req) + } + if err != nil { + return nil, fmt.Errorf("open http file: %w", err) + } + defer osx.Close(sf) + f = io.NewSectionReader(sf, 0, sf.Len()) + s = sf.Len() + } + + return parseGGUFFile(s, f, o) +} diff --git a/file_test.go b/file_test.go index 7991c39..e48731c 100644 --- a/file_test.go +++ b/file_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "testing" + "time" "github.com/davecgh/go-spew/spew" ) @@ -189,6 +190,10 @@ func TestParseGGUFFileFromHuggingFace(t *testing.T) { "lmstudio-community/Yi-1.5-9B-Chat-GGUF", "Yi-1.5-9B-Chat-Q5_K_M.gguf", }, + { + "bartowski/gemma-2-9b-it-GGUF", + "gemma-2-9b-it-Q3_K_M.gguf", + }, } for _, tc := range cases { t.Run(tc[0]+"/"+tc[1], func(t *testing.T) { @@ -201,3 +206,62 @@ func TestParseGGUFFileFromHuggingFace(t *testing.T) { }) } } + +func TestParseGGUFFileFromOllama(t *testing.T) { + ctx := context.Background() + + cases := []string{ + "gemma2", + "llama3:8b", + "qwen2:72b-instruct-q3_K_M", + } + for _, tc := range cases { + t.Run(tc, func(t *testing.T) { + start := time.Now() + cf, err := ParseGGUFFileFromOllama(ctx, tc, true, SkipLargeMetadata()) + if err != nil { + t.Fatal(err) + return + } + t.Logf("cost: %v\n", time.Since(start)) + t.Log("\n", spew.Sdump(cf), "\n") + + start = time.Now() + sf, err := ParseGGUFFileFromOllama(ctx, tc, false, SkipLargeMetadata()) + if err != nil { + t.Fatal(err) + return + } + t.Logf("cost: %v\n", time.Since(start)) + t.Log("\n", spew.Sdump(sf), "\n") + }) + } +} + +func BenchmarkParseGGUFFileOllamaCrawl(b *testing.B) { + ctx := context.Background() + + b.ReportAllocs() + + b.ResetTimer() + b.Run("Without Crawl", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFileFromOllama(ctx, "gemma2", false, SkipLargeMetadata()) + if err != nil { + b.Fatal(err) + return + } + } + }) + + b.ResetTimer() + b.Run("With Crawl", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := ParseGGUFFileFromOllama(ctx, "gemma2", true, SkipLargeMetadata()) + if err != nil { + b.Fatal(err) + return + } + } + }) +} diff --git a/go.mod b/go.mod index 2998fb2..78047d8 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,19 @@ go 1.22 require ( github.com/davecgh/go-spew v1.1.1 github.com/henvic/httpretty v0.1.3 + github.com/json-iterator/go v1.1.12 + github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 + golang.org/x/net v0.25.0 golang.org/x/sys v0.20.0 golang.org/x/tools v0.21.0 ) require ( + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/sync v0.7.0 // indirect diff --git a/go.sum b/go.sum index 467e640..bea2dfb 100644 --- a/go.sum +++ b/go.sum @@ -1,17 +1,32 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/henvic/httpretty v0.1.3 h1:4A6vigjz6Q/+yAfTD4wqipCv+Px69C7Th/NhT0ApuU8= github.com/henvic/httpretty v0.1.3/go.mod h1:UUEv7c2kHZ5SPQ51uS3wBpzPDibg2U3Y+IaXyHy5GBg= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 h1:18kd+8ZUlt/ARXhljq+14TwAoKa61q6dX8jtwOf6DH8= +github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA= github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b h1:e9eeuSYSLmUKxy7ALzKcxo7ggTceQaVcBhjDIcewa9c= github.com/smallnest/ringbuffer v0.0.0-20240423223918-bab516b2000b/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= diff --git a/ollama_model.go b/ollama_model.go new file mode 100644 index 0000000..8475b9e --- /dev/null +++ b/ollama_model.go @@ -0,0 +1,214 @@ +package gguf_parser + +import ( + "context" + "fmt" + "net/http" + "net/url" + "regexp" + "strings" + + "github.com/thxcode/gguf-parser-go/util/httpx" + "github.com/thxcode/gguf-parser-go/util/json" + "github.com/thxcode/gguf-parser-go/util/stringx" +) + +// Inspired by https://github.com/ollama/ollama/blob/380e06e5bea06ae8ded37f47c37bd5d604194d3e/types/model/name.go, +// and https://github.com/ollama/ollama/blob/380e06e5bea06ae8ded37f47c37bd5d604194d3e/server/modelpath.go. + +const ( + OllamaDefaultScheme = "https" + OllamaDefaultRegistry = "ollama.com" + OllamaDefaultNamespace = "library" + OllamaDefaultTag = "latest" +) + +type ( + OllamaModel struct { + Schema string `json:"schema"` + Registry string `json:"registry"` + Namespace string `json:"namespace"` + Repository string `json:"repository"` + Tag string `json:"tag"` + SchemaVersion uint32 `json:"schemaVersion"` + MediaType string `json:"mediaType"` + Config OllamaModelLayer `json:"config"` + Layers []OllamaModelLayer `json:"layers"` + } + OllamaModelLayer struct { + MediaType string `json:"mediaType"` + Size uint64 `json:"size"` + Digest string `json:"digest"` + + model *OllamaModel + } +) + +// ParseOllamaModel parses the given Ollama model string, +// and returns the OllamaModel, or nil if the model is invalid. +func ParseOllamaModel(model string) *OllamaModel { + if model == "" { + return nil + } + + om := OllamaModel{ + Schema: OllamaDefaultScheme, + Registry: OllamaDefaultRegistry, + Namespace: OllamaDefaultNamespace, + Tag: OllamaDefaultTag, + } + + m := model + + // Drop digest. + m, _, _ = stringx.CutFromRight(m, "@") + + // Get tag. + m, s, ok := stringx.CutFromRight(m, ":") + if ok && s != "" { + om.Tag = s + } + + // Get repository. + m, s, ok = stringx.CutFromRight(m, "/") + if ok && s != "" { + om.Repository = s + } else if m != "" { + om.Repository = m + m = "" + } + + // Get namespace. + m, s, ok = stringx.CutFromRight(m, "/") + if ok && s != "" { + om.Namespace = s + } else if m != "" { + om.Namespace = m + m = "" + } + + // Get registry. + m, s, ok = stringx.CutFromLeft(m, "://") + if ok && s != "" { + om.Schema = m + om.Registry = s + } else if m != "" { + om.Registry = m + } + + if om.Repository == "" { + return nil + } + return &om +} + +func (om *OllamaModel) String() string { + var b strings.Builder + if om.Registry != "" { + b.WriteString(om.Registry) + b.WriteByte('/') + } + if om.Namespace != "" { + b.WriteString(om.Namespace) + b.WriteByte('/') + } + b.WriteString(om.Repository) + if om.Tag != "" { + b.WriteByte(':') + b.WriteString(om.Tag) + } + return b.String() +} + +// GetLayer returns the OllamaModelLayer with the given media type, +// and true if found, and false otherwise. +func (om *OllamaModel) GetLayer(mediaType string) (OllamaModelLayer, bool) { + for i := range om.Layers { + if om.Layers[i].MediaType == mediaType { + return om.Layers[i], true + } + } + return OllamaModelLayer{}, false +} + +// SearchLayers returns a list of OllamaModelLayer with the media type that matches the given regex. +func (om *OllamaModel) SearchLayers(mediaTypeRegex *regexp.Regexp) []OllamaModelLayer { + var ls []OllamaModelLayer + for i := range om.Layers { + if mediaTypeRegex.MatchString(om.Layers[i].MediaType) { + ls = append(ls, om.Layers[i]) + } + } + return ls +} + +// URL returns the URL of the OllamaModel. +func (om *OllamaModel) URL() *url.URL { + u := &url.URL{ + Scheme: om.Schema, + Host: om.Registry, + } + return u.JoinPath("v2", om.Namespace, om.Repository, "manifests", om.Tag) +} + +// WebURL returns the Ollama web URL of the OllamaModel. +func (om *OllamaModel) WebURL() *url.URL { + u := &url.URL{ + Scheme: om.Schema, + Host: om.Registry, + } + return u.JoinPath(om.Namespace, om.Repository+":"+om.Tag) +} + +// Complete completes the OllamaModel with the given context and http client. +func (om *OllamaModel) Complete(ctx context.Context, cli *http.Client) error { + req, err := httpx.NewGetRequestWithContext(ctx, om.URL().String()) + if err != nil { + return fmt.Errorf("new request: %w", err) + } + + err = httpx.Do(cli, req, func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status code %d", resp.StatusCode) + } + return json.NewDecoder(resp.Body).Decode(om) + }) + if err != nil { + return fmt.Errorf("do request: %w", err) + } + + // Connect. + om.Config.model = om + for i := range om.Layers { + om.Layers[i].model = om + } + + return nil +} + +// URL returns the URL of the OllamaModelLayer. +func (ol *OllamaModelLayer) URL() *url.URL { + if ol.model == nil { + return nil + } + + u := &url.URL{ + Scheme: ol.model.Schema, + Host: ol.model.Registry, + } + return u.JoinPath("v2", ol.model.Namespace, ol.model.Repository, "blobs", ol.Digest) +} + +// WebURL returns the Ollama web URL of the OllamaModelLayer. +func (ol *OllamaModelLayer) WebURL() *url.URL { + if ol.model == nil || len(ol.MediaType) < 12 { + return nil + } + + dg := strings.TrimPrefix(ol.Digest, "sha256:")[:12] + u := &url.URL{ + Scheme: ol.model.Schema, + Host: ol.model.Registry, + } + return u.JoinPath(ol.model.Namespace, ol.model.Repository+":"+ol.model.Tag, "blobs", dg) +} diff --git a/ollama_model_test.go b/ollama_model_test.go new file mode 100644 index 0000000..704b8af --- /dev/null +++ b/ollama_model_test.go @@ -0,0 +1,81 @@ +package gguf_parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseOllamaModel(t *testing.T) { + cases := []struct { + given string + expected *OllamaModel + }{ + { + given: "gemma2", + expected: &OllamaModel{ + Schema: OllamaDefaultScheme, + Registry: OllamaDefaultRegistry, + Namespace: OllamaDefaultNamespace, + Repository: "gemma2", + Tag: OllamaDefaultTag, + }, + }, + { + given: "gemma2:awesome", + expected: &OllamaModel{ + Schema: OllamaDefaultScheme, + Registry: OllamaDefaultRegistry, + Namespace: OllamaDefaultNamespace, + Repository: "gemma2", + Tag: "awesome", + }, + }, + { + given: "gemma2:awesome@sha256:1234567890abcdef", + expected: &OllamaModel{ + Schema: OllamaDefaultScheme, + Registry: OllamaDefaultRegistry, + Namespace: OllamaDefaultNamespace, + Repository: "gemma2", + Tag: "awesome", + }, + }, + { + given: "awesome/gemma2:latest@sha256:1234567890abcdef", + expected: &OllamaModel{ + Schema: OllamaDefaultScheme, + Registry: OllamaDefaultRegistry, + Namespace: "awesome", + Repository: "gemma2", + Tag: "latest", + }, + }, + { + given: "mysite.com/library/gemma2:latest@sha256:1234567890abcdef", + expected: &OllamaModel{ + Schema: OllamaDefaultScheme, + Registry: "mysite.com", + Namespace: "library", + Repository: "gemma2", + Tag: "latest", + }, + }, + { + given: "http://mysite.com/library/gemma2:latest@sha256:1234567890abcdef", + expected: &OllamaModel{ + Schema: "http", + Registry: "mysite.com", + Namespace: "library", + Repository: "gemma2", + Tag: "latest", + }, + }, + } + for _, tc := range cases { + t.Run(tc.given, func(t *testing.T) { + actual := ParseOllamaModel(tc.given) + assert.Equal(t, tc.expected, actual) + }) + } +} diff --git a/util/httpx/client.go b/util/httpx/client.go index 3c182df..5d0c486 100644 --- a/util/httpx/client.go +++ b/util/httpx/client.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "time" "github.com/henvic/httpretty" @@ -58,18 +59,41 @@ func Client(opts ...*ClientOption) *http.Client { root = pretty.RoundTripper(root) } - transport := RoundTripperChain{ + rtc := RoundTripperChain{ Next: root, } - for i := range o.roundTrips { - transport = RoundTripperChain{ - Do: o.roundTrips[i], - Next: transport, + for i := range o.roundTrippers { + rtc = RoundTripperChain{ + Do: o.roundTrippers[i], + Next: rtc, } } + var rt http.RoundTripper = rtc + if o.retryIf != nil { + rt = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + for i := 0; ; i++ { + resp, err := rtc.RoundTrip(req) + if !o.retryIf(resp, err) { + return resp, err + } + w, ok := o.retryBackoff(i+1, resp) + if !ok { + return resp, err + } + wt := time.NewTimer(w) + select { + case <-req.Context().Done(): + wt.Stop() + return resp, req.Context().Err() + case <-wt.C: + } + } + }) + } + return &http.Client{ - Transport: transport, + Transport: rt, Timeout: o.timeout, } } diff --git a/util/httpx/client_helper.go b/util/httpx/client_helper.go index afe6ce5..42b4058 100644 --- a/util/httpx/client_helper.go +++ b/util/httpx/client_helper.go @@ -2,13 +2,14 @@ package httpx import ( "bytes" - "encoding/json" "errors" "io" "net/http" "regexp" "github.com/henvic/httpretty" + + "github.com/thxcode/gguf-parser-go/util/json" ) var _ httpretty.Formatter = (*JSONFormatter)(nil) @@ -58,3 +59,9 @@ func (c RoundTripperChain) RoundTrip(req *http.Request) (*http.Response, error) } return nil, nil } + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} diff --git a/util/httpx/client_options.go b/util/httpx/client_options.go index 2c2fa16..ed88dad 100644 --- a/util/httpx/client_options.go +++ b/util/httpx/client_options.go @@ -1,22 +1,29 @@ package httpx import ( + "math" "net/http" + "strconv" + "strings" "time" ) type ClientOption struct { *TransportOption - timeout time.Duration - debug bool - roundTrips []func(req *http.Request) error + timeout time.Duration + debug bool + retryIf func(resp *http.Response, err error) (retry bool) + retryBackoff func(attemptNum int, resp *http.Response) (wait time.Duration, ok bool) + roundTrippers []func(req *http.Request) error } func ClientOptions() *ClientOption { return &ClientOption{ TransportOption: TransportOptions().WithoutKeepalive(), timeout: 30 * time.Second, + retryIf: defaultRetryIf, + retryBackoff: createRetryBackoff(100*time.Millisecond, 5*time.Second, 5), } } @@ -51,18 +58,29 @@ func (o *ClientOption) WithDebug() *ClientOption { return o } -// WithRoundTrip sets the round trip function. -func (o *ClientOption) WithRoundTrip(rt func(req *http.Request) error) *ClientOption { - if o == nil || rt == nil { +// WithRetryIf specifies the if-condition of retry operation for request, +// or stops retrying if setting with `nil`. +func (o *ClientOption) WithRetryIf(retryIf func(resp *http.Response, err error) (retry bool)) *ClientOption { + if o == nil { + return o + } + o.retryIf = retryIf + return o +} + +// WithRetryBackoff specifies the retry-backoff mechanism for request, +// default retry 5 times within 1s, 2s, 4s, 8s, 15s waiting. +func (o *ClientOption) WithRetryBackoff(waitMin, waitMax time.Duration, attemptMax int) *ClientOption { + if o == nil || waitMin < 0 || waitMax < 0 || waitMax < waitMin || attemptMax <= 0 { return o } - o.roundTrips = append(o.roundTrips, rt) + o.retryBackoff = createRetryBackoff(waitMin, waitMax, attemptMax) return o } // WithUserAgent sets the user agent. func (o *ClientOption) WithUserAgent(ua string) *ClientOption { - return o.WithRoundTrip(func(req *http.Request) error { + return o.WithRoundTripper(func(req *http.Request) error { req.Header.Set("User-Agent", ua) return nil }) @@ -70,7 +88,7 @@ func (o *ClientOption) WithUserAgent(ua string) *ClientOption { // WithBearerAuth sets the bearer token. func (o *ClientOption) WithBearerAuth(token string) *ClientOption { - return o.WithRoundTrip(func(req *http.Request) error { + return o.WithRoundTripper(func(req *http.Request) error { req.Header.Set("Authorization", "Bearer "+token) return nil }) @@ -78,7 +96,7 @@ func (o *ClientOption) WithBearerAuth(token string) *ClientOption { // WithBasicAuth sets the basic authentication. func (o *ClientOption) WithBasicAuth(username, password string) *ClientOption { - return o.WithRoundTrip(func(req *http.Request) error { + return o.WithRoundTripper(func(req *http.Request) error { req.SetBasicAuth(username, password) return nil }) @@ -86,7 +104,7 @@ func (o *ClientOption) WithBasicAuth(username, password string) *ClientOption { // WithHeader sets the header. func (o *ClientOption) WithHeader(key, value string) *ClientOption { - return o.WithRoundTrip(func(req *http.Request) error { + return o.WithRoundTripper(func(req *http.Request) error { req.Header.Set(key, value) return nil }) @@ -94,7 +112,7 @@ func (o *ClientOption) WithHeader(key, value string) *ClientOption { // WithHeaders sets the headers. func (o *ClientOption) WithHeaders(headers map[string]string) *ClientOption { - return o.WithRoundTrip(func(req *http.Request) error { + return o.WithRoundTripper(func(req *http.Request) error { for k, v := range headers { req.Header.Set(k, v) } @@ -102,6 +120,15 @@ func (o *ClientOption) WithHeaders(headers map[string]string) *ClientOption { }) } +// WithRoundTripper sets the round tripper. +func (o *ClientOption) WithRoundTripper(rt func(req *http.Request) error) *ClientOption { + if o == nil || rt == nil { + return o + } + o.roundTrippers = append(o.roundTrippers, rt) + return o +} + // If is a conditional option, // which receives a boolean condition to trigger the given function or not. func (o *ClientOption) If(condition bool, then func(*ClientOption) *ClientOption) *ClientOption { @@ -110,3 +137,57 @@ func (o *ClientOption) If(condition bool, then func(*ClientOption) *ClientOption } return o } + +// defaultRetryIf is the default retry condition, +// inspired by https://github.com/hashicorp/go-retryablehttp/blob/40b0cad1633fd521cee5884724fcf03d039aaf3f/client.go#L68-L86. +func defaultRetryIf(resp *http.Response, respErr error) bool { + if respErr != nil { + switch errMsg := respErr.Error(); { + case strings.Contains(errMsg, `redirects`): + return false + case strings.Contains(errMsg, `unsupported protocol scheme`): + return false + case strings.Contains(errMsg, `certificate is not trusted`): + return false + case strings.Contains(errMsg, `invalid header`): + return false + case strings.Contains(errMsg, `failed to verify certificate`): + return false + } + + // Retry if receiving connection closed. + return true + } + + // Retry if receiving rate-limited of server. + if resp.StatusCode == http.StatusTooManyRequests { + return true + } + + // Retry if receiving unexpected responses. + if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented) { + return true + } + + return false +} + +// createRetryBackoff creates a backoff function for retry operation. +func createRetryBackoff(waitMin, waitMax time.Duration, attemptMax int) func(int, *http.Response) (time.Duration, bool) { + return func(attemptNum int, resp *http.Response) (wait time.Duration, ok bool) { + if attemptNum > attemptMax { + return 0, false + } + + if resp != nil && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable) { + if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" { + if seconds, err := strconv.Atoi(retryAfter); err == nil { + return time.Duration(seconds) * time.Second, true + } + } + } + + wait = time.Duration(math.Pow(2, float64(attemptNum)) * float64(waitMin)) + return min(wait, waitMax), true + } +} diff --git a/util/httpx/resolver.go b/util/httpx/resolver.go new file mode 100644 index 0000000..985c0b4 --- /dev/null +++ b/util/httpx/resolver.go @@ -0,0 +1,49 @@ +package httpx + +import ( + "context" + "net" + "time" + + "github.com/rs/dnscache" +) + +// DefaultResolver is the default DNS resolver used by the package, +// which caches DNS lookups in memory. +var DefaultResolver = &dnscache.Resolver{ + Timeout: time.Second, + Resolver: net.DefaultResolver, +} + +func init() { + go func() { + t := time.NewTimer(5 * time.Minute) + defer t.Stop() + for range t.C { + DefaultResolver.RefreshWithOptions(dnscache.ResolverRefreshOptions{ + ClearUnused: true, + PersistOnFailure: false, + }) + } + }() +} + +func DNSCacheDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return func(ctx context.Context, nw, addr string) (conn net.Conn, err error) { + h, p, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ips, err := DefaultResolver.LookupHost(ctx, h) + if err != nil { + return nil, err + } + for _, ip := range ips { + conn, err = dialer.DialContext(ctx, nw, net.JoinHostPort(ip, p)) + if err == nil { + break + } + } + return conn, err + } +} diff --git a/util/httpx/transport_options.go b/util/httpx/transport_options.go index fd0f224..3ac9e59 100644 --- a/util/httpx/transport_options.go +++ b/util/httpx/transport_options.go @@ -23,7 +23,7 @@ func TransportOptions() *TransportOption { TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, }, - DialContext: dialer.DialContext, + DialContext: DNSCacheDialContext(dialer), ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -163,13 +163,22 @@ func (o *TransportOption) WithTLSClientConfig(config *tls.Config) *TransportOpti return o } +// WithoutDNSCache disables the dns cache. +func (o *TransportOption) WithoutDNSCache() *TransportOption { + if o == nil || o.transport == nil || o.dialer == nil { + return o + } + o.transport.DialContext = o.dialer.DialContext + return o +} + // WithDialer sets the dialer. func (o *TransportOption) WithDialer(dialer *net.Dialer) *TransportOption { if o == nil || o.transport == nil || dialer == nil { return o } o.dialer = dialer - o.transport.DialContext = dialer.DialContext + o.transport.DialContext = DNSCacheDialContext(o.dialer) return o } diff --git a/util/json/common.go b/util/json/common.go new file mode 100644 index 0000000..ec77692 --- /dev/null +++ b/util/json/common.go @@ -0,0 +1,64 @@ +package json + +import ( + stdjson "encoding/json" + "fmt" +) + +type RawMessage = stdjson.RawMessage + +var ( + MarshalIndent = stdjson.MarshalIndent + Indent = stdjson.Indent +) + +// MustMarshal is similar to Marshal, +// but panics if found error. +func MustMarshal(v any) []byte { + bs, err := Marshal(v) + if err != nil { + panic(fmt.Errorf("error marshaling json: %w", err)) + } + + return bs +} + +// MustUnmarshal is similar to Unmarshal, +// but panics if found error. +func MustUnmarshal(data []byte, v any) { + err := Unmarshal(data, v) + if err != nil { + panic(fmt.Errorf("error unmarshaling json: %w", err)) + } +} + +// MustMarshalIndent is similar to MarshalIndent, +// but panics if found error. +func MustMarshalIndent(v any, prefix, indent string) []byte { + bs, err := MarshalIndent(v, prefix, indent) + if err != nil { + panic(fmt.Errorf("error marshaling indent json: %w", err)) + } + + return bs +} + +// ShouldMarshal is similar to Marshal, +// but never return error. +func ShouldMarshal(v any) []byte { + bs, _ := Marshal(v) + return bs +} + +// ShouldUnmarshal is similar to Unmarshal, +// but never return error. +func ShouldUnmarshal(data []byte, v any) { + _ = Unmarshal(data, v) +} + +// ShouldMarshalIndent is similar to MarshalIndent, +// but never return error. +func ShouldMarshalIndent(v any, prefix, indent string) []byte { + bs, _ := MarshalIndent(v, prefix, indent) + return bs +} diff --git a/util/json/jsoniter.go b/util/json/jsoniter.go new file mode 100644 index 0000000..6cd66c1 --- /dev/null +++ b/util/json/jsoniter.go @@ -0,0 +1,48 @@ +//go:build !stdjson + +package json + +import ( + stdjson "encoding/json" + "strconv" + "unsafe" + + jsoniter "github.com/json-iterator/go" +) + +var json = jsoniter.ConfigCompatibleWithStandardLibrary + +func init() { + // borrowed from https://github.com/json-iterator/go/issues/145#issuecomment-323483602 + decodeNumberAsInt64IfPossible := func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + switch iter.WhatIsNext() { + case jsoniter.NumberValue: + var number stdjson.Number + + iter.ReadVal(&number) + i, err := strconv.ParseInt(string(number), 10, 64) + + if err == nil { + *(*any)(ptr) = i + return + } + + f, err := strconv.ParseFloat(string(number), 64) + if err == nil { + *(*any)(ptr) = f + return + } + default: + *(*any)(ptr) = iter.Read() + } + } + jsoniter.RegisterTypeDecoderFunc("interface {}", decodeNumberAsInt64IfPossible) +} + +var ( + Marshal = json.Marshal + Unmarshal = json.Unmarshal + NewDecoder = json.NewDecoder + NewEncoder = json.NewEncoder + Valid = json.Valid +) diff --git a/util/json/stdjson.go b/util/json/stdjson.go new file mode 100644 index 0000000..602394e --- /dev/null +++ b/util/json/stdjson.go @@ -0,0 +1,15 @@ +//go:build stdjson + +package json + +import ( + "encoding/json" +) + +var ( + Marshal = json.Marshal + Unmarshal = json.Unmarshal + NewDecoder = json.NewDecoder + NewEncoder = json.NewEncoder + Valid = json.Valid +) diff --git a/util/stringx/strings.go b/util/stringx/strings.go new file mode 100644 index 0000000..739fc34 --- /dev/null +++ b/util/stringx/strings.go @@ -0,0 +1,26 @@ +package stringx + +import "strings" + +// CutFromLeft is the same as strings.Cut, +// which starts from left to right, +// slices s around the first instance of sep, +// returning the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, "", false. +func CutFromLeft(s, sep string) (before, after string, found bool) { + return strings.Cut(s, sep) +} + +// CutFromRight takes the same arguments as CutFromLeft, +// but starts from right to left, +// slices s around the last instance of sep, +// return the text before and after sep. +// The found result reports whether sep appears in s. +// If sep does not appear in s, cut returns s, "", false. +func CutFromRight(s, sep string) (before, after string, found bool) { + if i := strings.LastIndex(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, "", false +}