refactor: estimate for flash attention
Signed-off-by: thxCode <[email protected]>
thxCode committed Jun 14, 2024
1 parent 5228f7e commit f40a734
Showing 4 changed files with 150 additions and 65 deletions.
102 changes: 58 additions & 44 deletions cmd/gguf-parser/main.go
@@ -31,12 +31,13 @@ func main() {
skipProxy bool
skipTLS bool
// estimate options
ctxSize = -1
kvType = "f16"
offloadLayers = -1
batchSize = 512
parallel = 1
noMMap bool
ctxSize = -1
batchSize = 512
parallelSize = 1
kvType = "f16"
offloadLayers = -1
flashAttention bool
noMMap bool
// output options
version bool
skipModel bool
@@ -65,11 +66,13 @@ func main() {
fs.BoolVar(&skipProxy, "skip-proxy", skipProxy, "Skip using proxy when reading from a remote URL")
fs.BoolVar(&skipTLS, "skip-tls", skipTLS, "Skip TLS verification when reading from a remote URL")
fs.IntVar(&ctxSize, "ctx-size", ctxSize, "Context size to estimate memory usage, default is equal to the model's maximum context size")
fs.StringVar(&kvType, "kv-type", kvType, "Key-Value cache type, select from [f32, f16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1]")
fs.IntVar(&offloadLayers, "offload-layers", offloadLayers, "Specify how many layers to offload, default is fully offloading")
fs.IntVar(&batchSize, "batch-size", batchSize, "Physical maximum batch size")
fs.IntVar(&parallel, "parallel", parallel, "Number of parallel sequences to decode")
fs.BoolVar(&noMMap, "no-mmap", noMMap, "Do not use memory-mapping, which influences the estimate result")
fs.IntVar(&parallelSize, "parallel", parallelSize, "Number of parallel sequences to decode")
fs.StringVar(&kvType, "kv-type", kvType, "Key-Value cache type, select from [f32, f16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1], "+
"using quantization type means enabling Flash Attention as well")
fs.IntVar(&offloadLayers, "offload-layers", offloadLayers, "Specify how many layers to offload, default is fully offloading")
fs.BoolVar(&flashAttention, "flash-attention", flashAttention, "Enable Flash Attention to reduce the memory usage, which influences the estimate result")
fs.BoolVar(&noMMap, "no-mmap", noMMap, "Disable using memory-mapped model(file) loading, which influences the estimate result")
fs.BoolVar(&version, "version", version, "Show version")
fs.BoolVar(&skipModel, "skip-model", skipModel, "Skip model metadata")
fs.BoolVar(&skipArchitecture, "skip-architecture", skipArchitecture, "Skip architecture metadata")
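Taken together, the new flags compose into an ordinary invocation. A sketch is below (illustrative only; the flag that points the tool at a local file or remote URL is outside this hunk, so that part is elided):

    gguf-parser -ctx-size 8192 -kv-type q8_0 -flash-attention -no-mmap ...

Note that because q8_0 is a quantized cache type, the estimator enables flash attention even without -flash-attention, per the help text above and the gating added in file_estimate.go below.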
@@ -110,6 +113,12 @@ func main() {
if ctxSize > 0 {
eopts = append(eopts, WithContextSize(int32(ctxSize)))
}
if batchSize > 0 {
eopts = append(eopts, WithBatchSize(int32(batchSize)))
}
if parallelSize > 0 {
eopts = append(eopts, WithParallelSize(int32(parallelSize)))
}
if kvType != "" {
kv := GGMLTypeF16
switch kvType {
@@ -135,11 +144,8 @@ func main() {
if offloadLayers >= 0 {
eopts = append(eopts, WithOffloadLayers(uint64(offloadLayers)))
}
if batchSize > 0 {
eopts = append(eopts, WithBatchSize(int32(batchSize)))
}
if parallel > 0 {
eopts = append(eopts, WithParallelSize(int32(parallel)))
if flashAttention {
eopts = append(eopts, WithFlashAttention())
}

// Parse GGUF file.
@@ -214,7 +220,7 @@ func main() {
}

if !skipModel {
tprintf(
tprint(
"MODEL",
[]string{"Name", "Arch", "Quantization Version", "File Type", "Little Endian", "Size", "Parameters", "BPW"},
[]string{
@@ -230,12 +236,14 @@ func main() {
}

if !skipArchitecture {
tprintf(
tprint(
"ARCHITECTURE",
[]string{"Max Context Len", "Embedding Len", "Layers", "Feed Forward Len", "Expert Cnt", "Vocabulary Len"},
[]string{"Max Context Len", "Embedding Len", "Embedding GQA", "Attention Head Cnt", "Layers", "Feed Forward Len", "Expert Cnt", "Vocabulary Len"},
[]string{
sprintf(a.MaximumContextLength),
sprintf(a.EmbeddingLength),
sprintf(a.EmbeddingGQA),
sprintf(tenary(a.AttentionHeadCountKV == 0 || a.AttentionHeadCountKV == a.AttentionHeadCount, "N/A", a.AttentionHeadCount)),
sprintf(a.BlockCount),
sprintf(a.FeedForwardLength),
sprintf(a.ExpertCount),
@@ -244,55 +252,54 @@ func main() {
}

if !skipTokenizer {
sprintTokenID := func(a int64) string {
if a < 0 {
return "N/A"
}
return sprintf(a)
}
tprintf(
tprint(
"TOKENIZER",
[]string{"Model", "Tokens Size", "Tokens Len", "Added Tokens Len", "BOS Token", "EOS Token", "Unknown Token", "Separator Token", "Padding Token"},
[]string{
t.Model,
sprintf(GGUFBytesScalar(t.TokensSize)),
sprintf(t.TokensLength),
sprintf(t.AddedTokensLength),
sprintTokenID(t.BOSTokenID),
sprintTokenID(t.EOSTokenID),
sprintTokenID(t.UnknownTokenID),
sprintTokenID(t.SeparatorTokenID),
sprintTokenID(t.PaddingTokenID),
sprintf(tenary(t.BOSTokenID < 0, "N/A", t.BOSTokenID)),
sprintf(tenary(t.EOSTokenID < 0, "N/A", t.EOSTokenID)),
sprintf(tenary(t.UnknownTokenID < 0, "N/A", t.UnknownTokenID)),
sprintf(tenary(t.SeparatorTokenID < 0, "N/A", t.SeparatorTokenID)),
sprintf(tenary(t.PaddingTokenID < 0, "N/A", t.PaddingTokenID)),
})
}

if !skipEstimate {
es := e.Summarize(!noMMap)
tprintf(
tprint(
"ESTIMATE",
[]string{"Arch", "Context Size", "Full Offload", "MMap Support", "Mem. Arch", "Usage"},
[]string{"Arch", "Context Size", "Full Offload", "Flash Attention", "MMap Support", "Mem. Arch", "Usage"},
[]string{
sprintf(e.Architecture),
sprintf(e.ContextSize),
sprintf(e.FullOffload),
sprintf(!e.NoMMap),
sprintf(es.Architecture),
sprintf(es.ContextSize),
sprintf(es.FullOffload),
sprintf(es.FlashAttention),
sprintf(!es.NoMMap),
"UMA",
sprintf(es.UMA),
},
[]string{
sprintf(e.Architecture),
sprintf(e.ContextSize),
sprintf(e.FullOffload),
sprintf(!e.NoMMap),
sprintf(es.Architecture),
sprintf(es.ContextSize),
sprintf(es.FullOffload),
sprintf(es.FlashAttention),
sprintf(!es.NoMMap),
"NonUMA",
fmt.Sprintf("%s (RAM) + %s (VRAM)", es.NonUMA.RAM, es.NonUMA.VRAM),
sprintf("%s (RAM) + %s (VRAM)", es.NonUMA.RAM, es.NonUMA.VRAM),
})
}
}

func sprintf(a any) string {
switch v := a.(type) {
func sprintf(f any, a ...any) string {
switch v := f.(type) {
case string:
if len(a) != 0 {
return fmt.Sprintf(v, a...)
}
return v
case []byte:
return string(v)
@@ -319,7 +326,7 @@ func sprintf(a any) string {
}
}

func tprintf(title string, header []string, body ...[]string) {
func tprint(title string, header []string, body ...[]string) {
title = strings.ToUpper(title)
for i := range header {
header[i] = strings.ToUpper(header[i])
@@ -338,3 +345,10 @@ func tprintf(title string, header []string, body ...[]string) {
tb.Render()
fmt.Println()
}

func tenary(c bool, t, f any) any {
if c {
return t
}
return f
}
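The tenary helper above is what lets the per-table sprintTokenID closure be dropped: any "show N/A for a sentinel value" check becomes a single cell expression. A minimal self-contained sketch of the same pattern (using fmt directly instead of the file's sprintf so it runs on its own):

    package main

    import "fmt"

    // tenary mirrors the helper added to main.go: return t when c holds, else f.
    func tenary(c bool, t, f any) any {
        if c {
            return t
        }
        return f
    }

    func main() {
        bosTokenID := int64(-1) // a missing token ID
        fmt.Println(tenary(bosTokenID < 0, "N/A", bosTokenID)) // N/A

        headCount, headCountKV := uint32(32), uint32(0)
        fmt.Println(tenary(headCountKV == 0 || headCountKV == headCount, "N/A", headCount)) // N/A
    }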
4 changes: 2 additions & 2 deletions file_architecture.go
@@ -44,11 +44,11 @@ type GGUFArchitectureMetadata struct {
// AttentionLayerNormRMSEpsilon is the epsilon value used in the RMSNorm(Root Mean Square Layer Normalization),
// which is a simplification of the original LayerNorm.
AttentionLayerNormRMSEpsilon float32 `json:"attentionLayerNormRMSEpsilon,omitempty"`
// AttentionKeyLength is the size of a key head.
// AttentionKeyLength(n_embd_head_k) is the size of a key head.
//
// Defaults to `EmbeddingLength / AttentionHeadCount`.
AttentionKeyLength uint32 `json:"attentionKeyLength"`
// AttentionValueLength is the size of a value head.
// AttentionValueLength(n_embd_head_v) is the size of a value head.
//
// Defaults to `EmbeddingLength / AttentionHeadCount`.
AttentionValueLength uint32 `json:"attentionValueLength"`
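For intuition about the defaults referenced in these comments, take illustrative numbers (a hypothetical model, not one from this commit): an embedding length of 4096 and 32 attention heads give

    AttentionKeyLength   (n_embd_head_k) = EmbeddingLength / AttentionHeadCount = 4096 / 32 = 128
    AttentionValueLength (n_embd_head_v) = EmbeddingLength / AttentionHeadCount = 4096 / 32 = 128

unless the GGUF metadata overrides either value explicitly.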
87 changes: 75 additions & 12 deletions file_estimate.go
@@ -13,6 +13,9 @@ type (
LLaMACppUsageEstimate struct {
// Architecture describes what architecture this model implements.
Architecture string `json:"architecture"`
// FlashAttention is the flag to indicate whether flash attention is enabled,
// true for enabled.
FlashAttention bool `json:"flashAttention"`
// FullOffload is the flag to indicate whether the layers are fully offloaded,
// false for partial offloaded or zero offloaded.
FullOffload bool `json:"fullOffload"`
@@ -86,6 +89,22 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
a, t := gf.Architecture(), gf.Tokenizer()
e.Architecture = a.Architecture

// Flash attention.
{
// Quantization requires flash attention,
// see https://github.com/ggerganov/llama.cpp/blob/172c8256840ffd882ab9992ecedbb587d9b21f15/llama.cpp#L16055-L16058.
if *o.CacheValueType > GGMLTypeF16 && !o.FlashAttention {
o.FlashAttention = true
}
// Grok is not compatible with flash attention,
// see https://github.com/ggerganov/llama.cpp/blob/172c8256840ffd882ab9992ecedbb587d9b21f15/llama.cpp#L16050-L16053.
if a.Architecture == "grok" {
o.FlashAttention = false
}

e.FlashAttention = o.FlashAttention
}

// Init hyperparameters,
// https://github.com/ggerganov/llama.cpp/blob/d6ef0e77dd25f54fb5856af47e3926cf6f36c281/llama.cpp#L6957-L7000.
var (
@@ -283,25 +302,44 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
}
e.Offload.Computation.Compute = GGUFBytesScalar(convInc + ssmInc)
} else {
kvcInc := uint64(e.Load.KVCache.Key + e.Offload.KVCache.Key)
for _, l := range tfLs[len(tfLs)-1].Search(regexp.MustCompile(`.*\.\d+\.attn_(norm|q|qkv)\.weight`)) {
rs := GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
kvcInc += rs
switch {
default:
continue
case strings.HasSuffix(l.Name, ".attn_q.weight"):
case strings.HasSuffix(l.Name, ".attn_qkv.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
attnInc := uint64(0)
if o.FlashAttention {
// https://github.com/ggerganov/llama.cpp/blob/172c8256840ffd882ab9992ecedbb587d9b21f15/llama.cpp#L7387.
attnInc = GGMLTypeF16.RowSizeOf([]uint64{nKV, nTokens})
for _, l := range tfLs[len(tfLs)-1].Search(regexp.MustCompile(`.*\.\d+\.attn_(norm|q|qkv)\.weight`)) {
if strings.HasSuffix(l.Name, ".attn_norm.weight") {
rs := GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
attnInc += rs
continue
}
rs := l.Bytes()
attnInc += rs
}
rs := o.CacheKeyType.RowSizeOf([]uint64{uint64(a.AttentionKeyLength), nKV, a.AttentionHeadCountKV})
attnInc += rs
rs = o.CacheValueType.RowSizeOf([]uint64{uint64(a.AttentionValueLength), nKV, a.AttentionHeadCountKV})
attnInc += rs
} else {
attnInc = uint64(e.Load.KVCache.Key + e.Offload.KVCache.Key)
for _, l := range tfLs[len(tfLs)-1].Search(regexp.MustCompile(`.*\.\d+\.attn_(norm|q|qkv)\.weight`)) {
rs := GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
attnInc += rs
switch {
default:
continue
case strings.HasSuffix(l.Name, ".attn_q.weight"):
case strings.HasSuffix(l.Name, ".attn_qkv.weight"):
rs = GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[0], nTokens})
}
attnInc += rs * 2 // for RoPE
}
kvcInc += rs * 2 // for RoPE
}
ffnInc := uint64(0)
for _, l := range tfLs[len(tfLs)-1].Search(regexp.MustCompile(`.*\.\d+\.(attn_norm|ffn_norm|ffn_gate|ffn_up)\.weight`)) {
rs := GGMLTypeF32.RowSizeOf([]uint64{l.Dimensions[l.NDimensions-1], nTokens})
ffnInc += rs
}
e.Offload.Computation.Compute = GGUFBytesScalar(max(kvcInc, ffnInc))
e.Offload.Computation.Compute = GGUFBytesScalar(max(attnInc, ffnInc))
// Special case: we cannot use mmap for splitting expert weights in MoE.
if a.ExpertCount > 0 {
e.NoMMap = len(tfLs[0].Search(regexp.MustCompile(`.*\.\d+\.ffn_gate_exps\.weight`))) == 0
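To size the flash attention branch above, a back-of-the-envelope example with illustrative numbers (not taken from this commit), assuming RowSizeOf of a non-quantized type reduces to element count × type size:

    nKV = 8192, nTokens = 512
    GGMLTypeF16.RowSizeOf({nKV, nTokens}) ≈ 8192 × 512 × 2 B = 8 MiB

The K and V terms that follow scale as AttentionKeyLength × nKV × AttentionHeadCountKV (and likewise with AttentionValueLength) in the configured cache type, so they too grow linearly with the context length.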
@@ -330,6 +368,8 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (

// LLaMACppUsageEstimateSummery represents the summary of the usage for loading the GGUF file in llama.cpp.
type LLaMACppUsageEstimateSummery struct {
/* Basic */

// UMA represents the usage of Unified Memory Architecture.
UMA GGUFBytesScalar `json:"uma"`
// NonUMA represents the usage of Non-Unified Memory Architecture.
@@ -339,6 +379,22 @@ type LLaMACppUsageEstimateSummery struct {
// VRAM is the memory usage for loading the GGUF file in VRAM.
VRAM GGUFBytesScalar `json:"vram"`
} `json:"nonUMA"`

/* Appendix */

// Architecture describes what architecture this model implements.
Architecture string `json:"architecture"`
// FlashAttention is the flag to indicate whether flash attention is enabled,
// true for enabled.
FlashAttention bool `json:"flashAttention"`
// FullOffload is the flag to indicate whether the layers are fully offloaded,
// false for partial offloaded or zero offloaded.
FullOffload bool `json:"fullOffload"`
// NoMMap is the flag to indicate whether the file must be loaded without mmap,
// true for total loaded.
NoMMap bool `json:"noMMap"`
// ContextSize is the size of the context.
ContextSize uint64 `json:"contextSize"`
}

func (e LLaMACppUsageEstimate) Summarize(mmap bool) (es LLaMACppUsageEstimateSummery) {
@@ -385,6 +441,13 @@ func (e LLaMACppUsageEstimate) Summarize(mmap bool) (es LLaMACppUsageEstimateSum
es.NonUMA.VRAM = fp + wg + kv + cp
}

// Just copy from the original estimate.
es.Architecture = e.Architecture
es.FlashAttention = e.FlashAttention
es.FullOffload = e.FullOffload
es.NoMMap = e.NoMMap
es.ContextSize = e.ContextSize

return es
}

22 changes: 15 additions & 7 deletions file_estimate_option.go
@@ -7,11 +7,12 @@ import (
type (
_LLaMACppUsageEstimateOptions struct {
ContextSize *int32
ParallelSize *int32
BatchSize *int32
ParallelSize *int32
CacheKeyType *GGMLType
CacheValueType *GGMLType
OffloadLayers *uint64
FlashAttention bool
}
LLaMACppUsageEstimateOption func(*_LLaMACppUsageEstimateOptions)
)
@@ -26,23 +27,23 @@ func WithContextSize(size int32) LLaMACppUsageEstimateOption {
}
}

// WithParallelSize sets the (decoding sequences) parallel size for the estimate.
func WithParallelSize(size int32) LLaMACppUsageEstimateOption {
// WithBatchSize sets the physical batch size for the estimate.
func WithBatchSize(size int32) LLaMACppUsageEstimateOption {
return func(o *_LLaMACppUsageEstimateOptions) {
if size <= 0 {
return
}
o.ParallelSize = &size
o.BatchSize = &size
}
}

// WithBatchSize sets the physical batch size for the estimate.
func WithBatchSize(size int32) LLaMACppUsageEstimateOption {
// WithParallelSize sets the (decoding sequences) parallel size for the estimate.
func WithParallelSize(size int32) LLaMACppUsageEstimateOption {
return func(o *_LLaMACppUsageEstimateOptions) {
if size <= 0 {
return
}
o.BatchSize = &size
o.ParallelSize = &size
}
}

@@ -80,3 +81,10 @@ func WithOffloadLayers(layers uint64) LLaMACppUsageEstimateOption {
o.OffloadLayers = &layers
}
}

// WithFlashAttention sets the flash attention flag.
func WithFlashAttention() LLaMACppUsageEstimateOption {
return func(o *_LLaMACppUsageEstimateOptions) {
o.FlashAttention = true
}
}
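
Putting the new option together with the existing ones, a minimal library-usage sketch. The import path and the ParseGGUFFile entry point are assumptions (neither appears in this diff); only the option constructors, EstimateLLaMACppUsage, and Summarize come from the files changed above, and the exact shape of the estimate call is inferred from cmd/gguf-parser/main.go:

    package main

    import (
        "fmt"

        gguf "github.com/thxcode/gguf-parser-go" // import path is an assumption
    )

    func main() {
        // Parsing entry point assumed for this sketch.
        gf, err := gguf.ParseGGUFFile("/path/to/model.gguf")
        if err != nil {
            panic(err)
        }

        e := gf.EstimateLLaMACppUsage(
            gguf.WithContextSize(8192),
            gguf.WithBatchSize(512),
            gguf.WithParallelSize(1),
            gguf.WithFlashAttention(), // new in this commit
        )

        es := e.Summarize(true /* mmap */)
        fmt.Println("flash attention:", es.FlashAttention)
        fmt.Println("UMA:", es.UMA)
        fmt.Println("NonUMA:", es.NonUMA.RAM, "(RAM) +", es.NonUMA.VRAM, "(VRAM)")
    }

Selecting a quantized KV cache type through the corresponding option (not shown in this diff) would switch FlashAttention on automatically, mirroring the gating added in EstimateLLaMACppUsage.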
