Skip to content

Commit

Permalink
pkg/auth/spannerauth: add method for retrieving invalidated records
Browse files Browse the repository at this point in the history
This change adds a spannerauth method for retrieving records that have
been invalidated. This will be used when adding spannerauth support to
the authservice admin client.

References storj-private#507

Change-Id: If9aa2780d91824f7052af016bc5e4a7ef3f9b82e
  • Loading branch information
jewharton authored and Storj Robot committed Dec 20, 2023
1 parent d45a9c6 commit 14f7705
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 40 deletions.
16 changes: 16 additions & 0 deletions pkg/auth/authdb/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ type Record struct {
Public bool // if true, knowledge of secret key is not required
}

// FullRecord extends Record and includes invalidation information.
type FullRecord struct {
Record
InvalidatedAt time.Time
InvalidationReason string
}

// IsInvalid returns whether the record was invalidated.
func (f FullRecord) IsInvalid() bool {
return f.InvalidationReason != "" || !f.InvalidatedAt.IsZero()
}

// KeyHashSizeEncoded is the length of a hex encoded KeyHash.
const KeyHashSizeEncoded = 64

Expand Down Expand Up @@ -94,6 +106,10 @@ type Storage interface {
type StorageAdmin interface {
Storage

// GetFullRecord retrieves a record with invalidation information.
// It returns (nil, nil) if the key does not exist.
GetFullRecord(ctx context.Context, keyHash KeyHash) (record *FullRecord, err error)

// Invalidate invalidates the record.
Invalidate(ctx context.Context, keyHash KeyHash, reason string) error

Expand Down
39 changes: 30 additions & 9 deletions pkg/auth/spannerauth/spannerauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,22 @@ func (d *CloudDatabase) PutWithCreatedAt(ctx context.Context, keyHash authdb.Key
// It returns (nil, nil) if the key does not exist.
// If the record is invalid, the error contains why.
func (d *CloudDatabase) Get(ctx context.Context, keyHash authdb.KeyHash) (_ *authdb.Record, err error) {
full, err := d.GetFullRecord(ctx, keyHash)
if err != nil {
return nil, err
}
if full == nil {
return nil, nil
}
if full.IsInvalid() {
return nil, Error.Wrap(authdb.Invalid.New("%s", full.InvalidationReason))
}
return &full.Record, nil
}

// GetFullRecord retrieves the record from the remote Cloud Spanner database.
// It returns (nil, nil) if the key does not exist.
func (d *CloudDatabase) GetFullRecord(ctx context.Context, keyHash authdb.KeyHash) (_ *authdb.FullRecord, err error) {
defer mon.Task()(&ctx)(&err)

key := spanner.Key{keyHash.Bytes()}
Expand All @@ -143,6 +159,7 @@ func (d *CloudDatabase) Get(ctx context.Context, keyHash authdb.KeyHash) (_ *aut
"encrypted_secret_key",
"encrypted_access_grant",
"invalidation_reason",
"invalidated_at",
}

boundedTx := d.client.Single().WithTimestampBound(spanner.ExactStaleness(defaultExactStaleness))
Expand All @@ -166,15 +183,7 @@ func (d *CloudDatabase) Get(ctx context.Context, keyHash authdb.KeyHash) (_ *aut
}
}

var invalidationReason spanner.NullString
if err := row.ColumnByName("invalidation_reason", &invalidationReason); err != nil {
return nil, Error.Wrap(err)
}
if invalidationReason.StringVal != "" {
return nil, Error.Wrap(authdb.Invalid.New("%s", invalidationReason.StringVal))
}

record := new(authdb.Record)
record := new(authdb.FullRecord)
if err := row.ColumnByName("public", &record.Public); err != nil {
return nil, Error.Wrap(err)
}
Expand All @@ -194,6 +203,18 @@ func (d *CloudDatabase) Get(ctx context.Context, keyHash authdb.KeyHash) (_ *aut
return nil, Error.Wrap(err)
}

var invalidationReason spanner.NullString
if err := row.ColumnByName("invalidation_reason", &invalidationReason); err != nil {
return nil, Error.Wrap(err)
}
record.InvalidationReason = invalidationReason.StringVal

var invalidatedAt spanner.NullTime
if err := row.ColumnByName("invalidated_at", &invalidatedAt); err != nil {
return nil, Error.Wrap(err)
}
record.InvalidatedAt = invalidatedAt.Time

// From https://cloud.google.com/spanner/docs/ttl:
//
// TTL garbage collection deletes eligible rows continuously and in the
Expand Down
64 changes: 33 additions & 31 deletions pkg/auth/spannerauth/spannerauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,44 +101,46 @@ func TestCloudDatabaseAdmin(t *testing.T) {

require.NoError(t, db.HealthCheck(ctx))

reference := make(map[authdb.KeyHash]*authdb.Record)
for i := 0; i < 5; i++ {
withRecord := func(name string, fn func(t *testing.T, k authdb.KeyHash)) {
var k authdb.KeyHash
require.NoError(t, k.SetBytes([]byte(strconv.Itoa(i))))
r := createRandomRecord(t, time.Time{}, true)
reference[k] = r
require.NoError(t, db.Put(ctx, k, r))
testrand.Read(k[:])
require.NoError(t, db.Put(ctx, k, createRandomRecord(t, time.Time{}, true)))
t.Run(name, func(t *testing.T) {
fn(t, k)
})
}

// invalidate
{
var k authdb.KeyHash
require.NoError(t, k.SetBytes([]byte(strconv.Itoa(1))))
withRecord("Invalidate", func(t *testing.T, k authdb.KeyHash) {
require.NoError(t, db.Invalidate(ctx, k, "test"))
}
// unpublish
{
var k authdb.KeyHash
require.NoError(t, k.SetBytes([]byte(strconv.Itoa(2))))

_, err := db.Get(ctx, k)
require.True(t, authdb.Invalid.Has(err))

r, err := db.GetFullRecord(ctx, k)
require.NoError(t, err)
require.Equal(t, "test", r.InvalidationReason)
require.WithinDuration(t, time.Now(), r.InvalidatedAt, time.Minute)
})

withRecord("Unpublish", func(t *testing.T, k authdb.KeyHash) {
require.NoError(t, db.Unpublish(ctx, k))
reference[k].Public = false
}
// delete
{
var k authdb.KeyHash
require.NoError(t, k.SetBytes([]byte(strconv.Itoa(3))))

r, err := db.Get(ctx, k)
require.NoError(t, err)
require.False(t, r.Public)
})

withRecord("Delete", func(t *testing.T, k authdb.KeyHash) {
require.NoError(t, db.Delete(ctx, k))
reference[k] = nil
}

for k, r := range reference {
actual, err := db.Get(ctx, k)
if err != nil {
require.EqualError(t, err, spannerauth.Error.Wrap(authdb.Invalid.New("test")).Error())
} else {
require.Equal(t, r, actual)
}
}
record, err := db.Get(ctx, k)
require.NoError(t, err)
require.Nil(t, record)

fullRecord, err := db.GetFullRecord(ctx, k)
require.NoError(t, err)
require.Nil(t, fullRecord)
})
}

func TestRecordExpiry(t *testing.T) {
Expand Down

0 comments on commit 14f7705

Please sign in to comment.