From f1aa5f3eff07d634b2c410bfc74eaf08935a7bee Mon Sep 17 00:00:00 2001
From: thxCode
Date: Tue, 6 Aug 2024 16:55:59 +0800
Subject: [PATCH] refactor: batch size

Signed-off-by: thxCode
---
 cmd/gguf-parser/main.go | 5 ++++-
 file_estimate.go        | 6 ++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/cmd/gguf-parser/main.go b/cmd/gguf-parser/main.go
index b3ddd42..43a2e60 100644
--- a/cmd/gguf-parser/main.go
+++ b/cmd/gguf-parser/main.go
@@ -573,9 +573,12 @@ func mainAction(c *cli.Context) error {
 		eopts = append(eopts, WithinMaxContextSize())
 	}
 	if logicalBatchSize > 0 {
-		eopts = append(eopts, WithLogicalBatchSize(int32(logicalBatchSize)))
+		eopts = append(eopts, WithLogicalBatchSize(int32(max(32, logicalBatchSize))))
 	}
 	if physicalBatchSize > 0 {
+		if physicalBatchSize > logicalBatchSize {
+			return errors.New("--ubatch-size must be less than or equal to --batch-size")
+		}
 		eopts = append(eopts, WithPhysicalBatchSize(int32(physicalBatchSize)))
 	}
 	if parallelSize > 0 {
diff --git a/file_estimate.go b/file_estimate.go
index e446efc..5cf1068 100644
--- a/file_estimate.go
+++ b/file_estimate.go
@@ -103,10 +103,16 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
 	}
 	if o.LogicalBatchSize == nil {
 		o.LogicalBatchSize = ptr.To(int32(2048))
+	} else {
+		// See https://github.com/ggerganov/llama.cpp/blob/0bf16de07b0692e7df26b9a633e232bbd66e0360/src/llama.cpp#L16519-L16525.
+		o.LogicalBatchSize = ptr.To(max(32, *o.LogicalBatchSize))
 	}
 	if o.PhysicalBatchSize == nil {
 		o.PhysicalBatchSize = ptr.To(int32(512))
 	}
+	if *o.PhysicalBatchSize > *o.LogicalBatchSize {
+		panic("physical batch size must be less than or equal to logical batch size")
+	}
 
 	// Architecture and tokenizer metadata.
 	var (
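
Below is a minimal, standalone sketch (not part of the patch) illustrating the two batch-size rules this change enforces: the logical batch size is clamped to at least 32, and the physical (--ubatch-size) value must not exceed the logical (--batch-size) value. The helper name normalizeBatchSizes is hypothetical, and the built-in max assumes Go 1.21+.

package main

import (
	"errors"
	"fmt"
)

// normalizeBatchSizes mirrors the checks added in cmd/gguf-parser/main.go:
// it rejects a physical batch size larger than the logical one and clamps
// the logical batch size to llama.cpp's minimum of 32.
func normalizeBatchSizes(logical, physical int) (int32, int32, error) {
	if physical > logical {
		return 0, 0, errors.New("--ubatch-size must be less than or equal to --batch-size")
	}
	return int32(max(32, logical)), int32(physical), nil
}

func main() {
	// 16 is below llama.cpp's minimum, so the logical size is raised to 32.
	l, p, err := normalizeBatchSizes(16, 8)
	fmt.Println(l, p, err) // 32 8 <nil>

	// 256 > 128 violates the ubatch <= batch constraint and is rejected.
	_, _, err = normalizeBatchSizes(128, 256)
	fmt.Println(err) // --ubatch-size must be less than or equal to --batch-size
}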