Skip to content

Commit

Permalink
refactor: batch size
Browse files Browse the repository at this point in the history
Signed-off-by: thxCode <[email protected]>
  • Loading branch information
thxCode committed Aug 6, 2024
1 parent 1478adc commit f1aa5f3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cmd/gguf-parser/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,12 @@ func mainAction(c *cli.Context) error {
eopts = append(eopts, WithinMaxContextSize())
}
if logicalBatchSize > 0 {
eopts = append(eopts, WithLogicalBatchSize(int32(logicalBatchSize)))
eopts = append(eopts, WithLogicalBatchSize(int32(max(32, logicalBatchSize))))
}
if physicalBatchSize > 0 {
if physicalBatchSize > logicalBatchSize {
return errors.New("--ubatch-size must be less than or equal to --batch-size")
}
eopts = append(eopts, WithPhysicalBatchSize(int32(physicalBatchSize)))
}
if parallelSize > 0 {
Expand Down
6 changes: 6 additions & 0 deletions file_estimate.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,16 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
}
if o.LogicalBatchSize == nil {
o.LogicalBatchSize = ptr.To(int32(2048))
} else {
// See https://github.com/ggerganov/llama.cpp/blob/0bf16de07b0692e7df26b9a633e232bbd66e0360/src/llama.cpp#L16519-L16525.
o.LogicalBatchSize = ptr.To(max(32, *o.LogicalBatchSize))
}
if o.PhysicalBatchSize == nil {
o.PhysicalBatchSize = ptr.To(int32(512))
}
if *o.PhysicalBatchSize > *o.LogicalBatchSize {
panic("physical batch size must be less than or equal to logical batch size")
}

// Architecture and tokenizer metadata.
var (
Expand Down

0 comments on commit f1aa5f3

Please sign in to comment.