From a84eff96e545a168e46b66d82b40798188d78706 Mon Sep 17 00:00:00 2001 From: Daniel Mancia <21249320+dmanc@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:53:36 -0800 Subject: [PATCH] GPU accelerated encoder (#895) --- .gitignore | 2 + api/clients/retrieval_client_test.go | 5 +- core/test/core_test.go | 5 +- core/v2/core_test.go | 5 +- disperser/apiserver/server_test.go | 4 +- disperser/batcher/batcher_test.go | 3 +- disperser/cmd/apiserver/main.go | 3 +- disperser/cmd/encoder/config.go | 6 +- disperser/cmd/encoder/flags/flags.go | 16 ++ disperser/cmd/encoder/icicle.Dockerfile | 62 +++++ disperser/cmd/encoder/main.go | 31 ++- disperser/encoder/config.go | 2 + disperser/encoder/metrics.go | 2 +- disperser/encoder/server.go | 15 +- disperser/encoder/server_test.go | 3 +- disperser/encoder/server_v2.go | 7 + disperser/encoder/server_v2_test.go | 12 +- docker-bake.hcl | 131 ++++++---- encoding/backend.go | 44 ++++ encoding/bench/Makefile | 10 +- encoding/bench/main.go | 68 ++++-- encoding/icicle/device_setup.go | 136 +++++++++++ encoding/icicle/msm_setup.go | 41 ++++ encoding/icicle/ntt_setup.go | 45 ++++ encoding/icicle/utils.go | 103 ++++++++ encoding/kzg/kzgrs.go | 1 + encoding/kzg/pointsIO.go | 27 +- encoding/kzg/prover/decode.go | 2 +- encoding/kzg/prover/decode_test.go | 11 +- encoding/kzg/prover/gnark/commitments.go | 62 +++++ .../prover/{cpu => gnark}/multiframe_proof.go | 142 ++++------- encoding/kzg/prover/icicle.go | 73 ++++++ encoding/kzg/prover/icicle/ecntt.go | 45 ++++ encoding/kzg/prover/icicle/msm.go | 33 +++ .../kzg/prover/icicle/multiframe_proof.go | 231 ++++++++++++++++++ encoding/kzg/prover/noicicle.go | 16 ++ encoding/kzg/prover/parametrized_prover.go | 70 +++--- .../kzg/prover/parametrized_prover_test.go | 11 +- encoding/kzg/prover/precompute.go | 2 - encoding/kzg/prover/proof_backend.go | 19 ++ encoding/kzg/prover/proof_device.go | 16 -- encoding/kzg/prover/prover.go | 162 ++++++++++-- encoding/kzg/prover/prover_cpu.go | 89 ------- encoding/kzg/prover/prover_fuzz_test.go | 5 +- encoding/kzg/prover/prover_test.go | 12 +- .../verifier/batch_commit_equivalence_test.go | 14 +- encoding/kzg/verifier/frame_test.go | 8 +- encoding/kzg/verifier/length_test.go | 7 +- encoding/kzg/verifier/multiframe.go | 8 +- encoding/kzg/verifier/multiframe_test.go | 14 +- encoding/kzg/verifier/verifier.go | 87 ++++--- encoding/kzg/verifier/verifier_test.go | 14 +- encoding/rs/decode.go | 9 +- encoding/rs/encode.go | 140 ++--------- encoding/rs/encode_test.go | 75 ++---- encoding/rs/encoder.go | 97 ++++++-- encoding/rs/encoder_fuzz_test.go | 25 +- encoding/rs/frame_test.go | 63 +---- encoding/rs/{cpu => gnark}/extend_poly.go | 9 +- encoding/rs/icicle.go | 38 +++ encoding/rs/icicle/extend_poly.go | 65 +++++ encoding/rs/interpolation.go | 4 +- encoding/rs/noicicle.go | 15 ++ encoding/rs/parametrized_encoder.go | 125 ++++++++++ encoding/rs/params.go | 5 +- encoding/rs/utils_test.go | 5 +- encoding/test/main.go | 15 +- .../openCommitment/open_commitment_test.go | 15 +- go.mod | 3 +- go.sum | 6 +- inabox/tests/integration_suite_test.go | 10 +- node/grpc/server_test.go | 5 +- node/node.go | 3 +- relay/chunkstore/chunk_store_test.go | 47 +--- relay/relay_test_utils.go | 3 +- retriever/cmd/main.go | 4 +- retriever/server_test.go | 5 +- test/integration_test.go | 6 +- test/synthetic-test/synthetic_client_test.go | 1 + tools/traffic/generator_v2.go | 3 +- 80 files changed, 1940 insertions(+), 803 deletions(-) create mode 100644 disperser/cmd/encoder/icicle.Dockerfile create mode 100644 encoding/backend.go create mode 100644 encoding/icicle/device_setup.go create mode 100644 encoding/icicle/msm_setup.go create mode 100644 encoding/icicle/ntt_setup.go create mode 100644 encoding/icicle/utils.go create mode 100644 encoding/kzg/prover/gnark/commitments.go rename encoding/kzg/prover/{cpu => gnark}/multiframe_proof.go (56%) create mode 100644 encoding/kzg/prover/icicle.go create mode 100644 encoding/kzg/prover/icicle/ecntt.go create mode 100644 encoding/kzg/prover/icicle/msm.go create mode 100644 encoding/kzg/prover/icicle/multiframe_proof.go create mode 100644 encoding/kzg/prover/noicicle.go create mode 100644 encoding/kzg/prover/proof_backend.go delete mode 100644 encoding/kzg/prover/proof_device.go delete mode 100644 encoding/kzg/prover/prover_cpu.go rename encoding/rs/{cpu => gnark}/extend_poly.go (56%) create mode 100644 encoding/rs/icicle.go create mode 100644 encoding/rs/icicle/extend_poly.go create mode 100644 encoding/rs/noicicle.go create mode 100644 encoding/rs/parametrized_encoder.go diff --git a/.gitignore b/.gitignore index 30b0cf8ca9..40fe5f34c5 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ lightnode/docker/args.sh .idea .env .vscode + +icicle/* diff --git a/api/clients/retrieval_client_test.go b/api/clients/retrieval_client_test.go index 2eb1b72f40..0c61531b42 100644 --- a/api/clients/retrieval_client_test.go +++ b/api/clients/retrieval_client_test.go @@ -35,14 +35,15 @@ func makeTestComponents() (encoding.Prover, encoding.Verifier, error) { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, err := prover.NewProver(config, true) + p, err := prover.NewProver(config, nil) if err != nil { return nil, nil, err } - v, err := verifier.NewVerifier(config, true) + v, err := verifier.NewVerifier(config, nil) if err != nil { return nil, nil, err } diff --git a/core/test/core_test.go b/core/test/core_test.go index 9ee1ec6b10..a47e7e44a2 100644 --- a/core/test/core_test.go +++ b/core/test/core_test.go @@ -51,14 +51,15 @@ func makeTestComponents() (encoding.Prover, encoding.Verifier, error) { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, err := prover.NewProver(config, true) + p, err := prover.NewProver(config, nil) if err != nil { return nil, nil, err } - v, err := verifier.NewVerifier(config, true) + v, err := verifier.NewVerifier(config, nil) if err != nil { return nil, nil, err } diff --git a/core/v2/core_test.go b/core/v2/core_test.go index ebc6eadb77..e4ecd520b3 100644 --- a/core/v2/core_test.go +++ b/core/v2/core_test.go @@ -79,14 +79,15 @@ func makeTestComponents() (encoding.Prover, encoding.Verifier, error) { SRSOrder: 8192, SRSNumberToLoad: 8192, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, err := prover.NewProver(config, true) + p, err := prover.NewProver(config, nil) if err != nil { return nil, nil, err } - v, err := verifier.NewVerifier(config, true) + v, err := verifier.NewVerifier(config, nil) if err != nil { return nil, nil, err } diff --git a/disperser/apiserver/server_test.go b/disperser/apiserver/server_test.go index 83e6273bce..43a5b0333b 100644 --- a/disperser/apiserver/server_test.go +++ b/disperser/apiserver/server_test.go @@ -652,8 +652,10 @@ func setup() { SRSOrder: 8192, SRSNumberToLoad: 8192, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - prover, err = p.NewProver(config, true) + + prover, err = p.NewProver(config, nil) if err != nil { teardown() panic(fmt.Sprintf("failed to initialize KZG prover: %s", err.Error())) diff --git a/disperser/batcher/batcher_test.go b/disperser/batcher/batcher_test.go index 20fa0e7ae3..53c7950eff 100644 --- a/disperser/batcher/batcher_test.go +++ b/disperser/batcher/batcher_test.go @@ -58,9 +58,10 @@ func makeTestProver() (encoding.Prover, error) { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - return prover.NewProver(config, true) + return prover.NewProver(config, nil) } func makeTestBlob(securityParams []*core.SecurityParam) core.Blob { diff --git a/disperser/cmd/apiserver/main.go b/disperser/cmd/apiserver/main.go index 07ff7936cc..bebdc18128 100644 --- a/disperser/cmd/apiserver/main.go +++ b/disperser/cmd/apiserver/main.go @@ -162,7 +162,8 @@ func RunDisperserServer(ctx *cli.Context) error { bucketName := config.BlobstoreConfig.BucketName logger.Info("Blob store", "bucket", bucketName) if config.DisperserVersion == V2 { - prover, err := prover.NewProver(&config.EncodingConfig, true) + config.EncodingConfig.LoadG2Points = true + prover, err := prover.NewProver(&config.EncodingConfig, nil) if err != nil { return fmt.Errorf("failed to create encoder: %w", err) } diff --git a/disperser/cmd/encoder/config.go b/disperser/cmd/encoder/config.go index 4003adcf98..d69c2e6857 100644 --- a/disperser/cmd/encoder/config.go +++ b/disperser/cmd/encoder/config.go @@ -28,7 +28,7 @@ type Config struct { EncoderConfig kzg.KzgConfig LoggerConfig common.LoggerConfig ServerConfig *encoder.ServerConfig - MetricsConfig encoder.MetrisConfig + MetricsConfig *encoder.MetricsConfig } func NewConfig(ctx *cli.Context) (Config, error) { @@ -58,10 +58,12 @@ func NewConfig(ctx *cli.Context) (Config, error) { RequestPoolSize: ctx.GlobalInt(flags.RequestPoolSizeFlag.Name), EnableGnarkChunkEncoding: ctx.Bool(flags.EnableGnarkChunkEncodingFlag.Name), PreventReencoding: ctx.Bool(flags.PreventReencodingFlag.Name), + Backend: ctx.String(flags.BackendFlag.Name), + GPUEnable: ctx.Bool(flags.GPUEnableFlag.Name), PprofHttpPort: ctx.GlobalString(flags.PprofHttpPort.Name), EnablePprof: ctx.GlobalBool(flags.EnablePprof.Name), }, - MetricsConfig: encoder.MetrisConfig{ + MetricsConfig: &encoder.MetricsConfig{ HTTPPort: ctx.GlobalString(flags.MetricsHTTPPort.Name), EnableMetrics: ctx.GlobalBool(flags.EnableMetrics.Name), }, diff --git a/disperser/cmd/encoder/flags/flags.go b/disperser/cmd/encoder/flags/flags.go index 8c9399a399..dedb228d6b 100644 --- a/disperser/cmd/encoder/flags/flags.go +++ b/disperser/cmd/encoder/flags/flags.go @@ -3,6 +3,7 @@ package flags import ( "github.com/Layr-Labs/eigenda/common" "github.com/Layr-Labs/eigenda/common/aws" + "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/kzg" "github.com/urfave/cli" ) @@ -67,6 +68,19 @@ var ( Required: false, EnvVar: common.PrefixEnvVar(envVarPrefix, "ENABLE_GNARK_CHUNK_ENCODING"), } + GPUEnableFlag = cli.BoolFlag{ + Name: common.PrefixFlag(FlagPrefix, "gpu-enable"), + Usage: "Enable GPU, falls back to CPU if not available", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "GPU_ENABLE"), + } + BackendFlag = cli.StringFlag{ + Name: common.PrefixFlag(FlagPrefix, "backend"), + Usage: "Backend to use for encoding", + Required: false, + Value: string(encoding.GnarkBackend), + EnvVar: common.PrefixEnvVar(envVarPrefix, "BACKEND"), + } PreventReencodingFlag = cli.BoolTFlag{ Name: common.PrefixFlag(FlagPrefix, "prevent-reencoding"), Usage: "if true, will prevent reencoding of chunks by checking if the chunk already exists in the chunk store", @@ -100,6 +114,8 @@ var optionalFlags = []cli.Flag{ EnableGnarkChunkEncodingFlag, EncoderVersionFlag, S3BucketNameFlag, + GPUEnableFlag, + BackendFlag, PreventReencodingFlag, PprofHttpPort, EnablePprof, diff --git a/disperser/cmd/encoder/icicle.Dockerfile b/disperser/cmd/encoder/icicle.Dockerfile new file mode 100644 index 0000000000..28b8bd10d2 --- /dev/null +++ b/disperser/cmd/encoder/icicle.Dockerfile @@ -0,0 +1,62 @@ +FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS builder + +# Install Go +ENV GOLANG_VERSION=1.21.1 +ENV GOLANG_SHA256=b3075ae1ce5dab85f89bc7905d1632de23ca196bd8336afd93fa97434cfa55ae + +ADD https://go.dev/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz /tmp/go.tar.gz +RUN echo "${GOLANG_SHA256} /tmp/go.tar.gz" | sha256sum -c - && \ + tar -C /usr/local -xzf /tmp/go.tar.gz && \ + rm /tmp/go.tar.gz +ENV PATH="/usr/local/go/bin:${PATH}" + +# Set up the working directory +WORKDIR /app + +# Copy go.mod and go.sum first to leverage Docker cache +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy the source code +COPY ./disperser /app/disperser +COPY common /app/common +COPY contracts /app/contracts +COPY core /app/core +COPY api /app/api +COPY indexer /app/indexer +COPY encoding /app/encoding +COPY relay /app/relay + +# Define Icicle versions and checksums +ENV ICICLE_VERSION=3.1.0 +ENV ICICLE_BASE_SHA256=2e4e33b8bc3e335b2dd33dcfb10a9aaa18717885509614a24f492f47a2e4f4b1 +ENV ICICLE_CUDA_SHA256=cdba907eac6297445a6c128081ebba5c711d352003f69310145406a8fd781647 + +# Download Icicle tarballs +ADD https://github.com/ingonyama-zk/icicle/releases/download/v${ICICLE_VERSION}/icicle_${ICICLE_VERSION//./_}-ubuntu22.tar.gz /tmp/icicle.tar.gz +ADD https://github.com/ingonyama-zk/icicle/releases/download/v${ICICLE_VERSION}/icicle_${ICICLE_VERSION//./_}-ubuntu22-cuda122.tar.gz /tmp/icicle-cuda.tar.gz + +# Verify checksums and install Icicle +RUN echo "${ICICLE_BASE_SHA256} /tmp/icicle.tar.gz" | sha256sum -c - && \ + echo "${ICICLE_CUDA_SHA256} /tmp/icicle-cuda.tar.gz" | sha256sum -c - && \ + tar xzf /tmp/icicle.tar.gz && \ + cp -r ./icicle/lib/* /usr/lib/ && \ + cp -r ./icicle/include/icicle/ /usr/local/include/ && \ + tar xzf /tmp/icicle-cuda.tar.gz -C /opt && \ + rm /tmp/icicle.tar.gz /tmp/icicle-cuda.tar.gz + +# Build the server with icicle backend +WORKDIR /app/disperser +RUN go build -tags=icicle -o ./bin/server ./cmd/encoder + +# Start a new stage for the base image +FROM nvidia/cuda:12.2.2-base-ubuntu22.04 + +COPY --from=builder /app/disperser/bin/server /usr/local/bin/server +COPY --from=builder /usr/lib/libicicle* /usr/lib/ +COPY --from=builder /usr/local/include/icicle /usr/local/include/icicle +COPY --from=builder /opt/icicle /opt/icicle + +ENTRYPOINT ["server"] diff --git a/disperser/cmd/encoder/main.go b/disperser/cmd/encoder/main.go index 9c90d6bf7e..3382deedfa 100644 --- a/disperser/cmd/encoder/main.go +++ b/disperser/cmd/encoder/main.go @@ -10,11 +10,12 @@ import ( "github.com/Layr-Labs/eigenda/common/aws/s3" "github.com/Layr-Labs/eigenda/disperser/cmd/encoder/flags" blobstorev2 "github.com/Layr-Labs/eigenda/disperser/common/v2/blobstore" - grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" "github.com/Layr-Labs/eigenda/disperser/encoder" + "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/kzg/prover" - "github.com/prometheus/client_golang/prometheus" "github.com/Layr-Labs/eigenda/relay/chunkstore" + grpcprom "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + "github.com/prometheus/client_golang/prometheus" "github.com/urfave/cli" ) @@ -69,9 +70,23 @@ func RunEncoderServer(ctx *cli.Context) error { reg.MustRegister(grpcMetrics) } + backendType, err := encoding.ParseBackendType(config.ServerConfig.Backend) + if err != nil { + return err + } + + // Set the encoding config + encodingConfig := &encoding.Config{ + BackendType: backendType, + GPUEnable: config.ServerConfig.GPUEnable, + NumWorker: config.EncoderConfig.NumWorker, + } + if config.EncoderVersion == V2 { - // We no longer compute the commitments in the encoder, so we don't need to load the G2 points - prover, err := prover.NewProver(&config.EncoderConfig, false) + // We no longer load the G2 points in V2 because the KZG commitments are computed + // on the API server side. + config.EncoderConfig.LoadG2Points = false + prover, err := prover.NewProver(&config.EncoderConfig, encodingConfig) if err != nil { return fmt.Errorf("failed to create encoder: %w", err) } @@ -82,6 +97,10 @@ func RunEncoderServer(ctx *cli.Context) error { } blobStoreBucketName := config.BlobStoreConfig.BucketName + if blobStoreBucketName == "" { + return fmt.Errorf("blob store bucket name is required") + } + blobStore := blobstorev2.NewBlobStore(blobStoreBucketName, s3Client, logger) logger.Info("Blob store", "bucket", blobStoreBucketName) @@ -101,7 +120,8 @@ func RunEncoderServer(ctx *cli.Context) error { return server.Start() } - prover, err := prover.NewProver(&config.EncoderConfig, true) + config.EncoderConfig.LoadG2Points = true + prover, err := prover.NewProver(&config.EncoderConfig, encodingConfig) if err != nil { return fmt.Errorf("failed to create encoder: %w", err) } @@ -109,5 +129,4 @@ func RunEncoderServer(ctx *cli.Context) error { server := encoder.NewEncoderServer(*config.ServerConfig, logger, prover, metrics, grpcMetrics) return server.Start() - } diff --git a/disperser/encoder/config.go b/disperser/encoder/config.go index b543efe7b2..168f4185dc 100644 --- a/disperser/encoder/config.go +++ b/disperser/encoder/config.go @@ -10,6 +10,8 @@ type ServerConfig struct { RequestPoolSize int EnableGnarkChunkEncoding bool PreventReencoding bool + Backend string + GPUEnable bool PprofHttpPort string EnablePprof bool } diff --git a/disperser/encoder/metrics.go b/disperser/encoder/metrics.go index d4c4d88680..e8b85fa0d2 100644 --- a/disperser/encoder/metrics.go +++ b/disperser/encoder/metrics.go @@ -13,7 +13,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -type MetrisConfig struct { +type MetricsConfig struct { HTTPPort string EnableMetrics bool } diff --git a/disperser/encoder/server.go b/disperser/encoder/server.go index f7f06e682c..9cab207a30 100644 --- a/disperser/encoder/server.go +++ b/disperser/encoder/server.go @@ -99,13 +99,6 @@ func (s *EncoderServer) Start() error { return gs.Serve(listener) } -func (s *EncoderServer) Close() { - if s.close == nil { - return - } - s.close() -} - func (s *EncoderServer) EncodeBlob(ctx context.Context, req *pb.EncodeBlobRequest) (*pb.EncodeBlobReply, error) { startTime := time.Now() blobSize := len(req.GetData()) @@ -193,7 +186,6 @@ func (s *EncoderServer) handleEncoding(ctx context.Context, req *pb.EncodeBlobRe } var chunksData [][]byte - var format pb.ChunkEncodingFormat if s.config.EnableGnarkChunkEncoding { format = pb.ChunkEncodingFormat_GNARK @@ -228,3 +220,10 @@ func (s *EncoderServer) handleEncoding(ctx context.Context, req *pb.EncodeBlobRe ChunkEncodingFormat: format, }, nil } + +func (s *EncoderServer) Close() { + if s.close == nil { + return + } + s.close() +} diff --git a/disperser/encoder/server_test.go b/disperser/encoder/server_test.go index 0b2ba4da23..b8c83e0503 100644 --- a/disperser/encoder/server_test.go +++ b/disperser/encoder/server_test.go @@ -42,9 +42,10 @@ func makeTestProver(numPoint uint64) (encoding.Prover, ServerConfig) { SRSOrder: 3000, SRSNumberToLoad: numPoint, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, _ := prover.NewProver(kzgConfig, true) + p, _ := prover.NewProver(kzgConfig, nil) encoderServerConfig := ServerConfig{ GrpcPort: "3000", MaxConcurrentRequests: 16, diff --git a/disperser/encoder/server_v2.go b/disperser/encoder/server_v2.go index 498955594c..2842afd591 100644 --- a/disperser/encoder/server_v2.go +++ b/disperser/encoder/server_v2.go @@ -249,3 +249,10 @@ func extractProofsAndCoeffs(frames []*encoding.Frame) ([]*encoding.Proof, []*rs. } return proofs, coeffs } + +func (s *EncoderServerV2) Close() { + if s.close == nil { + return + } + s.close() +} diff --git a/disperser/encoder/server_v2_test.go b/disperser/encoder/server_v2_test.go index 69b30b7448..83dcb2af1d 100644 --- a/disperser/encoder/server_v2_test.go +++ b/disperser/encoder/server_v2_test.go @@ -49,9 +49,9 @@ func makeTestProver(numPoint uint64) (encoding.Prover, error) { SRSOrder: 300000, SRSNumberToLoad: numPoint, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: false, } - - p, err := prover.NewProver(kzgConfig, false) + p, err := prover.NewProver(kzgConfig, nil) return p, err } @@ -59,7 +59,7 @@ func makeTestProver(numPoint uint64) (encoding.Prover, error) { func TestEncodeBlob(t *testing.T) { const ( testDataSize = 16 * 1024 - timeoutSeconds = 30 + timeoutSeconds = 60 randSeed = uint64(42) ) @@ -176,6 +176,12 @@ func TestEncodeBlob(t *testing.T) { // Create and execute encoding request again resp, err := server.EncodeBlob(ctx, req) assert.NoError(t, err) + + if !assert.NotNil(t, resp, "Response should not be nil") { + t.FailNow() // Stop the test here to prevent nil pointer panic + return + } + assert.Equal(t, uint32(294916), resp.FragmentInfo.TotalChunkSizeBytes, "Unexpected total chunk size") assert.Equal(t, uint32(512*1024), resp.FragmentInfo.FragmentSizeBytes, "Unexpected fragment size") assert.Equal(t, c.s3Client.Called["UploadObject"], expectedUploadCalls) diff --git a/docker-bake.hcl b/docker-bake.hcl index ad20c497aa..08e0877083 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -30,13 +30,21 @@ variable "GITDATE" { } # GROUPS - group "default" { targets = ["all"] } group "all" { - targets = ["node-group", "batcher", "disperser", "encoder", "retriever", "churner", "dataapi", "traffic-generator"] + targets = [ + "node-group", + "batcher", + "disperser", + "encoder", + "retriever", + "churner", + "dataapi", + "traffic-generator" + ] } group "node-group" { @@ -50,17 +58,32 @@ group "node-group-release" { # Github CI builds group "ci-release" { - targets = ["node-group", "batcher", "disperser", "encoder", "retriever", "churner", "dataapi"] + targets = [ + "node-group", + "batcher", + "disperser", + "encoder", + "retriever", + "churner", + "dataapi" + ] } # Internal devops builds group "internal-release" { - targets = ["node-internal", "batcher-internal", "disperser-internal", "encoder-internal", "retriever-internal", "churner-internal", "dataapi-internal", "traffic-generator-internal"] + targets = [ + "node-internal", + "batcher-internal", + "disperser-internal", + "encoder-internal", + "retriever-internal", + "churner-internal", + "dataapi-internal", + "traffic-generator-internal" + ] } - # DISPERSER TARGETS - target "batcher" { context = "." dockerfile = "./Dockerfile" @@ -70,10 +93,11 @@ target "batcher" { target "batcher-internal" { inherits = ["batcher"] - tags = ["${REGISTRY}/eigenda-batcher:${BUILD_TAG}", - "${REGISTRY}/eigenda-batcher:${GIT_SHA}", - "${REGISTRY}/eigenda-batcher:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-batcher:${BUILD_TAG}", + "${REGISTRY}/eigenda-batcher:${GIT_SHA}", + "${REGISTRY}/eigenda-batcher:sha-${GIT_SHORT_SHA}" + ] } target "disperser" { @@ -85,10 +109,11 @@ target "disperser" { target "disperser-internal" { inherits = ["disperser"] - tags = ["${REGISTRY}/eigenda-disperser:${BUILD_TAG}", - "${REGISTRY}/eigenda-disperser:${GIT_SHA}", - "${REGISTRY}/eigenda-disperser:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-disperser:${BUILD_TAG}", + "${REGISTRY}/eigenda-disperser:${GIT_SHA}", + "${REGISTRY}/eigenda-disperser:sha-${GIT_SHORT_SHA}" + ] } target "encoder" { @@ -98,12 +123,28 @@ target "encoder" { tags = ["${REGISTRY}/${REPO}/encoder:${BUILD_TAG}"] } +target "encoder-icicle" { + context = "." + dockerfile = "./disperser/cmd/encoder/icicle.Dockerfile" + tags = ["${REGISTRY}/${REPO}/encoder-icicle:${BUILD_TAG}"] +} + target "encoder-internal" { inherits = ["encoder"] - tags = ["${REGISTRY}/eigenda-encoder:${BUILD_TAG}", - "${REGISTRY}/eigenda-encoder:${GIT_SHA}", - "${REGISTRY}/eigenda-encoder:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-encoder:${BUILD_TAG}", + "${REGISTRY}/eigenda-encoder:${GIT_SHA}", + "${REGISTRY}/eigenda-encoder:sha-${GIT_SHORT_SHA}" + ] +} + +target "encoder-icicle-internal" { + inherits = ["encoder-icicle"] + tags = [ + "${REGISTRY}/eigenda-encoder-icicle:${BUILD_TAG}", + "${REGISTRY}/eigenda-encoder-icicle:${GIT_SHA}", + "${REGISTRY}/eigenda-encoder-icicle:sha-${GIT_SHORT_SHA}" + ] } target "retriever" { @@ -115,10 +156,11 @@ target "retriever" { target "retriever-internal" { inherits = ["retriever"] - tags = ["${REGISTRY}/eigenda-retriever:${BUILD_TAG}", - "${REGISTRY}/eigenda-retriever:${GIT_SHA}", - "${REGISTRY}/eigenda-retriever:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-retriever:${BUILD_TAG}", + "${REGISTRY}/eigenda-retriever:${GIT_SHA}", + "${REGISTRY}/eigenda-retriever:sha-${GIT_SHORT_SHA}" + ] } target "churner" { @@ -130,10 +172,11 @@ target "churner" { target "churner-internal" { inherits = ["churner"] - tags = ["${REGISTRY}/eigenda-churner:${BUILD_TAG}", - "${REGISTRY}/eigenda-churner:${GIT_SHA}", - "${REGISTRY}/eigenda-churner:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-churner:${BUILD_TAG}", + "${REGISTRY}/eigenda-churner:${GIT_SHA}", + "${REGISTRY}/eigenda-churner:sha-${GIT_SHORT_SHA}" + ] } target "traffic-generator" { @@ -145,10 +188,11 @@ target "traffic-generator" { target "traffic-generator-internal" { inherits = ["traffic-generator"] - tags = ["${REGISTRY}/eigenda-traffic-generator:${BUILD_TAG}", - "${REGISTRY}/eigenda-traffic-generator:${GIT_SHA}", - "${REGISTRY}/eigenda-traffic-generator:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-traffic-generator:${BUILD_TAG}", + "${REGISTRY}/eigenda-traffic-generator:${GIT_SHA}", + "${REGISTRY}/eigenda-traffic-generator:sha-${GIT_SHORT_SHA}" + ] } target "traffic-generator2" { @@ -160,9 +204,10 @@ target "traffic-generator2" { target "traffic-generator2-internal" { inherits = ["traffic-generator2"] - tags = ["${REGISTRY}/eigenda-traffic-generator2:${BUILD_TAG}", + tags = [ + "${REGISTRY}/eigenda-traffic-generator2:${BUILD_TAG}", "${REGISTRY}/eigenda-traffic-generator2:${GIT_SHA}", - "${REGISTRY}/eigenda-traffic-generator2:sha-${GIT_SHORT_SHA}", + "${REGISTRY}/eigenda-traffic-generator2:sha-${GIT_SHORT_SHA}" ] } @@ -175,19 +220,19 @@ target "dataapi" { target "dataapi-internal" { inherits = ["dataapi"] - tags = ["${REGISTRY}/eigenda-dataapi:${BUILD_TAG}", - "${REGISTRY}/eigenda-dataapi:${GIT_SHA}", - "${REGISTRY}/eigenda-dataapi:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-dataapi:${BUILD_TAG}", + "${REGISTRY}/eigenda-dataapi:${GIT_SHA}", + "${REGISTRY}/eigenda-dataapi:sha-${GIT_SHORT_SHA}" + ] } # NODE TARGETS - target "node" { context = "." dockerfile = "./Dockerfile" target = "node" - args = { + args = { SEMVER = "${SEMVER}" GITCOMMIT = "${GIT_SHORT_SHA}" GITDATE = "${GITDATE}" @@ -197,10 +242,11 @@ target "node" { target "node-internal" { inherits = ["node"] - tags = ["${REGISTRY}/eigenda-node:${BUILD_TAG}", - "${REGISTRY}/eigenda-node:${GIT_SHA}", - "${REGISTRY}/eigenda-node:sha-${GIT_SHORT_SHA}", - ] + tags = [ + "${REGISTRY}/eigenda-node:${BUILD_TAG}", + "${REGISTRY}/eigenda-node:${GIT_SHA}", + "${REGISTRY}/eigenda-node:sha-${GIT_SHORT_SHA}" + ] } target "nodeplugin" { @@ -211,7 +257,6 @@ target "nodeplugin" { } # PUBLIC RELEASE TARGETS - target "_release" { platforms = ["linux/amd64", "linux/arm64"] } diff --git a/encoding/backend.go b/encoding/backend.go new file mode 100644 index 0000000000..d840e30824 --- /dev/null +++ b/encoding/backend.go @@ -0,0 +1,44 @@ +package encoding + +import ( + "fmt" + "runtime" + + _ "go.uber.org/automaxprocs/maxprocs" +) + +type BackendType string + +const ( + GnarkBackend BackendType = "gnark" + IcicleBackend BackendType = "icicle" +) + +type Config struct { + NumWorker uint64 + BackendType BackendType + GPUEnable bool + Verbose bool +} + +// DefaultConfig returns a Config struct with default values +func DefaultConfig() *Config { + return &Config{ + NumWorker: uint64(runtime.GOMAXPROCS(0)), + BackendType: GnarkBackend, + GPUEnable: false, + Verbose: false, + } +} + +// ParseBackendType converts a string to BackendType and validates it +func ParseBackendType(backend string) (BackendType, error) { + switch BackendType(backend) { + case GnarkBackend: + return GnarkBackend, nil + case IcicleBackend: + return IcicleBackend, nil + default: + return "", fmt.Errorf("unsupported backend type: %s. Must be one of: gnark, icicle", backend) + } +} diff --git a/encoding/bench/Makefile b/encoding/bench/Makefile index cfb1e2a366..4dc0065d84 100644 --- a/encoding/bench/Makefile +++ b/encoding/bench/Makefile @@ -1,9 +1,15 @@ build_cpu: - go build -gcflags="all=-N -l" -ldflags="-s=false -w=false" -o bin/main_cpu main.go + go build -gcflags="all=-N -l" -ldflags="-s=false -w=false" -o bin/main main.go -benchmark_cpu: +build_icicle: + go build -tags=icicle -gcflags="all=-N -l" -ldflags="-s=false -w=false" -o bin/main_icicle main.go + +benchmark_default: go run main.go -cpuprofile cpu.prof -memprofile mem.prof +benchmark_icicle: + go run -tags=icicle main.go -cpuprofile cpu.prof -memprofile mem.prof + cpu_profile: go tool pprof -http=:8080 cpu.prof diff --git a/encoding/bench/main.go b/encoding/bench/main.go index 2d9818977e..aceb04c511 100644 --- a/encoding/bench/main.go +++ b/encoding/bench/main.go @@ -28,48 +28,60 @@ type BenchmarkResult struct { } type Config struct { - OutputFile string - BlobLength uint64 - NumChunks uint64 - NumRuns uint64 - CPUProfile string - MemProfile string - EnableVerify bool + MinBlobLength uint64 `json:"min_blob_length"` + MaxBlobLength uint64 `json:"max_blob_length"` + OutputFile string + BlobLength uint64 + NumChunks uint64 + NumRuns uint64 + CPUProfile string + MemProfile string + EnableVerify bool } func parseFlags() Config { config := Config{} flag.StringVar(&config.OutputFile, "output", "benchmark_results.json", "Output file for results") - flag.Uint64Var(&config.BlobLength, "blob-length", 1048576, "Blob length (power of 2)") + flag.Uint64Var(&config.MinBlobLength, "min-blob-length", 1024, "Minimum blob length (power of 2)") + flag.Uint64Var(&config.MaxBlobLength, "max-blob-length", 1048576, "Maximum blob length (power of 2)") flag.Uint64Var(&config.NumChunks, "num-chunks", 8192, "Minimum number of chunks (power of 2)") - flag.Uint64Var(&config.NumRuns, "num-runs", 10, "Number of times to run the benchmark") flag.StringVar(&config.CPUProfile, "cpuprofile", "", "Write CPU profile to file") flag.StringVar(&config.MemProfile, "memprofile", "", "Write memory profile to file") - flag.BoolVar(&config.EnableVerify, "enable-verify", false, "Verify blobs after encoding") + flag.BoolVar(&config.EnableVerify, "enable-verify", true, "Verify blobs after encoding") flag.Parse() return config } +var kzgConfig = &kzg.KzgConfig{} + func main() { config := parseFlags() fmt.Println("Config output", config.OutputFile) // Setup phase - kzgConfig := &kzg.KzgConfig{ - G1Path: "/home/ec2-user/resources/kzg/g1.point", - G2Path: "/home/ec2-user/resources/kzg/g2.point", - CacheDir: "/home/ec2-user/resources/kzg/SRSTables", + kzgConfig = &kzg.KzgConfig{ + G1Path: "/home/ubuntu/resources/kzg/g1.point", + G2Path: "/home/ubuntu/resources/kzg/g2.point", + CacheDir: "/home/ubuntu/resources/kzg/SRSTables", SRSOrder: 268435456, - SRSNumberToLoad: 2097152, + SRSNumberToLoad: 1048576, NumWorker: uint64(runtime.GOMAXPROCS(0)), - Verbose: true, + LoadG2Points: true, } fmt.Printf("* Task Starts\n") - // create encoding object - p, _ := prover.NewProver(kzgConfig, true) + cfg := &encoding.Config{ + BackendType: encoding.IcicleBackend, + GPUEnable: true, + NumWorker: uint64(runtime.GOMAXPROCS(0)), + } + p, err := prover.NewProver(kzgConfig, cfg) + + if err != nil { + log.Fatalf("Failed to create prover: %v", err) + } if config.CPUProfile != "" { f, err := os.Create(config.CPUProfile) @@ -115,12 +127,13 @@ func runBenchmark(p *prover.Prover, config *Config) []BenchmarkResult { // Fixed coding ratio of 8 codingRatio := uint64(8) - for i := uint64(0); i < config.NumRuns; i++ { - chunkLen := (config.BlobLength * codingRatio) / config.NumChunks + + for blobLength := config.MinBlobLength; blobLength <= config.MaxBlobLength; blobLength *= 2 { + chunkLen := (blobLength * codingRatio) / config.NumChunks if chunkLen < 1 { continue // Skip invalid configurations } - result := benchmarkEncodeAndVerify(p, config.BlobLength, config.NumChunks, chunkLen, config.EnableVerify) + result := benchmarkEncodeAndVerify(p, blobLength, config.NumChunks, chunkLen, config.EnableVerify) results = append(results, result) } return results @@ -134,7 +147,10 @@ func benchmarkEncodeAndVerify(p *prover.Prover, blobLength uint64, numChunks uin fmt.Printf("Running benchmark: numChunks=%d, chunkLen=%d, blobLength=%d\n", params.NumChunks, params.ChunkLength, blobLength) - enc, _ := p.GetKzgEncoder(params) + enc, err := p.GetKzgEncoder(params) + if err != nil { + log.Fatalf("Failed to get KZG encoder: %v", err) + } // Create polynomial inputSize := blobLength @@ -166,9 +182,13 @@ func benchmarkEncodeAndVerify(p *prover.Prover, blobLength uint64, numChunks uin log.Fatal("leading coset inconsistency") } - lc := enc.Fs.ExpandedRootsOfUnity[uint64(j)] + rs, err := enc.GetRsEncoder(enc.EncodingParams) + if err != nil { + log.Fatalf("%v", err) + } + lc := rs.Fs.ExpandedRootsOfUnity[uint64(j)] - g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), p.KzgConfig) + g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), kzgConfig.SRSOrder, kzgConfig.G2Path) if err != nil { log.Fatalf("Load g2 %v failed\n", err) } diff --git a/encoding/icicle/device_setup.go b/encoding/icicle/device_setup.go new file mode 100644 index 0000000000..77c25626ab --- /dev/null +++ b/encoding/icicle/device_setup.go @@ -0,0 +1,136 @@ +//go:build icicle + +package icicle + +import ( + "errors" + "fmt" + "log/slog" + "sync" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + runtime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +// IcicleDevice wraps the core device setup and configurations +type IcicleDevice struct { + Device runtime.Device + NttCfg core.NTTConfig[[iciclebn254.SCALAR_LIMBS]uint32] + MsmCfg core.MSMConfig + FlatFFTPointsT []iciclebn254.Affine + SRSG1Icicle []iciclebn254.Affine +} + +// IcicleDeviceConfig holds configuration options for a single device. +// - The GPUEnable parameter is used to enable GPU acceleration. +// - The NTTSize parameter is used to set the maximum domain size for NTT configuration. +// - The FFTPointsT and SRSG1 parameters are used to set up the MSM configuration. +// - MSM setup is optional and can be skipped by not providing these parameters. +// The reason for this is that not all applications require an MSM setup. For example +// in the case of reed-solomon, it only requires the NTT setup. +type IcicleDeviceConfig struct { + GPUEnable bool + NTTSize uint8 + + // MSM setup parameters (optional) + FFTPointsT [][]bn254.G1Affine + SRSG1 []bn254.G1Affine +} + +// NewIcicleDevice creates and initializes a new IcicleDevice +func NewIcicleDevice(config IcicleDeviceConfig) (*IcicleDevice, error) { + runtime.LoadBackendFromEnvOrDefault() + + device, err := setupDevice(config.GPUEnable) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + wg.Add(1) + + var ( + nttCfg core.NTTConfig[[iciclebn254.SCALAR_LIMBS]uint32] + msmCfg core.MSMConfig + flatFftPointsT []iciclebn254.Affine + srsG1Icicle []iciclebn254.Affine + setupErr error + icicleErr runtime.EIcicleError + ) + + // Setup NTT and optionally MSM on device + runtime.RunOnDevice(&device, func(args ...any) { + defer wg.Done() + + // Setup NTT + nttCfg, icicleErr = SetupNTT(config.NTTSize) + if icicleErr != runtime.Success { + setupErr = fmt.Errorf("could not setup NTT: %v", icicleErr.AsString()) + return + } + + // Setup MSM if parameters are provided + if config.FFTPointsT != nil && config.SRSG1 != nil { + flatFftPointsT, srsG1Icicle, msmCfg, icicleErr = SetupMsmG1( + config.FFTPointsT, + config.SRSG1, + ) + if icicleErr != runtime.Success { + setupErr = fmt.Errorf("could not setup MSM: %v", icicleErr.AsString()) + return + } + } + }) + + wg.Wait() + + if setupErr != nil { + return nil, setupErr + } + + return &IcicleDevice{ + Device: device, + NttCfg: nttCfg, + MsmCfg: msmCfg, + FlatFFTPointsT: flatFftPointsT, + SRSG1Icicle: srsG1Icicle, + }, nil +} + +// setupDevice initializes either a GPU or CPU device +func setupDevice(gpuEnable bool) (runtime.Device, error) { + if gpuEnable { + return setupGPUDevice() + } + + return setupCPUDevice() +} + +// setupGPUDevice attempts to initialize a CUDA device, falling back to CPU if unavailable +func setupGPUDevice() (runtime.Device, error) { + deviceCuda := runtime.CreateDevice("CUDA", 0) + if runtime.IsDeviceAvailable(&deviceCuda) { + device := runtime.CreateDevice("CUDA", 0) + slog.Info("CUDA device available, setting device") + runtime.SetDevice(&device) + + return device, nil + } + + slog.Info("CUDA device not available, falling back to CPU") + return setupCPUDevice() +} + +// setupCPUDevice initializes a CPU device +func setupCPUDevice() (runtime.Device, error) { + device := runtime.CreateDevice("CPU", 0) + if !runtime.IsDeviceAvailable(&device) { + slog.Error("CPU device is not available") + return device, errors.New("cpu device is not available") + } + + runtime.SetDevice(&device) + return device, nil +} diff --git a/encoding/icicle/msm_setup.go b/encoding/icicle/msm_setup.go new file mode 100644 index 0000000000..6d84741483 --- /dev/null +++ b/encoding/icicle/msm_setup.go @@ -0,0 +1,41 @@ +//go:build icicle + +package icicle + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +// SetupMsmG1 initializes the MSM configuration for G1 points. +func SetupMsmG1(rowsG1 [][]bn254.G1Affine, srsG1 []bn254.G1Affine) ([]iciclebn254.Affine, []iciclebn254.Affine, core.MSMConfig, runtime.EIcicleError) { + // Calculate total length needed for rowsG1Icicle + totalLen := 0 + for _, row := range rowsG1 { + totalLen += len(row) + } + + // Pre-allocate slice with exact capacity needed + rowsG1Icicle := make([]iciclebn254.Affine, totalLen) + + currentIdx := 0 + for _, row := range rowsG1 { + converted := BatchConvertGnarkAffineToIcicleAffine(row) + copy(rowsG1Icicle[currentIdx:], converted) + currentIdx += len(row) + } + + srsG1Icicle := BatchConvertGnarkAffineToIcicleAffine(srsG1) + cfgBn254 := core.GetDefaultMSMConfig() + cfgBn254.IsAsync = true + + streamBn254, err := runtime.CreateStream() + if err != runtime.Success { + return nil, nil, cfgBn254, err + } + + cfgBn254.StreamHandle = streamBn254 + return rowsG1Icicle, srsG1Icicle, cfgBn254, runtime.Success +} diff --git a/encoding/icicle/ntt_setup.go b/encoding/icicle/ntt_setup.go new file mode 100644 index 0000000000..cc5ae7dc4e --- /dev/null +++ b/encoding/icicle/ntt_setup.go @@ -0,0 +1,45 @@ +//go:build icicle + +package icicle + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ntt" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +// SetupNTT initializes the NTT domain with the domain size of maxScale. +// It returns the NTT configuration and an error if the initialization fails. +func SetupNTT(maxScale uint8) (core.NTTConfig[[bn254.SCALAR_LIMBS]uint32], runtime.EIcicleError) { + cfg := core.GetDefaultNTTInitDomainConfig() + cfgBn254 := ntt.GetDefaultNttConfig() + cfgBn254.IsAsync = true + cfgBn254.Ordering = core.KNN + + err := initDomain(int(maxScale), cfg) + if err != runtime.Success { + return cfgBn254, err + } + + streamBn254, err := runtime.CreateStream() + if err != runtime.Success { + return cfgBn254, err + } + + cfgBn254.StreamHandle = streamBn254 + + return cfgBn254, runtime.Success +} + +func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcicleError { + rouMont, _ := fft.Generator(uint64(1 << largestTestSize)) + rou := rouMont.Bits() + rouIcicle := bn254.ScalarField{} + limbs := core.ConvertUint64ArrToUint32Arr(rou[:]) + + rouIcicle.FromLimbs(limbs) + e := ntt.InitDomain(rouIcicle, cfg) + return e +} diff --git a/encoding/icicle/utils.go b/encoding/icicle/utils.go new file mode 100644 index 0000000000..e4f405c640 --- /dev/null +++ b/encoding/icicle/utils.go @@ -0,0 +1,103 @@ +//go:build icicle + +package icicle + +import ( + "math" + "sync" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" +) + +func ConvertFrToScalarFieldsBytes(data []fr.Element) []iciclebn254.ScalarField { + scalars := make([]iciclebn254.ScalarField, len(data)) + + for i := 0; i < len(data); i++ { + src := data[i] // 4 uint64 + var littleEndian [32]byte + + fr.LittleEndian.PutElement(&littleEndian, src) + scalars[i].FromBytesLittleEndian(littleEndian[:]) + } + return scalars +} + +func ConvertScalarFieldsToFrBytes(scalars []iciclebn254.ScalarField) []fr.Element { + frElements := make([]fr.Element, len(scalars)) + + for i := 0; i < len(frElements); i++ { + v := scalars[i] + slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) + frElements[i] = slice64 + } + return frElements +} + +func BatchConvertGnarkAffineToIcicleAffine(gAffineList []bn254.G1Affine) []iciclebn254.Affine { + icicleAffineList := make([]iciclebn254.Affine, len(gAffineList)) + for i := 0; i < len(gAffineList); i++ { + GnarkAffineToIcicleAffine(&gAffineList[i], &icicleAffineList[i]) + } + return icicleAffineList +} + +func GnarkAffineToIcicleAffine(g1 *bn254.G1Affine, iciAffine *iciclebn254.Affine) { + var littleEndBytesX, littleEndBytesY [32]byte + fp.LittleEndian.PutElement(&littleEndBytesX, g1.X) + fp.LittleEndian.PutElement(&littleEndBytesY, g1.Y) + + iciAffine.X.FromBytesLittleEndian(littleEndBytesX[:]) + iciAffine.Y.FromBytesLittleEndian(littleEndBytesY[:]) +} + +func HostSliceIcicleProjectiveToGnarkAffine(ps core.HostSlice[iciclebn254.Projective], numWorker int) []bn254.G1Affine { + output := make([]bn254.G1Affine, len(ps)) + + if len(ps) < numWorker { + numWorker = len(ps) + } + + var wg sync.WaitGroup + + interval := int(math.Ceil(float64(len(ps)) / float64(numWorker))) + + for w := 0; w < numWorker; w++ { + wg.Add(1) + start := w * interval + end := (w + 1) * interval + if len(ps) < end { + end = len(ps) + } + + go func(workerStart, workerEnd int) { + defer wg.Done() + for i := workerStart; i < workerEnd; i++ { + output[i] = IcicleProjectiveToGnarkAffine(ps[i]) + } + + }(start, end) + } + wg.Wait() + return output +} + +func IcicleProjectiveToGnarkAffine(p iciclebn254.Projective) bn254.G1Affine { + px, _ := fp.LittleEndian.Element((*[fp.Bytes]byte)((&p.X).ToBytesLittleEndian())) + py, _ := fp.LittleEndian.Element((*[fp.Bytes]byte)((&p.Y).ToBytesLittleEndian())) + pz, _ := fp.LittleEndian.Element((*[fp.Bytes]byte)((&p.Z).ToBytesLittleEndian())) + + zInv := new(fp.Element) + x := new(fp.Element) + y := new(fp.Element) + + zInv.Inverse(&pz) + + x.Mul(&px, zInv) + y.Mul(&py, zInv) + + return bn254.G1Affine{X: *x, Y: *y} +} diff --git a/encoding/kzg/kzgrs.go b/encoding/kzg/kzgrs.go index ac52cc6aaf..c4b2a8e819 100644 --- a/encoding/kzg/kzgrs.go +++ b/encoding/kzg/kzgrs.go @@ -11,4 +11,5 @@ type KzgConfig struct { SRSNumberToLoad uint64 // Number of points to be loaded from the beginning Verbose bool PreloadEncoder bool + LoadG2Points bool } diff --git a/encoding/kzg/pointsIO.go b/encoding/kzg/pointsIO.go index 545cbece43..16a3abf32e 100644 --- a/encoding/kzg/pointsIO.go +++ b/encoding/kzg/pointsIO.go @@ -39,12 +39,12 @@ func ReadDesiredBytes(reader *bufio.Reader, numBytesToRead uint64) ([]byte, erro } // Read the n-th G1 point from SRS. -func ReadG1Point(n uint64, g *KzgConfig) (bn254.G1Affine, error) { - if n >= g.SRSOrder { - return bn254.G1Affine{}, fmt.Errorf("requested power %v is larger than SRSOrder %v", n, g.SRSOrder) +func ReadG1Point(n uint64, srsOrder uint64, g1Path string) (bn254.G1Affine, error) { + if n >= srsOrder { + return bn254.G1Affine{}, fmt.Errorf("requested power %v is larger than SRSOrder %v", n, srsOrder) } - g1point, err := ReadG1PointSection(g.G1Path, n, n+1, 1) + g1point, err := ReadG1PointSection(g1Path, n, n+1, 1) if err != nil { return bn254.G1Affine{}, fmt.Errorf("error read g1 point section %w", err) } @@ -53,12 +53,12 @@ func ReadG1Point(n uint64, g *KzgConfig) (bn254.G1Affine, error) { } // Read the n-th G2 point from SRS. -func ReadG2Point(n uint64, g *KzgConfig) (bn254.G2Affine, error) { - if n >= g.SRSOrder { - return bn254.G2Affine{}, fmt.Errorf("requested power %v is larger than SRSOrder %v", n, g.SRSOrder) +func ReadG2Point(n uint64, srsOrder uint64, g2Path string) (bn254.G2Affine, error) { + if n >= srsOrder { + return bn254.G2Affine{}, fmt.Errorf("requested power %v is larger than SRSOrder %v", n, srsOrder) } - g2point, err := ReadG2PointSection(g.G2Path, n, n+1, 1) + g2point, err := ReadG2PointSection(g2Path, n, n+1, 1) if err != nil { return bn254.G2Affine{}, fmt.Errorf("error read g2 point section %w", err) } @@ -66,7 +66,7 @@ func ReadG2Point(n uint64, g *KzgConfig) (bn254.G2Affine, error) { } // Read g2 points from power of 2 file -func ReadG2PointOnPowerOf2(exponent uint64, g *KzgConfig) (bn254.G2Affine, error) { +func ReadG2PointOnPowerOf2(exponent uint64, srsOrder uint64, g2PowerOf2Path string) (bn254.G2Affine, error) { // the powerOf2 file, only [tau^exp] are stored. // exponent 0, 1, 2, , .. @@ -79,18 +79,18 @@ func ReadG2PointOnPowerOf2(exponent uint64, g *KzgConfig) (bn254.G2Affine, error // if a actual SRS order is 15, the file will contain four symbols (1,2,4,8) with indices [0,1,2,3] // if a actual SRS order is 16, the file will contain five symbols (1,2,4,8,16) with indices [0,1,2,3,4] - actualPowerOfTau := g.SRSOrder - 1 + actualPowerOfTau := srsOrder - 1 largestPowerofSRS := uint64(math.Log2(float64(actualPowerOfTau))) if exponent > largestPowerofSRS { return bn254.G2Affine{}, fmt.Errorf("requested power %v is larger than largest power of SRS %v", uint64(math.Pow(2, float64(exponent))), largestPowerofSRS) } - if len(g.G2PowerOf2Path) == 0 { + if len(g2PowerOf2Path) == 0 { return bn254.G2Affine{}, errors.New("G2PathPowerOf2 path is empty") } - g2point, err := ReadG2PointSection(g.G2PowerOf2Path, exponent, exponent+1, 1) + g2point, err := ReadG2PointSection(g2PowerOf2Path, exponent, exponent+1, 1) if err != nil { return bn254.G2Affine{}, fmt.Errorf("error read g2 point on power of 2 %w", err) } @@ -120,7 +120,6 @@ func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bn254.G1Affine if err != nil { return nil, err } - // measure reading time t := time.Now() elapsed := t.Sub(startTimer) @@ -159,6 +158,7 @@ func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bn254.G1Affine t = time.Now() elapsed = t.Sub(startTimer) log.Println(" Parsing takes", elapsed) + return s1Outs, nil } @@ -328,6 +328,7 @@ func ReadG2Points(filepath string, n uint64, numWorker uint64) ([]bn254.G2Affine t = time.Now() elapsed = t.Sub(startTimer) log.Println(" Parsing takes", elapsed) + return s2Outs, nil } diff --git a/encoding/kzg/prover/decode.go b/encoding/kzg/prover/decode.go index 6b01e7d599..912ba34ca4 100644 --- a/encoding/kzg/prover/decode.go +++ b/encoding/kzg/prover/decode.go @@ -11,5 +11,5 @@ func (g *ParametrizedProver) Decode(frames []enc.Frame, indices []uint64, maxInp rsFrames[ind] = rs.Frame{Coeffs: frame.Coeffs} } - return g.Encoder.Decode(rsFrames, indices, maxInputSize) + return g.Encoder.Decode(rsFrames, indices, maxInputSize, g.EncodingParams) } diff --git a/encoding/kzg/prover/decode_test.go b/encoding/kzg/prover/decode_test.go index 4e14301882..32009c772a 100644 --- a/encoding/kzg/prover/decode_test.go +++ b/encoding/kzg/prover/decode_test.go @@ -5,13 +5,14 @@ import ( "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/kzg/prover" + "github.com/Layr-Labs/eigenda/encoding/rs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncodeDecodeFrame_AreInverses(t *testing.T) { - - group, _ := prover.NewProver(kzgConfig, true) + group, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) @@ -20,7 +21,11 @@ func TestEncodeDecodeFrame_AreInverses(t *testing.T) { require.Nil(t, err) require.NotNil(t, p) - _, _, _, frames, _, err := p.EncodeBytes(gettysburgAddressBytes) + // Convert to inputFr + inputFr, err := rs.ToFrArray(gettysburgAddressBytes) + require.Nil(t, err) + + frames, _, err := p.GetFrames(inputFr) require.Nil(t, err) require.NotNil(t, frames, err) diff --git a/encoding/kzg/prover/gnark/commitments.go b/encoding/kzg/prover/gnark/commitments.go new file mode 100644 index 0000000000..66609d51a2 --- /dev/null +++ b/encoding/kzg/prover/gnark/commitments.go @@ -0,0 +1,62 @@ +package gnark + +import ( + "fmt" + + "github.com/Layr-Labs/eigenda/encoding/kzg" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +type KzgCommitmentsGnarkBackend struct { + KzgConfig *kzg.KzgConfig + Srs *kzg.SRS + G2Trailing []bn254.G2Affine +} + +func (p *KzgCommitmentsGnarkBackend) ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine, error) { + inputLength := uint64(len(coeffs)) + return p.ComputeLengthProofForLength(coeffs, inputLength) +} + +func (p *KzgCommitmentsGnarkBackend) ComputeLengthProofForLength(coeffs []fr.Element, length uint64) (*bn254.G2Affine, error) { + if length < uint64(len(coeffs)) { + return nil, fmt.Errorf("length is less than the number of coefficients") + } + + start := p.KzgConfig.SRSNumberToLoad - length + shiftedSecret := p.G2Trailing[start : start+uint64(len(coeffs))] + config := ecc.MultiExpConfig{} + + //The proof of low degree is commitment of the polynomial shifted to the largest srs degree + var lengthProof bn254.G2Affine + _, err := lengthProof.MultiExp(shiftedSecret, coeffs, config) + if err != nil { + return nil, err + } + + return &lengthProof, nil +} + +func (p *KzgCommitmentsGnarkBackend) ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, error) { + // compute commit for the full poly + config := ecc.MultiExpConfig{} + var commitment bn254.G1Affine + _, err := commitment.MultiExp(p.Srs.G1[:len(coeffs)], coeffs, config) + if err != nil { + return nil, err + } + return &commitment, nil +} + +func (p *KzgCommitmentsGnarkBackend) ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Affine, error) { + config := ecc.MultiExpConfig{} + + var lengthCommitment bn254.G2Affine + _, err := lengthCommitment.MultiExp(p.Srs.G2[:len(coeffs)], coeffs, config) + if err != nil { + return nil, err + } + return &lengthCommitment, nil +} diff --git a/encoding/kzg/prover/cpu/multiframe_proof.go b/encoding/kzg/prover/gnark/multiframe_proof.go similarity index 56% rename from encoding/kzg/prover/cpu/multiframe_proof.go rename to encoding/kzg/prover/gnark/multiframe_proof.go index 10e1084fa9..ed00841a2c 100644 --- a/encoding/kzg/prover/cpu/multiframe_proof.go +++ b/encoding/kzg/prover/gnark/multiframe_proof.go @@ -1,7 +1,8 @@ -package cpu +package gnark import ( "fmt" + "log/slog" "math" "time" @@ -13,107 +14,32 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) -type KzgCpuProofDevice struct { +type KzgMultiProofGnarkBackend struct { *kzg.KzgConfig Fs *fft.FFTSettings FFTPointsT [][]bn254.G1Affine // transpose of FFTPoints SFs *fft.FFTSettings - Srs *kzg.SRS - G2Trailing []bn254.G2Affine } type WorkerResult struct { err error } -func (p *KzgCpuProofDevice) ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine, error) { - inputLength := uint64(len(coeffs)) - return p.ComputeLengthProofForLength(coeffs, inputLength) -} - -func (p *KzgCpuProofDevice) ComputeLengthProofForLength(coeffs []fr.Element, length uint64) (*bn254.G2Affine, error) { - - if length < uint64(len(coeffs)) { - return nil, fmt.Errorf("length is less than the number of coefficients") - } - - start := p.KzgConfig.SRSNumberToLoad - length - shiftedSecret := p.G2Trailing[start : start+uint64(len(coeffs))] - config := ecc.MultiExpConfig{} - //The proof of low degree is commitment of the polynomial shifted to the largest srs degree - var lengthProof bn254.G2Affine - _, err := lengthProof.MultiExp(shiftedSecret, coeffs, config) - if err != nil { - return nil, err - } - return &lengthProof, nil - -} - -func (p *KzgCpuProofDevice) ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, error) { - // compute commit for the full poly - config := ecc.MultiExpConfig{} - var commitment bn254.G1Affine - _, err := commitment.MultiExp(p.Srs.G1[:len(coeffs)], coeffs, config) - if err != nil { - return nil, err - } - return &commitment, nil -} - -func (p *KzgCpuProofDevice) ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Affine, error) { - config := ecc.MultiExpConfig{} - - var lengthCommitment bn254.G2Affine - _, err := lengthCommitment.MultiExp(p.Srs.G2[:len(coeffs)], coeffs, config) - if err != nil { - return nil, err - } - return &lengthCommitment, nil -} - -func (p *KzgCpuProofDevice) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) { +func (p *KzgMultiProofGnarkBackend) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) { begin := time.Now() // Robert: Standardizing this to use the same math used in precomputeSRS dimE := numChunks l := chunkLen - sumVec := make([]bn254.G1Affine, dimE*2) - - jobChan := make(chan uint64, numWorker) - results := make(chan WorkerResult, numWorker) - - // create storage for intermediate fft outputs - coeffStore := make([][]fr.Element, dimE*2) - for i := range coeffStore { - coeffStore[i] = make([]fr.Element, l) - } - - for w := uint64(0); w < numWorker; w++ { - go p.proofWorker(polyFr, jobChan, l, dimE, coeffStore, results) - } - - for j := uint64(0); j < l; j++ { - jobChan <- j - } - close(jobChan) - - // return last error - var err error - for w := uint64(0); w < numWorker; w++ { - wr := <-results - if wr.err != nil { - err = wr.err - } - } - + // Pre-processing stage + coeffStore, err := p.computeCoeffStore(polyFr, numWorker, l, dimE) if err != nil { - return nil, fmt.Errorf("proof worker error: %v", err) + return nil, fmt.Errorf("coefficient computation error: %v", err) } - preprocessDone := time.Now() // compute proof by multi scaler multiplication + sumVec := make([]bn254.G1Affine, dimE*2) msmErrors := make(chan error, dimE*2) for i := uint64(0); i < dimE*2; i++ { @@ -150,18 +76,54 @@ func (p *KzgCpuProofDevice) ComputeMultiFrameProof(polyFr []fr.Element, numChunk secondECNttDone := time.Now() - fmt.Printf("Multiproof Time Decomp \n\t\ttotal %-20s \n\t\tpreproc %-20s \n\t\tmsm %-20s \n\t\tfft1 %-20s \n\t\tfft2 %-20s\n", - secondECNttDone.Sub(begin).String(), - preprocessDone.Sub(begin).String(), - msmDone.Sub(preprocessDone).String(), - firstECNttDone.Sub(msmDone).String(), - secondECNttDone.Sub(firstECNttDone).String(), + slog.Info("Multiproof Time Decomp", + "total", secondECNttDone.Sub(begin), + "preproc", preprocessDone.Sub(begin), + "msm", msmDone.Sub(preprocessDone), + "fft1", firstECNttDone.Sub(msmDone), + "fft2", secondECNttDone.Sub(firstECNttDone), ) return proofs, nil } -func (p *KzgCpuProofDevice) proofWorker( +// Helper function to handle coefficient computation +func (p *KzgMultiProofGnarkBackend) computeCoeffStore(polyFr []fr.Element, numWorker, l, dimE uint64) ([][]fr.Element, error) { + jobChan := make(chan uint64, numWorker) + results := make(chan WorkerResult, numWorker) + + coeffStore := make([][]fr.Element, dimE*2) + for i := range coeffStore { + coeffStore[i] = make([]fr.Element, l) + } + + // Start workers + for w := uint64(0); w < numWorker; w++ { + go p.proofWorker(polyFr, jobChan, l, dimE, coeffStore, results) + } + + // Send jobs + for j := uint64(0); j < l; j++ { + jobChan <- j + } + close(jobChan) + + // Collect results + var lastErr error + for w := uint64(0); w < numWorker; w++ { + if wr := <-results; wr.err != nil { + lastErr = wr.err + } + } + + if lastErr != nil { + return nil, fmt.Errorf("proof worker error: %v", lastErr) + } + + return coeffStore, nil +} + +func (p *KzgMultiProofGnarkBackend) proofWorker( polyFr []fr.Element, jobChan <-chan uint64, l uint64, @@ -193,7 +155,7 @@ func (p *KzgCpuProofDevice) proofWorker( // phi ^ (coset size ) = 1 // // implicitly pad slices to power of 2 -func (p *KzgCpuProofDevice) GetSlicesCoeff(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) { +func (p *KzgMultiProofGnarkBackend) GetSlicesCoeff(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) { // there is a constant term m := uint64(len(polyFr)) - 1 dim := (m - j) / l diff --git a/encoding/kzg/prover/icicle.go b/encoding/kzg/prover/icicle.go new file mode 100644 index 0000000000..67bb3f4f08 --- /dev/null +++ b/encoding/kzg/prover/icicle.go @@ -0,0 +1,73 @@ +//go:build icicle + +package prover + +import ( + "math" + "sync" + + "github.com/Layr-Labs/eigenda/encoding" + "github.com/Layr-Labs/eigenda/encoding/fft" + "github.com/Layr-Labs/eigenda/encoding/icicle" + "github.com/Layr-Labs/eigenda/encoding/kzg" + gnarkprover "github.com/Layr-Labs/eigenda/encoding/kzg/prover/gnark" + icicleprover "github.com/Layr-Labs/eigenda/encoding/kzg/prover/icicle" +) + +const ( + // MAX_NTT_SIZE is the maximum NTT domain size needed to compute FFTs for the + // largest supported blobs. Assuming a coding ratio of 1/8 and symbol size of 32 bytes: + // - Encoded size: 2^{MAX_NTT_SIZE} * 32 bytes ≈ 1 GB + // - Original blob size: 2^{MAX_NTT_SIZE} * 32 / 8 = 2^{MAX_NTT_SIZE + 2} ≈ 128 MB + MAX_NTT_SIZE = 25 +) + +func CreateIcicleBackendProver(p *Prover, params encoding.EncodingParams, fs *fft.FFTSettings, ks *kzg.KZGSettings) (*ParametrizedProver, error) { + _, fftPointsT, err := p.SetupFFTPoints(params) + if err != nil { + return nil, err + } + icicleDevice, err := icicle.NewIcicleDevice(icicle.IcicleDeviceConfig{ + GPUEnable: p.Config.GPUEnable, + NTTSize: MAX_NTT_SIZE, + FFTPointsT: fftPointsT, + SRSG1: p.Srs.G1[:p.KzgConfig.SRSNumberToLoad], + }) + if err != nil { + return nil, err + } + + // Create subgroup FFT settings + t := uint8(math.Log2(float64(2 * params.NumChunks))) + sfs := fft.NewFFTSettings(t) + + // Set up icicle multiproof backend + multiproofBackend := &icicleprover.KzgMultiProofIcicleBackend{ + Fs: fs, + FlatFFTPointsT: icicleDevice.FlatFFTPointsT, + SRSIcicle: icicleDevice.SRSG1Icicle, + SFs: sfs, + Srs: p.Srs, + NttCfg: icicleDevice.NttCfg, + MsmCfg: icicleDevice.MsmCfg, + KzgConfig: p.KzgConfig, + Device: icicleDevice.Device, + GpuLock: sync.Mutex{}, + } + + // Set up gnark commitments backend + commitmentsBackend := &gnarkprover.KzgCommitmentsGnarkBackend{ + Srs: p.Srs, + G2Trailing: p.G2Trailing, + KzgConfig: p.KzgConfig, + } + + return &ParametrizedProver{ + EncodingParams: params, + Encoder: p.encoder, + KzgConfig: p.KzgConfig, + Ks: ks, + KzgMultiProofBackend: multiproofBackend, + KzgCommitmentsBackend: commitmentsBackend, + }, nil +} diff --git a/encoding/kzg/prover/icicle/ecntt.go b/encoding/kzg/prover/icicle/ecntt.go new file mode 100644 index 0000000000..12b73ab4dd --- /dev/null +++ b/encoding/kzg/prover/icicle/ecntt.go @@ -0,0 +1,45 @@ +//go:build icicle + +package icicle + +import ( + "fmt" + + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + ecntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ecntt" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +func (c *KzgMultiProofIcicleBackend) ECNttToGnarkOnDevice(batchPoints core.DeviceSlice, isInverse bool, totalSize int) (core.DeviceSlice, error) { + output, err := c.ECNttOnDevice(batchPoints, isInverse, totalSize) + if err != nil { + return output, err + } + + return output, nil +} + +func (c *KzgMultiProofIcicleBackend) ECNttOnDevice(batchPoints core.DeviceSlice, isInverse bool, totalSize int) (core.DeviceSlice, error) { + var p iciclebn254.Projective + var out core.DeviceSlice + + output, err := out.Malloc(p.Size(), totalSize) + if err != runtime.Success { + return out, fmt.Errorf("allocating bytes on device failed: %v", err.AsString()) + } + + if isInverse { + err := ecntt.ECNtt(batchPoints, core.KInverse, &c.NttCfg, output) + if err != runtime.Success { + return out, fmt.Errorf("inverse ecntt failed: %v", err.AsString()) + } + } else { + err := ecntt.ECNtt(batchPoints, core.KForward, &c.NttCfg, output) + if err != runtime.Success { + return out, fmt.Errorf("forward ecntt failed: %v", err.AsString()) + } + } + + return output, nil +} diff --git a/encoding/kzg/prover/icicle/msm.go b/encoding/kzg/prover/icicle/msm.go new file mode 100644 index 0000000000..5f11f503f5 --- /dev/null +++ b/encoding/kzg/prover/icicle/msm.go @@ -0,0 +1,33 @@ +//go:build icicle + +package icicle + +import ( + "fmt" + + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/msm" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +// MsmBatchOnDevice function supports batch across blobs. +// totalSize is the number of output points, which equals to numPoly * 2 * dimE , dimE is number of chunks +func (c *KzgMultiProofIcicleBackend) MsmBatchOnDevice(rowsFrIcicleCopy core.DeviceSlice, rowsG1Icicle []iciclebn254.Affine, totalSize int) (core.DeviceSlice, error) { + rowsG1IcicleCopy := core.HostSliceFromElements[iciclebn254.Affine](rowsG1Icicle) + + var p iciclebn254.Projective + var out core.DeviceSlice + + _, err := out.Malloc(p.Size(), totalSize) + if err != runtime.Success { + return out, fmt.Errorf("allocating bytes on device failed: %v", err.AsString()) + } + + err = msm.Msm(rowsFrIcicleCopy, rowsG1IcicleCopy, &c.MsmCfg, out) + if err != runtime.Success { + return out, fmt.Errorf("msm error: %v", err.AsString()) + } + + return out, nil +} diff --git a/encoding/kzg/prover/icicle/multiframe_proof.go b/encoding/kzg/prover/icicle/multiframe_proof.go new file mode 100644 index 0000000000..ac4fd1adbe --- /dev/null +++ b/encoding/kzg/prover/icicle/multiframe_proof.go @@ -0,0 +1,231 @@ +//go:build icicle + +package icicle + +import ( + "fmt" + "log/slog" + "sync" + "time" + + "github.com/Layr-Labs/eigenda/encoding/fft" + "github.com/Layr-Labs/eigenda/encoding/icicle" + "github.com/Layr-Labs/eigenda/encoding/kzg" + "github.com/Layr-Labs/eigenda/encoding/utils/toeplitz" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +type KzgMultiProofIcicleBackend struct { + *kzg.KzgConfig + Fs *fft.FFTSettings + FlatFFTPointsT []iciclebn254.Affine + SRSIcicle []iciclebn254.Affine + SFs *fft.FFTSettings + Srs *kzg.SRS + NttCfg core.NTTConfig[[iciclebn254.SCALAR_LIMBS]uint32] + MsmCfg core.MSMConfig + Device runtime.Device + GpuLock sync.Mutex +} + +type WorkerResult struct { + err error +} + +// This function supports batching over multiple blobs. +// All blobs must have same size and concatenated passed as polyFr +func (p *KzgMultiProofIcicleBackend) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) { + begin := time.Now() + + dimE := numChunks + l := chunkLen + numPoly := uint64(len(polyFr)) / dimE / chunkLen + + // Pre-processing stage - CPU computations + flattenCoeffStoreFr, err := p.computeCoeffStore(polyFr, numWorker, l, dimE) + if err != nil { + return nil, fmt.Errorf("coefficient computation error: %v", err) + } + preprocessDone := time.Now() + + flattenCoeffStoreSf := icicle.ConvertFrToScalarFieldsBytes(flattenCoeffStoreFr) + flattenCoeffStoreCopy := core.HostSliceFromElements[iciclebn254.ScalarField](flattenCoeffStoreSf) + + var icicleFFTBatch []bn254.G1Affine + var icicleErr error + + // GPU operations + p.GpuLock.Lock() + defer p.GpuLock.Unlock() + + wg := sync.WaitGroup{} + wg.Add(1) + + var msmDone, firstECNttDone, secondECNttDone time.Time + runtime.RunOnDevice(&p.Device, func(args ...any) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + icicleErr = fmt.Errorf("GPU operation panic: %v", r) + } + }() + + // Copy the flatten coeff to device + var flattenStoreCopyToDevice core.DeviceSlice + flattenCoeffStoreCopy.CopyToDevice(&flattenStoreCopyToDevice, true) + + sumVec, err := p.MsmBatchOnDevice(flattenStoreCopyToDevice, p.FlatFFTPointsT, int(numPoly)*int(dimE)*2) + if err != nil { + icicleErr = fmt.Errorf("msm error: %w", err) + return + } + + // Free the flatten coeff store + flattenStoreCopyToDevice.Free() + + msmDone = time.Now() + + // Compute the first ecntt, and set new batch size for ntt + p.NttCfg.BatchSize = int32(numPoly) + sumVecInv, err := p.ECNttOnDevice(sumVec, true, int(dimE)*2*int(numPoly)) + if err != nil { + icicleErr = fmt.Errorf("first ECNtt error: %w", err) + return + } + + sumVec.Free() + + firstECNttDone = time.Now() + + prunedSumVecInv := sumVecInv.Range(0, int(dimE), false) + + // Compute the second ecntt on the reduced size array + flatProofsBatch, err := p.ECNttToGnarkOnDevice(prunedSumVecInv, false, int(numPoly)*int(dimE)) + if err != nil { + icicleErr = fmt.Errorf("second ECNtt error: %w", err) + return + } + + prunedSumVecInv.Free() + + secondECNttDone = time.Now() + + flatProofsBatchHost := make(core.HostSlice[iciclebn254.Projective], int(numPoly)*int(dimE)) + flatProofsBatchHost.CopyFromDevice(&flatProofsBatch) + flatProofsBatch.Free() + icicleFFTBatch = icicle.HostSliceIcicleProjectiveToGnarkAffine(flatProofsBatchHost, int(p.NumWorker)) + }) + + wg.Wait() + + if icicleErr != nil { + return nil, icicleErr + } + + end := time.Now() + + slog.Info("Multiproof Time Decomp", + "total", end.Sub(begin), + "preproc", preprocessDone.Sub(begin), + "msm", msmDone.Sub(preprocessDone), + "fft1", firstECNttDone.Sub(msmDone), + "fft2", secondECNttDone.Sub(firstECNttDone), + ) + + return icicleFFTBatch, nil +} + +// Modify the function signature to return a flat array +func (p *KzgMultiProofIcicleBackend) computeCoeffStore(polyFr []fr.Element, numWorker, l, dimE uint64) ([]fr.Element, error) { + totalSize := dimE * 2 * l // Total size of the flattened array + coeffStore := make([]fr.Element, totalSize) + + jobChan := make(chan uint64, numWorker) + results := make(chan WorkerResult, numWorker) + + // Start workers + for w := uint64(0); w < numWorker; w++ { + go p.proofWorker(polyFr, jobChan, l, dimE, coeffStore, results) + } + + // Send jobs + for j := uint64(0); j < l; j++ { + jobChan <- j + } + close(jobChan) + + // Collect results + var lastErr error + for w := uint64(0); w < numWorker; w++ { + if wr := <-results; wr.err != nil { + lastErr = wr.err + } + } + + if lastErr != nil { + return nil, fmt.Errorf("proof worker error: %v", lastErr) + } + + return coeffStore, nil +} + +// Modified worker function to write directly to the flat array +func (p *KzgMultiProofIcicleBackend) proofWorker( + polyFr []fr.Element, + jobChan <-chan uint64, + l uint64, + dimE uint64, + coeffStore []fr.Element, + results chan<- WorkerResult, +) { + for j := range jobChan { + coeffs, err := p.GetSlicesCoeff(polyFr, dimE, j, l) + if err != nil { + results <- WorkerResult{ + err: err, + } + return + } + + // Write directly to the correct positions in the flat array + // For each j, we need to write to the corresponding position in each block + for i := uint64(0); i < dimE*2; i++ { + coeffStore[i*l+j] = coeffs[i] + } + } + + results <- WorkerResult{ + err: nil, + } +} + +// output is in the form see primeField toeplitz +// +// phi ^ (coset size ) = 1 +// +// implicitly pad slices to power of 2 +func (p *KzgMultiProofIcicleBackend) GetSlicesCoeff(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) { + // there is a constant term + m := uint64(len(polyFr)) - 1 + dim := (m - j) / l + + // maximal number of unique values from a toeplitz matrix + tDim := 2*dimE - 1 + + toeV := make([]fr.Element, tDim) + for i := uint64(0); i < dim; i++ { + + toeV[i].Set(&polyFr[m-(j+i*l)]) + } + + // use precompute table + tm, err := toeplitz.NewToeplitz(toeV, p.SFs) + if err != nil { + return nil, err + } + return tm.GetFFTCoeff() +} diff --git a/encoding/kzg/prover/noicicle.go b/encoding/kzg/prover/noicicle.go new file mode 100644 index 0000000000..fea3d68340 --- /dev/null +++ b/encoding/kzg/prover/noicicle.go @@ -0,0 +1,16 @@ +//go:build !icicle + +package prover + +import ( + "errors" + + "github.com/Layr-Labs/eigenda/encoding" + "github.com/Layr-Labs/eigenda/encoding/fft" + "github.com/Layr-Labs/eigenda/encoding/kzg" +) + +func CreateIcicleBackendProver(p *Prover, params encoding.EncodingParams, fs *fft.FFTSettings, ks *kzg.KZGSettings) (*ParametrizedProver, error) { + // Not supported + return nil, errors.New("icicle backend called without icicle build tag") +} diff --git a/encoding/kzg/prover/parametrized_prover.go b/encoding/kzg/prover/parametrized_prover.go index 0af9ab91f7..fade9db13e 100644 --- a/encoding/kzg/prover/parametrized_prover.go +++ b/encoding/kzg/prover/parametrized_prover.go @@ -2,7 +2,6 @@ package prover import ( "fmt" - "log" "log/slog" "time" @@ -16,12 +15,14 @@ import ( ) type ParametrizedProver struct { + encoding.EncodingParams *rs.Encoder - *kzg.KzgConfig - Ks *kzg.KZGSettings + KzgConfig *kzg.KzgConfig + Ks *kzg.KZGSettings - Computer ProofDevice + KzgMultiProofBackend KzgMultiProofsBackend + KzgCommitmentsBackend KzgCommitmentsBackend } type rsEncodeResult struct { @@ -104,11 +105,13 @@ func (g *ParametrizedProver) Encode(inputFr []fr.Element) (*bn254.G1Affine, *bn2 return nil, nil, nil, nil, nil, commitmentResult.Error } - totalProcessingTime := time.Since(encodeStart) + slog.Info("Encoding process details", + "Input_size_bytes", len(inputFr)*encoding.BYTES_PER_SYMBOL, + "Num_chunks", g.NumChunks, + "Chunk_length", g.ChunkLength, + "Total_duration", time.Since(encodeStart), + ) - if g.Verbose { - log.Printf("Total encoding took %v\n", totalProcessingTime) - } return commitmentResult.commitment, commitmentResult.lengthCommitment, commitmentResult.lengthProof, frames, indices, nil } @@ -126,7 +129,7 @@ func (g *ParametrizedProver) GetCommitments(inputFr []fr.Element, length uint64) // compute commit for the full poly go func() { start := time.Now() - commit, err := g.Computer.ComputeCommitment(inputFr) + commit, err := g.KzgCommitmentsBackend.ComputeCommitment(inputFr) commitmentChan <- commitmentResult{ Commitment: commit, Err: err, @@ -136,7 +139,7 @@ func (g *ParametrizedProver) GetCommitments(inputFr []fr.Element, length uint64) go func() { start := time.Now() - lengthCommitment, err := g.Computer.ComputeLengthCommitment(inputFr) + lengthCommitment, err := g.KzgCommitmentsBackend.ComputeLengthCommitment(inputFr) lengthCommitmentChan <- lengthCommitmentResult{ LengthCommitment: lengthCommitment, Err: err, @@ -146,7 +149,7 @@ func (g *ParametrizedProver) GetCommitments(inputFr []fr.Element, length uint64) go func() { start := time.Now() - lengthProof, err := g.Computer.ComputeLengthProofForLength(inputFr, length) + lengthProof, err := g.KzgCommitmentsBackend.ComputeLengthProofForLength(inputFr, length) lengthProofChan <- lengthProofResult{ LengthProof: lengthProof, Err: err, @@ -164,17 +167,16 @@ func (g *ParametrizedProver) GetCommitments(inputFr []fr.Element, length uint64) } totalProcessingTime := time.Since(encodeStart) - log.Printf("\n\t\tCommiting %-v\n\t\tLengthCommit %-v\n\t\tlengthProof %-v\n\t\tMetaInfo. order %-v shift %v\n", - commitmentResult.Duration, - lengthCommitmentResult.Duration, - lengthProofResult.Duration, - g.SRSOrder, - g.SRSOrder-uint64(len(inputFr)), + slog.Info("Commitment process details", + "Input_size_bytes", len(inputFr)*encoding.BYTES_PER_SYMBOL, + "Total_duration", totalProcessingTime, + "Commiting_duration", commitmentResult.Duration, + "LengthCommit_duration", lengthCommitmentResult.Duration, + "lengthProof_duration", lengthProofResult.Duration, + "SRSOrder", g.KzgConfig.SRSOrder, + "SRSOrder_shift", g.KzgConfig.SRSOrder-uint64(len(inputFr)), ) - if g.Verbose { - log.Printf("Total encoding took %v\n", totalProcessingTime) - } return commitmentResult.Commitment, lengthCommitmentResult.LengthCommitment, lengthProofResult.LengthProof, nil } @@ -183,6 +185,8 @@ func (g *ParametrizedProver) GetFrames(inputFr []fr.Element) ([]encoding.Frame, return nil, nil, err } + encodeStart := time.Now() + proofChan := make(chan proofsResult, 1) rsChan := make(chan rsEncodeResult, 1) @@ -190,7 +194,8 @@ func (g *ParametrizedProver) GetFrames(inputFr []fr.Element) ([]encoding.Frame, // compute chunks go func() { start := time.Now() - frames, indices, err := g.Encoder.Encode(inputFr) + + frames, indices, err := g.Encoder.Encode(inputFr, g.EncodingParams) rsChan <- rsEncodeResult{ Frames: frames, Indices: indices, @@ -212,7 +217,7 @@ func (g *ParametrizedProver) GetFrames(inputFr []fr.Element) ([]encoding.Frame, flatpaddedCoeffs = append(flatpaddedCoeffs, paddedCoeffs...) } - proofs, err := g.Computer.ComputeMultiFrameProof(flatpaddedCoeffs, g.NumChunks, g.ChunkLength, g.NumWorker) + proofs, err := g.KzgMultiProofBackend.ComputeMultiFrameProof(flatpaddedCoeffs, g.NumChunks, g.ChunkLength, g.KzgConfig.NumWorker) proofChan <- proofsResult{ Proofs: proofs, Err: err, @@ -227,11 +232,16 @@ func (g *ParametrizedProver) GetFrames(inputFr []fr.Element) ([]encoding.Frame, return nil, nil, multierror.Append(rsResult.Err, proofsResult.Err) } - log.Printf("\n\t\tRS encode %-v\n\t\tmultiProof %-v\n\t\tMetaInfo. order %-v shift %v\n", - rsResult.Duration, - proofsResult.Duration, - g.SRSOrder, - g.SRSOrder-uint64(len(inputFr)), + totalProcessingTime := time.Since(encodeStart) + slog.Info("Frame process details", + "Input_size_bytes", len(inputFr)*encoding.BYTES_PER_SYMBOL, + "Num_chunks", g.NumChunks, + "Chunk_length", g.ChunkLength, + "Total_duration", totalProcessingTime, + "RS_encode_duration", rsResult.Duration, + "multiProof_duration", proofsResult.Duration, + "SRSOrder", g.KzgConfig.SRSOrder, + "SRSOrder_shift", g.KzgConfig.SRSOrder-uint64(len(inputFr)), ) // assemble frames @@ -259,7 +269,7 @@ func (g *ParametrizedProver) GetMultiFrameProofs(inputFr []fr.Element) ([]encodi copy(paddedCoeffs, inputFr) paddingEnd := time.Since(paddingStart) - proofs, err := g.Computer.ComputeMultiFrameProof(paddedCoeffs, g.NumChunks, g.ChunkLength, g.NumWorker) + proofs, err := g.KzgMultiProofBackend.ComputeMultiFrameProof(paddedCoeffs, g.NumChunks, g.ChunkLength, g.KzgConfig.NumWorker) end := time.Since(start) @@ -269,8 +279,8 @@ func (g *ParametrizedProver) GetMultiFrameProofs(inputFr []fr.Element) ([]encodi "Chunk_length", g.ChunkLength, "Total_duration", end, "Padding_duration", paddingEnd, - "SRSOrder", g.SRSOrder, - "SRSOrder_shift", g.SRSOrder-uint64(len(inputFr)), + "SRSOrder", g.KzgConfig.SRSOrder, + "SRSOrder_shift", g.KzgConfig.SRSOrder-uint64(len(inputFr)), ) return proofs, err diff --git a/encoding/kzg/prover/parametrized_prover_test.go b/encoding/kzg/prover/parametrized_prover_test.go index 578ca5d0c5..8f9410aa7d 100644 --- a/encoding/kzg/prover/parametrized_prover_test.go +++ b/encoding/kzg/prover/parametrized_prover_test.go @@ -14,8 +14,8 @@ import ( ) func TestProveAllCosetThreads(t *testing.T) { - - group, _ := prover.NewProver(kzgConfig, true) + group, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) enc, err := group.GetKzgEncoder(params) @@ -37,9 +37,12 @@ func TestProveAllCosetThreads(t *testing.T) { assert.Equal(t, j, q, "leading coset inconsistency") fmt.Printf("frame %v leading coset %v\n", i, j) - lc := enc.Fs.ExpandedRootsOfUnity[uint64(j)] + rs, err := enc.GetRsEncoder(params) + require.Nil(t, err) + + lc := rs.Fs.ExpandedRootsOfUnity[uint64(j)] - g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), kzgConfig) + g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), kzgConfig.SRSOrder, kzgConfig.G2Path) require.Nil(t, err) assert.Nil(t, verifier.VerifyFrame(&f, enc.Ks, commit, &lc, &g2Atn), "Proof %v failed\n", i) } diff --git a/encoding/kzg/prover/precompute.go b/encoding/kzg/prover/precompute.go index 9c76166d82..4d36feb0a6 100644 --- a/encoding/kzg/prover/precompute.go +++ b/encoding/kzg/prover/precompute.go @@ -19,8 +19,6 @@ import ( ) type SubTable struct { - //SizeLow uint64 - //SizeUp uint64 FilePath string } diff --git a/encoding/kzg/prover/proof_backend.go b/encoding/kzg/prover/proof_backend.go new file mode 100644 index 0000000000..f9a05d5ed6 --- /dev/null +++ b/encoding/kzg/prover/proof_backend.go @@ -0,0 +1,19 @@ +package prover + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +// Proof device represents a backend capable of computing KZG multiproofs. +type KzgMultiProofsBackend interface { + ComputeMultiFrameProof(blobFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) +} + +// CommitmentDevice represents a backend capable of computing various KZG commitments. +type KzgCommitmentsBackend interface { + ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, error) + ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Affine, error) + ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine, error) + ComputeLengthProofForLength(blobFr []fr.Element, length uint64) (*bn254.G2Affine, error) +} diff --git a/encoding/kzg/prover/proof_device.go b/encoding/kzg/prover/proof_device.go deleted file mode 100644 index b08f5196dd..0000000000 --- a/encoding/kzg/prover/proof_device.go +++ /dev/null @@ -1,16 +0,0 @@ -package prover - -import ( - "github.com/consensys/gnark-crypto/ecc/bn254" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" -) - -// Proof device represents a device capable of computing various KZG-related computations. -type ProofDevice interface { - // blobFr are coefficients - ComputeCommitment(blobFr []fr.Element) (*bn254.G1Affine, error) - ComputeMultiFrameProof(blobFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) - ComputeLengthCommitment(blobFr []fr.Element) (*bn254.G2Affine, error) - ComputeLengthProof(blobFr []fr.Element) (*bn254.G2Affine, error) - ComputeLengthProofForLength(blobFr []fr.Element, length uint64) (*bn254.G2Affine, error) -} diff --git a/encoding/kzg/prover/prover.go b/encoding/kzg/prover/prover.go index a6aa5dea83..bdcce2ab92 100644 --- a/encoding/kzg/prover/prover.go +++ b/encoding/kzg/prover/prover.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "log" + "log/slog" + "math" "os" "runtime" "strconv" @@ -11,32 +13,39 @@ import ( "sync" "github.com/Layr-Labs/eigenda/encoding" + "github.com/Layr-Labs/eigenda/encoding/fft" "github.com/Layr-Labs/eigenda/encoding/kzg" + gnarkprover "github.com/Layr-Labs/eigenda/encoding/kzg/prover/gnark" "github.com/Layr-Labs/eigenda/encoding/rs" "github.com/consensys/gnark-crypto/ecc/bn254" - _ "go.uber.org/automaxprocs" ) type Prover struct { - *kzg.KzgConfig - Srs *kzg.SRS - G2Trailing []bn254.G2Affine - mu sync.Mutex - LoadG2Points bool + Config *encoding.Config + KzgConfig *kzg.KzgConfig + encoder *rs.Encoder + encoding.BackendType + Srs *kzg.SRS + G2Trailing []bn254.G2Affine + mu sync.Mutex ParametrizedProvers map[encoding.EncodingParams]*ParametrizedProver } var _ encoding.Prover = &Prover{} -func NewProver(config *kzg.KzgConfig, loadG2Points bool) (*Prover, error) { - if config.SRSNumberToLoad > config.SRSOrder { +func NewProver(kzgConfig *kzg.KzgConfig, encoderConfig *encoding.Config) (*Prover, error) { + if encoderConfig == nil { + encoderConfig = encoding.DefaultConfig() + } + + if kzgConfig.SRSNumberToLoad > kzgConfig.SRSOrder { return nil, errors.New("SRSOrder is less than srsNumberToLoad") } // read the whole order, and treat it as entire SRS for low degree proof - s1, err := kzg.ReadG1Points(config.G1Path, config.SRSNumberToLoad, config.NumWorker) + s1, err := kzg.ReadG1Points(kzgConfig.G1Path, kzgConfig.SRSNumberToLoad, kzgConfig.NumWorker) if err != nil { log.Println("failed to read G1 points", err) return nil, err @@ -46,29 +55,29 @@ func NewProver(config *kzg.KzgConfig, loadG2Points bool) (*Prover, error) { g2Trailing := make([]bn254.G2Affine, 0) // PreloadEncoder is by default not used by operator node, PreloadEncoder - if loadG2Points { - if len(config.G2Path) == 0 { + if kzgConfig.LoadG2Points { + if len(kzgConfig.G2Path) == 0 { return nil, errors.New("G2Path is empty. However, object needs to load G2Points") } - s2, err = kzg.ReadG2Points(config.G2Path, config.SRSNumberToLoad, config.NumWorker) + s2, err = kzg.ReadG2Points(kzgConfig.G2Path, kzgConfig.SRSNumberToLoad, kzgConfig.NumWorker) if err != nil { log.Println("failed to read G2 points", err) return nil, err } g2Trailing, err = kzg.ReadG2PointSection( - config.G2Path, - config.SRSOrder-config.SRSNumberToLoad, - config.SRSOrder, // last exclusive - config.NumWorker, + kzgConfig.G2Path, + kzgConfig.SRSOrder-kzgConfig.SRSNumberToLoad, + kzgConfig.SRSOrder, // last exclusive + kzgConfig.NumWorker, ) if err != nil { return nil, err } } else { // todo, there are better ways to handle it - if len(config.G2PowerOf2Path) == 0 { + if len(kzgConfig.G2PowerOf2Path) == 0 { return nil, errors.New("G2PowerOf2Path is empty. However, object needs to load G2Points") } } @@ -81,17 +90,25 @@ func NewProver(config *kzg.KzgConfig, loadG2Points bool) (*Prover, error) { fmt.Println("numthread", runtime.GOMAXPROCS(0)) + // Create RS encoder + rsEncoder, err := rs.NewEncoder(encoderConfig) + if err != nil { + slog.Error("Could not create RS encoder", "err", err) + return nil, err + } + encoderGroup := &Prover{ - KzgConfig: config, + Config: encoderConfig, + encoder: rsEncoder, + KzgConfig: kzgConfig, Srs: srs, G2Trailing: g2Trailing, ParametrizedProvers: make(map[encoding.EncodingParams]*ParametrizedProver), - LoadG2Points: loadG2Points, } - if config.PreloadEncoder { + if kzgConfig.PreloadEncoder { // create table dir if not exist - err := os.MkdirAll(config.CacheDir, os.ModePerm) + err := os.MkdirAll(kzgConfig.CacheDir, os.ModePerm) if err != nil { log.Println("Cannot make CacheDir", err) return nil, err @@ -104,11 +121,10 @@ func NewProver(config *kzg.KzgConfig, loadG2Points bool) (*Prover, error) { } return encoderGroup, nil - } func (g *Prover) PreloadAllEncoders() error { - paramsAll, err := GetAllPrecomputedSrsMap(g.CacheDir) + paramsAll, err := GetAllPrecomputedSrsMap(g.KzgConfig.CacheDir) if err != nil { return err } @@ -134,7 +150,6 @@ func (g *Prover) PreloadAllEncoders() error { } func (e *Prover) EncodeAndProve(data []byte, params encoding.EncodingParams) (encoding.BlobCommitments, []*encoding.Frame, error) { - enc, err := e.GetKzgEncoder(params) if err != nil { return encoding.BlobCommitments{}, nil, err @@ -188,7 +203,6 @@ func (e *Prover) GetFrames(data []byte, params encoding.EncodingParams) ([]*enco chunks := make([]*encoding.Frame, len(kzgFrames)) for ind, frame := range kzgFrames { - chunks[ind] = &encoding.Frame{ Coeffs: frame.Coeffs, Proof: frame.Proof, @@ -267,7 +281,7 @@ func (g *Prover) GetKzgEncoder(params encoding.EncodingParams) (*ParametrizedPro } func (g *Prover) GetSRSOrder() uint64 { - return g.SRSOrder + return g.KzgConfig.SRSOrder } // Detect the precomputed table from the specified directory @@ -321,6 +335,7 @@ func (p *Prover) Decode(chunks []*encoding.Frame, indices []encoding.ChunkNumber Coeffs: chunks[i].Coeffs, } } + encoder, err := p.GetKzgEncoder(params) if err != nil { return nil, err @@ -336,3 +351,98 @@ func toUint64Array(chunkIndices []encoding.ChunkNumber) []uint64 { } return res } + +func (p *Prover) newProver(params encoding.EncodingParams) (*ParametrizedProver, error) { + if err := encoding.ValidateEncodingParams(params, p.KzgConfig.SRSOrder); err != nil { + return nil, err + } + + // Create FFT settings based on params + n := uint8(math.Log2(float64(params.NumEvaluations()))) + if params.ChunkLength == 1 { + n = uint8(math.Log2(float64(2 * params.NumChunks))) + } + fs := fft.NewFFTSettings(n) + + // Create base KZG settings + ks, err := kzg.NewKZGSettings(fs, p.Srs) + if err != nil { + return nil, fmt.Errorf("failed to create KZG settings: %w", err) + } + + switch p.Config.BackendType { + case encoding.GnarkBackend: + return p.createGnarkBackendProver(params, fs, ks) + case encoding.IcicleBackend: + return p.createIcicleBackendProver(params, fs, ks) + default: + return nil, fmt.Errorf("unsupported backend type: %v", p.Config.BackendType) + } + +} + +func (p *Prover) createGnarkBackendProver(params encoding.EncodingParams, fs *fft.FFTSettings, ks *kzg.KZGSettings) (*ParametrizedProver, error) { + if p.Config.GPUEnable { + return nil, errors.New("GPU is not supported in gnark backend") + } + + _, fftPointsT, err := p.SetupFFTPoints(params) + if err != nil { + return nil, err + } + + // Create subgroup FFT settings + t := uint8(math.Log2(float64(2 * params.NumChunks))) + sfs := fft.NewFFTSettings(t) + + // Set KZG Prover gnark backend + multiproofBackend := &gnarkprover.KzgMultiProofGnarkBackend{ + Fs: fs, + FFTPointsT: fftPointsT, + SFs: sfs, + KzgConfig: p.KzgConfig, + } + + // Set KZG Commitments gnark backend + commitmentsBackend := &gnarkprover.KzgCommitmentsGnarkBackend{ + Srs: p.Srs, + G2Trailing: p.G2Trailing, + KzgConfig: p.KzgConfig, + } + + return &ParametrizedProver{ + Encoder: p.encoder, + EncodingParams: params, + KzgConfig: p.KzgConfig, + Ks: ks, + KzgMultiProofBackend: multiproofBackend, + KzgCommitmentsBackend: commitmentsBackend, + }, nil +} + +func (p *Prover) createIcicleBackendProver(params encoding.EncodingParams, fs *fft.FFTSettings, ks *kzg.KZGSettings) (*ParametrizedProver, error) { + return CreateIcicleBackendProver(p, params, fs, ks) +} + +// Helper methods for setup +func (p *Prover) SetupFFTPoints(params encoding.EncodingParams) ([][]bn254.G1Affine, [][]bn254.G1Affine, error) { + subTable, err := NewSRSTable(p.KzgConfig.CacheDir, p.Srs.G1, p.KzgConfig.NumWorker) + if err != nil { + return nil, nil, fmt.Errorf("failed to create SRS table: %w", err) + } + + fftPoints, err := subTable.GetSubTables(params.NumChunks, params.ChunkLength) + if err != nil { + return nil, nil, fmt.Errorf("failed to get sub tables: %w", err) + } + + fftPointsT := make([][]bn254.G1Affine, len(fftPoints[0])) + for i := range fftPointsT { + fftPointsT[i] = make([]bn254.G1Affine, len(fftPoints)) + for j := uint64(0); j < params.ChunkLength; j++ { + fftPointsT[i][j] = fftPoints[j][i] + } + } + + return fftPoints, fftPointsT, nil +} diff --git a/encoding/kzg/prover/prover_cpu.go b/encoding/kzg/prover/prover_cpu.go deleted file mode 100644 index 8a250c249d..0000000000 --- a/encoding/kzg/prover/prover_cpu.go +++ /dev/null @@ -1,89 +0,0 @@ -//go:build !gpu -// +build !gpu - -package prover - -import ( - "log" - "math" - - "github.com/Layr-Labs/eigenda/encoding" - "github.com/Layr-Labs/eigenda/encoding/fft" - "github.com/Layr-Labs/eigenda/encoding/kzg" - kzg_prover_cpu "github.com/Layr-Labs/eigenda/encoding/kzg/prover/cpu" - "github.com/Layr-Labs/eigenda/encoding/rs" - rs_cpu "github.com/Layr-Labs/eigenda/encoding/rs/cpu" - "github.com/consensys/gnark-crypto/ecc/bn254" - - _ "go.uber.org/automaxprocs" -) - -func (g *Prover) newProver(params encoding.EncodingParams) (*ParametrizedProver, error) { - if err := encoding.ValidateEncodingParams(params, g.SRSOrder); err != nil { - return nil, err - } - - encoder, err := rs.NewEncoder(params, g.Verbose) - if err != nil { - log.Println("Could not create encoder: ", err) - return nil, err - } - - subTable, err := NewSRSTable(g.CacheDir, g.Srs.G1, g.NumWorker) - if err != nil { - log.Println("Could not create srs table:", err) - return nil, err - } - - fftPoints, err := subTable.GetSubTables(encoder.NumChunks, encoder.ChunkLength) - if err != nil { - log.Println("could not get sub tables", err) - return nil, err - } - - fftPointsT := make([][]bn254.G1Affine, len(fftPoints[0])) - for i := range fftPointsT { - fftPointsT[i] = make([]bn254.G1Affine, len(fftPoints)) - for j := uint64(0); j < encoder.ChunkLength; j++ { - fftPointsT[i][j] = fftPoints[j][i] - } - } - _ = fftPoints - n := uint8(math.Log2(float64(encoder.NumEvaluations()))) - if encoder.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * encoder.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - ks, err := kzg.NewKZGSettings(fs, g.Srs) - if err != nil { - return nil, err - } - - t := uint8(math.Log2(float64(2 * encoder.NumChunks))) - sfs := fft.NewFFTSettings(t) - - // Set KZG Prover CPU computer - computer := &kzg_prover_cpu.KzgCpuProofDevice{ - Fs: fs, - FFTPointsT: fftPointsT, - SFs: sfs, - Srs: g.Srs, - G2Trailing: g.G2Trailing, - KzgConfig: g.KzgConfig, - } - - // Set RS CPU computer - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - encoder.Computer = RsComputeDevice - - return &ParametrizedProver{ - Encoder: encoder, - KzgConfig: g.KzgConfig, - Ks: ks, - Computer: computer, - }, nil -} diff --git a/encoding/kzg/prover/prover_fuzz_test.go b/encoding/kzg/prover/prover_fuzz_test.go index 4b2fd4457a..5fcf536386 100644 --- a/encoding/kzg/prover/prover_fuzz_test.go +++ b/encoding/kzg/prover/prover_fuzz_test.go @@ -6,14 +6,15 @@ import ( "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/kzg/prover" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func FuzzOnlySystematic(f *testing.F) { f.Add(gettysburgAddressBytes) f.Fuzz(func(t *testing.T, input []byte) { - - group, _ := prover.NewProver(kzgConfig, true) + group, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) params := encoding.ParamsFromSysPar(10, 3, uint64(len(input))) enc, err := group.GetKzgEncoder(params) diff --git a/encoding/kzg/prover/prover_test.go b/encoding/kzg/prover/prover_test.go index 596f61c9b3..54cc8b1d71 100644 --- a/encoding/kzg/prover/prover_test.go +++ b/encoding/kzg/prover/prover_test.go @@ -15,6 +15,7 @@ import ( "github.com/Layr-Labs/eigenda/encoding/utils/codec" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -43,6 +44,7 @@ func setup() { SRSOrder: 3000, SRSNumberToLoad: 2900, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } numNode = uint64(4) @@ -72,9 +74,11 @@ func sampleFrames(frames []encoding.Frame, num uint64) ([]encoding.Frame, []uint } func TestEncoder(t *testing.T) { + p, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) - p, _ := prover.NewProver(kzgConfig, true) - v, _ := verifier.NewVerifier(kzgConfig, true) + v, err := verifier.NewVerifier(kzgConfig, nil) + require.NoError(t, err) params := encoding.ParamsFromMins(5, 5) commitments, chunks, err := p.EncodeAndProve(gettysburgAddressBytes, params) @@ -118,8 +122,8 @@ func TestEncoder(t *testing.T) { // pkg: github.com/Layr-Labs/eigenda/core/encoding // BenchmarkEncode-12 1 2421900583 ns/op func BenchmarkEncode(b *testing.B) { - - p, _ := prover.NewProver(kzgConfig, true) + p, err := prover.NewProver(kzgConfig, nil) + require.NoError(b, err) params := encoding.EncodingParams{ ChunkLength: 512, diff --git a/encoding/kzg/verifier/batch_commit_equivalence_test.go b/encoding/kzg/verifier/batch_commit_equivalence_test.go index 6253079e4d..d2f9a67323 100644 --- a/encoding/kzg/verifier/batch_commit_equivalence_test.go +++ b/encoding/kzg/verifier/batch_commit_equivalence_test.go @@ -14,17 +14,21 @@ import ( ) func TestBatchEquivalence(t *testing.T) { + group, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) + + v, err := verifier.NewVerifier(kzgConfig, nil) + require.NoError(t, err) - group, _ := prover.NewProver(kzgConfig, true) - v, _ := verifier.NewVerifier(kzgConfig, true) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) enc, err := group.GetKzgEncoder(params) - require.Nil(t, err) + require.NoError(t, err) inputFr, err := rs.ToFrArray(gettysburgAddressBytes) - require.Nil(t, err) + require.NoError(t, err) + commit, g2commit, _, _, _, err := enc.Encode(inputFr) - require.Nil(t, err) + require.NoError(t, err) numBlob := 5 commitPairs := make([]verifier.CommitmentPair, numBlob) diff --git a/encoding/kzg/verifier/frame_test.go b/encoding/kzg/verifier/frame_test.go index 707d52fee2..912279d426 100644 --- a/encoding/kzg/verifier/frame_test.go +++ b/encoding/kzg/verifier/frame_test.go @@ -16,8 +16,8 @@ import ( ) func TestVerify(t *testing.T) { - - group, _ := prover.NewProver(kzgConfig, true) + group, err := prover.NewProver(kzgConfig, nil) + require.Nil(t, err) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) @@ -34,10 +34,10 @@ func TestVerify(t *testing.T) { fs := fft.NewFFTSettings(n) require.NotNil(t, fs) - lc := enc.Fs.ExpandedRootsOfUnity[uint64(0)] + lc := fs.ExpandedRootsOfUnity[uint64(0)] require.NotNil(t, lc) - g2Atn, err := kzg.ReadG2Point(uint64(len(frames[0].Coeffs)), kzgConfig) + g2Atn, err := kzg.ReadG2Point(uint64(len(frames[0].Coeffs)), kzgConfig.SRSOrder, kzgConfig.G2Path) require.Nil(t, err) assert.Nil(t, verifier.VerifyFrame(&frames[0], enc.Ks, commit, &lc, &g2Atn)) } diff --git a/encoding/kzg/verifier/length_test.go b/encoding/kzg/verifier/length_test.go index fb6f00df7b..a47d9b5639 100644 --- a/encoding/kzg/verifier/length_test.go +++ b/encoding/kzg/verifier/length_test.go @@ -12,9 +12,12 @@ import ( ) func TestLengthProof(t *testing.T) { + group, err := prover.NewProver(kzgConfig, nil) + require.Nil(t, err) + + v, err := verifier.NewVerifier(kzgConfig, nil) + require.Nil(t, err) - group, _ := prover.NewProver(kzgConfig, true) - v, _ := verifier.NewVerifier(kzgConfig, true) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) enc, err := group.GetKzgEncoder(params) require.Nil(t, err) diff --git a/encoding/kzg/verifier/multiframe.go b/encoding/kzg/verifier/multiframe.go index 5c27785479..0e294a9def 100644 --- a/encoding/kzg/verifier/multiframe.go +++ b/encoding/kzg/verifier/multiframe.go @@ -180,8 +180,8 @@ func (v *Verifier) UniversalVerify(params encoding.EncodingParams, samples []Sam D := params.ChunkLength - if D > v.SRSNumberToLoad { - return fmt.Errorf("requested chunkLen %v is larger than Loaded SRS points %v", D, v.SRSNumberToLoad) + if D > v.kzgConfig.SRSNumberToLoad { + return fmt.Errorf("requested chunkLen %v is larger than Loaded SRS points %v", D, v.kzgConfig.SRSNumberToLoad) } n := len(samples) @@ -212,11 +212,11 @@ func (v *Verifier) UniversalVerify(params encoding.EncodingParams, samples []Sam } // lhs g2 exponent := uint64(math.Log2(float64(D))) - G2atD, err := kzg.ReadG2PointOnPowerOf2(exponent, v.KzgConfig) + G2atD, err := kzg.ReadG2PointOnPowerOf2(exponent, v.kzgConfig.SRSOrder, v.kzgConfig.G2PowerOf2Path) if err != nil { // then try to access if there is a full list of g2 srs - G2atD, err = kzg.ReadG2Point(D, v.KzgConfig) + G2atD, err = kzg.ReadG2Point(D, v.kzgConfig.SRSOrder, v.kzgConfig.G2Path) if err != nil { return err } diff --git a/encoding/kzg/verifier/multiframe_test.go b/encoding/kzg/verifier/multiframe_test.go index 6312a2f302..5a4aef3937 100644 --- a/encoding/kzg/verifier/multiframe_test.go +++ b/encoding/kzg/verifier/multiframe_test.go @@ -12,9 +12,11 @@ import ( ) func TestUniversalVerify(t *testing.T) { + group, err := prover.NewProver(kzgConfig, nil) + require.Nil(t, err) - group, _ := prover.NewProver(kzgConfig, true) - v, _ := verifier.NewVerifier(kzgConfig, true) + v, err := verifier.NewVerifier(kzgConfig, nil) + require.Nil(t, err) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) enc, err := group.GetKzgEncoder(params) @@ -54,13 +56,11 @@ func TestUniversalVerify(t *testing.T) { } func TestUniversalVerifyWithPowerOf2G2(t *testing.T) { - kzgConfigCopy := *kzgConfig - group, err := prover.NewProver(&kzgConfigCopy, true) - assert.NoError(t, err) - group.KzgConfig.G2Path = "" + group, err := prover.NewProver(&kzgConfigCopy, nil) + require.Nil(t, err) - v, err := verifier.NewVerifier(kzgConfig, true) + v, err := verifier.NewVerifier(kzgConfig, nil) assert.NoError(t, err) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(gettysburgAddressBytes))) diff --git a/encoding/kzg/verifier/verifier.go b/encoding/kzg/verifier/verifier.go index 56492b1f28..97de7906b5 100644 --- a/encoding/kzg/verifier/verifier.go +++ b/encoding/kzg/verifier/verifier.go @@ -18,22 +18,23 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" + _ "go.uber.org/automaxprocs" ) type Verifier struct { - *kzg.KzgConfig - Srs *kzg.SRS - G2Trailing []bn254.G2Affine - mu sync.Mutex - LoadG2Points bool + kzgConfig *kzg.KzgConfig + encoder *rs.Encoder + + Srs *kzg.SRS + G2Trailing []bn254.G2Affine + mu sync.Mutex ParametrizedVerifiers map[encoding.EncodingParams]*ParametrizedVerifier } var _ encoding.Verifier = &Verifier{} -func NewVerifier(config *kzg.KzgConfig, loadG2Points bool) (*Verifier, error) { - +func NewVerifier(config *kzg.KzgConfig, encoderConfig *encoding.Config) (*Verifier, error) { if config.SRSNumberToLoad > config.SRSOrder { return nil, errors.New("SRSOrder is less than srsNumberToLoad") } @@ -48,7 +49,7 @@ func NewVerifier(config *kzg.KzgConfig, loadG2Points bool) (*Verifier, error) { g2Trailing := make([]bn254.G2Affine, 0) // PreloadEncoder is by default not used by operator node, PreloadEncoder - if loadG2Points { + if config.LoadG2Points { if len(config.G2Path) == 0 { return nil, errors.New("G2Path is empty. However, object needs to load G2Points") } @@ -93,94 +94,91 @@ func NewVerifier(config *kzg.KzgConfig, loadG2Points bool) (*Verifier, error) { fmt.Println("numthread", runtime.GOMAXPROCS(0)) + encoder, err := rs.NewEncoder(encoderConfig) + if err != nil { + return nil, fmt.Errorf("failed to create encoder: %v", err) + } + encoderGroup := &Verifier{ - KzgConfig: config, + kzgConfig: config, + encoder: encoder, Srs: srs, G2Trailing: g2Trailing, ParametrizedVerifiers: make(map[encoding.EncodingParams]*ParametrizedVerifier), - LoadG2Points: loadG2Points, } return encoderGroup, nil - } type ParametrizedVerifier struct { *kzg.KzgConfig Srs *kzg.SRS - *rs.Encoder - Fs *fft.FFTSettings Ks *kzg.KZGSettings } -func (g *Verifier) GetKzgVerifier(params encoding.EncodingParams) (*ParametrizedVerifier, error) { - g.mu.Lock() - defer g.mu.Unlock() +func (v *Verifier) GetKzgVerifier(params encoding.EncodingParams) (*ParametrizedVerifier, error) { + v.mu.Lock() + defer v.mu.Unlock() - if err := params.Validate(); err != nil { + if err := encoding.ValidateEncodingParams(params, v.kzgConfig.SRSOrder); err != nil { return nil, err } - ver, ok := g.ParametrizedVerifiers[params] + ver, ok := v.ParametrizedVerifiers[params] if ok { return ver, nil } - ver, err := g.newKzgVerifier(params) + ver, err := v.newKzgVerifier(params) if err == nil { - g.ParametrizedVerifiers[params] = ver + v.ParametrizedVerifiers[params] = ver } return ver, err } func (g *Verifier) NewKzgVerifier(params encoding.EncodingParams) (*ParametrizedVerifier, error) { - g.mu.Lock() - defer g.mu.Unlock() return g.newKzgVerifier(params) } -func (g *Verifier) newKzgVerifier(params encoding.EncodingParams) (*ParametrizedVerifier, error) { - +func (v *Verifier) newKzgVerifier(params encoding.EncodingParams) (*ParametrizedVerifier, error) { if err := params.Validate(); err != nil { return nil, err } + // Create FFT settings based on params n := uint8(math.Log2(float64(params.NumEvaluations()))) fs := fft.NewFFTSettings(n) - ks, err := kzg.NewKZGSettings(fs, g.Srs) + // Create KZG settings + ks, err := kzg.NewKZGSettings(fs, v.Srs) if err != nil { - return nil, err - } - - encoder, err := rs.NewEncoder(params, g.Verbose) - if err != nil { - log.Println("Could not create encoder: ", err) - return nil, err + return nil, fmt.Errorf("failed to create KZG settings: %w", err) } return &ParametrizedVerifier{ - KzgConfig: g.KzgConfig, - Srs: g.Srs, - Encoder: encoder, + KzgConfig: v.kzgConfig, + Srs: v.Srs, Fs: fs, Ks: ks, }, nil } func (v *Verifier) VerifyBlobLength(commitments encoding.BlobCommitments) error { - return v.VerifyCommit((*bn254.G2Affine)(commitments.LengthCommitment), (*bn254.G2Affine)(commitments.LengthProof), uint64(commitments.Length)) - + return v.VerifyCommit( + (*bn254.G2Affine)(commitments.LengthCommitment), + (*bn254.G2Affine)(commitments.LengthProof), + uint64(commitments.Length), + ) } // VerifyCommit verifies the low degree proof; since it doesn't depend on the encoding parameters // we leave it as a method of the KzgEncoderGroup func (v *Verifier) VerifyCommit(lengthCommit *bn254.G2Affine, lengthProof *bn254.G2Affine, length uint64) error { - g1Challenge, err := kzg.ReadG1Point(v.SRSOrder-length, v.KzgConfig) + g1Challenge, err := kzg.ReadG1Point(v.kzgConfig.SRSOrder-length, v.kzgConfig.SRSOrder, v.kzgConfig.G1Path) if err != nil { return err } @@ -215,6 +213,7 @@ func (v *Verifier) VerifyFrames(frames []*encoding.Frame, indices []encoding.Chu (*bn254.G1Affine)(commitments.Commitment), frames[ind], uint64(indices[ind]), + params.NumChunks, ) if err != nil { @@ -226,17 +225,17 @@ func (v *Verifier) VerifyFrames(frames []*encoding.Frame, indices []encoding.Chu } -func (v *ParametrizedVerifier) VerifyFrame(commit *bn254.G1Affine, f *encoding.Frame, index uint64) error { +func (v *ParametrizedVerifier) VerifyFrame(commit *bn254.G1Affine, f *encoding.Frame, index uint64, numChunks uint64) error { j, err := rs.GetLeadingCosetIndex( uint64(index), - v.NumChunks, + numChunks, ) if err != nil { return err } - g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), v.KzgConfig) + g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), v.KzgConfig.SRSOrder, v.KzgConfig.G2Path) if err != nil { return err } @@ -301,12 +300,8 @@ func (v *Verifier) Decode(chunks []*encoding.Frame, indices []encoding.ChunkNumb Coeffs: chunks[i].Coeffs, } } - encoder, err := v.GetKzgVerifier(params) - if err != nil { - return nil, err - } - return encoder.Decode(frames, toUint64Array(indices), maxInputSize) + return v.encoder.Decode(frames, toUint64Array(indices), maxInputSize, params) } func toUint64Array(chunkIndices []encoding.ChunkNumber) []uint64 { diff --git a/encoding/kzg/verifier/verifier_test.go b/encoding/kzg/verifier/verifier_test.go index fbcd4e419c..817e2de2a3 100644 --- a/encoding/kzg/verifier/verifier_test.go +++ b/encoding/kzg/verifier/verifier_test.go @@ -14,6 +14,7 @@ import ( "github.com/Layr-Labs/eigenda/encoding/kzg/verifier" "github.com/Layr-Labs/eigenda/encoding/utils/codec" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -42,6 +43,7 @@ func setup() { SRSOrder: 3000, SRSNumberToLoad: 2900, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } numNode = uint64(4) @@ -59,9 +61,11 @@ func teardown() { func TestBenchmarkVerifyChunks(t *testing.T) { t.Skip("This test is meant to be run manually, not as part of the test suite") + p, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) - p, _ := prover.NewProver(kzgConfig, true) - v, _ := verifier.NewVerifier(kzgConfig, true) + v, err := verifier.NewVerifier(kzgConfig, nil) + require.NoError(t, err) chunkLengths := []uint64{64, 128, 256, 512, 1024, 2048, 4096, 8192} chunkCounts := []int{4, 8, 16} @@ -112,9 +116,11 @@ func TestBenchmarkVerifyChunks(t *testing.T) { } func BenchmarkVerifyBlob(b *testing.B) { + p, err := prover.NewProver(kzgConfig, nil) + require.NoError(b, err) - p, _ := prover.NewProver(kzgConfig, true) - v, _ := verifier.NewVerifier(kzgConfig, true) + v, err := verifier.NewVerifier(kzgConfig, nil) + require.NoError(b, err) params := encoding.EncodingParams{ ChunkLength: 256, diff --git a/encoding/rs/decode.go b/encoding/rs/decode.go index 6bb67be3e6..5f56a08bf1 100644 --- a/encoding/rs/decode.go +++ b/encoding/rs/decode.go @@ -17,7 +17,13 @@ import ( // maxInputSize is the upper bound of the original data size. This is needed because // the frames and indices don't encode the length of the original data. If maxInputSize // is smaller than the original input size, decoded data will be trimmed to fit the maxInputSize. -func (g *Encoder) Decode(frames []Frame, indices []uint64, maxInputSize uint64) ([]byte, error) { +func (e *Encoder) Decode(frames []Frame, indices []uint64, maxInputSize uint64, params encoding.EncodingParams) ([]byte, error) { + // Get encoder + g, err := e.GetRsEncoder(params) + if err != nil { + return nil, err + } + numSys := encoding.GetNumSys(maxInputSize, g.ChunkLength) if uint64(len(frames)) < numSys { @@ -42,7 +48,6 @@ func (g *Encoder) Decode(frames []Frame, indices []uint64, maxInputSize uint64) for j := uint64(0); j < g.ChunkLength; j++ { p := j*g.NumChunks + uint64(e) samples[p] = new(fr.Element) - samples[p].Set(&evals[j]) } } diff --git a/encoding/rs/encode.go b/encoding/rs/encode.go index 5a8b5c4496..f467b8fca7 100644 --- a/encoding/rs/encode.go +++ b/encoding/rs/encode.go @@ -2,12 +2,10 @@ package rs import ( "fmt" - "log" + "log/slog" "time" "github.com/Layr-Labs/eigenda/encoding" - rb "github.com/Layr-Labs/eigenda/encoding/utils/reverseBits" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) @@ -17,12 +15,12 @@ type GlobalPoly struct { } // just a wrapper to take bytes not Fr Element -func (g *Encoder) EncodeBytes(inputBytes []byte) ([]Frame, []uint32, error) { +func (g *Encoder) EncodeBytes(inputBytes []byte, params encoding.EncodingParams) ([]Frame, []uint32, error) { inputFr, err := ToFrArray(inputBytes) if err != nil { return nil, nil, fmt.Errorf("cannot convert bytes to field elements, %w", err) } - return g.Encode(inputFr) + return g.Encode(inputFr, params) } // Encode function takes input in unit of Fr Element, creates a kzg commit and a list of frames @@ -32,140 +30,48 @@ func (g *Encoder) EncodeBytes(inputBytes []byte) ([]Frame, []uint32, error) { // frame, the multireveal interpolating coefficients are identical to the part of input bytes // in the form of field element. The extra returned integer list corresponds to which leading // coset root of unity, the frame is proving against, which can be deduced from a frame's index -func (g *Encoder) Encode(inputFr []fr.Element) ([]Frame, []uint32, error) { +func (g *Encoder) Encode(inputFr []fr.Element, params encoding.EncodingParams) ([]Frame, []uint32, error) { start := time.Now() intermediate := time.Now() - pdCoeffs, err := g.PadPolyEval(inputFr) + // Get RS encoder from params + encoder, err := g.GetRsEncoder(params) if err != nil { return nil, nil, err } - polyEvals, err := g.Computer.ExtendPolyEval(pdCoeffs) + pdCoeffs, err := encoder.PadPolyEval(inputFr) if err != nil { return nil, nil, err } + paddingDuration := time.Since(intermediate) - if g.verbose { - log.Printf(" Extending evaluation takes %v\n", time.Since(intermediate)) - } + intermediate = time.Now() - // create frames to group relevant info - frames, indices, err := g.MakeFrames(polyEvals) + polyEvals, err := encoder.RSEncoderComputer.ExtendPolyEval(pdCoeffs) if err != nil { return nil, nil, err } + extensionDuration := time.Since(intermediate) - log.Printf(" SUMMARY: RSEncode %v byte among %v numChunks with chunkLength %v takes %v\n", - len(inputFr)*encoding.BYTES_PER_SYMBOL, g.NumChunks, g.ChunkLength, time.Since(start)) - - return frames, indices, nil -} - -// PadPolyEval pads the input polynomial coefficients to match the number of evaluations -// required by the encoder. -func (g *Encoder) PadPolyEval(coeffs []fr.Element) ([]fr.Element, error) { - numEval := int(g.NumEvaluations()) - if len(coeffs) > numEval { - return nil, fmt.Errorf("the provided encoding parameters are not sufficient for the size of the data input") - } - - pdCoeffs := make([]fr.Element, numEval) - copy(pdCoeffs, coeffs) + intermediate = time.Now() - // Pad the remaining elements with zeroes - for i := len(coeffs); i < numEval; i++ { - pdCoeffs[i].SetZero() - } - - return pdCoeffs, nil -} - -// MakeFrames function takes extended evaluation data and bundles relevant information into Frame. -// Every frame is verifiable to the commitment. -func (g *Encoder) MakeFrames( - polyEvals []fr.Element, -) ([]Frame, []uint32, error) { - // reverse dataFr making easier to sample points - err := rb.ReverseBitOrderFr(polyEvals) + // create frames to group relevant info + frames, indices, err := encoder.MakeFrames(polyEvals) if err != nil { return nil, nil, err } - indices := make([]uint32, 0) - frames := make([]Frame, g.NumChunks) + framesDuration := time.Since(intermediate) - numWorker := uint64(g.NumRSWorker) - - if numWorker > g.NumChunks { - numWorker = g.NumChunks - } - - jobChan := make(chan JobRequest, numWorker) - results := make(chan error, numWorker) - - for w := uint64(0); w < numWorker; w++ { - go g.interpolyWorker( - polyEvals, - jobChan, - results, - frames, - ) - } - - for i := uint64(0); i < g.NumChunks; i++ { - j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) - jr := JobRequest{ - Index: i, - } - jobChan <- jr - indices = append(indices, j) - } - close(jobChan) - - for w := uint64(0); w < numWorker; w++ { - interPolyErr := <-results - if interPolyErr != nil { - err = interPolyErr - } - } - - if err != nil { - return nil, nil, fmt.Errorf("proof worker error: %v", err) - } + slog.Info("RSEncode details", + "input_size_bytes", len(inputFr)*encoding.BYTES_PER_SYMBOL, + "num_chunks", encoder.NumChunks, + "chunk_length", encoder.ChunkLength, + "padding_duration", paddingDuration, + "extension_duration", extensionDuration, + "frames_duration", framesDuration, + "total_duration", time.Since(start)) return frames, indices, nil } - -type JobRequest struct { - Index uint64 -} - -func (g *Encoder) interpolyWorker( - polyEvals []fr.Element, - jobChan <-chan JobRequest, - results chan<- error, - frames []Frame, -) { - - for jr := range jobChan { - i := jr.Index - j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) - ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)] - err := rb.ReverseBitOrderFr(ys) - if err != nil { - results <- err - continue - } - coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j)) - if err != nil { - results <- err - continue - } - - frames[i].Coeffs = coeffs - } - - results <- nil - -} diff --git a/encoding/rs/encode_test.go b/encoding/rs/encode_test.go index 59902761e5..7b0b3f16b7 100644 --- a/encoding/rs/encode_test.go +++ b/encoding/rs/encode_test.go @@ -2,16 +2,12 @@ package rs_test import ( "fmt" - "math" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/Layr-Labs/eigenda/encoding" - "github.com/Layr-Labs/eigenda/encoding/fft" "github.com/Layr-Labs/eigenda/encoding/rs" - rs_cpu "github.com/Layr-Labs/eigenda/encoding/rs/cpu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecode_InvertsWhenSamplingAllFrames(t *testing.T) { @@ -20,30 +16,18 @@ func TestEncodeDecode_InvertsWhenSamplingAllFrames(t *testing.T) { params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) inputFr, err := rs.ToFrArray(GETTYSBURG_ADDRESS_BYTES) assert.Nil(t, err) - frames, _, err := enc.Encode(inputFr) + frames, _, err := enc.Encode(inputFr, params) assert.Nil(t, err) // sample some frames samples, indices := sampleFrames(frames, uint64(len(frames))) - data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES))) + data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)), params) require.Nil(t, err) require.NotNil(t, data) @@ -56,30 +40,19 @@ func TestEncodeDecode_InvertsWhenSamplingMissingFrame(t *testing.T) { defer teardownSuite(t) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) inputFr, err := rs.ToFrArray(GETTYSBURG_ADDRESS_BYTES) assert.Nil(t, err) - frames, _, err := enc.Encode(inputFr) + frames, _, err := enc.Encode(inputFr, params) assert.Nil(t, err) // sample some frames samples, indices := sampleFrames(frames, uint64(len(frames)-1)) - data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES))) + data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)), params) require.Nil(t, err) require.NotNil(t, data) @@ -92,32 +65,20 @@ func TestEncodeDecode_ErrorsWhenNotEnoughSampledFrames(t *testing.T) { defer teardownSuite(t) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) - fmt.Println("Num Chunks: ", enc.NumChunks) + fmt.Println("Num Chunks: ", params.NumChunks) inputFr, err := rs.ToFrArray(GETTYSBURG_ADDRESS_BYTES) assert.Nil(t, err) - frames, _, err := enc.Encode(inputFr) + frames, _, err := enc.Encode(inputFr, params) assert.Nil(t, err) // sample some frames samples, indices := sampleFrames(frames, uint64(len(frames)-2)) - data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES))) + data, err := enc.Decode(samples, indices, uint64(len(GETTYSBURG_ADDRESS_BYTES)), params) require.Nil(t, data) require.NotNil(t, err) diff --git a/encoding/rs/encoder.go b/encoding/rs/encoder.go index fda65d13b0..e78eaa3179 100644 --- a/encoding/rs/encoder.go +++ b/encoding/rs/encoder.go @@ -1,39 +1,61 @@ package rs import ( + "errors" + "fmt" "math" - "runtime" + "sync" "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/fft" + gnarkencoder "github.com/Layr-Labs/eigenda/encoding/rs/gnark" "github.com/consensys/gnark-crypto/ecc/bn254/fr" + _ "go.uber.org/automaxprocs" ) type Encoder struct { - encoding.EncodingParams + Config *encoding.Config - Fs *fft.FFTSettings + mu sync.Mutex + ParametrizedEncoder map[encoding.EncodingParams]*ParametrizedEncoder +} - verbose bool +// Proof device represents a device capable of computing reed-solomon operations. +type EncoderDevice interface { + ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, error) +} + +// NewEncoder creates a new encoder with the given options +func NewEncoder(config *encoding.Config) (*Encoder, error) { + if config == nil { + config = encoding.DefaultConfig() + } - NumRSWorker int + e := &Encoder{ + Config: config, + mu: sync.Mutex{}, + ParametrizedEncoder: make(map[encoding.EncodingParams]*ParametrizedEncoder), + } - Computer RsComputeDevice + return e, nil } -// RsComputeDevice represents a device capable of performing Reed-Solomon encoding computations. -// Implementations of this interface are expected to handle polynomial evaluation extensions. -type RsComputeDevice interface { - // ExtendPolyEval extends the evaluation of a polynomial given its coefficients. - // It takes a slice of polynomial coefficients and returns an extended evaluation. - // - // Parameters: - // - coeffs: A slice of fr.Element representing the polynomial coefficients. - // - // Returns: - // - A slice of fr.Element representing the extended polynomial evaluation. - // - An error if the extension process fails. - ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, error) +// GetRsEncoder returns a parametrized encoder for the given parameters. +// It caches the encoder for reuse. +func (g *Encoder) GetRsEncoder(params encoding.EncodingParams) (*ParametrizedEncoder, error) { + g.mu.Lock() + defer g.mu.Unlock() + enc, ok := g.ParametrizedEncoder[params] + if ok { + return enc, nil + } + + enc, err := g.newEncoder(params) + if err == nil { + g.ParametrizedEncoder[params] = enc + } + + return enc, err } // The function creates a high level struct that determines the encoding the a data of a @@ -44,21 +66,42 @@ type RsComputeDevice interface { // original data. When some systematic chunks are missing but identical parity chunk are // available, the receive can go through a Reed Solomon decoding to reconstruct the // original data. -func NewEncoder(params encoding.EncodingParams, verbose bool) (*Encoder, error) { - +func (e *Encoder) newEncoder(params encoding.EncodingParams) (*ParametrizedEncoder, error) { err := params.Validate() if err != nil { return nil, err } + fs := e.CreateFFTSettings(params) + + switch e.Config.BackendType { + case encoding.GnarkBackend: + return e.createGnarkBackendEncoder(params, fs) + case encoding.IcicleBackend: + return e.createIcicleBackendEncoder(params, fs) + default: + return nil, fmt.Errorf("unsupported backend type: %v", e.Config.BackendType) + } +} + +func (e *Encoder) CreateFFTSettings(params encoding.EncodingParams) *fft.FFTSettings { n := uint8(math.Log2(float64(params.NumEvaluations()))) - fs := fft.NewFFTSettings(n) + return fft.NewFFTSettings(n) +} - return &Encoder{ - EncodingParams: params, - Fs: fs, - verbose: verbose, - NumRSWorker: runtime.GOMAXPROCS(0), +func (e *Encoder) createGnarkBackendEncoder(params encoding.EncodingParams, fs *fft.FFTSettings) (*ParametrizedEncoder, error) { + if e.Config.GPUEnable { + return nil, errors.New("GPU is not supported in gnark backend") + } + + return &ParametrizedEncoder{ + Config: e.Config, + EncodingParams: params, + Fs: fs, + RSEncoderComputer: &gnarkencoder.RsGnarkBackend{Fs: fs}, }, nil +} +func (e *Encoder) createIcicleBackendEncoder(params encoding.EncodingParams, fs *fft.FFTSettings) (*ParametrizedEncoder, error) { + return CreateIcicleBackendEncoder(e, params, fs) } diff --git a/encoding/rs/encoder_fuzz_test.go b/encoding/rs/encoder_fuzz_test.go index a23e3c2a6e..e0cca75788 100644 --- a/encoding/rs/encoder_fuzz_test.go +++ b/encoding/rs/encoder_fuzz_test.go @@ -1,15 +1,11 @@ package rs_test import ( - "math" "testing" "github.com/Layr-Labs/eigenda/encoding" - "github.com/Layr-Labs/eigenda/encoding/fft" "github.com/Layr-Labs/eigenda/encoding/rs" - rs_cpu "github.com/Layr-Labs/eigenda/encoding/rs/cpu" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func FuzzOnlySystematic(f *testing.F) { @@ -18,27 +14,14 @@ func FuzzOnlySystematic(f *testing.F) { f.Fuzz(func(t *testing.T, input []byte) { params := encoding.ParamsFromSysPar(10, 3, uint64(len(input))) - enc, err := rs.NewEncoder(params, true) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) if err != nil { t.Errorf("Error making rs: %q", err) } - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) - //encode the data - frames, _, err := enc.EncodeBytes(input) + frames, _, err := enc.EncodeBytes(input, params) if err != nil { t.Errorf("Error Encoding:\n Data:\n %q \n Err: %q", input, err) } @@ -46,7 +29,7 @@ func FuzzOnlySystematic(f *testing.F) { //sample the correct systematic frames samples, indices := sampleFrames(frames, uint64(len(frames))) - data, err := enc.Decode(samples, indices, uint64(len(input))) + data, err := enc.Decode(samples, indices, uint64(len(input)), params) if err != nil { t.Errorf("Error Decoding:\n Data:\n %q \n Err: %q", input, err) } diff --git a/encoding/rs/frame_test.go b/encoding/rs/frame_test.go index ccf664af15..3c8237178a 100644 --- a/encoding/rs/frame_test.go +++ b/encoding/rs/frame_test.go @@ -2,13 +2,10 @@ package rs_test import ( "fmt" - "math" "testing" "github.com/Layr-Labs/eigenda/encoding" - "github.com/Layr-Labs/eigenda/encoding/fft" "github.com/Layr-Labs/eigenda/encoding/rs" - rs_cpu "github.com/Layr-Labs/eigenda/encoding/rs/cpu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,23 +15,11 @@ func TestEncodeDecodeFrame_AreInverses(t *testing.T) { defer teardownSuite(t) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) - - frames, _, err := enc.EncodeBytes(GETTYSBURG_ADDRESS_BYTES) + frames, _, err := enc.EncodeBytes(GETTYSBURG_ADDRESS_BYTES, params) require.Nil(t, err) require.NotNil(t, frames, err) @@ -54,23 +39,11 @@ func TestGnarkEncodeDecodeFrame_AreInverses(t *testing.T) { defer teardownSuite(t) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) - - frames, _, err := enc.EncodeBytes(GETTYSBURG_ADDRESS_BYTES) + frames, _, err := enc.EncodeBytes(GETTYSBURG_ADDRESS_BYTES, params) require.Nil(t, err) require.NotNil(t, frames, err) @@ -91,23 +64,11 @@ func TestGnarkEncodeDecodeFrames_AreInverses(t *testing.T) { defer teardownSuite(t) params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(enc.NumEvaluations()))) - if enc.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * enc.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - enc.Computer = RsComputeDevice - require.NotNil(t, enc) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) - frames, _, err := enc.EncodeBytes(GETTYSBURG_ADDRESS_BYTES) + frames, _, err := enc.EncodeBytes(GETTYSBURG_ADDRESS_BYTES, params) assert.NoError(t, err) framesPointers := make([]*rs.Frame, len(frames)) diff --git a/encoding/rs/cpu/extend_poly.go b/encoding/rs/gnark/extend_poly.go similarity index 56% rename from encoding/rs/cpu/extend_poly.go rename to encoding/rs/gnark/extend_poly.go index f4e40be0b0..0de2b8d8ea 100644 --- a/encoding/rs/cpu/extend_poly.go +++ b/encoding/rs/gnark/extend_poly.go @@ -1,19 +1,16 @@ -package cpu +package gnark import ( - "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/fft" "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) -type RsCpuComputeDevice struct { +type RsGnarkBackend struct { Fs *fft.FFTSettings - - encoding.EncodingParams } // Encoding Reed Solomon using FFT -func (g *RsCpuComputeDevice) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, error) { +func (g *RsGnarkBackend) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, error) { evals, err := g.Fs.FFT(coeffs, false) if err != nil { return nil, err diff --git a/encoding/rs/icicle.go b/encoding/rs/icicle.go new file mode 100644 index 0000000000..da559d77cd --- /dev/null +++ b/encoding/rs/icicle.go @@ -0,0 +1,38 @@ +//go:build icicle + +package rs + +import ( + "sync" + + "github.com/Layr-Labs/eigenda/encoding" + "github.com/Layr-Labs/eigenda/encoding/fft" + "github.com/Layr-Labs/eigenda/encoding/icicle" + rsicicle "github.com/Layr-Labs/eigenda/encoding/rs/icicle" +) + +const ( + defaultNTTSize = 25 // Used for NTT setup in Icicle backend +) + +func CreateIcicleBackendEncoder(e *Encoder, params encoding.EncodingParams, fs *fft.FFTSettings) (*ParametrizedEncoder, error) { + icicleDevice, err := icicle.NewIcicleDevice(icicle.IcicleDeviceConfig{ + GPUEnable: e.Config.GPUEnable, + NTTSize: defaultNTTSize, + // No MSM setup needed for encoder + }) + if err != nil { + return nil, err + } + + return &ParametrizedEncoder{ + Config: e.Config, + EncodingParams: params, + Fs: fs, + RSEncoderComputer: &rsicicle.RsIcicleBackend{ + NttCfg: icicleDevice.NttCfg, + Device: icicleDevice.Device, + GpuLock: sync.Mutex{}, + }, + }, nil +} diff --git a/encoding/rs/icicle/extend_poly.go b/encoding/rs/icicle/extend_poly.go new file mode 100644 index 0000000000..b3c82062f8 --- /dev/null +++ b/encoding/rs/icicle/extend_poly.go @@ -0,0 +1,65 @@ +//go:build icicle + +package icicle + +import ( + "fmt" + "sync" + + "github.com/Layr-Labs/eigenda/encoding/icicle" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" + iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ntt" + icicleRuntime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" +) + +type RsIcicleBackend struct { + NttCfg core.NTTConfig[[iciclebn254.SCALAR_LIMBS]uint32] + Device icicleRuntime.Device + GpuLock sync.Mutex +} + +// Encoding Reed Solomon using FFT +func (g *RsIcicleBackend) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, error) { + // Lock the GPU for operations + g.GpuLock.Lock() + defer g.GpuLock.Unlock() + + // Convert and prepare data + g.NttCfg.BatchSize = int32(1) + scalarsSF := icicle.ConvertFrToScalarFieldsBytes(coeffs) + scalars := core.HostSliceFromElements[iciclebn254.ScalarField](scalarsSF) + outputDevice := make(core.HostSlice[iciclebn254.ScalarField], len(coeffs)) + + // Set device + err := icicleRuntime.SetDevice(&g.Device) + if err != icicleRuntime.Success { + return nil, fmt.Errorf("failed to set device: %v", err.AsString()) + } + + // Perform NTT + var icicleErr error + wg := sync.WaitGroup{} + wg.Add(1) + icicleRuntime.RunOnDevice(&g.Device, func(args ...any) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + icicleErr = fmt.Errorf("GPU operation panic: %v", r) + } + }() + + ntt.Ntt(scalars, core.KForward, &g.NttCfg, outputDevice) + }) + + wg.Wait() + + // Check if there was a panic + if icicleErr != nil { + return nil, icicleErr + } + + evals := icicle.ConvertScalarFieldsToFrBytes(outputDevice) + return evals, nil +} diff --git a/encoding/rs/interpolation.go b/encoding/rs/interpolation.go index 405799d0d2..e73d3b489c 100644 --- a/encoding/rs/interpolation.go +++ b/encoding/rs/interpolation.go @@ -16,7 +16,7 @@ import ( // The reason behind is because Reed Solomon extension using FFT insert evaluation within original // Data. i.e. [o_1, o_2, o_3..] with coding ratio 0.5 becomes [o_1, p_1, o_2, p_2...] -func (g *Encoder) GetInterpolationPolyEval( +func (g *ParametrizedEncoder) GetInterpolationPolyEval( interpolationPoly []fr.Element, j uint32, ) ([]fr.Element, error) { @@ -62,7 +62,7 @@ func (g *Encoder) GetInterpolationPolyEval( } // Since both F W are invertible, c = W^-1 F^-1 d, convert it back. F W W^-1 F^-1 d = c -func (g *Encoder) GetInterpolationPolyCoeff(chunk []fr.Element, k uint32) ([]fr.Element, error) { +func (g *ParametrizedEncoder) GetInterpolationPolyCoeff(chunk []fr.Element, k uint32) ([]fr.Element, error) { coeffs := make([]fr.Element, g.ChunkLength) shiftedInterpolationPoly := make([]fr.Element, len(chunk)) err := g.Fs.InplaceFFT(chunk, shiftedInterpolationPoly, true) diff --git a/encoding/rs/noicicle.go b/encoding/rs/noicicle.go new file mode 100644 index 0000000000..74c1720671 --- /dev/null +++ b/encoding/rs/noicicle.go @@ -0,0 +1,15 @@ +//go:build !icicle + +package rs + +import ( + "errors" + + "github.com/Layr-Labs/eigenda/encoding" + "github.com/Layr-Labs/eigenda/encoding/fft" +) + +func CreateIcicleBackendEncoder(p *Encoder, params encoding.EncodingParams, fs *fft.FFTSettings) (*ParametrizedEncoder, error) { + // Not supported + return nil, errors.New("icicle backend called without icicle build tag") +} diff --git a/encoding/rs/parametrized_encoder.go b/encoding/rs/parametrized_encoder.go new file mode 100644 index 0000000000..62edc168f4 --- /dev/null +++ b/encoding/rs/parametrized_encoder.go @@ -0,0 +1,125 @@ +package rs + +import ( + "fmt" + + "github.com/Layr-Labs/eigenda/encoding" + "github.com/Layr-Labs/eigenda/encoding/fft" + rb "github.com/Layr-Labs/eigenda/encoding/utils/reverseBits" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +type ParametrizedEncoder struct { + *encoding.Config + encoding.EncodingParams + Fs *fft.FFTSettings + RSEncoderComputer EncoderDevice +} + +// PadPolyEval pads the input polynomial coefficients to match the number of evaluations +// required by the encoder. +func (g *ParametrizedEncoder) PadPolyEval(coeffs []fr.Element) ([]fr.Element, error) { + numEval := int(g.NumEvaluations()) + + if len(coeffs) > numEval { + return nil, fmt.Errorf("the provided encoding parameters are not sufficient for the size of the data input") + } + + pdCoeffs := make([]fr.Element, numEval) + copy(pdCoeffs, coeffs) + + // Pad the remaining elements with zeroes + for i := len(coeffs); i < numEval; i++ { + pdCoeffs[i].SetZero() + } + + return pdCoeffs, nil +} + +// MakeFrames function takes extended evaluation data and bundles relevant information into Frame. +// Every frame is verifiable to the commitment. +func (g *ParametrizedEncoder) MakeFrames( + polyEvals []fr.Element, +) ([]Frame, []uint32, error) { + // reverse dataFr making easier to sample points + err := rb.ReverseBitOrderFr(polyEvals) + if err != nil { + return nil, nil, err + } + + indices := make([]uint32, 0) + frames := make([]Frame, g.NumChunks) + + numWorker := uint64(g.Config.NumWorker) + if numWorker > g.NumChunks { + numWorker = g.NumChunks + } + + jobChan := make(chan JobRequest, numWorker) + results := make(chan error, numWorker) + + for w := uint64(0); w < numWorker; w++ { + go g.interpolyWorker( + polyEvals, + jobChan, + results, + frames, + ) + } + + for i := uint64(0); i < g.NumChunks; i++ { + j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) + jr := JobRequest{ + Index: i, + } + jobChan <- jr + indices = append(indices, j) + } + close(jobChan) + + for w := uint64(0); w < numWorker; w++ { + interPolyErr := <-results + if interPolyErr != nil { + err = interPolyErr + } + } + + if err != nil { + return nil, nil, fmt.Errorf("proof worker error: %v", err) + } + + return frames, indices, nil +} + +type JobRequest struct { + Index uint64 +} + +func (g *ParametrizedEncoder) interpolyWorker( + polyEvals []fr.Element, + jobChan <-chan JobRequest, + results chan<- error, + frames []Frame, +) { + + for jr := range jobChan { + i := jr.Index + j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) + ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)] + err := rb.ReverseBitOrderFr(ys) + if err != nil { + results <- err + continue + } + coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j)) + if err != nil { + results <- err + continue + } + + frames[i].Coeffs = coeffs + } + + results <- nil + +} diff --git a/encoding/rs/params.go b/encoding/rs/params.go index 8bd39b1746..d3244a6f23 100644 --- a/encoding/rs/params.go +++ b/encoding/rs/params.go @@ -24,7 +24,6 @@ func (p EncodingParams) NumEvaluations() uint64 { } func (p EncodingParams) Validate() error { - if NextPowerOf2(p.NumChunks) != p.NumChunks { return ErrInvalidParams } @@ -51,14 +50,12 @@ func ParamsFromMins(numChunks, chunkLen uint64) EncodingParams { NumChunks: numChunks, ChunkLen: chunkLen, } - } func GetEncodingParams(numSys, numPar, dataSize uint64) EncodingParams { - numNodes := numSys + numPar dataLen := RoundUpDivision(dataSize, encoding.BYTES_PER_SYMBOL) chunkLen := RoundUpDivision(dataLen, numSys) - return ParamsFromMins(numNodes, chunkLen) + return ParamsFromMins(numNodes, chunkLen) } diff --git a/encoding/rs/utils_test.go b/encoding/rs/utils_test.go index 44b139ebd6..19c31f969e 100644 --- a/encoding/rs/utils_test.go +++ b/encoding/rs/utils_test.go @@ -38,8 +38,9 @@ func TestToFrArrayAndToByteArray_AreInverses(t *testing.T) { numEle := rs.GetNumElement(1000, encoding.BYTES_PER_SYMBOL) assert.Equal(t, numEle, uint64(32)) - params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(GETTYSBURG_ADDRESS_BYTES))) - enc, _ := rs.NewEncoder(params, true) + cfg := encoding.DefaultConfig() + enc, err := rs.NewEncoder(cfg) + assert.Nil(t, err) require.NotNil(t, enc) dataFr, err := rs.ToFrArray(GETTYSBURG_ADDRESS_BYTES) diff --git a/encoding/test/main.go b/encoding/test/main.go index fb8314c4dd..665b3f642e 100644 --- a/encoding/test/main.go +++ b/encoding/test/main.go @@ -69,11 +69,14 @@ func TestKzgRs() { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), - Verbose: true, + LoadG2Points: true, } // create encoding object - p, _ := prover.NewProver(kzgConfig, true) + p, err := prover.NewProver(kzgConfig, nil) + if err != nil { + log.Fatalf("Failed to create prover: %v", err) + } params := encoding.EncodingParams{NumChunks: numNode, ChunkLength: uint64(numSymbols) / numSys} enc, _ := p.GetKzgEncoder(params) @@ -113,9 +116,13 @@ func TestKzgRs() { } fmt.Printf("frame %v leading coset %v\n", i, j) - lc := enc.Fs.ExpandedRootsOfUnity[uint64(j)] + rsEncoder, err := enc.GetRsEncoder(params) + if err != nil { + log.Fatalf("%v", err) + } + lc := rsEncoder.Fs.ExpandedRootsOfUnity[uint64(j)] - g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), kzgConfig) + g2Atn, err := kzg.ReadG2Point(uint64(len(f.Coeffs)), kzgConfig.SRSOrder, kzgConfig.G2Path) if err != nil { log.Fatalf("Load g2 %v failed\n", err) } diff --git a/encoding/utils/openCommitment/open_commitment_test.go b/encoding/utils/openCommitment/open_commitment_test.go index 1430130b38..2493e55939 100644 --- a/encoding/utils/openCommitment/open_commitment_test.go +++ b/encoding/utils/openCommitment/open_commitment_test.go @@ -9,7 +9,7 @@ import ( "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/kzg" - kzgProver "github.com/Layr-Labs/eigenda/encoding/kzg/prover" + "github.com/Layr-Labs/eigenda/encoding/kzg/prover" "github.com/Layr-Labs/eigenda/encoding/rs" "github.com/Layr-Labs/eigenda/encoding/utils/codec" oc "github.com/Layr-Labs/eigenda/encoding/utils/openCommitment" @@ -38,6 +38,7 @@ func TestOpenCommitment(t *testing.T) { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } // input evaluation @@ -57,8 +58,8 @@ func TestOpenCommitment(t *testing.T) { } // we need prover only to access kzg SRS, and get kzg commitment of encoding - group, err := kzgProver.NewProver(kzgConfig, true) - require.Nil(t, err) + group, err := prover.NewProver(kzgConfig, nil) + require.NoError(t, err) // get root of unit for blob numNode = 4 @@ -69,10 +70,14 @@ func TestOpenCommitment(t *testing.T) { params := encoding.ParamsFromSysPar(numSys, numPar, uint64(len(validInput))) enc, err := group.GetKzgEncoder(params) require.Nil(t, err) - rootOfUnities := enc.Fs.ExpandedRootsOfUnity[:len(enc.Fs.ExpandedRootsOfUnity)-1] + + rs, err := enc.GetRsEncoder(params) + require.NoError(t, err) + + rootOfUnities := rs.Fs.ExpandedRootsOfUnity[:len(rs.Fs.ExpandedRootsOfUnity)-1] // Lagrange basis SRS in normal order, not butterfly - lagrangeG1SRS, err := enc.Fs.FFTG1(group.Srs.G1[:len(paddedInputFr)], true) + lagrangeG1SRS, err := rs.Fs.FFTG1(group.Srs.G1[:len(paddedInputFr)], true) require.Nil(t, err) // commit in lagrange form diff --git a/go.mod b/go.mod index dc262e397b..e01b2dfd41 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/gin-gonic/gin v1.9.1 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 github.com/hashicorp/go-multierror v1.1.1 + github.com/ingonyama-zk/icicle/v3 v3.1.0 github.com/jedib0t/go-pretty/v6 v6.5.9 github.com/joho/godotenv v1.5.1 github.com/onsi/ginkgo/v2 v2.11.0 @@ -198,7 +199,7 @@ require ( github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 github.com/prometheus/procfs v0.12.0 // indirect - github.com/rogpeppe/go-internal v1.10.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible github.com/stretchr/objx v0.5.2 // indirect diff --git a/go.sum b/go.sum index d3b4dde0bf..11c555d81c 100644 --- a/go.sum +++ b/go.sum @@ -309,6 +309,8 @@ github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/iden3/go-iden3-crypto v0.0.16 h1:zN867xiz6HgErXVIV/6WyteGcOukE9gybYTorBMEdsk= github.com/iden3/go-iden3-crypto v0.0.16/go.mod h1:dLpM4vEPJ3nDHzhWFXDjzkn1qHoBeOT/3UEhXsEsP3E= +github.com/ingonyama-zk/icicle/v3 v3.1.0 h1:NpYQxrcY7AN2ghepi2VyU52JiPF/SQdjSkxu3f0RtSU= +github.com/ingonyama-zk/icicle/v3 v3.1.0/go.mod h1:e0JHb27/P6WorCJS3YolbY5XffS4PGBuoW38OthLkDs= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= github.com/jackpal/go-nat-pmp v1.0.2/go.mod h1:QPH045xvCAeXUZOxsnwmrtiCoxIr9eob+4orBN1SBKc= github.com/jedib0t/go-pretty/v6 v6.5.9 h1:ACteMBRrrmm1gMsXe9PSTOClQ63IXDUt03H5U+UV8OU= @@ -451,8 +453,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= diff --git a/inabox/tests/integration_suite_test.go b/inabox/tests/integration_suite_test.go index 2f8250a1ae..3a4e53bbbb 100644 --- a/inabox/tests/integration_suite_test.go +++ b/inabox/tests/integration_suite_test.go @@ -161,17 +161,19 @@ func setupRetrievalClient(testConfig *deploy.Config) error { if err != nil { return err } - v, err := verifier.NewVerifier(&kzg.KzgConfig{ + kzgConfig := &kzg.KzgConfig{ G1Path: testConfig.Retriever.RETRIEVER_G1_PATH, G2Path: testConfig.Retriever.RETRIEVER_G2_PATH, G2PowerOf2Path: testConfig.Retriever.RETRIEVER_G2_POWER_OF_2_PATH, CacheDir: testConfig.Retriever.RETRIEVER_CACHE_PATH, - NumWorker: 1, SRSOrder: uint64(srsOrder), SRSNumberToLoad: uint64(srsOrder), - Verbose: true, + NumWorker: 1, PreloadEncoder: false, - }, false) + LoadG2Points: true, + } + + v, err := verifier.NewVerifier(kzgConfig, nil) if err != nil { return err } diff --git a/node/grpc/server_test.go b/node/grpc/server_test.go index 8953f1f1cd..825876192e 100644 --- a/node/grpc/server_test.go +++ b/node/grpc/server_test.go @@ -60,14 +60,15 @@ func makeTestComponents() (encoding.Prover, encoding.Verifier, error) { SRSOrder: 300000, SRSNumberToLoad: 300000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, err := prover.NewProver(config, true) + p, err := prover.NewProver(config, nil) if err != nil { return nil, nil, err } - v, err := verifier.NewVerifier(config, true) + v, err := verifier.NewVerifier(config, nil) if err != nil { return nil, nil, err } diff --git a/node/node.go b/node/node.go index 62845f1345..3e36d5f787 100644 --- a/node/node.go +++ b/node/node.go @@ -179,7 +179,8 @@ func NewNode( metrics := NewMetrics(eigenMetrics, reg, logger, ":"+config.MetricsPort, config.ID, config.OnchainMetricsInterval, tx, cst) // Make validator - v, err := verifier.NewVerifier(&config.EncoderConfig, false) + config.EncoderConfig.LoadG2Points = false + v, err := verifier.NewVerifier(&config.EncoderConfig, nil) if err != nil { return nil, err } diff --git a/relay/chunkstore/chunk_store_test.go b/relay/chunkstore/chunk_store_test.go index f1f8300a64..6d0c3e2d3e 100644 --- a/relay/chunkstore/chunk_store_test.go +++ b/relay/chunkstore/chunk_store_test.go @@ -2,7 +2,6 @@ package chunkstore import ( "context" - "math" "math/rand" "os" "testing" @@ -14,14 +13,13 @@ import ( tu "github.com/Layr-Labs/eigenda/common/testutils" corev2 "github.com/Layr-Labs/eigenda/core/v2" "github.com/Layr-Labs/eigenda/encoding" - "github.com/Layr-Labs/eigenda/encoding/fft" "github.com/Layr-Labs/eigenda/encoding/rs" - rs_cpu "github.com/Layr-Labs/eigenda/encoding/rs/cpu" "github.com/Layr-Labs/eigenda/encoding/utils/codec" "github.com/Layr-Labs/eigenda/inabox/deploy" "github.com/Layr-Labs/eigensdk-go/logging" "github.com/consensys/gnark-crypto/ecc/bn254/fp" "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -190,8 +188,8 @@ func TestRandomProofs(t *testing.T) { } } -func generateRandomFrames(t *testing.T, encoder *rs.Encoder, size int) []*rs.Frame { - frames, _, err := encoder.EncodeBytes(codec.ConvertByPaddingEmptyByte(tu.RandomBytes(size))) +func generateRandomFrames(t *testing.T, encoder *rs.Encoder, size int, params encoding.EncodingParams) []*rs.Frame { + frames, _, err := encoder.EncodeBytes(codec.ConvertByPaddingEmptyByte(tu.RandomBytes(size)), params) result := make([]*rs.Frame, len(frames)) require.NoError(t, err) @@ -208,22 +206,10 @@ func RandomCoefficientsTest(t *testing.T, client s3.Client) { chunkSize := uint64(rand.Intn(1024) + 100) fragmentSize := int(chunkSize / 2) - params := encoding.ParamsFromSysPar(3, 1, chunkSize) - encoder, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(encoder.NumEvaluations()))) - if encoder.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * encoder.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - encoder.Computer = RsComputeDevice + cfg := encoding.DefaultConfig() + encoder, err := rs.NewEncoder(cfg) + assert.Nil(t, err) require.NotNil(t, encoder) writer := NewChunkWriter(logger, client, bucket, fragmentSize) @@ -236,7 +222,7 @@ func RandomCoefficientsTest(t *testing.T, client s3.Client) { for i := 0; i < 100; i++ { key := corev2.BlobKey(tu.RandomBytes(32)) - coefficients := generateRandomFrames(t, encoder, int(chunkSize)) + coefficients := generateRandomFrames(t, encoder, int(chunkSize), params) expectedValues[key] = coefficients metadata, err := writer.PutChunkCoefficients(context.Background(), key, coefficients) @@ -282,20 +268,9 @@ func TestCheckProofCoefficientsExist(t *testing.T) { fragmentSize := int(chunkSize / 2) params := encoding.ParamsFromSysPar(3, 1, chunkSize) - encoder, _ := rs.NewEncoder(params, true) - - n := uint8(math.Log2(float64(encoder.NumEvaluations()))) - if encoder.ChunkLength == 1 { - n = uint8(math.Log2(float64(2 * encoder.NumChunks))) - } - fs := fft.NewFFTSettings(n) - - RsComputeDevice := &rs_cpu.RsCpuComputeDevice{ - Fs: fs, - EncodingParams: params, - } - - encoder.Computer = RsComputeDevice + cfg := encoding.DefaultConfig() + encoder, err := rs.NewEncoder(cfg) + assert.Nil(t, err) require.NotNil(t, encoder) writer := NewChunkWriter(logger, client, bucket, fragmentSize) @@ -308,7 +283,7 @@ func TestCheckProofCoefficientsExist(t *testing.T) { require.NoError(t, err) require.True(t, writer.ProofExists(ctx, key)) - coefficients := generateRandomFrames(t, encoder, int(chunkSize)) + coefficients := generateRandomFrames(t, encoder, int(chunkSize), params) metadata, err := writer.PutChunkCoefficients(ctx, key, coefficients) require.NoError(t, err) exist, fragmentInfo := writer.CoefficientsExists(ctx, key) diff --git a/relay/relay_test_utils.go b/relay/relay_test_utils.go index 1e5120f700..28051afb2c 100644 --- a/relay/relay_test_utils.go +++ b/relay/relay_test_utils.go @@ -72,9 +72,10 @@ func setup(t *testing.T) { SRSOrder: 8192, SRSNumberToLoad: 8192, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } var err error - prover, err = p.NewProver(config, true) + prover, err = p.NewProver(config, nil) require.NoError(t, err) } } diff --git a/retriever/cmd/main.go b/retriever/cmd/main.go index fd1704c315..8ea7b0429a 100644 --- a/retriever/cmd/main.go +++ b/retriever/cmd/main.go @@ -78,7 +78,9 @@ func RetrieverMain(ctx *cli.Context) error { } nodeClient := clients.NewNodeClient(config.Timeout) - v, err := verifier.NewVerifier(&config.EncoderConfig, true) + + config.EncoderConfig.LoadG2Points = true + v, err := verifier.NewVerifier(&config.EncoderConfig, nil) if err != nil { log.Fatalln("could not start tcp listener", err) } diff --git a/retriever/server_test.go b/retriever/server_test.go index 161e6f0677..79e2b22a8e 100644 --- a/retriever/server_test.go +++ b/retriever/server_test.go @@ -41,14 +41,15 @@ func makeTestComponents() (encoding.Prover, encoding.Verifier, error) { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, err := prover.NewProver(config, true) + p, err := prover.NewProver(config, nil) if err != nil { return nil, nil, err } - v, err := verifier.NewVerifier(config, true) + v, err := verifier.NewVerifier(config, nil) if err != nil { return nil, nil, err } diff --git a/test/integration_test.go b/test/integration_test.go index 3df7de45ed..b3efdc25f1 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -98,7 +98,6 @@ func init() { // makeTestEncoder makes an encoder currently using the only supported backend. func mustMakeTestComponents() (encoding.Prover, encoding.Verifier) { - config := &kzg.KzgConfig{ G1Path: "../inabox/resources/kzg/g1.point", G2Path: "../inabox/resources/kzg/g2.point", @@ -106,14 +105,15 @@ func mustMakeTestComponents() (encoding.Prover, encoding.Verifier) { SRSOrder: 3000, SRSNumberToLoad: 3000, NumWorker: uint64(runtime.GOMAXPROCS(0)), + LoadG2Points: true, } - p, err := prover.NewProver(config, true) + p, err := prover.NewProver(config, nil) if err != nil { log.Fatal(err) } - v, err := verifier.NewVerifier(config, true) + v, err := verifier.NewVerifier(config, nil) if err != nil { log.Fatal(err) } diff --git a/test/synthetic-test/synthetic_client_test.go b/test/synthetic-test/synthetic_client_test.go index ab868c557a..3854a214e9 100644 --- a/test/synthetic-test/synthetic_client_test.go +++ b/test/synthetic-test/synthetic_client_test.go @@ -241,6 +241,7 @@ func setupRetrievalClient(ethClient common.EthClient, retrievalClientConfig *Ret SRSNumberToLoad: uint64(srsOrder), Verbose: true, PreloadEncoder: false, + LoadG2Points: true, }, false) if err != nil { return err diff --git a/tools/traffic/generator_v2.go b/tools/traffic/generator_v2.go index dc39b4af85..beec5e393b 100644 --- a/tools/traffic/generator_v2.go +++ b/tools/traffic/generator_v2.go @@ -179,7 +179,8 @@ func buildRetriever(config *config.Config) (clients.RetrievalClient, retrivereth nodeClient := clients.NewNodeClient(config.NodeClientTimeout) - v, err := verifier.NewVerifier(&config.RetrievalClientConfig.EncoderConfig, true) + config.RetrievalClientConfig.EncoderConfig.LoadG2Points = true + v, err := verifier.NewVerifier(&config.RetrievalClientConfig.EncoderConfig, nil) if err != nil { panic(fmt.Sprintf("Unable to build statusTracker: %s", err)) }