optimize encoder interpolation latency #309

Merged: 13 commits, Mar 22, 2024
1 change: 1 addition & 0 deletions disperser/cmd/encoder/main.go
@@ -7,6 +7,7 @@ import (
"os"

"github.com/Layr-Labs/eigenda/common"

"github.com/Layr-Labs/eigenda/disperser/cmd/encoder/flags"
"github.com/urfave/cli"
)
8 changes: 4 additions & 4 deletions encoding/kzg/prover/parametrized_prover.go
@@ -225,10 +225,10 @@ func (p *ParametrizedProver) proofWorker(
points: nil,
err: err,
}
}

for i := 0; i < len(coeffs); i++ {
coeffStore[i][j] = coeffs[i]
} else {
for i := 0; i < len(coeffs); i++ {
coeffStore[i][j] = coeffs[i]
}
}
}

93 changes: 73 additions & 20 deletions encoding/rs/encode.go
@@ -2,6 +2,7 @@ package rs

import (
"errors"
"fmt"
"log"
"time"

@@ -56,8 +57,8 @@ func (g *Encoder) Encode(inputFr []fr.Element) (*GlobalPoly, []Frame, []uint32,
return nil, nil, nil, err
}

log.Printf(" SUMMARY: Encode %v byte among %v numNode takes %v\n",
len(inputFr)*encoding.BYTES_PER_COEFFICIENT, g.NumChunks, time.Since(start))
log.Printf(" SUMMARY: RSEncode %v byte among %v numChunks with chunkLength %v takes %v\n",
len(inputFr)*encoding.BYTES_PER_COEFFICIENT, g.NumChunks, g.ChunkLength, time.Since(start))

return poly, frames, indices, nil
}
@@ -72,34 +73,51 @@ func (g *Encoder) MakeFrames(
if err != nil {
return nil, nil, err
}
k := uint64(0)


indices := make([]uint32, 0)
frames := make([]Frame, g.NumChunks)

for i := uint64(0); i < uint64(g.NumChunks); i++ {
numWorker := uint64(g.NumRSWorker)
Contributor: No need to convert to uint64, since L83 already makes the conversion. Also, it probably has no need to use that many workers.

Contributor (author): It is assigned to the number of chunks at L84.

// finds out which coset leader i-th node is having
j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i))
if numWorker > g.NumChunks {
numWorker = g.NumChunks
}

// mutltiprover return proof in butterfly order
frame := Frame{}
indices = append(indices, j)
jobChan := make(chan JobRequest, numWorker)
results := make(chan error, numWorker)

ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)]
err := rb.ReverseBitOrderFr(ys)
if err != nil {
return nil, nil, err
}
coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j))
if err != nil {
return nil, nil, err
for w := uint64(0); w < numWorker; w++ {
go g.interpolyWorker(
polyEvals,
jobChan,
results,
frames,
)
}

Contributor: Nit: moving the k defined on L76 down here would make it more readable.

k := uint64(0)
for i := uint64(0); i < g.NumChunks; i++ {
j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i))
jr := JobRequest{
Index: uint64(i),
Contributor: i is already uint64, so the conversion is unnecessary.

FrameIndex: k,
}
jobChan <- jr
k++
Contributor: Is k needed? It's the same as i.

indices = append(indices, j)
}
close(jobChan)

frame.Coeffs = coeffs
for w := uint64(0); w < numWorker; w++ {
interPolyErr := <-results
if interPolyErr != nil {
err = interPolyErr
}
}

frames[k] = frame
k++
if err != nil {
return nil, nil, fmt.Errorf("proof worker error: %v", err)
}

return frames, indices, nil
@@ -127,3 +145,38 @@ func (g *Encoder) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, []fr.Elemen

return evals, pdCoeffs, nil
}

type JobRequest struct {
Index uint64
FrameIndex uint64
}

func (g *Encoder) interpolyWorker(
polyEvals []fr.Element,
jobChan <-chan JobRequest,
Contributor: Can it just use a worker pool, instead of a dedicated channel to create job requests?

Contributor (author): It is actually a good idea. I think we should create a global persistent object under the prover? What do you think?

Contributor: Why global/persistent? Is the concern the cost of creating and destroying the work pool? It's generally better not to have a global/persistent object if it can be avoided.

Contributor (author): Say each request is 256 KB; 10 MB/s is equivalent to 40 parallel requests, and they need to share threads. If multiple pool objects are specified, it is unclear how to allocate them. I will mark it as a TODO, since I don't know the performance of the thread pool compared to the current approach.

results chan<- error,
frames []Frame,
) {

for jr := range jobChan {
i := jr.Index
k := jr.FrameIndex
j := rb.ReverseBitsLimited(uint32(g.NumChunks), uint32(i))
ys := polyEvals[g.ChunkLength*i : g.ChunkLength*(i+1)]
err := rb.ReverseBitOrderFr(ys)
if err != nil {
results <- err
Collaborator: Should have a continue here.

Contributor (author): Good catch. I used return, because if anything errors, the entire MakeFrames fails anyway.

Collaborator: I think it's better to use continue. If you had enough errors for all of the workers to return, the program could hang on L104. Continuing ensures that the workers keep consuming the requests. Obviously, you could optimize this further if an error were likely here, but that would be a programming error, so I don't think it needs to be optimized.

Contributor (author): I see, good call.

continue
}
coeffs, err := g.GetInterpolationPolyCoeff(ys, uint32(j))
if err != nil {
results <- err
Collaborator: Also here.

Contributor (author): Same as above.

continue
}

frames[k].Coeffs = coeffs
}

results <- nil
Contributor: Should this be inside the for loop? It would then be one result per JobRequest.

Contributor (author): There are numWorker threads and NumChunks jobs. If it is moved inside the for loop, the worker is terminated.

}
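
The review threads above touch on three design points: whether a shared worker pool could replace the per-call job channel, why a worker should continue past an error instead of returning, and why exactly one result is sent per worker rather than per job. Below is a minimal, self-contained sketch of that dispatch pattern, not the eigenda implementation: computeCoeffs, the job type, and the collapsed one-error-per-worker reporting are illustrative stand-ins.

package main

import (
	"errors"
	"fmt"
)

// job mirrors the shape of a JobRequest: which chunk to process and where the
// resulting frame goes.
type job struct {
	Index      uint64
	FrameIndex uint64
}

// computeCoeffs is a hypothetical stand-in for GetInterpolationPolyCoeff.
func computeCoeffs(i uint64) ([]uint64, error) {
	if i == 3 {
		return nil, errors.New("synthetic failure for chunk 3")
	}
	return []uint64{i, i * i}, nil
}

// worker drains the job channel even when individual jobs fail (the point of
// using continue instead of return), and sends exactly one value on results
// when it exits, so the collector below reads numWorker values and cannot hang.
func worker(jobs <-chan job, results chan<- error, frames [][]uint64) {
	var firstErr error
	for j := range jobs {
		coeffs, err := computeCoeffs(j.Index)
		if err != nil {
			if firstErr == nil {
				firstErr = err
			}
			continue // keep consuming jobs; returning here could strand the dispatcher
		}
		// Each job owns a distinct FrameIndex, so concurrent writes never collide.
		frames[j.FrameIndex] = coeffs
	}
	results <- firstErr
}

func main() {
	const numChunks = uint64(8)
	numWorker := uint64(4)
	if numWorker > numChunks {
		numWorker = numChunks
	}

	jobs := make(chan job, numWorker)
	results := make(chan error, numWorker)
	frames := make([][]uint64, numChunks)

	for w := uint64(0); w < numWorker; w++ {
		go worker(jobs, results, frames)
	}
	for i := uint64(0); i < numChunks; i++ {
		jobs <- job{Index: i, FrameIndex: i}
	}
	close(jobs)

	// One result per worker, not per job; any non-nil error fails the whole batch.
	var err error
	for w := uint64(0); w < numWorker; w++ {
		if e := <-results; e != nil {
			err = e
		}
	}
	fmt.Println("frames:", frames)
	fmt.Println("err:", err)
}

With this shape the collector always performs exactly numWorker receives, so close(jobs) plus those receives are enough to shut everything down cleanly; whether a shared long-lived pool would beat these short-lived goroutines is exactly the open question left as a TODO in the thread above.
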
4 changes: 4 additions & 0 deletions encoding/rs/encoder.go
@@ -2,6 +2,7 @@ package rs

import (
"math"
"runtime"

"github.com/Layr-Labs/eigenda/encoding"
"github.com/Layr-Labs/eigenda/encoding/fft"
@@ -13,6 +14,8 @@ type Encoder struct {
Fs *fft.FFTSettings

verbose bool

NumRSWorker int
}

// The function creates a high level struct that determines the encoding the a data of a
@@ -37,6 +40,7 @@ func NewEncoder(params encoding.EncodingParams, verbose bool) (*Encoder, error)
EncodingParams: params,
Fs: fs,
verbose: verbose,
NumRSWorker: runtime.GOMAXPROCS(0),
}, nil

}
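
For context, runtime.GOMAXPROCS(0) is a query: passing 0 returns the current setting without changing it, so NumRSWorker defaults to the number of CPUs the Go scheduler may use. A caller that wants a different degree of parallelism can overwrite the exported field after construction. A minimal sketch, assuming EncodingParams exposes the NumChunks and ChunkLength fields referenced elsewhere in this diff; the numeric values are placeholders, not recommended settings.

package main

import (
	"fmt"
	"runtime"

	"github.com/Layr-Labs/eigenda/encoding"
	"github.com/Layr-Labs/eigenda/encoding/rs"
)

func main() {
	// GOMAXPROCS(0) only reports the current setting; it does not change it.
	fmt.Println("default NumRSWorker:", runtime.GOMAXPROCS(0))

	// Placeholder parameters: NumChunks and ChunkLength must match the coding
	// rate and chunk sizing used by the rest of the system.
	enc, err := rs.NewEncoder(encoding.EncodingParams{NumChunks: 8, ChunkLength: 4}, false)
	if err != nil {
		panic(err)
	}
	enc.NumRSWorker = 2 // cap RS interpolation at two goroutines for this encoder
	fmt.Println("worker cap:", enc.NumRSWorker)
}
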
23 changes: 7 additions & 16 deletions encoding/rs/interpolation.go
@@ -54,9 +54,7 @@ func (g *Encoder) GetInterpolationPolyEval(
//var tmp, tmp2 fr.Element
for i := 0; i < len(interpolationPoly); i++ {
shiftedInterpolationPoly[i].Mul(&interpolationPoly[i], &wPow)

wPow.Mul(&wPow, &w)

}

err := g.Fs.InplaceFFT(shiftedInterpolationPoly, evals, false)
@@ -66,26 +64,19 @@
// Since both F W are invertible, c = W^-1 F^-1 d, convert it back. F W W^-1 F^-1 d = c
func (g *Encoder) GetInterpolationPolyCoeff(chunk []fr.Element, k uint32) ([]fr.Element, error) {
coeffs := make([]fr.Element, g.ChunkLength)
w := g.Fs.ExpandedRootsOfUnity[uint64(k)]
shiftedInterpolationPoly := make([]fr.Element, len(chunk))
err := g.Fs.InplaceFFT(chunk, shiftedInterpolationPoly, true)
if err != nil {
return coeffs, err
}
var wPow fr.Element
wPow.SetOne()

var tmp, tmp2 fr.Element

mod := int32(len(g.Fs.ExpandedRootsOfUnity) - 1)

for i := 0; i < len(chunk); i++ {
tmp.Inverse(&wPow)

tmp2.Mul(&shiftedInterpolationPoly[i], &tmp)

coeffs[i].Set(&tmp2)

tmp.Mul(&wPow, &w)

wPow.Set(&tmp)
// We can lookup the inverse power by counting RootOfUnity backward
j := (-int32(k)*int32(i))%mod + mod
coeffs[i].Mul(&shiftedInterpolationPoly[i], &g.Fs.ExpandedRootsOfUnity[j])
}

return coeffs, nil
}
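
The rewritten loop leans on the identity w^(-k*i) = w^(order - ((k*i) mod order)) for a root of unity w of order `order`, so each inverse power is read directly from ExpandedRootsOfUnity (whose final entry wraps back around to 1) instead of maintaining a running wPow and inverting it on every iteration. Below is a small self-contained check of that identity over a toy prime field; the modulus, root, and exponents are illustrative and are not the BN254 parameters the encoder actually uses.

package main

import "fmt"

// modPow computes base^exp mod m by square-and-multiply.
func modPow(base, exp, m uint64) uint64 {
	result := uint64(1)
	base %= m
	for exp > 0 {
		if exp&1 == 1 {
			result = result * base % m
		}
		base = base * base % m
		exp >>= 1
	}
	return result
}

func main() {
	const p = uint64(17)   // toy prime field
	const w = uint64(2)    // primitive 8th root of unity mod 17 (2^8 = 256 = 15*17 + 1)
	const order = int32(8) // plays the role of mod = len(ExpandedRootsOfUnity) - 1

	// roots[j] = w^j for j = 0..order, mirroring ExpandedRootsOfUnity, whose
	// final entry wraps back around to 1.
	roots := make([]uint64, order+1)
	for j := int32(0); j <= order; j++ {
		roots[j] = modPow(w, uint64(j), p)
	}

	k, i := int32(3), int32(5)

	// Index computed exactly as in GetInterpolationPolyCoeff: count backward
	// through the table to land on w^(-k*i). Go's % keeps the sign of the
	// dividend, so the remainder lies in (-order, 0] and "+ order" maps it to
	// (0, order], a valid index because roots[order] == 1.
	j := (-k*i)%order + order

	// Direct inverse for comparison: (w^(k*i))^(p-2) mod p, by Fermat's little theorem.
	direct := modPow(modPow(w, uint64(k*i), p), p-2, p)

	fmt.Println(roots[j], direct, roots[j] == direct) // 2 2 true
}
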