diff --git a/cmd/gguf-parser/main.go b/cmd/gguf-parser/main.go
index b3ddd42..43a2e60 100644
--- a/cmd/gguf-parser/main.go
+++ b/cmd/gguf-parser/main.go
@@ -573,9 +573,12 @@ func mainAction(c *cli.Context) error {
 		eopts = append(eopts, WithinMaxContextSize())
 	}
 	if logicalBatchSize > 0 {
-		eopts = append(eopts, WithLogicalBatchSize(int32(logicalBatchSize)))
+		eopts = append(eopts, WithLogicalBatchSize(int32(max(32, logicalBatchSize))))
 	}
 	if physicalBatchSize > 0 {
+		if logicalBatchSize > 0 && physicalBatchSize > logicalBatchSize {
+			return errors.New("--ubatch-size must be less than or equal to --batch-size")
+		}
 		eopts = append(eopts, WithPhysicalBatchSize(int32(physicalBatchSize)))
 	}
 	if parallelSize > 0 {
diff --git a/file_estimate.go b/file_estimate.go
index e446efc..5cf1068 100644
--- a/file_estimate.go
+++ b/file_estimate.go
@@ -103,10 +103,16 @@ func (gf *GGUFFile) EstimateLLaMACppUsage(opts ...LLaMACppUsageEstimateOption) (
 	}
 	if o.LogicalBatchSize == nil {
 		o.LogicalBatchSize = ptr.To(int32(2048))
+	} else {
+		// See https://github.com/ggerganov/llama.cpp/blob/0bf16de07b0692e7df26b9a633e232bbd66e0360/src/llama.cpp#L16519-L16525.
+		o.LogicalBatchSize = ptr.To(max(32, *o.LogicalBatchSize))
 	}
 	if o.PhysicalBatchSize == nil {
 		o.PhysicalBatchSize = ptr.To(int32(512))
 	}
+	if *o.PhysicalBatchSize > *o.LogicalBatchSize {
+		o.PhysicalBatchSize = ptr.To(*o.LogicalBatchSize)
+	}
 
 	// Architecture and tokenizer metadata.
 	var (