Skip to content

Commit

Permalink
Add lookup function to multiwrapper and simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
jefferai committed Aug 17, 2020
1 parent 6d40896 commit 6714981
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 35 deletions.
91 changes: 56 additions & 35 deletions wrappers/multiwrapper/multiwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ var ErrKeyNotFound = errors.New("given key ID not found")
// Functions on this type will likely panic if the wrapper is not created via
// NewMultiWrapper.
type MultiWrapper struct {
wrappers *sync.Map
m sync.RWMutex
wrappers map[string]wrapping.Wrapper
}

// NewMultiWrapper creates a MultiWrapper and sets its encrypting wrapper to
Expand All @@ -32,32 +33,51 @@ func NewMultiWrapper(base wrapping.Wrapper) *MultiWrapper {
}

ret := &MultiWrapper{
wrappers: new(sync.Map),
wrappers: make(map[string]wrapping.Wrapper, 3),
}
ret.wrappers.Store(baseEncryptor, base)
ret.wrappers[baseEncryptor] = base
ret.wrappers[base.KeyID()] = base
return ret
}

// AddWrapper adds a wrapper to the MultiWrapper. For safety, it will refuse to
// overwrite an existing wrapper; use RemoveWrapper to remove that one first.
// The return parameter indicates if the wrapper was successfully added, that
// is, it will be false if an existing wrapper would have been overridden. If
// you want to change the encrypting wrapper, create a new MultiWrapper. This
// function will panic if w is nil.
// you want to change the encrypting wrapper, create a new MultiWrapper or call
// SetEncryptingWrapper. This function will panic if w is nil.
func (m *MultiWrapper) AddWrapper(w wrapping.Wrapper) (added bool) {
_, loaded := m.wrappers.LoadOrStore(w.KeyID(), w)
return !loaded
m.m.Lock()
defer m.m.Unlock()

wrapper := m.wrappers[w.KeyID()]
if wrapper != nil {
return false
}
m.wrappers[w.KeyID()] = w
return true
}

// RemoveWrapper removes a wrapper from the MultiWrapper, identified by key ID.
// It will not remove the encrypting wrapper; use SetEncryptingWrapper for
// that.
func (m *MultiWrapper) RemoveWrapper(keyID string) {
// Don't allow removing our base encryptor
// that. Returns whether or not a wrapper was removed, which will always be
// true unless it was the base encryptor.
func (m *MultiWrapper) RemoveWrapper(keyID string) (removed bool) {
// For safety, no real reason this should happen
if keyID == baseEncryptor {
return
panic("invalid key ID")
}

m.m.Lock()
defer m.m.Unlock()

base := m.wrappers[baseEncryptor]
if base.KeyID() == keyID {
// Don't allow removing the base encryptor
return false
}
m.wrappers.Delete(keyID)
delete(m.wrappers, keyID)
return true
}

// SetEncryptingWrapper resets the encrypting wrapper to the one passed in. It
Expand All @@ -71,30 +91,32 @@ func (m *MultiWrapper) SetEncryptingWrapper(w wrapping.Wrapper) (success bool) {
panic("invalid key ID")
}

// Note: we keep this simple and don't return errors because there are no
// reasonable ways this should fail, other than trying to give a new
// encryptor with an existing key ID.
val, ok := m.wrappers.Load(baseEncryptor)
if !ok {
m.wrappers.Store(baseEncryptor, w)
return true
}
oldW := val.(wrapping.Wrapper)
_, loaded := m.wrappers.LoadOrStore(oldW.KeyID(), oldW)
if loaded {
return false
}
m.m.Lock()
defer m.m.Unlock()

m.wrappers.Store(baseEncryptor, w)
m.wrappers[baseEncryptor] = w
m.wrappers[w.KeyID()] = w
return true
}

// WrapperForKeyID returns the wrapper for the given keyID. Returns nil if no
// wrapper was found for the given key ID.
func (m *MultiWrapper) WrapperForKeyID(keyID string) wrapping.Wrapper {
m.m.RLock()
defer m.m.RUnlock()

return m.wrappers[keyID]
}

func (m *MultiWrapper) encryptor() wrapping.Wrapper {
val, ok := m.wrappers.Load(baseEncryptor)
if !ok {
m.m.RLock()
defer m.m.RUnlock()

wrapper := m.wrappers[baseEncryptor]
if wrapper == nil {
panic("no base encryptor found")
}
return val.(wrapping.Wrapper)
return wrapper
}

func (m *MultiWrapper) Type() string {
Expand Down Expand Up @@ -133,15 +155,14 @@ func (m *MultiWrapper) Encrypt(ctx context.Context, pt []byte, aad []byte) (*wra
// decryption with the current encryptor. It will return an ErrKeyNotFound if
// it cannot find a suitable key.
func (m *MultiWrapper) Decrypt(ctx context.Context, ct *wrapping.EncryptedBlobInfo, aad []byte) ([]byte, error) {
// First check the encryptor
enc := m.encryptor()
if ct.KeyInfo == nil || ct.KeyInfo.KeyID == enc.KeyID() {
if ct.KeyInfo == nil {
enc := m.encryptor()
return enc.Decrypt(ctx, ct, aad)
}

val, ok := m.wrappers.Load(ct.KeyInfo.KeyID)
if !ok {
wrapper := m.WrapperForKeyID(ct.KeyInfo.KeyID)
if wrapper == nil {
return nil, ErrKeyNotFound
}
return val.(wrapping.Wrapper).Decrypt(ctx, ct, aad)
return wrapper.Decrypt(ctx, ct, aad)
}
20 changes: 20 additions & 0 deletions wrappers/multiwrapper/multiwrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,26 @@ func TestMultiWrapper(t *testing.T) {
}
}

// Check retriving the wrappers
checkW1 := multi.WrapperForKeyID("w1")
if checkW1 == nil {
t.Fatal("nil w1")
}
if checkW1.KeyID() != "w1" {
t.Fatal("mismatch")
}
checkW2 := multi.WrapperForKeyID("w2")
if checkW2 == nil {
t.Fatal("nil w2")
}
if checkW2.KeyID() != "w2" {
t.Fatal("mismatch")
}
checkW3 := multi.WrapperForKeyID("w3")
if checkW3 != nil {
t.Fatal("expected key not found")
}

// Check removing a wrapper, and not removing the base wrapper
multi.RemoveWrapper("w1")
multi.RemoveWrapper("w2")
Expand Down

0 comments on commit 6714981

Please sign in to comment.