From 503ed56ce099edbf589b2fffff12ad59fca6428b Mon Sep 17 00:00:00 2001 From: Piers Powlesland Date: Tue, 18 Apr 2023 15:52:33 +0100 Subject: [PATCH 1/2] Add validation to for compute batch size --- config/config.go | 4 +++ initialization/initialization_test.go | 39 +++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/config/config.go b/config/config.go index b8dde57d1..8d9ba8f27 100644 --- a/config/config.go +++ b/config/config.go @@ -136,6 +136,10 @@ func Validate(cfg Config, opts InitOpts) error { return fmt.Errorf("invalid `opts.MaxFileSize`; expected: >= %d, given: %d", minFileSize, opts.MaxFileSize) } + if opts.ComputeBatchSize == 0 || opts.ComputeBatchSize%8 != 0 { + return fmt.Errorf("invalid `opts.ComputeBatchSize` expected: > 0 and divisible by 8, given: %d", opts.ComputeBatchSize) + } + if res := shared.Uint64MulOverflow(cfg.LabelsPerUnit, uint64(opts.NumUnits)); res { return fmt.Errorf("uint64 overflow: `cfg.LabelsPerUnit` (%v) * `opts.NumUnits` (%v) exceeds the range allowed by uint64", cfg.LabelsPerUnit, opts.NumUnits) diff --git a/initialization/initialization_test.go b/initialization/initialization_test.go index ce66430f2..b6a4d22ab 100644 --- a/initialization/initialization_test.go +++ b/initialization/initialization_test.go @@ -834,6 +834,45 @@ func TestStop(t *testing.T) { } } +func TestValidateComputeBatchSize(t *testing.T) { + cfg := config.DefaultConfig() + opts := config.DefaultInitOpts() + + // Set invalid value of 0 + opts.ComputeBatchSize = 0 + + _, err := NewInitializer( + WithNodeId(nodeId), + WithCommitmentAtxId(commitmentAtxId), + WithConfig(cfg), + WithInitOpts(opts), + WithLogger(testLogger{t: t}), + ) + assert.Error(t, err) + + // Set invalid value of 4 (batch sizes must be divisible by 8) + opts.ComputeBatchSize = 4 + _, err = NewInitializer( + WithNodeId(nodeId), + WithCommitmentAtxId(commitmentAtxId), + WithConfig(cfg), + WithInitOpts(opts), + WithLogger(testLogger{t: t}), + ) + assert.Error(t, err) + + // Set invalid value of 13 (batch sizes must be divisible by 8) + opts.ComputeBatchSize = 13 + _, err = NewInitializer( + WithNodeId(nodeId), + WithCommitmentAtxId(commitmentAtxId), + WithConfig(cfg), + WithInitOpts(opts), + WithLogger(testLogger{t: t}), + ) + assert.Error(t, err) +} + func assertNumLabelsWritten(ctx context.Context, t *testing.T, init *Initializer) func() error { return func() error { timer := time.NewTimer(50 * time.Millisecond) From 2e0377cf46742ccebd871e31a23a7c36bad14270 Mon Sep 17 00:00:00 2001 From: Piers Powlesland Date: Mon, 24 Apr 2023 14:39:42 +0100 Subject: [PATCH 2/2] Change validation to just be > 0 Removed outdated doc specifying that the compute batch size must be divisible by 8. Updated returned validation error and test to not check for divisibility by 8. --- config/config.go | 9 ++++----- initialization/initialization_test.go | 22 ---------------------- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/config/config.go b/config/config.go index 8d9ba8f27..7a384206f 100644 --- a/config/config.go +++ b/config/config.go @@ -12,8 +12,6 @@ import ( const ( DefaultDataDirName = "data" - // DefaultComputeBatchSize value must be divisible by 8, to guarantee that writing to disk - // and file truncating is byte-granular. DefaultComputeBatchSize = 1 << 20 MinBitsPerLabel = 1 @@ -71,7 +69,8 @@ type InitOpts struct { ComputeProviderID int Throttle bool Scrypt ScryptParams - ComputeBatchSize uint64 + // ComputeBatchSize must be greater than 0 + ComputeBatchSize uint64 } type ScryptParams struct { @@ -136,8 +135,8 @@ func Validate(cfg Config, opts InitOpts) error { return fmt.Errorf("invalid `opts.MaxFileSize`; expected: >= %d, given: %d", minFileSize, opts.MaxFileSize) } - if opts.ComputeBatchSize == 0 || opts.ComputeBatchSize%8 != 0 { - return fmt.Errorf("invalid `opts.ComputeBatchSize` expected: > 0 and divisible by 8, given: %d", opts.ComputeBatchSize) + if opts.ComputeBatchSize == 0 { + return fmt.Errorf("invalid `opts.ComputeBatchSize` expected: > 0, given: %d", opts.ComputeBatchSize) } if res := shared.Uint64MulOverflow(cfg.LabelsPerUnit, uint64(opts.NumUnits)); res { diff --git a/initialization/initialization_test.go b/initialization/initialization_test.go index b6a4d22ab..d593728bb 100644 --- a/initialization/initialization_test.go +++ b/initialization/initialization_test.go @@ -849,28 +849,6 @@ func TestValidateComputeBatchSize(t *testing.T) { WithLogger(testLogger{t: t}), ) assert.Error(t, err) - - // Set invalid value of 4 (batch sizes must be divisible by 8) - opts.ComputeBatchSize = 4 - _, err = NewInitializer( - WithNodeId(nodeId), - WithCommitmentAtxId(commitmentAtxId), - WithConfig(cfg), - WithInitOpts(opts), - WithLogger(testLogger{t: t}), - ) - assert.Error(t, err) - - // Set invalid value of 13 (batch sizes must be divisible by 8) - opts.ComputeBatchSize = 13 - _, err = NewInitializer( - WithNodeId(nodeId), - WithCommitmentAtxId(commitmentAtxId), - WithConfig(cfg), - WithInitOpts(opts), - WithLogger(testLogger{t: t}), - ) - assert.Error(t, err) } func assertNumLabelsWritten(ctx context.Context, t *testing.T, init *Initializer) func() error {