Skip to content

Commit

Permalink
feat: replace EmbedPublicKey by option
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias committed Jun 13, 2023
1 parent 9c7cc1d commit ca538d9
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 280 deletions.
84 changes: 53 additions & 31 deletions ipns/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,21 @@ const (
)

type options struct {
compatibleWithV1 bool
v1Compatibility bool
embedPublicKey *bool
}

type Option func(*options)

func CompatibleWithV1(compatible bool) Option {
return func(opts *options) {
opts.compatibleWithV1 = compatible
func WithV1Compatibility(compatible bool) Option {
return func(o *options) {
o.v1Compatibility = compatible
}
}

func WithPublicKey(embedded bool) Option {
return func(o *options) {
o.embedPublicKey = &embedded
}
}

Expand All @@ -214,7 +221,9 @@ func processOptions(opts ...Option) *options {
}

// NewRecord creates a new IPNS [Record] and signs it with the given private key.
// This function does not embed the public key. To do so, call [EmbedPublicKey].
// By default, we embed the public key for key types whose peer IDs do not encode
// the public key, such as RSA and ECDSA key types. This can be changed with the
// option [WithPublicKey].
func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl time.Duration, opts ...Option) (*Record, error) {
options := processOptions(opts...)

Expand Down Expand Up @@ -243,7 +252,7 @@ func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl ti
SignatureV2: sig2,
}

if options.compatibleWithV1 {
if options.v1Compatibility {
pb.Value = []byte(value)
typ := ipns_pb.IpnsEntry_EOL
pb.ValidityType = &typ
Expand All @@ -263,6 +272,24 @@ func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl ti
pb.SignatureV1 = sig1
}

embedPublicKey := false
if options.embedPublicKey == nil {
embedPublicKey, err = needToEmbedPublicKey(sk.GetPublic())
if err != nil {
return nil, err
}
} else {
embedPublicKey = *options.embedPublicKey
}

if embedPublicKey {
pkBytes, err := ic.MarshalPublicKey(sk.GetPublic())
if err != nil {
return nil, err
}
pb.PubKey = pkBytes
}

return &Record{
pb: pb,
node: node,
Expand Down Expand Up @@ -396,31 +423,6 @@ func compare(a, b *Record) (int, error) {
return 0, nil
}

// EmbedPublicKey embeds the given public key in the given [Record]. While not
// strictly required, some nodes (e.g., DHT servers), may reject IPNS Records
// that do not embed their public keys as they may not be able to validate them
// efficiently.
func EmbedPublicKey(r *Record, pk ic.PubKey) error {
// Try extracting the public key from the ID. If we can, do not embed it.
pid, err := peer.IDFromPublicKey(pk)
if err != nil {
return err
}
if _, err := pid.ExtractPublicKey(); err != peer.ErrNoPublicKey {
// Either a *real* error or nil.
return err
}

// We failed to extract the public key from the peer ID, embed it.
pkBytes, err := ic.MarshalPublicKey(pk)
if err != nil {
return err
}

r.pb.PubKey = pkBytes
return nil
}

// ExtractPublicKey extracts a [crypto.PubKey] matching the given [peer.ID] from
// the IPNS Record, if possible.
func ExtractPublicKey(r *Record, pid peer.ID) (ic.PubKey, error) {
Expand Down Expand Up @@ -459,3 +461,23 @@ func Key(pid peer.ID) string {

return "/ipns/" + encoded
}

func needToEmbedPublicKey(pk ic.PubKey) (bool, error) {
// First try extracting the peer ID from the public key.
pid, err := peer.IDFromPublicKey(pk)
if err != nil {
return false, fmt.Errorf("cannot convert public key to peer ID: %w", err)
}

_, err = pid.ExtractPublicKey()
if err == nil {
// Can be extracted, therefore no need to embed the public key.
return false, nil
}

if errors.Is(err, peer.ErrNoPublicKey) {
return true, nil
}

return false, fmt.Errorf("cannot extract ID from public key: %w", err)
}
101 changes: 75 additions & 26 deletions ipns/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/gogo/protobuf/proto"
ipns_pb "github.com/ipfs/boxo/ipns/pb"
"github.com/ipfs/boxo/path"
"github.com/ipfs/boxo/util"
Expand Down Expand Up @@ -107,7 +108,7 @@ func TestNewRecord(t *testing.T) {
t.Run("V1+V2 with option", func(t *testing.T) {
t.Parallel()

rec := mustNewRecord(t, sk, testPath, seq, eol, ttl, CompatibleWithV1(true))
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl, WithV1Compatibility(true))
require.NotEmpty(t, rec.pb.SignatureV1)

_, err := rec.PubKey()
Expand All @@ -116,51 +117,50 @@ func TestNewRecord(t *testing.T) {
fieldsMatch(t, rec, testPath, seq, eol, ttl)
fieldsMatchV1(t, rec, testPath, seq, eol, ttl)
})
}

func TestEmbedPublicKey(t *testing.T) {
t.Parallel()

sk, pk, pid := mustKeyPair(t, ic.RSA)

seq := uint64(0)
eol := time.Now().Add(time.Hour)
ttl := time.Minute * 10

rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
t.Run("Public key embedded by default for RSA and ECDSA keys", func(t *testing.T) {
t.Parallel()

_, err := rec.PubKey()
require.ErrorIs(t, err, ErrPublicKeyNotFound)
for _, keyType := range []int{ic.RSA, ic.ECDSA} {
sk, _, _ := mustKeyPair(t, keyType)
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
fieldsMatch(t, rec, testPath, seq, eol, ttl)

err = EmbedPublicKey(rec, pk)
require.NoError(t, err)
pk, err := rec.PubKey()
require.NoError(t, err)
require.True(t, pk.Equals(sk.GetPublic()))
}
})

recPK, err := rec.PubKey()
require.NoError(t, err)
t.Run("Public key not embedded by default for Ed25519 keys", func(t *testing.T) {
t.Parallel()

recPID, err := peer.IDFromPublicKey(recPK)
require.NoError(t, err)
for _, keyType := range []int{ic.Ed25519, ic.Secp256k1} {
sk, _, _ := mustKeyPair(t, keyType)
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
fieldsMatch(t, rec, testPath, seq, eol, ttl)

require.Equal(t, pid, recPID)
_, err := rec.PubKey()
require.ErrorIs(t, err, ErrPublicKeyNotFound)
}
})
}

func TestExtractPublicKey(t *testing.T) {
t.Parallel()

t.Run("Returns expected public key when embedded in Peer ID", func(t *testing.T) {
sk, pk, pid := mustKeyPair(t, ic.Ed25519)
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10, WithPublicKey(false))

pk2, err := ExtractPublicKey(rec, pid)
require.Nil(t, err)
require.Equal(t, pk, pk2)
})

t.Run("Returns expected public key when embedded in Record", func(t *testing.T) {
t.Run("Returns expected public key when embedded in Record (by default)", func(t *testing.T) {
sk, pk, pid := mustKeyPair(t, ic.RSA)
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
err := EmbedPublicKey(rec, pk)
require.Nil(t, err)

pk2, err := ExtractPublicKey(rec, pid)
require.Nil(t, err)
Expand All @@ -169,7 +169,7 @@ func TestExtractPublicKey(t *testing.T) {

t.Run("Errors when not embedded in Record or Peer ID", func(t *testing.T) {
sk, _, pid := mustKeyPair(t, ic.RSA)
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10, WithPublicKey(false))

pk, err := ExtractPublicKey(rec, pid)
require.Error(t, err)
Expand Down Expand Up @@ -248,3 +248,52 @@ func TestCBORDataSerialization(t *testing.T) {
assert.Equal(t, expected, f)
}
}

func TestUnmarshal(t *testing.T) {
t.Parallel()

t.Run("Errors on invalid bytes", func(t *testing.T) {
_, err := UnmarshalRecord([]byte("blah blah blah"))
require.ErrorIs(t, err, ErrBadRecord)
})

t.Run("Errors if record is too long", func(t *testing.T) {
data := make([]byte, MaxRecordSize+1)
_, err := UnmarshalRecord(data)
require.ErrorIs(t, err, ErrRecordSize)
})

t.Run("Errors with V1-only records", func(t *testing.T) {
pb := ipns_pb.IpnsEntry{}
data, err := proto.Marshal(&pb)
require.NoError(t, err)
_, err = UnmarshalRecord(data)
require.ErrorIs(t, err, ErrDataMissing)
})

t.Run("Errors on bad data", func(t *testing.T) {
pb := ipns_pb.IpnsEntry{
Data: []byte("definitely not cbor"),
}
data, err := proto.Marshal(&pb)
require.NoError(t, err)
_, err = UnmarshalRecord(data)
require.ErrorIs(t, err, ErrBadRecord)
})
}

func TestKey(t *testing.T) {
for _, v := range [][]string{
{"RSA", "QmRp2LvtSQtCkUWCpi92ph5MdQyRtfb9jHbkNgZzGExGuG", "/ipns/k2k4r8kpauqq30hoj9oktej5btbgz1jeos16d3te36xd78trvak0jcor"},
{"Ed25519", "12D3KooWSzRuSFHgLsKr6jJboAPdP7xMga2YBgBspYuErxswcgvt", "/ipns/k51qzi5uqu5dmjjgoe7s21dncepi970722cn30qlhm9qridas1c9ktkjb6ejux"},
{"ECDSA", "QmSBUTocZ9LxE53Br9PDDcPWnR1FJQRv94U96Wkt8eypAw", "/ipns/k2k4r8ku8cnc1sl2h5xn7i07dma9abfnkqkxi4a6nd1xq0knoxe7b0y4"},
{"Secp256k1", "16Uiu2HAmUymv6JpFwNZppdKUMxGJuHsTeicXgHGKbBasu4Ruj3K1", "/ipns/kzwfwjn5ji4puw3jc1qw4b073j74xvq21iziuqw4rem21pr7f0l4dj8i9yb978s"},
} {
t.Run(v[0], func(t *testing.T) {
pid, err := peer.Decode(v[1])
require.NoError(t, err)
key := Key(pid)
require.Equal(t, v[2], key)
})
}
}
21 changes: 11 additions & 10 deletions ipns/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ func TestOrdering(t *testing.T) {
func TestValidator(t *testing.T) {
t.Parallel()

check := func(t *testing.T, sk ic.PrivKey, keybook peerstore.KeyBook, key string, val []byte, eol time.Time, exp error) {
check := func(t *testing.T, sk ic.PrivKey, keybook peerstore.KeyBook, key string, val []byte, eol time.Time, exp error, opts ...Option) {
validator := Validator{keybook}
data := val
if data == nil {
// do not call mustNewRecord because that validates the record!
rec, err := NewRecord(sk, testPath, 1, eol, 0)
rec, err := NewRecord(sk, testPath, 1, eol, 0, opts...)
require.NoError(t, err)
data = mustMarshal(t, rec)
}
Expand All @@ -99,9 +99,10 @@ func TestValidator(t *testing.T) {
check(t, sk, kb, RoutingKey(pid), nil, ts.Add(time.Hour*-1), ErrExpiredRecord)
check(t, sk, kb, RoutingKey(pid), []byte("bad data"), ts.Add(time.Hour), ErrBadRecord)
check(t, sk, kb, "/ipns/"+"bad key", nil, ts.Add(time.Hour), ErrKeyFormat)
check(t, sk, emptyKB, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyNotFound)
check(t, sk2, kb, RoutingKey(pid2), nil, ts.Add(time.Hour), ErrPublicKeyNotFound)
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrSignature)
check(t, sk, emptyKB, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyNotFound, WithPublicKey(false))
check(t, sk2, kb, RoutingKey(pid2), nil, ts.Add(time.Hour), ErrPublicKeyNotFound, WithPublicKey(false))
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyMismatch)
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrSignature, WithPublicKey(false))
check(t, sk, kb, "//"+string(pid), nil, ts.Add(time.Hour), ErrInvalidPath)
check(t, sk, kb, "/wrong/"+string(pid), nil, ts.Add(time.Hour), ErrInvalidPath)
})
Expand All @@ -128,14 +129,14 @@ func TestValidator(t *testing.T) {
kb, err := pstoremem.NewPeerstore()
require.NoError(t, err)

sk, pk, pid := mustKeyPair(t, ic.RSA)
rec := mustNewRecord(t, sk, testPath, 1, eol, 0)
sk, _, pid := mustKeyPair(t, ic.RSA)
rec := mustNewRecord(t, sk, testPath, 1, eol, 0, WithPublicKey(false))

// Fails with RSA key without embedded public key.
check(t, sk, kb, RoutingKey(pid), mustMarshal(t, rec), eol, ErrPublicKeyNotFound)

// Embeds public key, must work now.
require.NoError(t, EmbedPublicKey(rec, pk))
rec = mustNewRecord(t, sk, testPath, 1, eol, 0)
check(t, sk, kb, RoutingKey(pid), mustMarshal(t, rec), eol, nil)

// Force bad public key. Validation fails.
Expand Down Expand Up @@ -163,8 +164,8 @@ func TestValidate(t *testing.T) {

v := Validator{}

rec1 := mustNewRecord(t, sk, path.FromString("/path/1"), 1, eol, 0, CompatibleWithV1(true))
rec2 := mustNewRecord(t, sk, path.FromString("/path/2"), 2, eol, 0, CompatibleWithV1(true))
rec1 := mustNewRecord(t, sk, path.FromString("/path/1"), 1, eol, 0, WithV1Compatibility(true))
rec2 := mustNewRecord(t, sk, path.FromString("/path/2"), 2, eol, 0, WithV1Compatibility(true))

best, err := v.Select(ipnsRoutingKey, [][]byte{mustMarshal(t, rec1), mustMarshal(t, rec2)})
require.NoError(t, err)
Expand Down
Loading

0 comments on commit ca538d9

Please sign in to comment.