Skip to content

Commit

Permalink
refactor: estimate for mamba
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Jun 14, 2024
1 parent f81e6c4 commit 5228f7e
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 89 deletions.
20 changes: 10 additions & 10 deletions cmd/gguf-parser/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,22 +268,22 @@ func main() {

if !skipEstimate {
es := e.Summarize(!noMMap)
if ctxSize <= 0 {
if a.MaximumContextLength == 0 {
a = gf.Architecture()
}
ctxSize = int(a.MaximumContextLength)
}
tprintf(
"ESTIMATE",
[]string{"Context Size", "Mem. Arch", "Usage"},
[]string{"Arch", "Context Size", "Full Offload", "MMap Support", "Mem. Arch", "Usage"},
[]string{
sprintf(ctxSize),
sprintf(e.Architecture),
sprintf(e.ContextSize),
sprintf(e.FullOffload),
sprintf(!e.NoMMap),
"UMA",
sprintf(es.UMA),
},
[]string{
sprintf(ctxSize),
sprintf(e.Architecture),
sprintf(e.ContextSize),
sprintf(e.FullOffload),
sprintf(!e.NoMMap),
"NonUMA",
fmt.Sprintf("%s (RAM) + %s (VRAM)", es.NonUMA.RAM, es.NonUMA.VRAM),
})
Expand Down Expand Up @@ -330,7 +330,7 @@ func tprintf(title string, header []string, body ...[]string) {
tb.SetAlignment(tablewriter.ALIGN_CENTER)
tb.SetHeaderLine(true)
tb.SetRowLine(true)
tb.SetAutoMergeCellsByColumnIndex([]int{0, 1, 2, 3})
tb.SetAutoMergeCellsByColumnIndex([]int{0, 1, 2, 3, 4})
tb.Append(append([]string{title}, header...))
for i := range body {
tb.Append(append([]string{title}, body[i]...))
Expand Down
15 changes: 9 additions & 6 deletions file_architecture.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ type GGUFArchitectureMetadata struct {

/* Appendix */

// EmbeddingHeadCount is the number of heads in the embedding layer.
EmbeddingHeadCount uint64 `json:"embeddingHeadCount,omitempty"`
// EmbeddingKeyGQA is the number of key GQA in the embedding layer.
EmbeddingKeyGQA uint64 `json:"embeddingKeyGQA,omitempty"`
// EmbeddingValueGQA is the number of value GQA in the embedding layer.
Expand Down Expand Up @@ -261,10 +259,15 @@ func (gf *GGUFFile) Architecture() (ga GGUFArchitectureMetadata) {
ga.VocabularyLength = v.ValueArray().Len
}

if ga.AttentionHeadCount > 0 {
ga.EmbeddingHeadCount = ga.EmbeddingLength / ga.AttentionHeadCount
ga.EmbeddingKeyGQA = uint64(ga.AttentionKeyLength) * ga.AttentionHeadCountKV
ga.EmbeddingValueGQA = uint64(ga.AttentionValueLength) * ga.AttentionHeadCountKV
{
if ga.AttentionHeadCount > 0 {
ga.EmbeddingKeyGQA = uint64(ga.AttentionKeyLength) * ga.AttentionHeadCountKV
ga.EmbeddingValueGQA = uint64(ga.AttentionValueLength) * ga.AttentionHeadCountKV
}
if ga.Architecture == "mamba" {
ga.EmbeddingKeyGQA = uint64((ga.SSMConvolutionKernel - 1) * ga.SSMInnerSize)
ga.EmbeddingValueGQA = uint64(ga.SSMStateSize * ga.SSMInnerSize)
}
ga.EmbeddingGQA = ga.EmbeddingValueGQA
}

Expand Down
Loading

0 comments on commit 5228f7e

Please sign in to comment.