Skip to content

Commit

Permalink
Add support for full primitives in the verifier factory.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696498129
Change-Id: I1dae2e123214a4d8e3ecf878de2655aef855e13d
  • Loading branch information
morambro authored and copybara-github committed Nov 14, 2024
1 parent 1bbf2cc commit 560d705
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 83 deletions.
193 changes: 160 additions & 33 deletions signature/signature_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/tink-crypto/tink-go/v2/testing/fakemonitoring"
"github.com/tink-crypto/tink-go/v2/testkeyset"
"github.com/tink-crypto/tink-go/v2/testutil"
"github.com/tink-crypto/tink-go/v2/tink"
commonpb "github.com/tink-crypto/tink-go/v2/proto/common_go_proto"
tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
)
Expand Down Expand Up @@ -486,42 +487,90 @@ func TestVerifyWithLegacyKeyDoesNotHaveSideEffectOnMessage(t *testing.T) {
}
}

const stubKeyURL = "type.googleapis.com/google.crypto.tink.SomeKey"
const stubPrivateKeyURL = "type.googleapis.com/google.crypto.tink.StubPrivateKey"
const stubPublicKeyURL = "type.googleapis.com/google.crypto.tink.StubPublicKey"

type stubFullSigner struct{}

func (s *stubFullSigner) Sign(data []byte) ([]byte, error) {
return slices.Concat([]byte("full_signer_prefix"), data), nil
return slices.Concat([]byte("full_primitive_prefix"), data), nil
}

type stubFullVerifier struct{}

func (s *stubFullVerifier) Verify(sig, data []byte) error {
if !bytes.Equal(sig, slices.Concat([]byte("full_primitive_prefix"), data)) {
return fmt.Errorf("invalid signature %s", sig)
}
return nil
}

var _ tink.Verifier = (*stubFullVerifier)(nil)

type stubParams struct{}

var _ key.Parameters = (*stubParams)(nil)

func (p *stubParams) Equals(_ key.Parameters) bool { return true }
func (p *stubParams) HasIDRequirement() bool { return true }

type stubPublicKey struct{}
type stubPublicKey struct {
prefixType tinkpb.OutputPrefixType
idRequirement uint32
}

var _ key.Key = (*stubPublicKey)(nil)

func (p *stubPublicKey) Equals(_ key.Key) bool { return true }
func (p *stubPublicKey) Parameters() key.Parameters { return &stubParams{} }
func (p *stubPublicKey) IDRequirement() (uint32, bool) { return 0, false }
func (p *stubPublicKey) HasIDRequirement() bool { return true }
func (p *stubPublicKey) IDRequirement() (uint32, bool) { return p.idRequirement, p.HasIDRequirement() }
func (p *stubPublicKey) HasIDRequirement() bool { return p.prefixType == tinkpb.OutputPrefixType_RAW }

type stubPublicKeySerialization struct{}

var _ protoserialization.KeySerializer = (*stubPublicKeySerialization)(nil)

func (s *stubPublicKeySerialization) SerializeKey(key key.Key) (*protoserialization.KeySerialization, error) {
return protoserialization.NewKeySerialization(
&tinkpb.KeyData{
TypeUrl: stubPublicKeyURL,
Value: []byte("serialized_public_key"),
KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
},
key.(*stubPublicKey).prefixType,
key.(*stubPublicKey).idRequirement,
)
}

type stubPublicKeyParser struct{}

var _ protoserialization.KeyParser = (*stubPublicKeyParser)(nil)

func (s *stubPublicKeyParser) ParseKey(serialization *protoserialization.KeySerialization) (key.Key, error) {
idRequirement, _ := serialization.IDRequirement()
return &stubPublicKey{
prefixType: serialization.OutputPrefixType(),
idRequirement: idRequirement,
}, nil
}

type stubPrivateKey struct {
prefixType tinkpb.OutputPrefixType
idRequrement uint32
idRequirement uint32
}

var _ key.Key = (*stubPrivateKey)(nil)

func (p *stubPrivateKey) Equals(_ key.Key) bool { return true }
func (p *stubPrivateKey) Parameters() key.Parameters { return &stubParams{} }
func (p *stubPrivateKey) IDRequirement() (uint32, bool) { return p.idRequrement, p.HasIDRequirement() }
func (p *stubPrivateKey) IDRequirement() (uint32, bool) { return p.idRequirement, p.HasIDRequirement() }
func (p *stubPrivateKey) HasIDRequirement() bool { return p.prefixType != tinkpb.OutputPrefixType_RAW }
func (p *stubPrivateKey) PublicKey() (key.Key, error) { return &stubPublicKey{}, nil }
func (p *stubPrivateKey) PublicKey() (key.Key, error) {
return &stubPublicKey{
prefixType: p.prefixType,
idRequirement: p.idRequirement,
}, nil
}

type stubPrivateKeySerialization struct{}

