Skip to content

Commit

Permalink
Merge pull request kroxylicious#763 from robobario/remove-kek-from-km…
Browse files Browse the repository at this point in the history
…s-decrypt

Make KMS responsible for encoding immutable kek id into edek
  • Loading branch information
robobario authored Dec 5, 2023
2 parents 9fcfc5d + 57fd82d commit 83df861
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ public class InBandKeyManager<K, E> implements KeyManager<K> {

private final Kms<K, E> kms;
private final BufferPool bufferPool;
private final Serde<K> kekIdSerde;
private final Serde<E> edekSerde;
// TODO cache expiry, with key descruction
private final ConcurrentHashMap<K, CompletionStage<KeyContext>> keyContextCache;
Expand All @@ -66,7 +65,6 @@ public InBandKeyManager(Kms<K, E> kms,
this.kms = kms;
this.bufferPool = bufferPool;
this.edekSerde = kms.edekSerde();
this.kekIdSerde = kms.keyIdSerde();
this.dekTtlNanos = 5_000_000_000L;
this.maxEncryptionsPerDek = 500_000;
// TODO This ^^ must be > the maximum size of a batch to avoid an infinite loop
Expand Down Expand Up @@ -125,15 +123,10 @@ private Supplier<CompletionStage<KeyContext>> makeKeyContext(@NonNull K kekId) {
return () -> kms.generateDekPair(kekId)
.thenApply(dekPair -> {
E edek = dekPair.edek();
short kekIdSize = (short) kekIdSerde.sizeOf(kekId);
short edekSize = (short) edekSerde.sizeOf(edek);
ByteBuffer prefix = bufferPool.acquire(
Short.BYTES + // kekId size
kekIdSize + // the kekId
Short.BYTES + // DEK size
Short.BYTES + // DEK size
edekSize); // the DEK
prefix.putShort(kekIdSize);
kekIdSerde.serialize(kekId, prefix);
prefix.putShort(edekSize);
edekSerde.serialize(edek, prefix);
prefix.flip();
Expand Down Expand Up @@ -285,11 +278,10 @@ static RecordHeader dek(Record kafkaRecord) {
}

private CompletionStage<AesGcmEncryptor> getOrCacheDecryptor(RecordHeader dekHeader,
K kekId,
E edek) {
return decryptorCache.compute(dekHeader, (k, v) -> {
if (v == null) {
return kms.decryptEdek(kekId, edek)
return kms.decryptEdek(edek)
.thenApply(AesGcmEncryptor::forDecrypt).toCompletableFuture();
// TODO what happens if the CS complete exceptionally
// TODO what happens if the CS doesn't complete at all in a reasonably time frame?
Expand Down Expand Up @@ -349,16 +341,11 @@ private void decryptRecord(@NonNull Receiver receiver, Record kafkaRecord, AesGc
private CompletionStage<AesGcmEncryptor> resolveEncryptor(Record kafkaRecord) {
var dekHeader = dek(kafkaRecord);
var buffer = ByteBuffer.wrap(dekHeader.value());
var kekLength = buffer.getShort();
int origLimit = buffer.limit();
buffer.limit(buffer.position() + kekLength);
var kekId = kekIdSerde.deserialize(buffer);
buffer.limit(origLimit);
var edekLength = buffer.getShort();
buffer.limit(buffer.position() + edekLength);
var edek = edekSerde.deserialize(buffer);
buffer.rewind();
return getOrCacheDecryptor(dekHeader, kekId, edek);
return getOrCacheDecryptor(dekHeader, edek);
}

private Header decryptRecordHeader(Header header, AesGcmEncryptor encryptor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

import java.util.Arrays;
import java.util.Objects;
import java.util.UUID;

record InMemoryEdek(
int numAuthBits,
byte[] iv,
UUID kekRef,
byte[] edek) {

InMemoryEdek {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,42 @@
package io.kroxylicious.kms.provider.kroxylicious.inmemory;

import java.nio.ByteBuffer;
import java.util.UUID;

import io.kroxylicious.kms.service.Serde;

import edu.umd.cs.findbugs.annotations.NonNull;

class InMemoryEdekSerde implements Serde<InMemoryEdek> {

private static final InMemoryEdekSerde INSTANCE = new InMemoryEdekSerde();
private final Serde<UUID> uuidSerde = UUIDSerde.instance();

private InMemoryEdekSerde() {
}

public static Serde<InMemoryEdek> instance() {
return INSTANCE;
}

@Override
public InMemoryEdek deserialize(@NonNull ByteBuffer buffer) {
short numAuthBits = Serde.getUnsignedByte(buffer);
var ivLength = Serde.getUnsignedByte(buffer);
var iv = new byte[ivLength];
buffer.get(iv);
UUID keyRef = uuidSerde.deserialize(buffer);
int edekLength = buffer.limit() - buffer.position();
var edek = new byte[edekLength];
buffer.get(edek);
return new InMemoryEdek(numAuthBits, iv, edek);
return new InMemoryEdek(numAuthBits, iv, keyRef, edek);
}

@Override
public int sizeOf(InMemoryEdek inMemoryEdek) {
return Byte.BYTES // Auth tag: NIST.SP.800-38D §5.2.1.2 suggests max tag length is 128
+ Byte.BYTES // IV length: NIST.SP.800-38D §8.2 certainly doesn't _limit_ IV to 96 bits
+ uuidSerde.sizeOf(inMemoryEdek.kekRef())
+ inMemoryEdek.iv().length
+ inMemoryEdek.edek().length;
}
Expand All @@ -39,6 +52,7 @@ public void serialize(InMemoryEdek inMemoryEdek, @NonNull ByteBuffer buffer) {
Serde.putUnsignedByte(buffer, inMemoryEdek.numAuthBits());
Serde.putUnsignedByte(buffer, inMemoryEdek.iv().length);
buffer.put(inMemoryEdek.iv());
uuidSerde.serialize(inMemoryEdek.kekRef(), buffer);
buffer.put(inMemoryEdek.edek());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ private InMemoryEdek wrap(UUID kekRef, Supplier<SecretKey> generator) {
catch (IllegalBlockSizeException | InvalidKeyException e) {
throw new KmsException(e);
}
return new InMemoryEdek(spec.getTLen(), spec.getIV(), edek);
return new InMemoryEdek(spec.getTLen(), spec.getIV(), kekRef, edek);
}

@NonNull
Expand Down Expand Up @@ -152,9 +152,9 @@ private SecretKey lookupKey(UUID kekRef) {

@NonNull
@Override
public CompletableFuture<SecretKey> decryptEdek(@NonNull UUID kekRef, @NonNull InMemoryEdek edek) {
public CompletableFuture<SecretKey> decryptEdek(@NonNull InMemoryEdek edek) {
try {
var kek = lookupKey(kekRef);
var kek = lookupKey(edek.kekRef());
Cipher aesCipher = aesGcm();
initializeforUnwrap(aesCipher, edek, kek);
SecretKey key = unwrap(edek, aesCipher);
Expand Down Expand Up @@ -193,12 +193,6 @@ private static Cipher aesGcm() {
}
}

@NonNull
@Override
public Serde<UUID> keyIdSerde() {
return new UUIDSerde();
}

@NonNull
@Override
public CompletableFuture<UUID> resolveAlias(@NonNull String alias) {
Expand All @@ -212,6 +206,6 @@ public CompletableFuture<UUID> resolveAlias(@NonNull String alias) {
@NonNull
@Override
public Serde<InMemoryEdek> edekSerde() {
return new InMemoryEdekSerde();
return InMemoryEdekSerde.instance();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@

class UUIDSerde implements Serde<UUID> {

private static final UUIDSerde UUID_SERDE = new UUIDSerde();

private UUIDSerde() {
}

@Override
public UUID deserialize(@NonNull ByteBuffer buffer) {
var msb = buffer.getLong();
Expand All @@ -32,4 +37,8 @@ public void serialize(UUID uuid, @NonNull ByteBuffer buffer) {
buffer.putLong(uuid.getMostSignificantBits());
buffer.putLong(uuid.getLeastSignificantBits());
}

public static Serde<UUID> instance() {
return UUID_SERDE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
package io.kroxylicious.kms.provider.kroxylicious.inmemory;

import java.nio.ByteBuffer;
import java.util.UUID;

import org.junit.jupiter.api.Test;

import io.kroxylicious.kms.service.Serde;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

Expand All @@ -18,8 +21,8 @@ class InMemoryEdekSerdeTest {
@Test
void shouldRoundTripAllAllowedAuthBits() {
for (int bits : new int[]{ 128, 120, 112, 104, 96 }) {
var edek = new InMemoryEdek(bits, new byte[0], new byte[0]);
InMemoryEdekSerde serde = new InMemoryEdekSerde();
var edek = new InMemoryEdek(bits, new byte[0], UUID.randomUUID(), new byte[0]);
Serde<InMemoryEdek> serde = InMemoryEdekSerde.instance();
int size = serde.sizeOf(edek);
var buffer = ByteBuffer.allocate(size);
serde.serialize(edek, buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,27 @@

package io.kroxylicious.kms.provider.kroxylicious.inmemory;

import java.util.UUID;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;

class InMemoryEdekTest {

private static final UUID KEK_REF = UUID.randomUUID();

@Test
void testEqualsAndHashCode() {
var edek1 = new InMemoryEdek(96, new byte[]{ (byte) 1, (byte) 2, (byte) 3 },
new byte[]{ (byte) 4, (byte) 5, (byte) 6 });
KEK_REF, new byte[]{ (byte) 4, (byte) 5, (byte) 6 });

var edek2 = new InMemoryEdek(96, new byte[]{ (byte) 1, (byte) 2, (byte) 3 },
new byte[]{ (byte) 4, (byte) 5, (byte) 6 });
KEK_REF, new byte[]{ (byte) 4, (byte) 5, (byte) 6 });

var edek3 = new InMemoryEdek(96, new byte[]{ (byte) 4, (byte) 5, (byte) 6 },
new byte[]{ (byte) 1, (byte) 2, (byte) 3 });
KEK_REF, new byte[]{ (byte) 1, (byte) 2, (byte) 3 });

assertEquals(edek1, edek1);
assertEquals(edek1, edek2);
Expand All @@ -42,9 +46,9 @@ void testEqualsAndHashCode() {
@Test
void testToString() {
var edek1 = new InMemoryEdek(96, new byte[]{ (byte) 1, (byte) 2, (byte) 3 },
new byte[]{ (byte) 4, (byte) 5, (byte) 6 });
KEK_REF, new byte[]{ (byte) 4, (byte) 5, (byte) 6 });
assertEquals("InMemoryEdek{numAuthBits=96, iv=[1, 2, 3], edek=[4, 5, 6]}", edek1.toString());

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.junit.jupiter.api.Test;

import io.kroxylicious.kms.service.DekPair;
import io.kroxylicious.kms.service.Serde;
import io.kroxylicious.kms.service.UnknownAliasException;
import io.kroxylicious.kms.service.UnknownKeyException;

Expand Down Expand Up @@ -64,27 +63,6 @@ void shouldWorkAcrossServiceInstances() {
IntegrationTestingKmsService.delete(kmsId);
}

@Test
void shouldSerializeAndDeserialiseKeks() {
// given
var kmsId = UUID.randomUUID().toString();
var kms = service.buildKms(new IntegrationTestingKmsService.Config(kmsId));
var kek = kms.generateKey();
assertNotNull(kek);

// when
Serde<UUID> keyIdSerde = kms.keyIdSerde();
var buffer = ByteBuffer.allocate(keyIdSerde.sizeOf(kek));
keyIdSerde.serialize(kek, buffer);
assertFalse(buffer.hasRemaining());
buffer.flip();
var loadedKek = keyIdSerde.deserialize(buffer);

// then
assertEquals(kek, loadedKek, "Expect the deserialized kek to be equal to the original kek");
IntegrationTestingKmsService.delete(kmsId);
}

@Test
void shouldGenerateDeks() {
// given
Expand Down Expand Up @@ -153,7 +131,7 @@ void shouldDecryptDeks() {
assertNotNull(pair.dek());

// when
var decryptedDek = kms.decryptEdek(kek, pair.edek()).join();
var decryptedDek = kms.decryptEdek(pair.edek()).join();

// then
assertEquals(pair.dek(), decryptedDek, "Expect the decrypted DEK to equal the originally generated DEK");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import io.kroxylicious.kms.service.DekPair;
import io.kroxylicious.kms.service.Serde;
import io.kroxylicious.kms.service.UnknownAliasException;
import io.kroxylicious.kms.service.UnknownKeyException;

Expand Down Expand Up @@ -46,25 +44,6 @@ void shouldRejectOutOfBoundAuthBits() {
assertThrows(IllegalArgumentException.class, () -> new UnitTestingKmsService.Config(12, 0));
}

@Test
void shouldSerializeAndDeserialiseKeks() {
// given
var kms = service.buildKms(new UnitTestingKmsService.Config());
var kek = kms.generateKey();
assertNotNull(kek);

// when
Serde<UUID> keyIdSerde = kms.keyIdSerde();
var buffer = ByteBuffer.allocate(keyIdSerde.sizeOf(kek));
keyIdSerde.serialize(kek, buffer);
assertFalse(buffer.hasRemaining());
buffer.flip();
var loadedKek = keyIdSerde.deserialize(buffer);

// then
assertEquals(kek, loadedKek, "Expect the deserialized kek to be equal to the original kek");
}

@Test
void shouldGenerateDeks() {
// given
Expand Down Expand Up @@ -122,7 +101,7 @@ void shouldDecryptDeks() {
assertNotNull(pair.dek());

// when
var decryptedDek = kms.decryptEdek(kek, pair.edek()).join();
var decryptedDek = kms.decryptEdek(pair.edek()).join();

// then
assertEquals(pair.dek(), decryptedDek, "Expect the decrypted DEK to equal the originally generated DEK");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ public interface Kms<K, E> {
/**
* Asynchronously generates a new Data Encryption Key (DEK) and returns it together with the same DEK wrapped by the Key Encryption Key (KEK) given
* by the {@code kekRef},
* The returned encrypted DEK can later be decrypted with {@link Kms#decryptEdek(Object, Object)}.
* The returned encrypted DEK can later be decrypted with {@link #decryptEdek(Object)}. It is expected that
* the returned EDEK contains everything required for decryption including an immutable reference to the KEK
* @param kekRef The key encryption key used to encrypt the generated data encryption key.
* @return A completion stage for the wrapped data encryption key.
* @throws UnknownKeyException If the kek was not known to this KMS.
Expand All @@ -34,24 +35,14 @@ public interface Kms<K, E> {

/**
* Asynchronously decrypts a data encryption key that was {@linkplain #generateDekPair(Object) previously encrypted}.
* @param kek The key encryption key.
* @param edek The encrypted data encryption key.
* @return A completion stage for the data encryption key
* @throws UnknownKeyException If the kek was not known to this KMS.
* @throws InvalidKeyUsageException If the given kek was not intended for key wrapping.
* @throws UnknownKeyException If the edek was not encrypted by a KEK known to this KMS.
* @throws InvalidKeyUsageException If the edek refers to a kek that was not intended for key wrapping.
* @throws KmsException For other exceptions
*/
@NonNull
CompletionStage<SecretKey> decryptEdek(@NonNull K kek, @NonNull E edek);

/**
* Get a serializer for KEK ids.
* It is required that {@code deserialize(serialize(kekId)).equals(kekId)}.
*
* @return A serializer for KEK ids.
*/
@NonNull
Serde<K> keyIdSerde();
CompletionStage<SecretKey> decryptEdek(@NonNull E edek);

/**
* Get a serializer for encrypted DEKs.
Expand Down

0 comments on commit 83df861

Please sign in to comment.