Skip to content

Commit

Permalink
Populate "full" primitives in the PrimitiveSet
Browse files Browse the repository at this point in the history
A full primitive is constructed using `Config`'s `PrimitiveFromKey`. The keyset handle tries this first, and falls back to the legacy `PrimitiveFromKeyData` method otherwise.

PiperOrigin-RevId: 696027102
Change-Id: Icd28645081aec158cae7397a5e27490195e8110e
  • Loading branch information
morambro authored and copybara-github committed Nov 13, 2024
1 parent fbc3214 commit 7e40e1f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 31 deletions.
2 changes: 1 addition & 1 deletion aead/aead_factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (sc *stubConfig) PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalap
}

func (sc *stubConfig) PrimitiveFromKey(_ key.Key, _ internalapi.Token) (any, error) {
return new(stubAEAD), nil
return nil, fmt.Errorf("unsupported")
}

func TestNewWithConfig(t *testing.T) {
Expand Down
42 changes: 25 additions & 17 deletions keyset/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ func (h *Handle) WriteWithNoSecrets(w Writer) error {
// The config.Config concrete type is not used directly due to circular dependencies.
type Config interface {
PrimitiveFromKeyData(keyData *tinkpb.KeyData, _ internalapi.Token) (any, error)
// PrimitiveFromKey creates a primitive from a key.Key.
// PrimitiveFromKey creates a primitive from a [key.Key].
PrimitiveFromKey(key key.Key, _ internalapi.Token) (any, error)
}

Expand Down Expand Up @@ -499,6 +499,29 @@ func (h *Handle) PrimitivesWithKeyManager(km registry.KeyManager, _ internalapi.
return p, nil
}

func addToPrimitiveSet(primitiveSet *primitiveset.PrimitiveSet, entry *Entry, km registry.KeyManager, config Config) (*primitiveset.Entry, error) {
protoKey, err := entryToProtoKey(entry)
if err != nil {
return nil, err
}
if km != nil && km.DoesSupport(protoKey.GetKeyData().GetTypeUrl()) {
primitive, err := km.Primitive(protoKey.GetKeyData().GetValue())
if err != nil {
return nil, fmt.Errorf("cannot get primitive from key: %v", err)
}
return primitiveSet.Add(primitive, protoKey)
}
primitive, err := config.PrimitiveFromKey(entry.Key(), internalapi.Token{})
if err == nil {
return primitiveSet.AddFullPrimitive(primitive, protoKey)
}
primitive, err = config.PrimitiveFromKeyData(protoKey.GetKeyData(), internalapi.Token{})
if err != nil {
return nil, fmt.Errorf("cannot get primitive from key data: %v", err)
}
return primitiveSet.Add(primitive, protoKey)
}

func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (*primitiveset.PrimitiveSet, error) {
if h == nil {
return nil, fmt.Errorf("nil handle")
Expand All @@ -522,22 +545,7 @@ func (h *Handle) primitives(km registry.KeyManager, opts ...PrimitivesOption) (*
if entry.KeyStatus() != Enabled {
continue
}
protoKey, err := entryToProtoKey(entry)
if err != nil {
return nil, err
}
var primitive any
if km != nil && km.DoesSupport(protoKey.GetKeyData().GetTypeUrl()) {
primitive, err = km.Primitive(protoKey.GetKeyData().GetValue())
} else {
// TODO: b/369551049 - Use the new PrimitiveFromKey method once we have
// added tooling to distinguish between "full" and "partial" primitives.
primitive, err = config.PrimitiveFromKeyData(protoKey.GetKeyData(), internalapi.Token{})
}
if err != nil {
return nil, fmt.Errorf("cannot get primitive from key: %v", err)
}
primitiveSetEntry, err := primitiveSet.Add(primitive, protoKey)
primitiveSetEntry, err := addToPrimitiveSet(primitiveSet, entry, km, config)
if err != nil {
return nil, fmt.Errorf("cannot add primitive: %v", err)
}
Expand Down
80 changes: 67 additions & 13 deletions keyset/handle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"google.golang.org/protobuf/proto"
"github.com/tink-crypto/tink-go/v2/aead"
"github.com/tink-crypto/tink-go/v2/aead/aesgcm"
"github.com/tink-crypto/tink-go/v2/core/registry"
"github.com/tink-crypto/tink-go/v2/internal/internalapi"
"github.com/tink-crypto/tink-go/v2/internal/protoserialization"
Expand Down Expand Up @@ -869,30 +870,83 @@ func TestPrimitivesWithRegistry(t *testing.T) {

type testConfig struct{}

type stubPrimitive struct {
isFull bool
}

func (c *testConfig) PrimitiveFromKeyData(_ *tinkpb.KeyData, _ internalapi.Token) (any, error) {
return testPrimitive{}, nil
return &stubPrimitive{false}, nil
}

func (c *testConfig) PrimitiveFromKey(_ key.Key, _ internalapi.Token) (any, error) {
return testPrimitive{}, nil
func (c *testConfig) PrimitiveFromKey(k key.Key, _ internalapi.Token) (any, error) {
if _, ok := k.(*aesgcm.Key); !ok {
return nil, fmt.Errorf("Unable to create primitive from key")
}
return &stubPrimitive{true}, nil
}

func TestPrimitivesWithConfig(t *testing.T) {
template := mac.HMACSHA256Tag128KeyTemplate()
template.OutputPrefixType = tinkpb.OutputPrefixType_RAW
handle, err := keyset.NewHandle(template)
func TestPrimitives(t *testing.T) {
handle, err := keyset.NewHandle(mac.HMACSHA256Tag128KeyTemplate())
if err != nil {
t.Fatalf("keyset.NewHandle(%v) = %v, want nil", template, err)
t.Fatalf("keyset.NewHandle(%v) = %v, want nil", mac.HMACSHA256Tag128KeyTemplate(), err)
}
primitives, err := handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{}))
primitives, err := handle.Primitives(internalapi.Token{})
if err != nil {
t.Fatalf("handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{})) err = %v, want nil", err)
t.Fatalf("handle.Primitives(internalapi.Token{}) err = %v, want nil", err)
}
if len(primitives.EntriesInKeysetOrder) != 1 {
t.Fatalf("len(handle.Primitives(internalapi.Token{}, )) = %d, want 1", len(primitives.EntriesInKeysetOrder))
t.Fatalf("len(handle.Primitives(internalapi.Token{})) = %d, want 1", len(primitives.EntriesInKeysetOrder))
}
if _, ok := (primitives.Primary.Primitive).(testPrimitive); !ok {
t.Errorf("handle.Primitives(internalapi.Token{}, ).Primary = %v, want instance of `testPrimitive`", primitives.Primary.Primitive)
if primitives.Primary.Primitive == nil {
t.Fatalf("handle.Primitives(internalapi.Token{}).Primary.Primitive = nil, want instance of `stubPrimitive`")
}
if _, ok := primitives.Primary.Primitive.(tink.MAC); !ok {
t.Fatalf("handle.Primitives(internalapi.Token{}).Primary.Primitive = %T, want instance of `tink.MAC`", primitives.Primary.FullPrimitive)
}
}

func TestPrimitivesWithConfig(t *testing.T) {
for _, tc := range []struct {
name string
keyTemplate *tinkpb.KeyTemplate
wantFull bool
}{
{
name: "legacy primitive",
keyTemplate: mac.HMACSHA256Tag128KeyTemplate(),
wantFull: false,
},
{
name: "full primitive",
keyTemplate: aead.AES256GCMKeyTemplate(),
wantFull: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
handle, err := keyset.NewHandle(tc.keyTemplate)
if err != nil {
t.Fatalf("keyset.NewHandle(%v) = %v, want nil", tc.keyTemplate, err)
}
primitives, err := handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{}))
if err != nil {
t.Fatalf("handle.Primitives(internalapi.Token{}, keyset.WithConfig(&testConfig{})) err = %v, want nil", err)
}
if len(primitives.EntriesInKeysetOrder) != 1 {
t.Fatalf("len(handle.Primitives(internalapi.Token{})) = %d, want 1", len(primitives.EntriesInKeysetOrder))
}
var p any
if tc.wantFull {
p = primitives.Primary.FullPrimitive
} else {
p = primitives.Primary.Primitive
}
if _, ok := p.(*stubPrimitive); !ok {
t.Fatalf("handle.Primitives(internalapi.Token{}).Primary.FullPrimitive = %v, want instance of `stubPrimitive`", p)
}
if p.(*stubPrimitive).isFull != tc.wantFull {
t.Errorf("handle.Primitives(internalapi.Token{}).Primary.FullPrimitive.isFull = %v, want %v", p.(*stubPrimitive).isFull, tc.wantFull)
}
})
}
}

Expand Down

0 comments on commit 7e40e1f

Please sign in to comment.