diff --git a/internal/config/config.go b/internal/config/config.go index 6bb818a..2c6fcc0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,8 +19,10 @@ import ( "fmt" "reflect" + "github.com/tink-crypto/tink-go/v2/core/registry" "github.com/tink-crypto/tink-go/v2/internal/internalapi" "github.com/tink-crypto/tink-go/v2/key" + tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) // Config keeps a collection of functions that create a primitive from @@ -29,10 +31,24 @@ import ( // This is an internal API. type Config struct { primitiveConstructors map[reflect.Type]primitiveConstructor + keysetManagers map[string]registry.KeyManager } type primitiveConstructor func(key key.Key) (any, error) +// PrimitiveFromKeyData creates a primitive from the given [tinkpb.KeyData]. +// Returns an error if there is no key manager registered for the given key +// type URL. +// +// This is an internal API. +func (c *Config) PrimitiveFromKeyData(kd *tinkpb.KeyData, _ internalapi.Token) (any, error) { + km, ok := c.keysetManagers[kd.GetTypeUrl()] + if !ok { + return nil, fmt.Errorf("PrimitiveFromKeyData: no key manager for key URL %v", kd.GetTypeUrl()) + } + return km.Primitive(kd.GetValue()) +} + // PrimitiveFromKey creates a primitive from the given [key.Key]. Returns an // error if there is no primitiveConstructor registered for the given key. // @@ -63,7 +79,23 @@ func (c *Config) RegisterPrimitiveConstructor(keyType reflect.Type, constructor return nil } +// RegisterKeyManger registers a key manager for a key type URL. +// +// Not thread-safe. +// +// This is an internal API. +func (c *Config) RegisterKeyManger(keyTypeURL string, km registry.KeyManager, _ internalapi.Token) error { + if _, ok := c.keysetManagers[keyTypeURL]; ok { + return fmt.Errorf("RegisterKeyManger: attempt to register a different key manager for %v", keyTypeURL) + } + c.keysetManagers[keyTypeURL] = km + return nil +} + // New creates an empty Config. func New() (*Config, error) { - return &Config{map[reflect.Type]primitiveConstructor{}}, nil + return &Config{ + primitiveConstructors: map[reflect.Type]primitiveConstructor{}, + keysetManagers: map[string]registry.KeyManager{}, + }, nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 5566591..c5c12ad 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -18,9 +18,11 @@ import ( "reflect" "testing" + "google.golang.org/protobuf/proto" "github.com/tink-crypto/tink-go/v2/internal/config" "github.com/tink-crypto/tink-go/v2/internal/internalapi" "github.com/tink-crypto/tink-go/v2/key" + tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) type testParameters0 struct{} @@ -59,7 +61,7 @@ func (tk testKeyUnregistered) Parameters() key.Parameters { retur func (tk testKeyUnregistered) IDRequirement() (id uint32, required bool) { return 0, false } func (tk testKeyUnregistered) Equals(other key.Key) bool { return false } -func TestConfigWorks(t *testing.T) { +func TestConfigPrimitiveFromKeyWorks(t *testing.T) { testConfig, err := config.New() if err != nil { t.Fatalf("Config.New() err = %v, want nil", err) @@ -80,6 +82,54 @@ func TestConfigWorks(t *testing.T) { } } +const ( + typeURL0 = "type_url_0" + typeURL1 = "type_url_1" +) + +type stubKeyManager0 struct{} + +func (km *stubKeyManager0) Primitive(_ []byte) (any, error) { return &testPrimitive0{}, nil } +func (km *stubKeyManager0) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil } +func (km *stubKeyManager0) DoesSupport(t string) bool { return t == typeURL0 } +func (km *stubKeyManager0) TypeURL() string { return typeURL0 } +func (km *stubKeyManager0) NewKey(serializedKeyFormat []byte) (proto.Message, error) { + return nil, nil +} + +type stubKeyManager1 struct{} + +func (km *stubKeyManager1) Primitive(_ []byte) (any, error) { return &testPrimitive1{}, nil } +func (km *stubKeyManager1) NewKeyData(_ []byte) (*tinkpb.KeyData, error) { return nil, nil } +func (km *stubKeyManager1) DoesSupport(t string) bool { return t == typeURL1 } +func (km *stubKeyManager1) TypeURL() string { return typeURL1 } +func (km *stubKeyManager1) NewKey(serializedKeyFormat []byte) (proto.Message, error) { return nil, nil } + +func TestConfigPrimitiveFromKeDataWorks(t *testing.T) { + testConfig, err := config.New() + if err != nil { + t.Fatalf("Config.New() err = %v, want nil", err) + } + token := internalapi.Token{} + + err = testConfig.RegisterKeyManger(typeURL0, &stubKeyManager0{}, token) + if err != nil { + t.Fatalf("testConfig.RegisterKeyManger() err = %v, want nil", err) + } + + keyData := &tinkpb.KeyData{ + TypeUrl: typeURL0, + Value: []byte("key"), + } + p0, err := testConfig.PrimitiveFromKeyData(keyData, token) + if err != nil { + t.Fatalf("testConfig.PrimitiveFromKeyData() err = %v, want nil", err) + } + if p0.(*testPrimitive0) == nil { + t.Errorf("Wrong primitive returned: got %T, want testPrimitive0", p0) + } +} + func TestMultiplePrimitiveConstructors(t *testing.T) { testConfig, err := config.New() if err != nil { @@ -112,6 +162,38 @@ func TestMultiplePrimitiveConstructors(t *testing.T) { } } +func TestMultipleKeyManagers(t *testing.T) { + testConfig, err := config.New() + if err != nil { + t.Fatalf("config.New() err = %v, want nil", err) + } + token := internalapi.Token{} + + err = testConfig.RegisterKeyManger(typeURL0, &stubKeyManager0{}, token) + if err != nil { + t.Fatalf("testConfig.RegisterKeyManger() err = %v, want nil", err) + } + err = testConfig.RegisterKeyManger(typeURL1, &stubKeyManager1{}, token) + if err != nil { + t.Fatalf("testConfig.RegisterKeyManger() err = %v, want nil", err) + } + + p0, err := testConfig.PrimitiveFromKeyData(&tinkpb.KeyData{TypeUrl: typeURL0, Value: []byte("key")}, token) + if err != nil { + t.Fatalf("testConfig.RegisterPrimitiveConstructor() err = %v, want nil", err) + } + if p0.(*testPrimitive0) == nil { + t.Errorf("Wrong primitive returned: got %T, want testPrimitive0", p0) + } + p1, err := testConfig.PrimitiveFromKeyData(&tinkpb.KeyData{TypeUrl: typeURL1, Value: []byte("key")}, token) + if err != nil { + t.Fatalf("testConfig.RegisterPrimitiveConstructor() err = %v, want nil", err) + } + if p1.(*testPrimitive1) == nil { + t.Errorf("Wrong primitive returned: got %T, want testPrimitive0", p1) + } +} + func TestRegisterDifferentPrimitiveConstructor(t *testing.T) { testConfig, err := config.New() if err != nil { @@ -131,6 +213,24 @@ func TestRegisterDifferentPrimitiveConstructor(t *testing.T) { } } +func TestRegisterDifferentKeyManagers(t *testing.T) { + testConfig, err := config.New() + if err != nil { + t.Fatalf("config.New() err = %v, want nil", err) + } + token := internalapi.Token{} + + err = testConfig.RegisterKeyManger(typeURL0, &stubKeyManager0{}, token) + if err != nil { + t.Fatalf("testConfig.RegisterKeyManger() err = %v, want nil", err) + } + + // Register another primitiveCreator for the same key type fails. + if err = testConfig.RegisterKeyManger(typeURL0, &stubKeyManager1{}, token); err == nil { + t.Errorf("testConfig.RegisterKeyManger() err = nil, want error") + } +} + func TestUnregisteredPrimitive(t *testing.T) { testConfig, err := config.New() if err != nil { @@ -151,3 +251,19 @@ func TestUnregisteredPrimitive(t *testing.T) { t.Errorf("testConfig.PrimitiveFromKey() err = nil, want error") } } + +func TestUnregisteredKeyManager(t *testing.T) { + testConfig, err := config.New() + if err != nil { + t.Fatalf("config.New() err = %v, want nil", err) + } + token := internalapi.Token{} + + if err = testConfig.RegisterKeyManger(typeURL0, &stubKeyManager0{}, token); err != nil { + t.Fatalf("testConfig.RegisterKeyManger() err = %v, want nil", err) + } + + if _, err := testConfig.PrimitiveFromKeyData(&tinkpb.KeyData{TypeUrl: typeURL1, Value: []byte("key")}, token); err == nil { + t.Errorf("testConfig.PrimitiveFromKey() err = nil, want error") + } +} diff --git a/internal/registryconfig/registry_config.go b/internal/registryconfig/registry_config.go index aa2152b..cad1749 100644 --- a/internal/registryconfig/registry_config.go +++ b/internal/registryconfig/registry_config.go @@ -24,8 +24,8 @@ import ( "github.com/tink-crypto/tink-go/v2/core/registry" "github.com/tink-crypto/tink-go/v2/internal/internalapi" - "github.com/tink-crypto/tink-go/v2/internal/protoserialization" "github.com/tink-crypto/tink-go/v2/key" + tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) var ( @@ -39,6 +39,11 @@ type primitiveConstructor func(key key.Key) (any, error) // old global Registry through the new Configuration interface. type RegistryConfig struct{} +// PrimitiveFromKeyData constructs a primitive from a [key.Key] using the registry. +func (c *RegistryConfig) PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalapi.Token) (any, error) { + return registry.PrimitiveFromKeyData(keyData) +} + // PrimitiveFromKey constructs a primitive from a [key.Key] using the registry. func (c *RegistryConfig) PrimitiveFromKey(key key.Key, _ internalapi.Token) (any, error) { if key == nil { @@ -46,12 +51,7 @@ func (c *RegistryConfig) PrimitiveFromKey(key key.Key, _ internalapi.Token) (any } constructor, found := primitiveConstructors[reflect.TypeOf(key)] if !found { - // Fallback to using the key manager. - keySerialization, err := protoserialization.SerializeKey(key) - if err != nil { - return nil, err - } - return registry.PrimitiveFromKeyData(keySerialization.KeyData()) + return nil, fmt.Errorf("no constructor found for key %T", key) } return constructor(key) } diff --git a/internal/registryconfig/registry_config_test.go b/internal/registryconfig/registry_config_test.go index 156287c..ddc36ba 100644 --- a/internal/registryconfig/registry_config_test.go +++ b/internal/registryconfig/registry_config_test.go @@ -19,21 +19,21 @@ import ( "testing" "google.golang.org/protobuf/proto" + "github.com/tink-crypto/tink-go/v2/aead" + aeadsubtle "github.com/tink-crypto/tink-go/v2/aead/subtle" "github.com/tink-crypto/tink-go/v2/core/registry" "github.com/tink-crypto/tink-go/v2/internal/internalapi" "github.com/tink-crypto/tink-go/v2/internal/protoserialization" "github.com/tink-crypto/tink-go/v2/internal/registryconfig" "github.com/tink-crypto/tink-go/v2/key" "github.com/tink-crypto/tink-go/v2/keyset" - "github.com/tink-crypto/tink-go/v2/mac" - "github.com/tink-crypto/tink-go/v2/mac/subtle" "github.com/tink-crypto/tink-go/v2/testutil" commonpb "github.com/tink-crypto/tink-go/v2/proto/common_go_proto" tinkpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto" ) func TestPrimitiveFromKey(t *testing.T) { - keyset, err := keyset.NewHandle(mac.HMACSHA256Tag256KeyTemplate()) + keyset, err := keyset.NewHandle(aead.AES256GCMKeyTemplate()) if err != nil { t.Fatalf("keyset.NewHandle() err = %v, want nil", err) } @@ -47,51 +47,90 @@ func TestPrimitiveFromKey(t *testing.T) { if err != nil { t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err) } - if _, ok := p.(*subtle.HMAC); !ok { - t.Error("p is not of type *subtle.HMAC") + if _, ok := p.(*aeadsubtle.AESGCM); !ok { + t.Errorf("p is not of type *aeadsubtle.AESGCM; got %T", p) } } -func TestPrimitiveFromKeyErrors(t *testing.T) { +func TestPrimitiveFromKeyData(t *testing.T) { + keyset, err := keyset.NewHandle(aead.AES256GCMKeyTemplate()) + if err != nil { + t.Fatalf("keyset.NewHandle() err = %v, want nil", err) + } + entry, err := keyset.Entry(0) + if err != nil { + t.Fatalf("keyset.Entry() err = %v, want nil", err) + } + protoKey, err := protoserialization.SerializeKey(entry.Key()) + if err != nil { + t.Fatalf("protoserialization.SerializeKey() err = %v, want nil", err) + } registryConfig := ®istryconfig.RegistryConfig{} + p, err := registryConfig.PrimitiveFromKeyData(protoKey.KeyData(), internalapi.Token{}) + if err != nil { + t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err) + } + if _, ok := p.(*aeadsubtle.AESGCM); !ok { + t.Error("p is not of type *aeadsubtle.AESGCM") + } +} +func TestPrimitiveFromKeyErrors(t *testing.T) { + registryConfig := ®istryconfig.RegistryConfig{} testCases := []struct { name string key key.Key + }{ + { + name: "unregistered key type", + key: &stubKey{}, + }, + { + name: "nil key", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if _, err := registryConfig.PrimitiveFromKey(tc.key, internalapi.Token{}); err == nil { + t.Errorf("registryConfig.PrimitiveFromKey() err = nil, want error") + } + }) + } +} + +func TestPrimitiveFromKeyDataErrors(t *testing.T) { + registryConfig := ®istryconfig.RegistryConfig{} + + testCases := []struct { + name string + keyData *tinkpb.KeyData }{ { name: "unregistered url", - key: func() key.Key { - key := testutil.NewHMACKeyData(commonpb.HashType_SHA256, 16) - key.TypeUrl = "some-unregistered-url" - keySerialization, err := protoserialization.NewKeySerialization(key, tinkpb.OutputPrefixType_TINK, 1) - if err != nil { - t.Fatalf("protoserialization.NewKeySerialization() err = %v, want nil", err) - } - return protoserialization.NewFallbackProtoKey(keySerialization) + keyData: func() *tinkpb.KeyData { + kd := testutil.NewHMACKeyData(commonpb.HashType_SHA256, 16) + kd.TypeUrl = "some url" + return kd }(), }, { name: "mismatching url", - key: func() key.Key { - key := testutil.NewHMACKeyData(commonpb.HashType_SHA256, 16) - key.TypeUrl = testutil.AESGCMTypeURL - keySerialization, err := protoserialization.NewKeySerialization(key, tinkpb.OutputPrefixType_TINK, 1) - if err != nil { - t.Fatalf("protoserialization.NewKeySerialization() err = %v, want nil", err) - } - return protoserialization.NewFallbackProtoKey(keySerialization) + keyData: func() *tinkpb.KeyData { + kd := testutil.NewHMACKeyData(commonpb.HashType_SHA256, 16) + kd.TypeUrl = testutil.AESGCMTypeURL + return kd }(), }, { - name: "nil key", + name: "nil KeyData", + keyData: nil, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if _, err := registryConfig.PrimitiveFromKey(tc.key, internalapi.Token{}); err == nil { - t.Errorf("registryConfig.PrimitiveFromKey() err = nil, want error") + if _, err := registryConfig.PrimitiveFromKeyData(tc.keyData, internalapi.Token{}); err == nil { + t.Errorf("registryConfig.Primitive() err = nil, want not-nil") } }) } @@ -170,41 +209,6 @@ func stubPrimitiveConstructorFromFallbackProtoKey(k key.Key) (any, error) { return &stubPrimitive{}, nil } -func TestRegisterPrimitiveConstructorUsesCreatorFirst(t *testing.T) { - defer registryconfig.ClearPrimitiveConstructors() - keyset, err := keyset.NewHandle(mac.HMACSHA256Tag256KeyTemplate()) - if err != nil { - t.Fatalf("keyset.NewHandle() err = %v, want nil", err) - } - entry, err := keyset.Entry(0) - if err != nil { - t.Fatalf("keyset.Entry() err = %v, want nil", err) - } - - registryConfig := ®istryconfig.RegistryConfig{} - p, err := registryConfig.PrimitiveFromKey(entry.Key(), internalapi.Token{}) - if err != nil { - t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err) - } - if _, ok := p.(*subtle.HMAC); !ok { - t.Error("p is not of type *subtle.HMAC") - } - - rc := ®istryconfig.RegistryConfig{} - // We now register a constructor for protoserialization.FallbackProtoKey that - // returns a stubPrimitive instead of a HMAC. - if err := registryconfig.RegisterPrimitiveConstructor[*protoserialization.FallbackProtoKey](stubPrimitiveConstructorFromFallbackProtoKey); err != nil { - t.Errorf("registryconfig.RegisterPrimitiveConstructor[*protoserialization.FallbackProtoKey](stubPrimitiveConstructorFromFallbackProtoKey) err = %v, want nil", err) - } - p, err = rc.PrimitiveFromKey(entry.Key(), internalapi.Token{}) - if err != nil { - t.Errorf("registryConfig.PrimitiveFromKey() err = %v, want nil", err) - } - if _, ok := p.(*stubPrimitive); !ok { - t.Error("p is not of type *stubPrimitive") - } -} - func TestPrimitiveFromKeyFailsIfCreatorFails(t *testing.T) { defer registryconfig.ClearPrimitiveConstructors() if err := registryconfig.RegisterPrimitiveConstructor[*stubKey](alwaysFailingStubPrimitiveConstructor); err != nil { diff --git a/keyset/handle.go b/keyset/handle.go index 0ede22a..2406eb3 100644 --- a/keyset/handle.go +++ b/keyset/handle.go @@ -400,6 +400,7 @@ func (h *Handle) WriteWithNoSecrets(w Writer) error { // Config defines methods in the config.Config concrete type that are used by keyset.Handle. // The config.Config concrete type is not used directly due to circular dependencies. type Config interface { + PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalapi.Token) (any, error) // PrimitiveFromKey creates a primitive from a key.Key. PrimitiveFromKey(key key.Key, _ internalapi.Token) (any, error) } @@ -494,7 +495,9 @@ func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (* if km != nil && km.DoesSupport(protoKey.GetKeyData().GetTypeUrl()) { primitive, err = km.Primitive(protoKey.GetKeyData().GetValue()) } else { - primitive, err = config.PrimitiveFromKey(entry.Key(), internalapi.Token{}) + // TODO: b/369551049 - Use the new PrimitiveFromKey method once we have + // added tooling to distinguish between "full" and "partial" primitives. + primitive, err = config.PrimitiveFromKeyData(protoKey.GetKeyData(), internalapi.Token{}) } if err != nil { return nil, fmt.Errorf("cannot get primitive from key: %v", err)