Skip to content

Commit

Permalink
feat: support rerank
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Oct 17, 2024
1 parent 026e6d5 commit 40f24e1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cmd/gguf-parser/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,7 @@ func mainAction(c *cli.Context) error {
"Flash Attention",
"MMap Load",
"Embedding Only",
"Reranking",
"Distributable",
"Offload Layers",
"Full Offloaded",
Expand All @@ -1326,6 +1327,7 @@ func mainAction(c *cli.Context) error {
"Flash Attention",
"MMap Load",
"Embedding Only",
"Reranking",
"Distributable",
"Offload Layers",
"Full Offloaded",
Expand Down Expand Up @@ -1385,6 +1387,7 @@ func mainAction(c *cli.Context) error {
sprintf(tenary(flashAttention, tenary(es.FlashAttention, "Enabled", "Unsupported"), "Disabled")),
sprintf(tenary(mmap, tenary(!es.NoMMap, "Enabled", "Unsupported"), "Disabled")),
sprintf(tenary(es.EmbeddingOnly, "Yes", "No")),
sprintf(tenary(es.Reranking, "Supported", "Unsupported")),
sprintf(tenary(es.Distributable, "Supported", "Unsupported")),
sprintf(tenary(es.Items[i].FullOffloaded, sprintf("%d (%d + 1)",
es.Items[i].OffloadLayers, es.Items[i].OffloadLayers-1), es.Items[i].OffloadLayers)),
Expand Down
17 changes: 17 additions & 0 deletions file_estimate.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type (
// EmbeddingOnly is the flag to indicate whether the model is used for embedding only,
// true for embedding only.
EmbeddingOnly bool `json:"embeddingOnly"`
// Reranking is the flag to indicate whether the model is used for reranking,
// true for reranking.
Reranking bool `json:"reranking"`
// Distributable is the flag to indicate whether the model is distributable,
// true for distributable.
Distributable bool `json:"distributable"`
Expand Down Expand Up @@ -215,6 +218,10 @@ func (gf *GGUFFile) EstimateLLaMACppRun(opts ...LLaMACppRunEstimateOption) (e LL
if a.Type == "model" && !a.AttentionCausal {
e.EmbeddingOnly = true
o.PhysicalBatchSize = o.LogicalBatchSize
// Reranking.
if _, found := gf.TensorInfos.Index([]string{"cls.bias", "cls.weight"}); found > 0 {
e.Reranking = true
}
}

// Distributable,
Expand Down Expand Up @@ -357,13 +364,19 @@ func (gf *GGUFFile) EstimateLLaMACppRun(opts ...LLaMACppRunEstimateOption) (e LL
ls := gf.Layers()
ioLs, tfLs, _ := ls.Cut([]string{
"token_embd.weight",
"token_embd_norm.weight",
"token_embd_norm.bias",
"token_types.weight",
"output.weight",
"output.bias",
"output_norm.weight",
"output_norm.bias",
})
ipLs, opLs, _ := ioLs.Cut([]string{
"token_embd.weight",
"token_embd_norm.weight",
"token_embd_norm.bias",
"token_types.weight",
})

// Weight.
Expand Down Expand Up @@ -685,6 +698,9 @@ type (
// EmbeddingOnly is the flag to indicate whether the model is used for embedding only,
// true for embedding only.
EmbeddingOnly bool `json:"embeddingOnly"`
// Reranking is the flag to indicate whether the model is used for reranking,
// true for reranking.
Reranking bool `json:"reranking"`
// Distributable is the flag to indicate whether the model is distributable,
// true for distributable.
Distributable bool `json:"distributable"`
Expand Down Expand Up @@ -848,6 +864,7 @@ func (e LLaMACppRunEstimate) Summarize(mmap bool, nonUMARamFootprint, nonUMAVram
es.FlashAttention = e.FlashAttention
es.NoMMap = e.NoMMap
es.EmbeddingOnly = e.EmbeddingOnly
es.Reranking = e.Reranking
es.LogicalBatchSize = e.LogicalBatchSize
es.PhysicalBatchSize = e.PhysicalBatchSize
es.Distributable = e.Distributable
Expand Down

0 comments on commit 40f24e1

Please sign in to comment.