Expand All @@ -530,12 +579,12 @@ var _ protoserialization.KeySerializer = (*stubPrivateKeySerialization)(nil)
func (s *stubPrivateKeySerialization) SerializeKey(key key.Key) (*protoserialization.KeySerialization, error) {
return protoserialization.NewKeySerialization(
&tinkpb.KeyData{
TypeUrl: stubKeyURL,
TypeUrl: stubPrivateKeyURL,
Value: []byte("serialized_key"),
KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PRIVATE,
},
key.(*stubPrivateKey).prefixType,
key.(*stubPrivateKey).idRequrement,
key.(*stubPrivateKey).idRequirement,
)
}

Expand All @@ -545,31 +594,45 @@ var _ protoserialization.KeyParser = (*stubPrivateKeyParser)(nil)

func (s *stubPrivateKeyParser) ParseKey(serialization *protoserialization.KeySerialization) (key.Key, error) {
idRequirement, _ := serialization.IDRequirement()
return &stubPrivateKey{serialization.OutputPrefixType(), idRequirement}, nil
return &stubPrivateKey{
prefixType: serialization.OutputPrefixType(),
idRequirement: idRequirement,
}, nil
}

func TestPrimitiveFactoryUsesFullPrimitiveIfRegistered(t *testing.T) {
defer registryconfig.ClearPrimitiveConstructors()
defer protoserialization.ClearKeyParsers()
defer protoserialization.UnregisterKeySerializer[*stubPrivateKey]()
defer protoserialization.UnregisterKeySerializer[*stubPublicKey]()

if err := protoserialization.RegisterKeyParser(stubKeyURL, &stubPrivateKeyParser{}); err != nil {
if err := protoserialization.RegisterKeyParser(stubPublicKeyURL, &stubPublicKeyParser{}); err != nil {
t.Fatalf("protoserialization.RegisterKeyParser() err = %v, want nil", err)
}
if err := protoserialization.RegisterKeySerializer[*stubPublicKey](&stubPublicKeySerialization{}); err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer() err = %v, want nil", err)
}
if err := protoserialization.RegisterKeyParser(stubPrivateKeyURL, &stubPrivateKeyParser{}); err != nil {
t.Fatalf("protoserialization.RegisterKeyParser() err = %v, want nil", err)
}
if err := protoserialization.RegisterKeySerializer[*stubPrivateKey](&stubPrivateKeySerialization{}); err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer() err = %v, want nil", err)
}
// Register a primitive constructor to make sure that the factory uses the
// full primitive.
primitiveConstructor := func(key key.Key) (any, error) { return &stubFullSigner{}, nil }
if err := registryconfig.RegisterPrimitiveConstructor[*stubPrivateKey](primitiveConstructor); err != nil {
// Register primitive constructors to make sure that the factory uses full
// primitives.
signerConstructor := func(key key.Key) (any, error) { return &stubFullSigner{}, nil }
if err := registryconfig.RegisterPrimitiveConstructor[*stubPrivateKey](signerConstructor); err != nil {
t.Fatalf("registryconfig.RegisterPrimitiveConstructor() err = %v, want nil", err)
}
verifierConstructor := func(key key.Key) (any, error) { return &stubFullVerifier{}, nil }
if err := registryconfig.RegisterPrimitiveConstructor[*stubPublicKey](verifierConstructor); err != nil {
t.Fatalf("registryconfig.RegisterPrimitiveConstructor() err = %v, want nil", err)
}

km := keyset.NewManager()
keyID, err := km.AddKey(&stubPrivateKey{
tinkpb.OutputPrefixType_TINK,
0x1234,
prefixType: tinkpb.OutputPrefixType_RAW,
idRequirement: 0,
})
if err != nil {
t.Fatalf("km.AddKey() err = %v, want nil", err)
Expand All @@ -587,47 +650,98 @@ func TestPrimitiveFactoryUsesFullPrimitiveIfRegistered(t *testing.T) {
t.Fatalf("signature.NewSigner() err = %v, want nil", err)
}
data := []byte("data")
signature, err := signer.Sign(data)
sig, err := signer.Sign(data)
if err != nil {
t.Fatalf("signer.Sign() err = %v, want nil", err)
}
if !bytes.Equal(signature, slices.Concat([]byte("full_signer_prefix"), data)) {
t.Errorf("signature = %q, want: %q", signature, data)
if !bytes.Equal(sig, slices.Concat([]byte("full_primitive_prefix"), data)) {
t.Errorf("sig = %q, want: %q", sig, data)
}

// Try verifying the signature.
publicHandle, err := handle.Public()
if err != nil {
t.Fatalf("handle.Public() err = %v, want nil", err)
}
verifier, err := signature.NewVerifier(publicHandle)
if err != nil {
t.Fatalf("signature.NewVerifier() err = %v, want nil", err)
}

if err := verifier.Verify(sig, data); err != nil {
t.Errorf("verifier.Verify() err = %v, want nil", err)
}
}

type stubLegacySigner struct{}

var _ tink.Signer = (*stubLegacySigner)(nil)

func (s *stubLegacySigner) Sign(data []byte) ([]byte, error) {
return slices.Concat([]byte("legacy_signer_prefix"), data), nil
}

type stubKeyManager struct{}
type stubPrivateKeyManager struct{}

var _ registry.KeyManager = (*stubPrivateKeyManager)(nil)

func (km *stubPrivateKeyManager) NewKey(_ []byte) (proto.Message, error) {
return nil, fmt.Errorf("not implemented")
}
func (km *stubPrivateKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) {
return nil, fmt.Errorf("not implemented")
}
func (km *stubPrivateKeyManager) DoesSupport(keyURL string) bool { return keyURL == stubPrivateKeyURL }
func (km *stubPrivateKeyManager) TypeURL() string { return stubPrivateKeyURL }
func (km *stubPrivateKeyManager) Primitive(_ []byte) (any, error) { return &stubLegacySigner{}, nil }

type stubLegacyVerifier struct{}

var _ tink.Verifier = (*stubLegacyVerifier)(nil)

func (s *stubLegacyVerifier) Verify(sig, data []byte) error {
if !bytes.Equal(sig, slices.Concat([]byte("legacy_signer_prefix"), data)) {
return fmt.Errorf("invalid data")
}
return nil
}

type stubPublicKeyManager struct{}

var _ registry.KeyManager = (*stubKeyManager)(nil)
var _ registry.KeyManager = (*stubPublicKeyManager)(nil)

func (km *stubKeyManager) NewKey(_ []byte) (proto.Message, error) {
func (km *stubPublicKeyManager) NewKey(_ []byte) (proto.Message, error) {
return nil, fmt.Errorf("not implemented")
}
func (km *stubKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) {
func (km *stubPublicKeyManager) NewKeyData(_ []byte) (*tinkpb.KeyData, error) {
return nil, fmt.Errorf("not implemented")
}
func (km *stubKeyManager) DoesSupport(keyURL string) bool { return keyURL == stubKeyURL }
func (km *stubKeyManager) TypeURL() string { return stubKeyURL }
func (km *stubKeyManager) Primitive(_ []byte) (any, error) { return &stubLegacySigner{}, nil }
func (km *stubPublicKeyManager) DoesSupport(keyURL string) bool { return keyURL == stubPublicKeyURL }
func (km *stubPublicKeyManager) TypeURL() string { return stubPublicKeyURL }
func (km *stubPublicKeyManager) Primitive(_ []byte) (any, error) { return &stubLegacyVerifier{}, nil }

func TestPrimitiveFactoryUsesLegacyPrimitive(t *testing.T) {
defer protoserialization.ClearKeyParsers()
defer protoserialization.UnregisterKeySerializer[*stubPrivateKey]()
defer protoserialization.UnregisterKeySerializer[*stubPublicKey]()

if err := protoserialization.RegisterKeyParser(stubKeyURL, &stubPrivateKeyParser{}); err != nil {
if err := protoserialization.RegisterKeyParser(stubPublicKeyURL, &stubPublicKeyParser{}); err != nil {
t.Fatalf("protoserialization.RegisterKeyParser() err = %v, want nil", err)
}
if err := protoserialization.RegisterKeySerializer[*stubPublicKey](&stubPublicKeySerialization{}); err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer() err = %v, want nil", err)
}
if err := protoserialization.RegisterKeyParser(stubPrivateKeyURL, &stubPrivateKeyParser{}); err != nil {
t.Fatalf("protoserialization.RegisterKeyParser() err = %v, want nil", err)
}
if err := protoserialization.RegisterKeySerializer[*stubPrivateKey](&stubPrivateKeySerialization{}); err != nil {
t.Fatalf("protoserialization.RegisterKeySerializer() err = %v, want nil", err)
}

if err := registry.RegisterKeyManager(&stubKeyManager{}); err != nil {
if err := registry.RegisterKeyManager(&stubPrivateKeyManager{}); err != nil {
t.Fatalf("registry.RegisterKeyManager() err = %v, want nil", err)
}
if err := registry.RegisterKeyManager(&stubPublicKeyManager{}); err != nil {
t.Fatalf("registry.RegisterKeyManager() err = %v, want nil", err)
}

Expand Down Expand Up @@ -678,12 +792,25 @@ func TestPrimitiveFactoryUsesLegacyPrimitive(t *testing.T) {
if err != nil {
t.Fatalf("signature.NewSigner() err = %v, want nil", err)
}
signature, err := signer.Sign(data)
sig, err := signer.Sign(data)
if err != nil {
t.Fatalf("signer.Sign() err = %v, want nil", err)
}
if !bytes.Equal(signature, tc.wantSigature) {
t.Errorf("signature = %q, want: %q", signature, data)
if !bytes.Equal(sig, tc.wantSigature) {
t.Errorf("sig = %q, want: %q", sig, data)
}

// Try verifying the signature.
publicHandle, err := handle.Public()
if err != nil {
t.Fatalf("handle.Public() err = %v, want nil", err)
}
verifier, err := signature.NewVerifier(publicHandle)
if err != nil {
t.Fatalf("signature.NewVerifier() err = %v, want nil", err)
}
if err := verifier.Verify(sig, data); err != nil {
t.Errorf("verifier.Verify() err = %v, want nil", err)
}
})
}
Expand Down
Loading

0 comments on commit 560d705

Please sign in to comment.