diff --git a/file_estimate.go b/file_estimate.go index 65e2b8d..92c7020 100644 --- a/file_estimate.go +++ b/file_estimate.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/gpustack/gguf-parser-go/util/ptr" + "github.com/gpustack/gguf-parser-go/util/slicex" ) // Types for LLaMACpp estimation. @@ -260,13 +261,15 @@ func (gf *GGUFFile) EstimateLLaMACppRun(opts ...LLaMACppRunEstimateOption) (e LL // Partial offload: !isOffloadOutputLayer. // Zero offload: nOffloadLayers == 0. var ( - nLoadLayers = a.BlockCount nOffloadLayers uint64 - isOffloadOutputLayer bool + nActualOffloadLayers uint64 + nLoadLayers = a.BlockCount fullOffload, partialOffload, zeroOffload bool ) { + var isOffloadOutputLayer bool + // For none model, // see https://github.com/ggerganov/llama.cpp/blob/148ec970b62c3c5ae0a8bfdaad2fc237aaae350d/examples/llava/clip.cpp#L994-L1008. if a.Type != "model" { @@ -284,14 +287,18 @@ func (gf *GGUFFile) EstimateLLaMACppRun(opts ...LLaMACppRunEstimateOption) (e LL nOffloadLayers = a.BlockCount } } + nActualOffloadLayers = nOffloadLayers + if isOffloadOutputLayer { + nActualOffloadLayers += 1 + } nLoadLayers -= nOffloadLayers - e.FullOffloaded = isOffloadOutputLayer && nLoadLayers == 0 - e.OffloadLayers = nOffloadLayers - fullOffload = isOffloadOutputLayer && nLoadLayers == 0 partialOffload = !isOffloadOutputLayer zeroOffload = nOffloadLayers == 0 + + e.FullOffloaded = isOffloadOutputLayer && nLoadLayers == 0 + e.OffloadLayers = nOffloadLayers } // Footprint. 
@@ -343,13 +350,8 @@ func (gf *GGUFFile) EstimateLLaMACppRun(opts ...LLaMACppRunEstimateOption) (e LL e.Devices[0].HandleLastLayer = i e.Devices[0].Weight.Compute += GGUFBytesScalar(tfLs[i].Bytes()) case i >= offloadStart: - x := float64(i-offloadStart) / float64(nOffloadLayers) - for k := j; k < len(o.TensorSplitFraction); k++ { - if x < o.TensorSplitFraction[k] { - j = k - break - } - } + x := float64(i-offloadStart) / float64(nActualOffloadLayers) + j = slicex.UpperBound(o.TensorSplitFraction, x) e.Devices[j+1].HandleLayers += 1 e.Devices[j+1].HandleLastLayer = i e.Devices[j+1].Remote = len(o.TensorSplitFraction)-len(o.RPCServers) <= j diff --git a/util/slicex/search.go b/util/slicex/search.go new file mode 100644 index 0000000..2887511 --- /dev/null +++ b/util/slicex/search.go @@ -0,0 +1,17 @@ +package slicex + +import "golang.org/x/exp/constraints" + +// UpperBound returns the index of the first element of s that is strictly greater than e, or len(s) when no element is; the binary search relies on s being sorted in ascending order. +func UpperBound[T constraints.Integer | constraints.Float](s []T, e T) int { + l, r := 0, len(s) + for l < r { + m := l + (r-l)/2 // midpoint of the half-open search range [l, r) + if s[m] <= e { // s[m] is not greater than e, so the answer lies to the right of m + l = m + 1 + } else { // s[m] is greater than e, so m itself is still a candidate + r = m + } + } + return l +}