From 9e79171d39dc7d463b7a5f5f7650375801e7a450 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 3 Mar 2024 02:25:33 +0000 Subject: [PATCH] optimize encoder interpolation latency --- encoding/rs/encode.go | 104 +++++++++++++++++++++++++++++------ encoding/rs/encoder.go | 5 ++ encoding/rs/interpolation.go | 41 +++++++++++--- 3 files changed, 126 insertions(+), 24 deletions(-) diff --git a/encoding/rs/encode.go b/encoding/rs/encode.go index 68bf2438bd..be568b7ef0 100644 --- a/encoding/rs/encode.go +++ b/encoding/rs/encode.go @@ -77,31 +77,68 @@ func (g *Encoder) MakeFrames( indices := make([]uint32, 0) frames := make([]Frame, g.NumChunks) - for i := uint64(0); i < uint64(g.NumChunks); i++ { + jobChan := make(chan JobRequest, g.NumRSWorker) + results := make(chan error, g.NumRSWorker) + + for w := uint64(0); w < uint64(g.NumRSWorker); w++ { + go g.interpolyWorker( + polyEvals, + jobChan, + results, + frames, + ) + } - // finds out which coset leader i-th node is having + for i := uint64(0); i < g.NumChunks; i++ { j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) - - // mutltiprover return proof in butterfly order - frame := Frame{} + jr := JobRequest{ + Index: uint64(i), + FrameIndex: k, + } + jobChan <- jr + k++ indices = append(indices, j) + } + close(jobChan) - ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)] - err := rb.ReverseBitOrderFr(ys) - if err != nil { - return nil, nil, err - } - coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j)) - if err != nil { - return nil, nil, err + for w := uint64(0); w < uint64(g.NumRSWorker); w++ { + interPolyErr := <-results + if interPolyErr != nil { + err = interPolyErr } + } - frame.Coeffs = coeffs - - frames[k] = frame - k++ + if err != nil { + return nil, nil, fmt.Errorf("proof worker error: %v", err) } + /* + for i := uint64(0); i < uint64(g.NumChunks); i++ { + + // finds out which coset leader i-th node is having + j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) + + // mutltiprover return proof in butterfly order + frame := Frame{} + indices = append(indices, j) + + ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)] + err := rb.ReverseBitOrderFr(ys) + if err != nil { + return nil, nil, err + } + coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j)) + if err != nil { + return nil, nil, err + } + + frame.Coeffs = coeffs + + frames[k] = frame + k++ + } + */ + return frames, indices, nil } @@ -127,3 +164,36 @@ func (g *Encoder) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, []fr.Elemen return evals, pdCoeffs, nil } + +type JobRequest struct { + Index uint64 + FrameIndex uint64 +} + +func (g *Encoder) interpolyWorker( + polyEvals []fr.Element, + jobChan <-chan JobRequest, + results chan<- error, + frames []Frame, +) { + + for jr := range jobChan { + i := jr.Index + k := jr.FrameIndex + j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i)) + ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)] + err := rb.ReverseBitOrderFr(ys) + if err != nil { + results <- err + } + coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j)) + if err != nil { + results <- err + } + + frames[k].Coeffs = coeffs + } + + results <- nil + +} diff --git a/encoding/rs/encoder.go b/encoding/rs/encoder.go index 3d6a4346e9..7eeb1d7b6c 100644 --- a/encoding/rs/encoder.go +++ b/encoding/rs/encoder.go @@ -2,9 +2,11 @@ package rs import ( "math" + "runtime" "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/encoding/fft" + //"github.com/consensys/gnark-crypto/ecc/bn254/fr" ) type Encoder struct { @@ -13,6 +15,8 @@ type Encoder struct { Fs *fft.FFTSettings verbose bool + + NumRSWorker int } // The function creates a high level struct that determines the encoding the a data of a @@ -37,6 +41,7 @@ func NewEncoder(params encoding.EncodingParams, verbose bool) (*Encoder, error) EncodingParams: params, Fs: fs, verbose: verbose, + NumRSWorker: runtime.GOMAXPROCS(0), }, nil } diff --git a/encoding/rs/interpolation.go b/encoding/rs/interpolation.go index 5aa062edb7..e71b41fe5e 100644 --- a/encoding/rs/interpolation.go +++ b/encoding/rs/interpolation.go @@ -54,9 +54,9 @@ func (g *Encoder) GetInterpolationPolyEval( //var tmp, tmp2 fr.Element for i := 0; i < len(interpolationPoly); i++ { shiftedInterpolationPoly[i].Mul(&interpolationPoly[i], &wPow) - + wPow.Mul(&wPow, &w) - + } err := g.Fs.InplaceFFT(shiftedInterpolationPoly, evals, false) @@ -74,18 +74,45 @@ func (g *Encoder) GetInterpolationPolyCoeff(chunk []fr.Element, k uint32) ([]fr. } var wPow fr.Element wPow.SetOne() - + var tmp, tmp2 fr.Element for i := 0; i < len(chunk); i++ { tmp.Inverse(&wPow) - + tmp2.Mul(&shiftedInterpolationPoly[i], &tmp) - + coeffs[i].Set(&tmp2) - + tmp.Mul(&wPow, &w) - + wPow.Set(&tmp) } return coeffs, nil } + +/* +// exp is the exponent for the entire fft inverse root of unity array +func (g *Encoder) GetInvRootOfUnityArray(exp uint8) []fr.Element { + rous, ok := g.InvRootOfUnityTable[exp] + if !ok { + rous = g.CreateRootsOfUnityArray(exp) + g.InvRootOfUnityTable[exp] = rous + } + return rous +} + +func (g *Encoder) CreateRootsOfUnityArray(exp uint8) []fr.Element { + w := g.Fs.ExpandedRootsOfUnity[uint64(k)] + for i := 0; i < len(chunk); i++ { + tmp.Inverse(&wPow) + + tmp2.Mul(&shiftedInterpolationPoly[i], &tmp) + + coeffs[i].Set(&tmp2) + + tmp.Mul(&wPow, &w) + + wPow.Set(&tmp) + } +} +*/