diff --git a/extendeddatacrossword_test.go b/extendeddatacrossword_test.go index bd1982e..cb4ed67 100644 --- a/extendeddatacrossword_test.go +++ b/extendeddatacrossword_test.go @@ -2,6 +2,7 @@ package rsmt2d import ( "bytes" + crand "crypto/rand" "errors" "fmt" "math/rand" @@ -423,6 +424,186 @@ func TestCorruptedEdsReturnsErrByzantineData_UnorderedShares(t *testing.T) { } } +func TestFuzzRandByzantine(t *testing.T) { + // This test is slow and should be skipped during normal testing + t.Skip() + for i := 0; i < 10000; i++ { + TestErrRandByzantine(t) + } +} + +func TestErrRandByzantine(t *testing.T) { + codec := NewLeoRSCodec() + original, corrupted, idx := randCorruptedEDS(t, codec, 8) + require.False(t, original.Equals(corrupted), "corrupted eds is equal to original eds") + + newEds, err := repairNewFromCorrupted(codec, corrupted, idx) + if err != nil && newEds != nil { + // visual check of the new eds + prettyPrintEds(newEds) + fmt.Println("new eds is original", original.Equals(newEds)) + fmt.Println("new eds is corrupted", corrupted.Equals(newEds)) + } + require.NoError(t, err, "failure to reconstruct the extended data square") +} + +func randCorruptedEDS(t require.TestingT, codec Codec, size int) (original, corrupted *ExtendedDataSquare, idx int) { + ds := genRandDS(size, shareSize) + original, err := ComputeExtendedDataSquare(ds, codec, NewDefaultTree) + require.NoError(t, err) + + // create random share + randShare := make([]byte, shareSize) + _, _ = crand.Read(randShare) + + // choose a random share to corrupt + shares := original.Flattened() + idx = rand.Intn(len(shares)) + + // copy namespace to avoid namespace ordering issues + copy(randShare, shares[idx][:nmt.DefaultNamespaceIDLen]) + + // corrupt the share + shares[idx] = randShare + + corrupted, err = ImportExtendedDataSquare( + shares, + codec, + NewDefaultTree) + require.NoError(t, err) + return original, corrupted, idx +} + +func repairNewFromCorrupted(codec Codec, corrupted *ExtendedDataSquare, corruptedIdx int) (*ExtendedDataSquare, error) { + samples := make([][]bool, corrupted.Width()) + for i := range samples { + samples[i] = make([]bool, corrupted.Width()) + } + + square, err := NewExtendedDataSquare( + codec, + NewDefaultTree, + corrupted.Width(), + shareSize, + ) + if err != nil { + return nil, fmt.Errorf("failure to create extended data square: %w", err) + } + + // set corrupted share first + corruptedX, corruptedY := corruptedIdx/int(corrupted.Width()), corruptedIdx%int(corrupted.Width()) + share := corrupted.GetCell(uint(corruptedX), uint(corruptedY)) + err = square.SetCell(uint(corruptedX), uint(corruptedY), share) + if err != nil { + return nil, fmt.Errorf("failure to set corrupted share: %w", err) + } + + rowRoots, err := corrupted.RowRoots() + if err != nil { + return nil, fmt.Errorf("failure to get row roots: %w", err) + } + colRoots, err := corrupted.ColRoots() + if err != nil { + return nil, fmt.Errorf("failure to get column roots: %w", err) + } + + // loop until repaired or byzantine error + for { + repaired, err := fillRandomCellAndRepair(corrupted, square, rowRoots, colRoots, samples) + if repaired { + prettyPrintSamples(samples, corruptedIdx) + return square, errors.New("no byzantine error") + } + var errByz *ErrByzantineData + if errors.As(err, &errByz) { + err = checkErrByzantine(errByz, corruptedX, corruptedY) + if err != nil { + prettyPrintSamples(samples, corruptedIdx) + } + return square, err + } + } +} + +func fillRandomCellAndRepair( + eds, square *ExtendedDataSquare, + rowRoots, colRoots [][]byte, + samples [][]bool, +) (repaired bool, err error) { + // select random share + x, y := rand.Intn(int(eds.Width())), rand.Intn(int(eds.Width())) + + // skip if share is already set + if square.GetCell(uint(x), uint(y)) != nil { + return false, nil + } + + share := eds.GetCell(uint(x), uint(y)) + err = square.SetCell(uint(x), uint(y), share) + if err != nil { + return false, fmt.Errorf("failure to set cell: %w", err) + } + samples[x][y] = true + + err = square.Repair(rowRoots, colRoots) + if err != nil { + return false, err + } + return true, nil +} + +func checkErrByzantine(errByz *ErrByzantineData, x, y int) error { + var axisIdx int + if errByz.Axis == Row { + axisIdx = x + } else { + axisIdx = y + } + + if errByz.Index != uint(axisIdx) { + return fmt.Errorf("byzantine error index mismatch: got %s, want %d", errByz, axisIdx) + } + return nil +} + +// prettyPrintSamples prints coordinates of shares in the 2D array +func prettyPrintSamples(samples [][]bool, corruptedIdx int) { + fmt.Println("SAMPLES", corruptedIdx) + for i, row := range samples { + for j, sampled := range row { + if corruptedIdx == i*len(samples)+j { + if !sampled { + fmt.Print("x ") + continue + } + fmt.Print("X ") + continue + } + if !sampled { + fmt.Print(". ") + continue + } + fmt.Print("O ") + } + fmt.Println() + } +} + +func prettyPrintEds(eds *ExtendedDataSquare) { + fmt.Println("EDS") + for r := 0; r < int(eds.Width()); r++ { + for _, sh := range eds.Row(uint(r)) { + if sh == nil { + fmt.Print(". ") + continue + } + fmt.Print("O ") + } + fmt.Println() + } + fmt.Println() +} + // createTestEdsWithNMT creates an extended data square with the given shares and namespace size. // Shares are placed in row-major order. // The first namespaceSize bytes of each share are treated as its namespace.