From 96817b0e30727c2e6d44bd43b76ce06f50e966b3 Mon Sep 17 00:00:00 2001 From: Tom Bentley Date: Thu, 1 Feb 2024 03:42:56 +0000 Subject: [PATCH] Add RecordEncryptor and RecordDecryptor (#899) * Add RecordEncryptor * Add the RecordDecryptor Signed-off-by: Tom Bentley --- .../test/assertj/HeaderAssert.java | 26 +- .../test/assertj/RecordAssert.java | 4 +- .../test/record/RecordTestUtils.java | 20 +- .../test/assertj/HeaderAssertTest.java | 4 +- .../encryption/inband/DecryptState.java | 31 ++ .../encryption/inband/InBandKeyManager.java | 4 +- .../filter/encryption/inband/Parcel.java | 6 +- .../encryption/inband/RecordDecryptor.java | 115 +++++++ .../encryption/inband/RecordEncryptor.java | 172 ++++++++++ .../encryption/records/RecordTransform.java | 4 +- .../filter/encryption/inband/ParcelTest.java | 4 +- .../inband/RecordEncryptorTest.java | 298 ++++++++++++++++++ 12 files changed, 667 insertions(+), 21 deletions(-) create mode 100644 kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/DecryptState.java create mode 100644 kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordDecryptor.java create mode 100644 kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordEncryptor.java create mode 100644 kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/RecordEncryptorTest.java diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java index d952a2f3ea..e2d85faa8e 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java @@ -6,8 +6,11 @@ package io.kroxylicious.test.assertj; +import java.nio.charset.StandardCharsets; + import org.apache.kafka.common.header.Header; import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.AbstractByteArrayAssert; import org.assertj.core.api.Assertions; public class HeaderAssert extends AbstractAssert { @@ -29,20 +32,25 @@ public HeaderAssert hasKeyEqualTo(String expected) { } public HeaderAssert hasValueEqualTo(String expected) { - isNotNull(); - String valueString = actual.value() == null ? null : new String(actual.value()); - Assertions.assertThat(valueString) - .describedAs("header value") - .isEqualTo(expected); + valueAssert().isEqualTo(expected == null ? null : expected.getBytes(StandardCharsets.UTF_8)); + return this; + } + + public HeaderAssert hasValueEqualTo(byte[] expected) { + valueAssert().isEqualTo(expected); return this; } public HeaderAssert hasNullValue() { - isNotNull(); - Assertions.assertThat(actual.value()) - .describedAs("header value") - .isNull(); + valueAssert().isNull(); return this; } + private AbstractByteArrayAssert valueAssert() { + isNotNull(); + AbstractByteArrayAssert headerValue = Assertions.assertThat(actual.value()) + .describedAs("header value"); + return headerValue; + } + } diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java index 7c08af6b00..3b45a54df0 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java @@ -32,7 +32,7 @@ public static RecordAssert assertThat(Record actual) { return new RecordAssert(actual); } - public RecordAssert hasOffsetEqualTo(int expect) { + public RecordAssert hasOffsetEqualTo(long expect) { isNotNull(); AbstractLongAssert offset = offsetAssert(); offset.isEqualTo(expect); @@ -45,7 +45,7 @@ private AbstractLongAssert offsetAssert() { .describedAs("record offset"); } - public RecordAssert hasTimestampEqualTo(int expect) { + public RecordAssert hasTimestampEqualTo(long expect) { isNotNull(); AbstractLongAssert timestamp = timestampAssert(); timestamp.isEqualTo(expect); diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java index dca9fc2808..7d3e264eff 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java @@ -59,6 +59,10 @@ private static byte[] bytesOf(ByteBuffer buffer) { return result; } + public static String asString(ByteBuffer buffer) { + return buffer == null ? null : new String(bytesOf(buffer), StandardCharsets.UTF_8); + } + /** * Get a copy of the bytes contained in the given {@code record}'s value, without changing the * {@code record}'s {@link ByteBuffer#position() position}, {@link ByteBuffer#limit() limit} or {@link ByteBuffer#mark() mark}. @@ -70,7 +74,7 @@ public static byte[] recordValueAsBytes(Record record) { } public static String recordValueAsString(Record record) { - return record.value() == null ? null : new String(bytesOf(record.value()), StandardCharsets.UTF_8); + return asString(record.value()); } public static byte[] recordKeyAsBytes(Record record) { @@ -78,7 +82,7 @@ public static byte[] recordKeyAsBytes(Record record) { } public static String recordKeyAsString(Record record) { - return record.key() == null ? null : new String(bytesOf(record.key()), StandardCharsets.UTF_8); + return asString(record.key()); } /** @@ -247,6 +251,18 @@ public static Record record(byte magic, return MemoryRecords.readableRecords(mr.buffer()).records().iterator().next(); } + public static Record record(long offset, long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers) { + return record(RecordBatch.CURRENT_MAGIC_VALUE, offset, timestamp, key, value, headers); + } + + public static Record record(long offset, long timestamp, String key, String value, Header[] headers) { + return record(RecordBatch.CURRENT_MAGIC_VALUE, offset, timestamp, key, value, headers); + } + + public static Record record(long offset, long timestamp, byte[] key, byte[] value, Header[] headers) { + return record(RecordBatch.CURRENT_MAGIC_VALUE, offset, timestamp, key, value, headers); + } + /** * Return a singleton RecordBatch containing a single Record with the given key, value and headers. * The batch will use the current magic. The baseOffset and offset of the record will be 0 diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java index ee57afdc05..cdc6ba7a68 100644 --- a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java +++ b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java @@ -46,10 +46,10 @@ void testHeaderHasValueEqualTo() { RecordHeader nonNullValue = new RecordHeader("foo", "abc".getBytes(StandardCharsets.UTF_8)); HeaderAssert nonNullValueAssert = KafkaAssertions.assertThat(nonNullValue); - nullValueAssert.hasValueEqualTo(null); + nullValueAssert.hasValueEqualTo((String) null); nonNullValueAssert.hasValueEqualTo("abc"); throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo("other"), "[header value]"); - throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo(null), "[header value]"); + throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo((String) null), "[header value]"); throwsAssertionErrorContaining(() -> nullValueAssert.hasValueEqualTo("other"), "[header value]"); assertThrowsIfHeaderNull(nullAssert -> nullAssert.hasValueEqualTo("any")); } diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/DecryptState.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/DecryptState.java new file mode 100644 index 0000000000..d6c9ecc13e --- /dev/null +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/DecryptState.java @@ -0,0 +1,31 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.filter.encryption.inband; + +import org.apache.kafka.common.record.Record; + +import io.kroxylicious.filter.encryption.EncryptionVersion; + +import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; + +/** + * Helper class to group together some state for decryption. + * Either both, or neither, of the given {@code decryptionVersion} and {@code encryptor} should be null. + * @param kafkaRecord The record + * @param decryptionVersion The version + * @param encryptor The encryptor + */ +record DecryptState(@NonNull Record kafkaRecord, @Nullable EncryptionVersion decryptionVersion, + @Nullable AesGcmEncryptor encryptor) { + + DecryptState { + if (decryptionVersion == null ^ encryptor == null) { + throw new IllegalArgumentException(); + } + } +} diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index fb552141e8..d93dac81ef 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -463,7 +463,9 @@ private void decryptRecord(EncryptionVersion decryptionVersion, synchronized (encryptor) { plaintextParcel = decryptParcel(wrapper.slice(), encryptor); } - Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, builder); + Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, (v, h) -> { + builder.appendWithOffset(kafkaRecord.offset(), kafkaRecord.timestamp(), kafkaRecord.key(), v, h); + }); } private ByteBuffer decryptParcel(ByteBuffer ciphertextParcel, AesGcmEncryptor encryptor) { diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java index debc3b066e..9f96afbcee 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java @@ -9,6 +9,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Set; +import java.util.function.BiConsumer; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; @@ -19,7 +20,6 @@ import io.kroxylicious.filter.encryption.EncryptionException; import io.kroxylicious.filter.encryption.ParcelVersion; import io.kroxylicious.filter.encryption.RecordField; -import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; @@ -58,7 +58,7 @@ static void writeParcel(ParcelVersion parcelVersion, Set recordFiel static void readParcel(ParcelVersion parcelVersion, ByteBuffer parcel, Record encryptedRecord, - @NonNull BatchAwareMemoryRecordsBuilder builder) { + @NonNull BiConsumer consumer) { switch (parcelVersion) { case V1: var parcelledValue = readRecordValue(parcel); @@ -80,7 +80,7 @@ static void readParcel(ParcelVersion parcelVersion, usedHeaders = parcelledHeaders; } ByteBuffer parcelledBuffer = parcelledValue == ABSENT_VALUE ? encryptedRecord.value() : parcelledValue; - builder.appendWithOffset(encryptedRecord.offset(), encryptedRecord.timestamp(), encryptedRecord.key(), parcelledBuffer, usedHeaders); + consumer.accept(parcelledBuffer, usedHeaders); break; default: throw new EncryptionException("Unknown parcel version " + parcelVersion); diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordDecryptor.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordDecryptor.java new file mode 100644 index 0000000000..a909d986ca --- /dev/null +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordDecryptor.java @@ -0,0 +1,115 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.filter.encryption.inband; + +import java.nio.ByteBuffer; +import java.util.Objects; +import java.util.function.IntFunction; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.utils.ByteUtils; + +import io.kroxylicious.filter.encryption.AadSpec; +import io.kroxylicious.filter.encryption.CipherCode; +import io.kroxylicious.filter.encryption.records.RecordTransform; + +import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; + +/** + * A {@link RecordTransform} that decrypts records that were previously encrypted by {@link RecordEncryptor}. + */ +public class RecordDecryptor implements RecordTransform { + + private final IntFunction keys; + int index = 0; + private ByteBuffer transformedValue; + private Header[] transformedHeaders; + + public RecordDecryptor(@NonNull IntFunction keys) { + Objects.requireNonNull(keys); + this.keys = keys; + } + + @Override + public void init(@NonNull Record record) { + DecryptState decryptState = keys.apply(index); + if (decryptState == null || decryptState.encryptor() == null) { + transformedValue = record.value(); + transformedHeaders = record.headers(); + return; + } + var encryptor = decryptState.encryptor(); + var decryptionVersion = decryptState.decryptionVersion(); + var wrapper = record.value(); + // Skip the edek + var edekLength = ByteUtils.readUnsignedVarint(wrapper); + wrapper.position(wrapper.position() + edekLength); + + var aadSpec = AadSpec.fromCode(wrapper.get()); + ByteBuffer aad = switch (aadSpec) { + case NONE -> ByteUtils.EMPTY_BUF; + }; + + var cipherCode = CipherCode.fromCode(wrapper.get()); + + ByteBuffer plaintextParcel; + synchronized (encryptor) { + plaintextParcel = decryptParcel(wrapper.slice(), encryptor); + } + Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, record, (v, h) -> { + transformedValue = v; + transformedHeaders = h; + }); + + } + + private ByteBuffer decryptParcel(ByteBuffer ciphertextParcel, AesGcmEncryptor encryptor) { + ByteBuffer plaintext = ciphertextParcel.duplicate(); + encryptor.decrypt(ciphertextParcel, plaintext); + plaintext.flip(); + return plaintext; + } + + @Override + public long transformOffset(@NonNull Record record) { + return record.offset(); + } + + @Override + public long transformTimestamp(@NonNull Record record) { + return record.timestamp(); + } + + @Nullable + @Override + public ByteBuffer transformKey(@NonNull Record record) { + return record.key(); + } + + @Nullable + @Override + public ByteBuffer transformValue(@NonNull Record record) { + return transformedValue == null ? null : transformedValue.duplicate(); + } + + @Nullable + @Override + public Header[] transformHeaders(@NonNull Record record) { + return transformedHeaders; + } + + @Override + public void resetAfterTransform(@NonNull Record record) { + if (transformedValue != null) { + transformedValue.clear(); + } + transformedHeaders = null; + index++; + } +} diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordEncryptor.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordEncryptor.java new file mode 100644 index 0000000000..2be393d5b9 --- /dev/null +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/RecordEncryptor.java @@ -0,0 +1,172 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.filter.encryption.inband; + +import java.nio.ByteBuffer; +import java.util.Objects; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.utils.ByteUtils; + +import io.kroxylicious.filter.encryption.AadSpec; +import io.kroxylicious.filter.encryption.CipherCode; +import io.kroxylicious.filter.encryption.EncryptionScheme; +import io.kroxylicious.filter.encryption.EncryptionVersion; +import io.kroxylicious.filter.encryption.RecordField; +import io.kroxylicious.filter.encryption.records.RecordTransform; + +import edu.umd.cs.findbugs.annotations.NonNull; +import edu.umd.cs.findbugs.annotations.Nullable; + +/** + * A {@link RecordTransform} that encrypts records so that they can be later decrypted by {@link RecordDecryptor}. + * @param The type of KEK id + */ +class RecordEncryptor implements RecordTransform { + + /** + * The encryption header. The value is the encryption version that was used to serialize the parcel and the wrapper. + */ + static final String ENCRYPTION_HEADER_NAME = "kroxylicious.io/encryption"; + private final EncryptionVersion encryptionVersion; + private final EncryptionScheme encryptionScheme; + private final KeyContext keyContext; + private final ByteBuffer parcelBuffer; + private final ByteBuffer wrapperBuffer; + /** + * The encryption version used on the produce path. + * Note that the encryption version used on the fetch path is read from the + * {@link #ENCRYPTION_HEADER_NAME} header. + */ + private final Header[] encryptionHeader; + private @Nullable ByteBuffer transformedValue; + private @Nullable Header[] transformedHeaders; + + /** + * Constructor (obviously). + * @param encryptionVersion The encryption version + * @param encryptionScheme The encryption scheme for this key + * @param keyContext The key context + * @param parcelBuffer A buffer big enough to write the parcel + * @param wrapperBuffer A buffer big enough to write the wrapper + */ + RecordEncryptor(@NonNull EncryptionVersion encryptionVersion, + @NonNull EncryptionScheme encryptionScheme, + @NonNull KeyContext keyContext, + @NonNull ByteBuffer parcelBuffer, + @NonNull ByteBuffer wrapperBuffer) { + Objects.requireNonNull(encryptionVersion); + Objects.requireNonNull(encryptionScheme); + Objects.requireNonNull(keyContext); + Objects.requireNonNull(parcelBuffer); + Objects.requireNonNull(wrapperBuffer); + this.encryptionVersion = encryptionVersion; + this.encryptionScheme = encryptionScheme; + this.keyContext = keyContext; + this.parcelBuffer = parcelBuffer; + this.wrapperBuffer = wrapperBuffer; + this.encryptionHeader = new Header[]{ new RecordHeader(ENCRYPTION_HEADER_NAME, new byte[]{ encryptionVersion.code() }) }; + } + + @Override + public void init(@NonNull Record kafkaRecord) { + if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) + && kafkaRecord.headers().length > 0 + && !kafkaRecord.hasValue()) { + // todo implement header encryption preserving null record-values + throw new IllegalStateException("encrypting headers prohibited when original record value null, we must preserve the null for tombstoning"); + } + + this.transformedValue = doTransformValue(kafkaRecord); + this.transformedHeaders = doTransformHeaders(kafkaRecord); + } + + @Nullable + private ByteBuffer doTransformValue(@NonNull Record kafkaRecord) { + final ByteBuffer transformedValue; + if (kafkaRecord.hasValue()) { + Parcel.writeParcel(encryptionVersion.parcelVersion(), encryptionScheme.recordFields(), kafkaRecord, parcelBuffer); + parcelBuffer.flip(); + transformedValue = writeWrapper(parcelBuffer); + parcelBuffer.rewind(); + } + else { + transformedValue = null; + } + return transformedValue; + } + + private Header[] doTransformHeaders(@NonNull Record kafkaRecord) { + final Header[] transformedHeaders; + if (kafkaRecord.hasValue()) { + Header[] oldHeaders = kafkaRecord.headers(); + if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) || oldHeaders.length == 0) { + transformedHeaders = encryptionHeader; + } + else { + transformedHeaders = new Header[1 + oldHeaders.length]; + transformedHeaders[0] = encryptionHeader[0]; + System.arraycopy(oldHeaders, 0, transformedHeaders, 1, oldHeaders.length); + } + } + else { + transformedHeaders = kafkaRecord.headers(); + } + return transformedHeaders; + } + + @Nullable + private ByteBuffer writeWrapper(ByteBuffer parcelBuffer) { + switch (encryptionVersion.wrapperVersion()) { + case V1 -> { + var edek = keyContext.serializedEdek(); + ByteUtils.writeUnsignedVarint(edek.length, wrapperBuffer); + wrapperBuffer.put(edek); + wrapperBuffer.put(AadSpec.NONE.code()); // aadCode + wrapperBuffer.put(CipherCode.AES_GCM_96_128.code()); + keyContext.encodedSize(parcelBuffer.limit()); + ByteBuffer aad = ByteUtils.EMPTY_BUF; // TODO pass the AAD to encode + keyContext.encode(parcelBuffer, wrapperBuffer); // iv and ciphertext + } + } + wrapperBuffer.flip(); + return wrapperBuffer; + } + + @Override + public void resetAfterTransform(Record record) { + wrapperBuffer.rewind(); + } + + @Override + public long transformOffset(Record record) { + return record.offset(); + } + + @Override + public long transformTimestamp(Record record) { + return record.timestamp(); + } + + @Override + public @Nullable ByteBuffer transformKey(Record record) { + return record.key(); + } + + @Override + public @Nullable ByteBuffer transformValue(Record kafkaRecord) { + return transformedValue == null ? null : transformedValue.duplicate(); + } + + @Override + public @Nullable Header[] transformHeaders(Record kafkaRecord) { + return transformedHeaders; + } + +} diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/RecordTransform.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/RecordTransform.java index 8dda5a1f92..2b52631f58 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/RecordTransform.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/RecordTransform.java @@ -25,7 +25,9 @@ *
    *
  1. Transform one record at a time
  2. *
  3. Invoke {@link #init(Record)} for that record before any other methods
  4. - *
  5. Invoke each of the {@code transform*()} methods for that record, as required
  6. + *
  7. Invoke the {@code transform*()} methods for that record, as required. + * They may be invoked zero, one or many times, and should be idempotent. + * They don't have to be invoked in any particular order.
  8. *
  9. Invoke {@link #resetAfterTransform(Record)} for that record
  10. *
*/ diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java index 23bbeffd6d..c7390c83ae 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java @@ -58,7 +58,9 @@ void shouldRoundTrip(Set fields, Record record) { buffer.flip(); BatchAwareMemoryRecordsBuilder mockBuilder = Mockito.mock(BatchAwareMemoryRecordsBuilder.class); - Parcel.readParcel(ParcelVersion.V1, buffer, record, mockBuilder); + Parcel.readParcel(ParcelVersion.V1, buffer, record, (v, h) -> { + mockBuilder.appendWithOffset(record.offset(), record.timestamp(), record.key(), v, h); + }); verify(mockBuilder).appendWithOffset(record.offset(), record.timestamp(), record.key(), expectedValue, record.headers()); assertThat(buffer.remaining()).isEqualTo(0); } diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/RecordEncryptorTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/RecordEncryptorTest.java new file mode 100644 index 0000000000..eab3c4ac90 --- /dev/null +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/RecordEncryptorTest.java @@ -0,0 +1,298 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.filter.encryption.inband; + +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Set; + +import javax.crypto.KeyGenerator; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import io.kroxylicious.filter.encryption.EncryptionScheme; +import io.kroxylicious.filter.encryption.EncryptionVersion; +import io.kroxylicious.filter.encryption.RecordField; +import io.kroxylicious.filter.encryption.records.RecordTransform; +import io.kroxylicious.test.assertj.KafkaAssertions; +import io.kroxylicious.test.record.RecordTestUtils; + +class RecordEncryptorTest { + + private static AesGcmEncryptor DECRYPTOR; + private static KeyContext KEY_CONTEXT; + private static ByteBuffer EDEK; + + @BeforeAll + public static void initKeyContext() throws NoSuchAlgorithmException { + var generator = KeyGenerator.getInstance("AES"); + var key = generator.generateKey(); + EDEK = ByteBuffer.wrap(key.getEncoded()); // it doesn't matter for this test that it's not encrypted + AesGcmEncryptor encryptor = AesGcmEncryptor.forEncrypt(new AesGcmIvGenerator(new SecureRandom()), key); + DECRYPTOR = AesGcmEncryptor.forDecrypt(key); + KEY_CONTEXT = new KeyContext(EDEK, Long.MAX_VALUE, Integer.MAX_VALUE, encryptor); + } + + private Record encryptSingleRecord(Set fields, long offset, long timestamp, String key, String value, Header... headers) { + var kc = KEY_CONTEXT; + var re = new RecordEncryptor(EncryptionVersion.V1, + new EncryptionScheme<>("key", fields), + kc, + ByteBuffer.allocate(100), + ByteBuffer.allocate(100)); + + Record record = RecordTestUtils.record(RecordBatch.MAGIC_VALUE_V2, offset, timestamp, key, value, headers); + + return transformRecord(re, record); + } + + private Record transformRecord(RecordTransform recordTransform, Record record) { + recordTransform.init(record); + var tOffset = recordTransform.transformOffset(record); + var tTimestamp = recordTransform.transformTimestamp(record); + var tKey = recordTransform.transformKey(record); + var tValue = recordTransform.transformValue(record); + var tHeaders = recordTransform.transformHeaders(record); + recordTransform.resetAfterTransform(record); + return RecordTestUtils.record(tOffset, tTimestamp, tKey, tValue, tHeaders); + } + + @Test + void shouldEncryptValueOnlyWithNoExistingHeaders() { + // Given + Set fields = Set.of(RecordField.RECORD_VALUE); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = "world"; + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueNotEqualTo(value) + .singleHeader() + .hasKeyEqualTo("kroxylicious.io/encryption") + .hasValueEqualTo(new byte[]{ 1 }); + + // And when + var rd = new RecordDecryptor(index -> new DecryptState(t, EncryptionVersion.V1, DECRYPTOR)); + var rt = transformRecord(rd, t); + + // Then + KafkaAssertions.assertThat(rt) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueEqualTo(value) + .hasEmptyHeaders(); + } + + // TODO with legacy magic + + @Test + void shouldEncryptValueOnlyWithExistingHeaders() { + // Given + Set fields = Set.of(RecordField.RECORD_VALUE); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = "world"; + var header = new RecordHeader("bob", null); + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value, header); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueNotEqualTo(value) + .hasHeadersSize(2) + .containsHeaderWithKey("kroxylicious.io/encryption") + .containsHeaderWithKey("bob"); + + // And when + var rd = new RecordDecryptor(index -> new DecryptState(t, EncryptionVersion.V1, DECRYPTOR)); + var rt = transformRecord(rd, t); + + // Then + KafkaAssertions.assertThat(rt) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueEqualTo(value) + .singleHeader() + .hasKeyEqualTo("bob") + .hasNullValue(); + } + + @Test + void shouldEncryptValueOnlyPreservesNullValue() { + // Given + Set fields = Set.of(RecordField.RECORD_VALUE); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = (String) null; + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasNullValue() + .hasEmptyHeaders(); + + // And when + var rd = new RecordDecryptor(index -> null); // note the null return + var rt = transformRecord(rd, t); + + // Then + KafkaAssertions.assertThat(rt) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasNullValue() + .hasEmptyHeaders(); + } + + @Test + void shouldEncryptValueAndHeaders() { + // Given + Set fields = Set.of(RecordField.RECORD_VALUE, RecordField.RECORD_HEADER_VALUES); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = "world"; + var header = new RecordHeader("bob", null); + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value, header); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueNotEqualTo(value) + .singleHeader() + .hasKeyEqualTo("kroxylicious.io/encryption") + .hasValueEqualTo(new byte[]{ 1 }); + + // And when + var rd = new RecordDecryptor(index -> new DecryptState(t, EncryptionVersion.V1, DECRYPTOR)); + var rt = transformRecord(rd, t); + + // Then + KafkaAssertions.assertThat(rt) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueEqualTo(value) + .singleHeader() + .hasKeyEqualTo("bob") + .hasNullValue(); + } + + @Test + @Disabled("Not implemented yet") + void shouldEncryptValueAndHeadersPreservesNullValue() { + // Given + Set fields = Set.of(RecordField.RECORD_VALUE, RecordField.RECORD_HEADER_VALUES); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = (String) null; + var header = new RecordHeader("bob", null); + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value, header); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasNullValue() + .singleHeader() + .hasKeyEqualTo("kroxylicious.io/encryption") + .hasValueEqualTo(new byte[]{ 1 }); + + // TODO decryption + } + + @Test + @Disabled("Not supported yet") + void shouldEncryptHeadersOnly() { + // Given + Set fields = Set.of(RecordField.RECORD_HEADER_VALUES); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = "world"; + var header = new RecordHeader("bob", null); + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value, header); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueEqualTo(value) + .singleHeader() + .hasKeyEqualTo("kroxylicious.io/encryption") + .hasValueEqualTo(new byte[]{ 1 }); + + // TODO decryption + } + + @Test + @Disabled("Not supported yet") + void shouldEncryptHeadersOnlyPreservesNullValue() { + // Given + Set fields = Set.of(RecordField.RECORD_HEADER_VALUES); + long offset = 55L; + long timestamp = System.currentTimeMillis(); + var key = "hello"; + var value = (String) null; + var header = new RecordHeader("bob", null); + + // When + var t = encryptSingleRecord(fields, offset, timestamp, key, value, header); + + // Then + KafkaAssertions.assertThat(t) + .hasOffsetEqualTo(offset) + .hasTimestampEqualTo(timestamp) + .hasKeyEqualTo(key) + .hasValueEqualTo(value) + .singleHeader() + .hasKeyEqualTo("kroxylicious.io/encryption") + .hasValueEqualTo(new byte[]{ 1 }); + + // TODO decryption + } + +}