Skip to content

Commit

Permalink
working parallel kzg proof compute
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Jul 7, 2024
1 parent 4102710 commit 0e81887
Show file tree
Hide file tree
Showing 14 changed files with 203 additions and 116 deletions.
2 changes: 1 addition & 1 deletion encoding/kzg/prover/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)

type ProofComputer interface {
type ProofComputeDevice interface {
// blobFr are coefficients
ComputeCommitment(blobFr []fr.Element) (*bn254.G1Affine, error)
ComputeMultiFrameProof(blobFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error)
Expand Down
2 changes: 1 addition & 1 deletion encoding/kzg/prover/gpu/ecntt.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
)

func (c *GpuComputer) ECNtt(batchPoints []bn254.G1Affine, isInverse bool) ([]bn254.G1Affine, error) {
func (c *GpuComputeDevice) ECNtt(batchPoints []bn254.G1Affine, isInverse bool) ([]bn254.G1Affine, error) {
totalNumSym := len(batchPoints)

// convert gnark affine to icicle projective on slice
Expand Down
2 changes: 1 addition & 1 deletion encoding/kzg/prover/gpu/msm.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

// MsmBatch function supports batch across blobs
func (c *GpuComputer) MsmBatch(rowsFr [][]fr.Element, rowsG1 [][]bn254.G1Affine) ([]bn254.G1Affine, error) {
func (c *GpuComputeDevice) MsmBatch(rowsFr [][]fr.Element, rowsG1 [][]bn254.G1Affine) ([]bn254.G1Affine, error) {
msmCfg := icicle_bn254_msm.GetDefaultMSMConfig()
rowsSfIcicle := make([]icicle_bn254.ScalarField, 0)
rowsAffineIcicle := make([]icicle_bn254.Affine, 0)
Expand Down
22 changes: 14 additions & 8 deletions encoding/kzg/prover/gpu/multiframe_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gpu

import (
"fmt"
"sync"
"time"

"github.com/Layr-Labs/eigenda/encoding/fft"
Expand All @@ -19,18 +20,19 @@ type WorkerResult struct {
err error
}

type GpuComputer struct {
type GpuComputeDevice struct {
*kzg.KzgConfig
Fs *fft.FFTSettings
FFTPointsT [][]bn254.G1Affine // transpose of FFTPoints
SFs *fft.FFTSettings
Srs *kzg.SRS
G2Trailing []bn254.G2Affine
NttCfg core.NTTConfig[[bn254_icicle.SCALAR_LIMBS]uint32]
NttCfg core.NTTConfig[[bn254_icicle.SCALAR_LIMBS]uint32]
GpuLock *sync.Mutex // lock whenever gpu is needed,
}

// Benchmarks show the CPU commit on a 2MB blob takes only 24.165562ms, so the CPU is used for now.
func (p *GpuComputer) ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine, error) {
func (p *GpuComputeDevice) ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine, error) {
inputLength := uint64(len(coeffs))
shiftedSecret := p.G2Trailing[p.KzgConfig.SRSNumberToLoad-inputLength:]
config := ecc.MultiExpConfig{}
Expand All @@ -44,7 +46,7 @@ func (p *GpuComputer) ComputeLengthProof(coeffs []fr.Element) (*bn254.G2Affine,
}

// Benchmarks show the CPU commit on a 2MB blob takes only 11.673738ms, so the CPU is used for now.
func (p *GpuComputer) ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, error) {
func (p *GpuComputeDevice) ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, error) {
// compute commit for the full poly
config := ecc.MultiExpConfig{}
var commitment bn254.G1Affine
Expand All @@ -56,7 +58,7 @@ func (p *GpuComputer) ComputeCommitment(coeffs []fr.Element) (*bn254.G1Affine, e
}

// Benchmarks show the CPU commit on a 2MB blob takes only 31.318661ms, so the CPU is used for now.
func (p *GpuComputer) ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Affine, error) {
func (p *GpuComputeDevice) ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Affine, error) {
config := ecc.MultiExpConfig{}

var lengthCommitment bn254.G2Affine
Expand All @@ -69,7 +71,7 @@ func (p *GpuComputer) ComputeLengthCommitment(coeffs []fr.Element) (*bn254.G2Aff

// This function supports batching over multiple blobs.
// All blobs must have the same size and be passed concatenated as polyFr.
func (p *GpuComputer) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) {
func (p *GpuComputeDevice) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) {
// Robert: Standardizing this to use the same math used in precomputeSRS
dimE := numChunks
l := chunkLen
Expand Down Expand Up @@ -110,6 +112,10 @@ func (p *GpuComputer) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chu
}
preprocessDone := time.Now()

// Start using GPU
p.GpuLock.Lock()
defer p.GpuLock.Unlock()

// Compute NTT on the coeff matrix
p.NttCfg.BatchSize = int32(l)
coeffStoreFFT, e := p.NTT(coeffStore)
Expand Down Expand Up @@ -180,7 +186,7 @@ func (p *GpuComputer) ComputeMultiFrameProof(polyFr []fr.Element, numChunks, chu
return flatProofsBatch, nil
}

func (p *GpuComputer) proofWorkerGPU(
func (p *GpuComputeDevice) proofWorkerGPU(
polyFr []fr.Element,
jobChan <-chan uint64,
l uint64,
Expand Down Expand Up @@ -208,7 +214,7 @@ func (p *GpuComputer) proofWorkerGPU(
}

// capable of batching blobs
func (p *GpuComputer) GetSlicesCoeffWithoutFFT(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) {
func (p *GpuComputeDevice) GetSlicesCoeffWithoutFFT(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) {
// there is a constant term
m := uint64(dimE*l) - 1
dim := (m - j%l) / l
Expand Down
8 changes: 7 additions & 1 deletion encoding/kzg/prover/gpu/ntt.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package gpu

import (
"fmt"

"github.com/Layr-Labs/eigenda/encoding/utils/gpu_utils"
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
bn254_icicle "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254"
bn254_icicle_ntt "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/ntt"
)

func (c *GpuComputer) NTT(batchFr [][]fr.Element) ([][]fr.Element, error) {
func (c *GpuComputeDevice) NTT(batchFr [][]fr.Element) ([][]fr.Element, error) {
if len(batchFr) == 0 {
return nil, fmt.Errorf("input to NTT contains no blob")
}

numSymbol := len(batchFr[0])
batchSize := len(batchFr)

Expand Down
159 changes: 114 additions & 45 deletions encoding/kzg/prover/parametrized_prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/Layr-Labs/eigenda/encoding"
"github.com/hashicorp/go-multierror"

"github.com/Layr-Labs/eigenda/encoding/fft"
"github.com/Layr-Labs/eigenda/encoding/kzg"
Expand All @@ -27,7 +28,34 @@ type ParametrizedProver struct {
FFTPointsT [][]bn254.G1Affine // transpose of FFTPoints

UseGpu bool
Computer ProofComputer
Computer ProofComputeDevice
}

// RsEncodeResult is the message sent over a channel by the goroutine in
// ParametrizedProver.Encode that runs g.Encoder.Encode.
type RsEncodeResult struct {
	// Frames holds the frames returned by the encoder call.
	Frames []rs.Frame

	// Indices is returned alongside Frames. Encode uses Indices[i] to pick
	// the proof that is paired with Frames[i].
	Indices []uint32

	// Err is the error returned by the encoder call, or nil.
	Err error

	// Duration is the elapsed time of the encoder call (time.Since the
	// moment the goroutine started it).
	Duration time.Duration
}
// LengthCommitmentResult is the message sent over a channel by the goroutine
// in ParametrizedProver.Encode that runs Computer.ComputeLengthCommitment.
type LengthCommitmentResult struct {
	// LengthCommitment is the G2 point produced by the call, stored by value.
	LengthCommitment bn254.G2Affine

	// Err is the error returned by the call, or nil.
	Err error

	// Duration is the elapsed time of the call (time.Since its start).
	Duration time.Duration
}
// LengthProofResult is the message sent over a channel by the goroutine in
// ParametrizedProver.Encode that runs Computer.ComputeLengthProof.
type LengthProofResult struct {
	// LengthProof is the G2 point produced by the call, stored by value.
	LengthProof bn254.G2Affine

	// Err is the error returned by the call, or nil.
	Err error

	// Duration is the elapsed time of the call (time.Since its start).
	Duration time.Duration
}
// CommitmentResult is the message sent over a channel by the goroutine in
// ParametrizedProver.Encode that runs Computer.ComputeCommitment.
type CommitmentResult struct {
	// Commitment is the G1 point produced by the call, stored by value.
	Commitment bn254.G1Affine

	// Err is the error returned by the call, or nil.
	Err error

	// Duration is the elapsed time of the call (time.Since its start).
	Duration time.Duration
}
// ProofsResult is the message sent over a channel by the goroutine in
// ParametrizedProver.Encode that runs Computer.ComputeMultiFrameProof.
type ProofsResult struct {
	// Proofs holds the G1 proofs returned by the call. Encode indexes this
	// slice with the entries of the RS encoding result's Indices.
	Proofs []bn254.G1Affine

	// Err is the error returned by the call, or nil.
	Err error

	// Duration is the elapsed time of the call (time.Since its start),
	// including the allocation and copy of the padded coefficients.
	Duration time.Duration
}

// just a wrapper to take bytes not Fr Element
Expand All @@ -45,66 +73,107 @@ func (g *ParametrizedProver) Encode(inputFr []fr.Element) (*bn254.G1Affine, *bn2
return nil, nil, nil, nil, nil, fmt.Errorf("poly Coeff length %v is greater than Loaded SRS points %v", len(inputFr), int(g.KzgConfig.SRSNumberToLoad))
}

startTime := time.Now()
// compute chunks
poly, frames, indices, err := g.Encoder.Encode(inputFr)
if err != nil {
return nil, nil, nil, nil, nil, err
}
rsEncodeDone := time.Now()

// compute commit for the full poly
commit, err := g.Computer.ComputeCommitment(poly.Coeffs)
if err != nil {
return nil, nil, nil, nil, nil, err
}
commitDone := time.Now()
encodeStart := time.Now()

lengthCommitment, err := g.Computer.ComputeLengthCommitment(poly.Coeffs)
if err != nil {
return nil, nil, nil, nil, nil, err
}
lengthCommitDone := time.Now()
rsChan := make(chan RsEncodeResult, 1)
lengthCommitmentChan := make(chan LengthCommitmentResult, 1)
lengthProofChan := make(chan LengthProofResult, 1)
commitmentChan := make(chan CommitmentResult, 1)
proofChan := make(chan ProofsResult, 1)

lengthProof, err := g.Computer.ComputeLengthProof(poly.Coeffs)
if err != nil {
return nil, nil, nil, nil, nil, err
}
lengthProofDone := time.Now()
// inputFr is untouched
// compute chunks
go func() {
start := time.Now()
frames, indices, err := g.Encoder.Encode(inputFr)
rsChan <- RsEncodeResult{
Frames: frames,
Indices: indices,
Err: err,
Duration: time.Since(start),
}
}()

// compute proofs
paddedCoeffs := make([]fr.Element, g.NumEvaluations())
// polyCoeffs has less points than paddedCoeffs in general due to erasure redundancy
copy(paddedCoeffs, poly.Coeffs)
proofs, err := g.Computer.ComputeMultiFrameProof(paddedCoeffs, g.NumChunks, g.ChunkLength, g.NumWorker)
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("could not generate proofs: %v", err)
// compute commit for the full poly
go func() {
start := time.Now()
commit, err := g.Computer.ComputeCommitment(inputFr)
commitmentChan <- CommitmentResult{
Commitment: *commit,
Err: err,
Duration: time.Since(start),
}
}()

go func() {
start := time.Now()
lengthCommitment, err := g.Computer.ComputeLengthCommitment(inputFr)
lengthCommitmentChan <- LengthCommitmentResult{
LengthCommitment: *lengthCommitment,
Err: err,
Duration: time.Since(start),
}
}()

go func() {
start := time.Now()
lengthProof, err := g.Computer.ComputeLengthProof(inputFr)
lengthProofChan <- LengthProofResult{
LengthProof: *lengthProof,
Err: err,
Duration: time.Since(start),
}
}()

go func() {
start := time.Now()
// compute proofs
paddedCoeffs := make([]fr.Element, g.NumEvaluations())
// polyCoeffs has less points than paddedCoeffs in general due to erasure redundancy
copy(paddedCoeffs, inputFr)
proofs, err := g.Computer.ComputeMultiFrameProof(paddedCoeffs, g.NumChunks, g.ChunkLength, g.NumWorker)
proofChan <- ProofsResult{
Proofs: proofs,
Err: err,
Duration: time.Since(start),
}
}()

lengthProofResult := <-lengthProofChan
lengthCommitmentResult := <-lengthCommitmentChan
commitmentResult := <-commitmentChan
rsResult := <-rsChan
proofsResult := <-proofChan

if lengthProofResult.Err != nil || lengthCommitmentResult.Err != nil ||
commitmentResult.Err != nil || rsResult.Err != nil ||
proofsResult.Err != nil {
return nil, nil, nil, nil, nil, multierror.Append(lengthProofResult.Err, lengthCommitmentResult.Err, commitmentResult.Err, rsResult.Err, proofsResult.Err)
}
multiProofDone := time.Now()

totalProcessingTime := time.Since(encodeStart)
if g.Verbose {
log.Printf("\n\t\tRS encode %-v\n\t\tCommiting %-v\n\t\tLengthCommit %-v\n\t\tlengthProof %-v\n\t\tmultiProof %-v\n\t\tMetaInfo. order %-v shift %v\n",
rsEncodeDone.Sub(startTime),
commitDone.Sub(rsEncodeDone),
lengthCommitDone.Sub(commitDone),
lengthProofDone.Sub(lengthCommitDone),
multiProofDone.Sub(lengthProofDone),
rsResult.Duration,
commitmentResult.Duration,
lengthCommitmentResult.Duration,
lengthProofResult.Duration,
proofsResult.Duration,
len(g.Srs.G2),
g.SRSOrder-uint64(len(inputFr)),
)
}

// assemble frames
kzgFrames := make([]encoding.Frame, len(frames))
for i, index := range indices {
kzgFrames := make([]encoding.Frame, len(rsResult.Frames))
for i, index := range rsResult.Indices {
kzgFrames[i] = encoding.Frame{
Proof: proofs[index],
Coeffs: frames[i].Coeffs,
Proof: proofsResult.Proofs[index],
Coeffs: rsResult.Frames[i].Coeffs,
}
}

if g.Verbose {
log.Printf("Total encoding took %v\n", time.Since(startTime))
log.Printf("Total encoding took %v\n", totalProcessingTime)
}
return commit, lengthCommitment, lengthProof, kzgFrames, indices, nil
return &commitmentResult.Commitment, &lengthCommitmentResult.LengthCommitment, &lengthProofResult.LengthProof, kzgFrames, rsResult.Indices, nil
}
16 changes: 9 additions & 7 deletions encoding/kzg/prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ func (g *Prover) newProver(params encoding.EncodingParams) (*ParametrizedProver,
sfs := fft.NewFFTSettings(t)

// Set RS computer
var rsComputer rs.RSComputer
var RsComputeDevice rs.RsComputeDevice

// Set KZG Prover computer
var computer ProofComputer
var computer ProofComputeDevice
if !g.UseGpu {
computer = &cpu.CpuComputer{
Fs: fs,
Expand All @@ -256,30 +256,32 @@ func (g *Prover) newProver(params encoding.EncodingParams) (*ParametrizedProver,
G2Trailing: g.G2Trailing,
KzgConfig: g.KzgConfig,
}
rsComputer = &rs_cpu.CpuComputer{
RsComputeDevice = &rs_cpu.CpuComputer{
Fs: fs,
EncodingParams: params,
}
} else {
nttCfg := gpu_utils.SetupNTT()

computer = &gpu.GpuComputer{
GpuLock := sync.Mutex{}
computer = &gpu.GpuComputeDevice{
Fs: fs,
FFTPointsT: fftPointsT,
SFs: sfs,
Srs: g.Srs,
G2Trailing: g.G2Trailing,
KzgConfig: g.KzgConfig,
NttCfg: nttCfg,
GpuLock: &GpuLock,
}

rsComputer = &rs_gpu.GpuComputer{
RsComputeDevice = &rs_gpu.GpuComputeDevice{
Fs: fs,
EncodingParams: params,
NttCfg: nttCfg,
GpuLock: &GpuLock,
}
}
encoder.Computer = rsComputer
encoder.Computer = RsComputeDevice

return &ParametrizedProver{
Encoder: encoder,
Expand Down
Loading

0 comments on commit 0e81887

Please sign in to comment.