Skip to content

Commit

Permalink
curve: Fix the RistrettoPoint.UnmarshalBinary method
Browse files Browse the repository at this point in the history
Teaches me a lesson about copy-paste adding these as an afterthought.
  • Loading branch information
Yawning committed May 20, 2021
1 parent 294cf0f commit 96af29d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
11 changes: 9 additions & 2 deletions curve/ristretto.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ func (p *CompressedRistretto) MarshalBinary() ([]byte, error) {
func (p *CompressedRistretto) UnmarshalBinary(data []byte) error {
p.Identity() // Foot + gun avoidance.

var ep EdwardsPoint
if err := ep.UnmarshalBinary(data); err != nil {
var rp RistrettoPoint
if err := rp.UnmarshalBinary(data); err != nil {
return err
}

Expand Down Expand Up @@ -411,6 +411,13 @@ func (p *RistrettoPoint) MultiscalarMulVartime(scalars []*scalar.Scalar, points
return p
}

// IsIdentity returns true iff the point is equivalent to the identity element
// of the curve.
func (p *RistrettoPoint) IsIdentity() bool {
var id RistrettoPoint
return p.Equal(id.Identity()) == 1
}

func (p *RistrettoPoint) elligatorRistrettoFlavor(r_0 *field.FieldElement) {
c := constMINUS_ONE

Expand Down
48 changes: 48 additions & 0 deletions curve/ristretto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
package curve

import (
"bytes"
"testing"

"github.com/oasisprotocol/curve25519-voi/curve/scalar"
Expand Down Expand Up @@ -60,6 +61,7 @@ func TestRistretto(t *testing.T) {
t.Run("Ristretto/FourTorsion/Random", testRistrettoFourTorsionRandom)
t.Run("Ristretto/Elligator", testRistrettoElligator)
t.Run("Ristretto/TestVectors", testRistrettoVectors)
t.Run("Ristretto/Serialization", testRistrettoSerialization)
}

func testRistrettoSum(t *testing.T) {
Expand Down Expand Up @@ -270,3 +272,49 @@ func testRistrettoRandomRoundtrip(t *testing.T) {
}
}
}

func testRistrettoSerialization(t *testing.T) {
var p RistrettoPoint

if _, err := p.SetRandom(nil); err != nil {
t.Fatalf("p.SetRandom: %v", err)
}
if p.IsIdentity() {
t.Fatalf("random point is identity???")
}

b, err := p.MarshalBinary()
if err != nil {
t.Fatalf("RistrettoPoint.MarshalBinary: %v", err)
}

// Check that RistrettoPoints round-trip.
var pp RistrettoPoint
if err = pp.UnmarshalBinary(b); err != nil {
t.Fatalf("RistrettoPoint.UnmarshalBinary: %v", err)
}
if p.Equal(&pp) != 1 {
t.Fatalf("p != pp (Got %v, %v)", p, pp)
}

// Check that CompressedRistrettos round-trip.
var pc CompressedRistretto
pp.Identity()
if err = pc.UnmarshalBinary(b); err != nil {
t.Fatalf("CompressedRistretto.UnmarshalBinary: %v", err)
}
if _, err = pp.SetCompressed(&pc); err != nil {
t.Fatalf("RistrettoPoint.SetCompressed: %v", err)
}
if p.Equal(&pp) != 1 {
t.Fatalf("compressed p != pp (Got %v, %v)", p, pp)
}

bb, err := pc.MarshalBinary()
if err != nil {
t.Fatalf("CompressedRistretto.MarshalBinary: %v", err)
}
if !bytes.Equal(bb, b) {
t.Fatalf("b != bb (Got %v, %v)", b, bb)
}
}

0 comments on commit 96af29d

Please sign in to comment.