From 7e40e1f00a7b0b187cc944670e28b29cdf2c8a05 Mon Sep 17 00:00:00 2001 From: Moreno Ambrosin Date: Wed, 13 Nov 2024 01:15:21 -0800 Subject: [PATCH] Populate "full" primitives in the `PrimitiveSet` A full primitive is constructed using `Config`'s `PrimitiveFromKey`. The keyset handle tries this first, and falls back to the legacy `PrimitiveFromKeyData` method otherwise. PiperOrigin-RevId: 696027102 Change-Id: Icd28645081aec158cae7397a5e27490195e8110e --- aead/aead_factory_test.go | 2 +- keyset/handle.go | 42 +++++++++++--------- keyset/handle_test.go | 80 ++++++++++++++++++++++++++++++++------- 3 files changed, 93 insertions(+), 31 deletions(-) diff --git a/aead/aead_factory_test.go b/aead/aead_factory_test.go index c897437..532feb0 100644 --- a/aead/aead_factory_test.go +++ b/aead/aead_factory_test.go @@ -123,7 +123,7 @@ func (sc *stubConfig) PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalap } func (sc *stubConfig) PrimitiveFromKey(_ key.Key, _ internalapi.Token) (any, error) { - return new(stubAEAD), nil + return nil, fmt.Errorf("unsupported") } func TestNewWithConfig(t *testing.T) { diff --git a/keyset/handle.go b/keyset/handle.go index 3101c49..c1d2a33 100644 --- a/keyset/handle.go +++ b/keyset/handle.go @@ -432,7 +432,7 @@ func (h *Handle) WriteWithNoSecrets(w Writer) error { // The config.Config concrete type is not used directly due to circular dependencies. type Config interface { PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalapi.Token) (any, error) - // PrimitiveFromKey creates a primitive from a key.Key. + // PrimitiveFromKey creates a primitive from a [key.Key]. PrimitiveFromKey(key key.Key, _ internalapi.Token) (any, error) } @@ -499,6 +499,29 @@ func (h *Handle) PrimitivesWithKeyManager(km registry.KeyManager, _ internalapi. return p, nil } +func addToPrimitiveSet(primitiveSet *primitiveset.PrimitiveSet, entry *Entry, km registry.KeyManager, config Config) (*primitiveset.Entry, error) { + protoKey, err := entryToProtoKey(entry) + if err != nil { + return nil, err + } + if km != nil && km.DoesSupport(protoKey.GetKeyData().GetTypeUrl()) { + primitive, err := km.Primitive(protoKey.GetKeyData().GetValue()) + if err != nil { + return nil, fmt.Errorf("cannot get primitive from key: %v", err) + } + return primitiveSet.Add(primitive, protoKey) + } + primitive, err := config.PrimitiveFromKey(entry.Key(), internalapi.Token{}) + if err == nil { + return primitiveSet.AddFullPrimitive(primitive, protoKey) + } + primitive, err = config.PrimitiveFromKeyData(protoKey.GetKeyData(), internalapi.Token{}) + if err != nil { + return nil, fmt.Errorf("cannot get primitive from key data: %v", err) + } + return primitiveSet.Add(primitive, protoKey) +} + func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (*primitiveset.PrimitiveSet, error) { if h == nil { return nil, fmt.Errorf("nil handle") @@ -522,22 +545,7 @@ func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (* if entry.KeyStatus() != Enabled { continue } - protoKey, err := entryToProtoKey(entry) - if err != nil { - return nil, err - } - var primitive any - if km != nil && km.DoesSupport(protoKey.GetKeyData().GetTypeUrl()) { - primitive, err = km.Primitive(protoKey.GetKeyData().GetValue()) - } else { - // TODO: b/369551049 - Use the new PrimitiveFromKey method once we have - // added tooling to distinguish between "full" and "partial" primitives. - primitive, err = config.PrimitiveFromKeyData(protoKey.GetKeyData(), internalapi.Token{}) - } - if err != nil { - return nil, fmt.Errorf("cannot get primitive from key: %v", err) - } - primitiveSetEntry, err := primitiveSet.Add(primitive, protoKey) + primitiveSetEntry, err := addToPrimitiveSet(primitiveSet, entry, km, config) if err != nil { return nil, fmt.Errorf("cannot add primitive: %v", err) } diff --git a/keyset/handle_test.go b/keyset/handle_test.go index 348a631..3b64cd9 100644 --- a/keyset/handle_test.go +++ b/keyset/handle_test.go @@ -23,6 +23,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/tink-crypto/tink-go/v2/aead" + "github.com/tink-crypto/tink-go/v2/aead/aesgcm" "github.com/tink-crypto/tink-go/v2/core/registry" "github.com/tink-crypto/tink-go/v2/internal/internalapi" "github.com/tink-crypto/tink-go/v2/internal/protoserialization" @@ -869,30 +870,83 @@ func TestPrimitivesWithRegistry(t *testing.T) { type testConfig struct{} +type stubPrimitive struct { + isFull bool +} + func (c *testConfig) PrimitiveFromKeyData(_ *tinkpb.KeyData, _ internalapi.Token) (any, error) { - return testPrimitive{}, nil + return &stubPrimitive{false}, nil } -func (c *testConfig) PrimitiveFromKey(_ key.Key, _ internalapi.Token) (any, error) { - return testPrimitive{}, nil +func (c *testConfig) PrimitiveFromKey(k key.Key, _ internalapi.Token) (any, error) { + if _, ok := k.(*aesgcm.Key); !ok { + return nil, fmt.Errorf("Unable to create primitive from key") + } + return &stubPrimitive{true}, nil } -func TestPrimitivesWithConfig(t *testing.T) { - template := mac.HMACSHA256Tag128KeyTemplate() - template.OutputPrefixType = tinkpb.OutputPrefixType_RAW - handle, err := keyset.NewHandle(template) +func TestPrimitives(t *testing.T) { + handle, err := keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate()) if err != nil { - t.Fatalf("keyset.NewHandle(%v) = %v, want nil", template, err) + t.Fatalf("keyset.NewHandle(%v) = %v, want nil", mac.HMACSHA256Tag128KeyTemplate(), err) } - primitives, err := handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{})) + primitives, err := handle.Primitives(internalapi.Token{}) if err != nil { - t.Fatalf("handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{})) err = %v, want nil", err) + t.Fatalf("handle.Primitives(internalapi.Token{}) err = %v, want nil", err) } if len(primitives.EntriesInKeysetOrder) != 1 { - t.Fatalf("len(handle.Primitives(internalapi.Token{}, )) = %d, want 1", len(primitives.EntriesInKeysetOrder)) + t.Fatalf("len(handle.Primitives(internalapi.Token{})) = %d, want 1", len(primitives.EntriesInKeysetOrder)) } - if _, ok := (primitives.Primary.Primitive).(testPrimitive); !ok { - t.Errorf("handle.Primitives(internalapi.Token{}, ).Primary = %v, want instance of `testPrimitive`", primitives.Primary.Primitive) + if primitives.Primary.Primitive == nil { + t.Fatalf("handle.Primitives(internalapi.Token{}).Primary.Primitive = nil, want instance of `stubPrimitive`") + } + if _, ok := primitives.Primary.Primitive.(tink.MAC); !ok { + t.Fatalf("handle.Primitives(internalapi.Token{}).Primary.Primitive = %T, want instance of `tink.MAC`", primitives.Primary.FullPrimitive) + } +} + +func TestPrimitivesWithConfig(t *testing.T) { + for _, tc := range []struct { + name string + keyTemplate *tinkpb.KeyTemplate + wantFull bool + }{ + { + name: "legacy primitive", + keyTemplate: mac.HMACSHA256Tag128KeyTemplate(), + wantFull: false, + }, + { + name: "full primitive", + keyTemplate: aead.AES256GCMKeyTemplate(), + wantFull: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + handle, err := keyset.NewHandle(tc.keyTemplate) + if err != nil { + t.Fatalf("keyset.NewHandle(%v) = %v, want nil", tc.keyTemplate, err) + } + primitives, err := handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{})) + if err != nil { + t.Fatalf("handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{})) err = %v, want nil", err) + } + if len(primitives.EntriesInKeysetOrder) != 1 { + t.Fatalf("len(handle.Primitives(internalapi.Token{})) = %d, want 1", len(primitives.EntriesInKeysetOrder)) + } + var p any + if tc.wantFull { + p = primitives.Primary.FullPrimitive + } else { + p = primitives.Primary.Primitive + } + if _, ok := p.(*stubPrimitive); !ok { + t.Fatalf("handle.Primitives(internalapi.Token{}).Primary.FullPrimitive = %v, want instance of `stubPrimitive`", p) + } + if p.(*stubPrimitive).isFull != tc.wantFull { + t.Errorf("handle.Primitives(internalapi.Token{}).Primary.FullPrimitive.isFull = %v, want %v", p.(*stubPrimitive).isFull, tc.wantFull) + } + }) } }