Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cpu prover #629

Merged
merged 5 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 128 additions & 224 deletions encoding/kzg/prover/parametrized_prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@ package prover
import (
"fmt"
"log"
"math"
"time"

"github.com/Layr-Labs/eigenda/encoding"
"github.com/hashicorp/go-multierror"

"github.com/Layr-Labs/eigenda/encoding/fft"
"github.com/Layr-Labs/eigenda/encoding/kzg"
"github.com/Layr-Labs/eigenda/encoding/rs"
"github.com/Layr-Labs/eigenda/encoding/utils/toeplitz"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/ecc/bn254"
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
)
Expand All @@ -21,18 +18,36 @@ type ParametrizedProver struct {
*rs.Encoder

*kzg.KzgConfig
Srs *kzg.SRS
G2Trailing []bn254.G2Affine
Ks *kzg.KZGSettings

Fs *fft.FFTSettings
Ks *kzg.KZGSettings
SFs *fft.FFTSettings // fft used for submatrix product helper
FFTPointsT [][]bn254.G1Affine // transpose of FFTPoints
Computer ProofComputer
}

type WorkerResult struct {
points []bn254.G1Affine
err error
type RsEncodeResult struct {
Frames []rs.Frame
Indices []uint32
Duration time.Duration
Err error
}
type LengthCommitmentResult struct {
LengthCommitment bn254.G2Affine
Duration time.Duration
Err error
}
type LengthProofResult struct {
LengthProof bn254.G2Affine
Duration time.Duration
Err error
}
type CommitmentResult struct {
Commitment bn254.G1Affine
Duration time.Duration
Err error
}
type ProofsResult struct {
Proofs []bn254.G1Affine
Duration time.Duration
Err error
}

// just a wrapper to take bytes not Fr Element
Expand All @@ -46,228 +61,117 @@ func (g *ParametrizedProver) EncodeBytes(inputBytes []byte) (*bn254.G1Affine, *b

func (g *ParametrizedProver) Encode(inputFr []fr.Element) (*bn254.G1Affine, *bn254.G2Affine, *bn254.G2Affine, []encoding.Frame, []uint32, error) {

startTime := time.Now()
poly, frames, indices, err := g.Encoder.Encode(inputFr)
if err != nil {
return nil, nil, nil, nil, nil, err
}

if len(poly.Coeffs) > int(g.KzgConfig.SRSNumberToLoad) {
return nil, nil, nil, nil, nil, fmt.Errorf("poly Coeff length %v is greater than Loaded SRS points %v", len(poly.Coeffs), int(g.KzgConfig.SRSNumberToLoad))
}

// compute commit for the full poly
commit, err := g.Commit(poly.Coeffs)
if err != nil {
return nil, nil, nil, nil, nil, err
if len(inputFr) > int(g.KzgConfig.SRSNumberToLoad) {
return nil, nil, nil, nil, nil, fmt.Errorf("poly Coeff length %v is greater than Loaded SRS points %v", len(inputFr), int(g.KzgConfig.SRSNumberToLoad))
}

config := ecc.MultiExpConfig{}
encodeStart := time.Now()

var lengthCommitment bn254.G2Affine
_, err = lengthCommitment.MultiExp(g.Srs.G2[:len(poly.Coeffs)], poly.Coeffs, config)
if err != nil {
return nil, nil, nil, nil, nil, err
}

intermediate := time.Now()

chunkLength := uint64(len(inputFr))

if g.Verbose {
log.Printf(" Commiting takes %v\n", time.Since(intermediate))
intermediate = time.Now()

log.Printf("shift %v\n", g.SRSOrder-chunkLength)
log.Printf("order %v\n", len(g.Srs.G2))
log.Println("low degree verification info")
}

shiftedSecret := g.G2Trailing[g.KzgConfig.SRSNumberToLoad-chunkLength:]

//The proof of low degree is commitment of the polynomial shifted to the largest srs degree
var lengthProof bn254.G2Affine
_, err = lengthProof.MultiExp(shiftedSecret, poly.Coeffs, config)
if err != nil {
return nil, nil, nil, nil, nil, err
}
rsChan := make(chan RsEncodeResult, 1)
lengthCommitmentChan := make(chan LengthCommitmentResult, 1)
lengthProofChan := make(chan LengthProofResult, 1)
commitmentChan := make(chan CommitmentResult, 1)
proofChan := make(chan ProofsResult, 1)

if g.Verbose {
log.Printf(" Generating Length Proof takes %v\n", time.Since(intermediate))
intermediate = time.Now()
}

// compute proofs
paddedCoeffs := make([]fr.Element, g.NumEvaluations())
copy(paddedCoeffs, poly.Coeffs)

proofs, err := g.ProveAllCosetThreads(paddedCoeffs, g.NumChunks, g.ChunkLength, g.NumWorker)
if err != nil {
return nil, nil, nil, nil, nil, fmt.Errorf("could not generate proofs: %v", err)
}

if g.Verbose {
log.Printf(" Proving takes %v\n", time.Since(intermediate))
}

kzgFrames := make([]encoding.Frame, len(frames))
for i, index := range indices {
kzgFrames[i] = encoding.Frame{
Proof: proofs[index],
Coeffs: frames[i].Coeffs,
// inputFr is untouched
// compute chunks
go func() {
start := time.Now()
frames, indices, err := g.Encoder.Encode(inputFr)
rsChan <- RsEncodeResult{
Frames: frames,
Indices: indices,
Err: err,
Duration: time.Since(start),
}
}

if g.Verbose {
log.Printf("Total encoding took %v\n", time.Since(startTime))
}
return &commit, &lengthCommitment, &lengthProof, kzgFrames, indices, nil
}
}()

func (g *ParametrizedProver) Commit(polyFr []fr.Element) (bn254.G1Affine, error) {
commit, err := g.Ks.CommitToPoly(polyFr)
return *commit, err
}

func (p *ParametrizedProver) ProveAllCosetThreads(polyFr []fr.Element, numChunks, chunkLen, numWorker uint64) ([]bn254.G1Affine, error) {
begin := time.Now()
// Robert: Standardizing this to use the same math used in precomputeSRS
dimE := numChunks
l := chunkLen

sumVec := make([]bn254.G1Affine, dimE*2)

jobChan := make(chan uint64, numWorker)
results := make(chan WorkerResult, numWorker)

// create storage for intermediate fft outputs
coeffStore := make([][]fr.Element, dimE*2)
for i := range coeffStore {
coeffStore[i] = make([]fr.Element, l)
}

for w := uint64(0); w < numWorker; w++ {
go p.proofWorker(polyFr, jobChan, l, dimE, coeffStore, results)
}

for j := uint64(0); j < l; j++ {
jobChan <- j
}
close(jobChan)

// return last error
var err error
for w := uint64(0); w < numWorker; w++ {
wr := <-results
if wr.err != nil {
err = wr.err
// compute commit for the full poly
go func() {
start := time.Now()
commit, err := g.Computer.ComputeCommitment(inputFr)
commitmentChan <- CommitmentResult{
Commitment: *commit,
Err: err,
Duration: time.Since(start),
}
}

if err != nil {
return nil, fmt.Errorf("proof worker error: %v", err)
}

t0 := time.Now()

// compute proof by multi scaler multiplication
msmErrors := make(chan error, dimE*2)
for i := uint64(0); i < dimE*2; i++ {

go func(k uint64) {
_, err := sumVec[k].MultiExp(p.FFTPointsT[k], coeffStore[k], ecc.MultiExpConfig{})
// handle error
msmErrors <- err
}(i)
}

for i := uint64(0); i < dimE*2; i++ {
err := <-msmErrors
if err != nil {
fmt.Println("Error. MSM while adding points", err)
return nil, err
}()

go func() {
start := time.Now()
lengthCommitment, err := g.Computer.ComputeLengthCommitment(inputFr)
lengthCommitmentChan <- LengthCommitmentResult{
LengthCommitment: *lengthCommitment,
Err: err,
Duration: time.Since(start),
}
}

t1 := time.Now()

// only 1 ifft is needed
sumVecInv, err := p.Fs.FFTG1(sumVec, true)
if err != nil {
return nil, fmt.Errorf("fft error: %v", err)
}

t2 := time.Now()

// outputs is out of order - buttefly
proofs, err := p.Fs.FFTG1(sumVecInv[:dimE], false)
if err != nil {
return nil, err
}

t3 := time.Now()

fmt.Printf("mult-th %v, msm %v,fft1 %v, fft2 %v,\n", t0.Sub(begin), t1.Sub(t0), t2.Sub(t1), t3.Sub(t2))

return proofs, nil
}

func (p *ParametrizedProver) proofWorker(
polyFr []fr.Element,
jobChan <-chan uint64,
l uint64,
dimE uint64,
coeffStore [][]fr.Element,
results chan<- WorkerResult,
) {

for j := range jobChan {
coeffs, err := p.GetSlicesCoeff(polyFr, dimE, j, l)
if err != nil {
results <- WorkerResult{
points: nil,
err: err,
}
} else {
for i := 0; i < len(coeffs); i++ {
coeffStore[i][j] = coeffs[i]
}
}()

go func() {
start := time.Now()
lengthProof, err := g.Computer.ComputeLengthProof(inputFr)
lengthProofChan <- LengthProofResult{
LengthProof: *lengthProof,
Err: err,
Duration: time.Since(start),
}
}()

go func() {
start := time.Now()
// compute proofs
paddedCoeffs := make([]fr.Element, g.NumEvaluations())
// polyCoeffs has less points than paddedCoeffs in general due to erasure redundancy
copy(paddedCoeffs, inputFr)

numBlob := 1
flatpaddedCoeffs := make([]fr.Element, 0, numBlob*len(paddedCoeffs))
for i := 0; i < numBlob; i++ {
flatpaddedCoeffs = append(flatpaddedCoeffs, paddedCoeffs...)
}
}

results <- WorkerResult{
err: nil,
}
}

// output is in the form see primeField toeplitz
//
// phi ^ (coset size ) = 1
//
// implicitly pad slices to power of 2
func (p *ParametrizedProver) GetSlicesCoeff(polyFr []fr.Element, dimE, j, l uint64) ([]fr.Element, error) {
// there is a constant term
m := uint64(len(polyFr)) - 1
dim := (m - j) / l

toeV := make([]fr.Element, 2*dimE-1)
for i := uint64(0); i < dim; i++ {

toeV[i].Set(&polyFr[m-(j+i*l)])
proofs, err := g.Computer.ComputeMultiFrameProof(flatpaddedCoeffs, g.NumChunks, g.ChunkLength, g.NumWorker)
proofChan <- ProofsResult{
Proofs: proofs,
Err: err,
Duration: time.Since(start),
}
}()

lengthProofResult := <-lengthProofChan
lengthCommitmentResult := <-lengthCommitmentChan
commitmentResult := <-commitmentChan
rsResult := <-rsChan
proofsResult := <-proofChan

if lengthProofResult.Err != nil || lengthCommitmentResult.Err != nil ||
commitmentResult.Err != nil || rsResult.Err != nil ||
proofsResult.Err != nil {
return nil, nil, nil, nil, nil, multierror.Append(lengthProofResult.Err, lengthCommitmentResult.Err, commitmentResult.Err, rsResult.Err, proofsResult.Err)
}
totalProcessingTime := time.Since(encodeStart)

log.Printf("\n\t\tRS encode %-v\n\t\tCommiting %-v\n\t\tLengthCommit %-v\n\t\tlengthProof %-v\n\t\tmultiProof %-v\n\t\tMetaInfo. order %-v shift %v\n",
rsResult.Duration,
commitmentResult.Duration,
lengthCommitmentResult.Duration,
lengthProofResult.Duration,
proofsResult.Duration,
g.SRSOrder,
g.SRSOrder-uint64(len(inputFr)),
)

// assemble frames
kzgFrames := make([]encoding.Frame, len(rsResult.Frames))
for i, index := range rsResult.Indices {
kzgFrames[i] = encoding.Frame{
Proof: proofsResult.Proofs[index],
Coeffs: rsResult.Frames[i].Coeffs,
}
}

// use precompute table
tm, err := toeplitz.NewToeplitz(toeV, p.SFs)
if err != nil {
return nil, err
if g.Verbose {
log.Printf("Total encoding took %v\n", totalProcessingTime)
}
return tm.GetFFTCoeff()
}

/*
returns the power of 2 which is immediately bigger than the input
*/
func CeilIntPowerOf2Num(d uint64) uint64 {
nextPower := math.Ceil(math.Log2(float64(d)))
return uint64(math.Pow(2.0, nextPower))
return &commitmentResult.Commitment, &lengthCommitmentResult.LengthCommitment, &lengthProofResult.LengthProof, kzgFrames, rsResult.Indices, nil
}
Loading
Loading