From 3ed2b3f1b399e0abd0737fde7efe71832858b116 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 10:44:37 +1300 Subject: [PATCH 01/11] Remove Receiver from encrypt path Signed-off-by: Robert Young --- .../encryption/inband/InBandKeyManager.java | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index be68cb42c2..44faef5640 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -136,9 +136,7 @@ public CompletionStage encrypt(@NonNull String topicName, return CompletableFuture.completedFuture(records); } MemoryRecordsBuilder builder = recordsBuilder(allocateBufferForEncode(records, bufferAllocator), records); - return attemptEncrypt(topicName, partition, encryptionScheme, encryptionRequests, (kafkaRecord, plaintextBuffer, headers) -> { - builder.appendWithOffset(kafkaRecord.offset(), kafkaRecord.timestamp(), kafkaRecord.key(), plaintextBuffer, headers); - }, 0).thenApply(unused -> builder.build()); + return attemptEncrypt(topicName, partition, encryptionScheme, encryptionRequests, builder, 0).thenApply(unused -> builder.build()); } @NonNull @@ -171,7 +169,7 @@ private ByteBufferOutputStream allocateBufferForEncode(MemoryRecords records, In @SuppressWarnings("java:S2445") private CompletionStage attemptEncrypt(String topicName, int partition, @NonNull EncryptionScheme encryptionScheme, @NonNull List records, - @NonNull Receiver receiver, int attempt) { + MemoryRecordsBuilder builder, int attempt) { if (attempt >= MAX_ATTEMPTS) { return CompletableFuture.failedFuture( new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + records.size() + " records for topic " + topicName + " partition " @@ -187,17 +185,17 @@ private CompletionStage attemptEncrypt(String topicName, int partition, @N } else { // todo ensure that a failure during encryption terminates the entire operation with a failed future - return encrypt(encryptionScheme, records, receiver, keyContext); + return encrypt(encryptionScheme, records, builder, keyContext); } } } - return attemptEncrypt(topicName, partition, encryptionScheme, records, receiver, attempt + 1); + return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, attempt + 1); }); } @NonNull private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionScheme, @NonNull List records, - @NonNull Receiver receiver, KeyContext keyContext) { + @NonNull MemoryRecordsBuilder builder, KeyContext keyContext) { var maxParcelSize = records.stream() .mapToInt(kafkaRecord -> Parcel.sizeOfParcel( encryptionVersion.parcelVersion(), @@ -214,7 +212,7 @@ private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionS ByteBuffer parcelBuffer = bufferPool.acquire(maxParcelSize); ByteBuffer wrapperBuffer = bufferPool.acquire(maxWrapperSize); try { - encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, receiver); + encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder); } finally { if (wrapperBuffer != null) { @@ -240,7 +238,7 @@ private void encryptRecords(@NonNull EncryptionScheme encryptionScheme, @NonNull List records, @NonNull ByteBuffer parcelBuffer, @NonNull ByteBuffer wrapperBuffer, - @NonNull Receiver receiver) { + @NonNull MemoryRecordsBuilder builder) { records.forEach(kafkaRecord -> { if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) && kafkaRecord.headers().length > 0 @@ -253,12 +251,12 @@ private void encryptRecords(@NonNull EncryptionScheme encryptionScheme, parcelBuffer.flip(); var transformedValue = writeWrapper(keyContext, parcelBuffer, wrapperBuffer); Header[] headers = transformHeaders(encryptionScheme, kafkaRecord); - receiver.accept(kafkaRecord, transformedValue, headers); + builder.appendWithOffset(kafkaRecord.offset(), kafkaRecord.timestamp(), kafkaRecord.key(), transformedValue, headers); wrapperBuffer.rewind(); parcelBuffer.rewind(); } else { - receiver.accept(kafkaRecord, null, kafkaRecord.headers()); + builder.appendWithOffset(kafkaRecord.offset(), kafkaRecord.timestamp(), kafkaRecord.key(), null, kafkaRecord.headers()); } }); } From 57c05473976eff6a9b823b43c91441a952c97630 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 10:52:52 +1300 Subject: [PATCH 02/11] Remove Receiver on decrypt path and delete it Signed-off-by: Robert Young --- .../filter/encryption/Receiver.java | 26 ------------------- .../encryption/inband/InBandKeyManager.java | 16 +++++------- .../filter/encryption/inband/Parcel.java | 9 +++---- .../filter/encryption/inband/ParcelTest.java | 17 +++++------- 4 files changed, 17 insertions(+), 51 deletions(-) delete mode 100644 kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/Receiver.java diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/Receiver.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/Receiver.java deleted file mode 100644 index 389eee51f1..0000000000 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/Receiver.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright Kroxylicious Authors. - * - * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 - */ - -package io.kroxylicious.filter.encryption; - -import java.nio.ByteBuffer; - -import org.apache.kafka.common.header.Header; -import org.apache.kafka.common.record.Record; - -/** - * Something that receives the result of an encryption or decryption operation - */ -public interface Receiver { - /** - * Receive the ciphertext (encryption) or the plaintext (decryption) associated with the given record.. - * - * @param kafkaRecord The record on which to base the revised record - * @param value The ciphertext or plaintext buffer. This buffer may be re-used, the implementor should extract all - * the bytes they need from the buffer before the end of the accept call. - */ - void accept(Record kafkaRecord, ByteBuffer value, Header[] headers); -} diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index 44faef5640..f14c99f417 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -35,7 +35,6 @@ import io.kroxylicious.filter.encryption.EncryptionVersion; import io.kroxylicious.filter.encryption.EnvelopeEncryptionFilter; import io.kroxylicious.filter.encryption.KeyManager; -import io.kroxylicious.filter.encryption.Receiver; import io.kroxylicious.filter.encryption.RecordField; import io.kroxylicious.filter.encryption.WrapperVersion; import io.kroxylicious.kms.service.Kms; @@ -352,16 +351,14 @@ public CompletionStage decrypt(@NonNull String topicName, int par } ByteBufferOutputStream buffer = allocateBufferForDecode(records, bufferAllocator); MemoryRecordsBuilder outputBuilder = recordsBuilder(buffer, records); - return decrypt(topicName, partition, recordStream(records).toList(), (kafkaRecord, plaintextBuffer, headers) -> { - outputBuilder.appendWithOffset(kafkaRecord.offset(), kafkaRecord.timestamp(), kafkaRecord.key(), plaintextBuffer, headers); - }).thenApply(unused -> outputBuilder.build()); + return decrypt(topicName, partition, recordStream(records).toList(), outputBuilder).thenApply(unused -> outputBuilder.build()); } @NonNull private CompletionStage decrypt(String topicName, int partition, @NonNull List records, - @NonNull Receiver receiver) { + @NonNull MemoryRecordsBuilder builder) { var decryptStateStages = new ArrayList>(records.size()); for (Record kafkaRecord : records) { @@ -382,10 +379,11 @@ private CompletionStage decrypt(String topicName, .thenApply(decryptStates -> { decryptStates.forEach(decryptState -> { if (decryptState.encryptor() == null) { - receiver.accept(decryptState.kafkaRecord(), decryptState.valueWrapper(), decryptState.kafkaRecord().headers()); + Record record = decryptState.kafkaRecord(); + builder.appendWithOffset(record.offset(), record.timestamp(), record.key(), decryptState.valueWrapper(), record.headers()); } else { - decryptRecord(decryptState.decryptionVersion(), decryptState.encryptor(), decryptState.valueWrapper(), decryptState.kafkaRecord(), receiver); + decryptRecord(decryptState.decryptionVersion(), decryptState.encryptor(), decryptState.valueWrapper(), decryptState.kafkaRecord(), builder); } }); return null; @@ -402,7 +400,7 @@ private void decryptRecord(EncryptionVersion decryptionVersion, AesGcmEncryptor encryptor, ByteBuffer wrapper, Record kafkaRecord, - @NonNull Receiver receiver) { + @NonNull MemoryRecordsBuilder builder) { var aadSpec = AadSpec.fromCode(wrapper.get()); ByteBuffer aad = switch (aadSpec) { case NONE -> ByteUtils.EMPTY_BUF; @@ -414,7 +412,7 @@ private void decryptRecord(EncryptionVersion decryptionVersion, synchronized (encryptor) { plaintextParcel = decryptParcel(wrapper.slice(), encryptor); } - Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, receiver); + Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, builder); } private CompletionStage resolveEncryptor(WrapperVersion wrapperVersion, ByteBuffer wrapper) { diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java index 92d2194fba..1188bd3f83 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java @@ -12,13 +12,13 @@ import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.utils.ByteUtils; import org.apache.kafka.common.utils.Utils; import io.kroxylicious.filter.encryption.EncryptionException; import io.kroxylicious.filter.encryption.ParcelVersion; -import io.kroxylicious.filter.encryption.Receiver; import io.kroxylicious.filter.encryption.RecordField; import edu.umd.cs.findbugs.annotations.NonNull; @@ -58,7 +58,7 @@ static void writeParcel(ParcelVersion parcelVersion, Set recordFiel static void readParcel(ParcelVersion parcelVersion, ByteBuffer parcel, Record encryptedRecord, - @NonNull Receiver receiver) { + @NonNull MemoryRecordsBuilder builder) { switch (parcelVersion) { case V1: var parcelledValue = readRecordValue(parcel); @@ -79,9 +79,8 @@ static void readParcel(ParcelVersion parcelVersion, else { usedHeaders = parcelledHeaders; } - receiver.accept(encryptedRecord, - parcelledValue == ABSENT_VALUE ? encryptedRecord.value() : parcelledValue, - usedHeaders); + ByteBuffer parcelledBuffer = parcelledValue == ABSENT_VALUE ? encryptedRecord.value() : parcelledValue; + builder.appendWithOffset(encryptedRecord.offset(), encryptedRecord.timestamp(), encryptedRecord.key(), parcelledBuffer, usedHeaders); break; default: throw new EncryptionException("Unknown parcel version " + parcelVersion); diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java index 3c2241c63d..1457c30790 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java @@ -11,19 +11,20 @@ import java.util.Set; import java.util.stream.Stream; -import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.Record; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; import io.kroxylicious.filter.encryption.ParcelVersion; -import io.kroxylicious.filter.encryption.Receiver; import io.kroxylicious.filter.encryption.RecordField; import io.kroxylicious.test.record.RecordTestUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; class ParcelTest { @@ -56,15 +57,9 @@ void shouldRoundTrip(Set fields, Record record) { buffer.flip(); - Parcel.readParcel(ParcelVersion.V1, buffer, record, new Receiver() { - @Override - public void accept(Record kafkaRecord, ByteBuffer value, Header[] headers) { - assertThat(kafkaRecord).isEqualTo(record); - assertThat(value).isEqualTo(expectedValue); - assertThat(headers).isEqualTo(record.headers()); - } - }); - + MemoryRecordsBuilder mockBuilder = Mockito.mock(MemoryRecordsBuilder.class); + Parcel.readParcel(ParcelVersion.V1, buffer, record, mockBuilder); + verify(mockBuilder).appendWithOffset(record.offset(), record.timestamp(), record.key(), expectedValue, record.headers()); assertThat(buffer.remaining()).isEqualTo(0); } From a1950017dd15179e8f6adcee2878978a77433cae Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 11:19:25 +1300 Subject: [PATCH 03/11] Refactor encryption to be batch aware Signed-off-by: Robert Young --- .../encryption/inband/InBandKeyManager.java | 130 +++++++++++++----- .../inband/InBandKeyManagerTest.java | 46 +++++++ 2 files changed, 140 insertions(+), 36 deletions(-) diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index f14c99f417..7eee88ea9a 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -10,6 +10,7 @@ import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; +import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.function.IntFunction; @@ -20,10 +21,13 @@ import org.apache.kafka.common.header.internals.RecordHeader; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.BufferSupplier; import org.apache.kafka.common.utils.ByteBufferOutputStream; import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.CloseableIterator; import com.github.benmanes.caffeine.cache.AsyncLoadingCache; import com.github.benmanes.caffeine.cache.Caffeine; @@ -37,6 +41,7 @@ import io.kroxylicious.filter.encryption.KeyManager; import io.kroxylicious.filter.encryption.RecordField; import io.kroxylicious.filter.encryption.WrapperVersion; +import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import io.kroxylicious.kms.service.Kms; import io.kroxylicious.kms.service.Serde; @@ -116,6 +121,26 @@ private CompletableFuture makeKeyContext(@NonNull K kekId) { }).toCompletableFuture(); } + record BatchDescription(int index, MutableRecordBatch batch, int recordCount) { + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BatchDescription that = (BatchDescription) o; + return index == that.index && recordCount == that.recordCount; + } + + @Override + public int hashCode() { + return Objects.hash(index, recordCount); + } + } + @Override @NonNull @SuppressWarnings("java:S2445") @@ -128,14 +153,40 @@ public CompletionStage encrypt(@NonNull String topicName, // no records to transform, return input without modification return CompletableFuture.completedFuture(records); } - List encryptionRequests = recordStream(records).toList(); + + List descriptions = describeBatches(records); // it is possible to encounter MemoryRecords that have had all their records compacted away, but // the recordbatch metadata still exists. https://kafka.apache.org/documentation/#recordbatch - if (encryptionRequests.isEmpty()) { + if (descriptions.stream().allMatch(batchDescription -> batchDescription.recordCount == 0)) { return CompletableFuture.completedFuture(records); } - MemoryRecordsBuilder builder = recordsBuilder(allocateBufferForEncode(records, bufferAllocator), records); - return attemptEncrypt(topicName, partition, encryptionScheme, encryptionRequests, builder, 0).thenApply(unused -> builder.build()); + BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(allocateBufferForEncode(records, bufferAllocator)); + return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, 0, descriptions).thenApply(unused -> builder.build()); + } + + @NonNull + private static List describeBatches(@NonNull MemoryRecords records) { + int batchIndex = 0; + List descriptions = new ArrayList<>(); + for (MutableRecordBatch batch : records.batches()) { + descriptions.add(new BatchDescription(batchIndex++, batch, batchSize(batch))); + } + return descriptions; + } + + private static int batchSize(MutableRecordBatch batch) { + Integer count = batch.countOrNull(); + if (count == null) { + // for magic <2 count will be null + CloseableIterator iterator = batch.skipKeyValueIterator(BufferSupplier.NO_CACHING); + int c = 0; + while (iterator.hasNext()) { + c++; + iterator.next(); + } + count = c; + } + return count; } @NonNull @@ -167,18 +218,19 @@ private ByteBufferOutputStream allocateBufferForEncode(MemoryRecords records, In } @SuppressWarnings("java:S2445") - private CompletionStage attemptEncrypt(String topicName, int partition, @NonNull EncryptionScheme encryptionScheme, @NonNull List records, - MemoryRecordsBuilder builder, int attempt) { + private CompletionStage attemptEncrypt(String topicName, int partition, @NonNull EncryptionScheme encryptionScheme, @NonNull MemoryRecords records, + BatchAwareMemoryRecordsBuilder builder, int attempt, List batchDescriptions) { + int recordsCount = batchDescriptions.stream().mapToInt(value -> value.recordCount).sum(); if (attempt >= MAX_ATTEMPTS) { return CompletableFuture.failedFuture( - new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + records.size() + " records for topic " + topicName + " partition " + new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + recordsCount + " records for topic " + topicName + " partition " + partition + " after " + attempt + " attempts")); } return currentDekContext(encryptionScheme.kekId()).thenCompose(keyContext -> { synchronized (keyContext) { // if it's not alive we know a previous encrypt call has removed this stage from the cache and fall through to retry encrypt if (!keyContext.isDestroyed()) { - if (!keyContext.hasAtLeastRemainingEncryptions(records.size())) { + if (!keyContext.hasAtLeastRemainingEncryptions(recordsCount)) { // remove the key context from the cache, then call encrypt again to drive caffeine to recreate it rotateKeyContext(encryptionScheme, keyContext); } @@ -188,40 +240,46 @@ private CompletionStage attemptEncrypt(String topicName, int partition, @N } } } - return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, attempt + 1); + return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, attempt + 1, batchDescriptions); }); } @NonNull - private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionScheme, @NonNull List records, - @NonNull MemoryRecordsBuilder builder, KeyContext keyContext) { - var maxParcelSize = records.stream() - .mapToInt(kafkaRecord -> Parcel.sizeOfParcel( - encryptionVersion.parcelVersion(), - encryptionScheme.recordFields(), - kafkaRecord)) - .filter(value -> value > 0) - .max() - .orElseThrow(); - var maxWrapperSize = records.stream() - .mapToInt(kafkaRecord -> sizeOfWrapper(keyContext, maxParcelSize)) - .filter(value -> value > 0) - .max() - .orElseThrow(); - ByteBuffer parcelBuffer = bufferPool.acquire(maxParcelSize); - ByteBuffer wrapperBuffer = bufferPool.acquire(maxWrapperSize); - try { - encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder); - } - finally { - if (wrapperBuffer != null) { - bufferPool.release(wrapperBuffer); + private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionScheme, + @NonNull MemoryRecords memoryRecords, + @NonNull BatchAwareMemoryRecordsBuilder builder, + @NonNull KeyContext keyContext) { + for (MutableRecordBatch batch : memoryRecords.batches()) { + List records = StreamSupport.stream(batch.spliterator(), false).toList(); + builder.addBatchLike(batch); + var maxParcelSize = records.stream() + .mapToInt(kafkaRecord -> Parcel.sizeOfParcel( + encryptionVersion.parcelVersion(), + encryptionScheme.recordFields(), + kafkaRecord)) + .filter(value -> value > 0) + .max() + .orElseThrow(); + var maxWrapperSize = records.stream() + .mapToInt(kafkaRecord -> sizeOfWrapper(keyContext, maxParcelSize)) + .filter(value -> value > 0) + .max() + .orElseThrow(); + ByteBuffer parcelBuffer = bufferPool.acquire(maxParcelSize); + ByteBuffer wrapperBuffer = bufferPool.acquire(maxWrapperSize); + try { + encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder); } - if (parcelBuffer != null) { - bufferPool.release(parcelBuffer); + finally { + if (wrapperBuffer != null) { + bufferPool.release(wrapperBuffer); + } + if (parcelBuffer != null) { + bufferPool.release(parcelBuffer); + } } + keyContext.recordEncryptions(records.size()); } - keyContext.recordEncryptions(records.size()); return CompletableFuture.completedFuture(null); } @@ -237,7 +295,7 @@ private void encryptRecords(@NonNull EncryptionScheme encryptionScheme, @NonNull List records, @NonNull ByteBuffer parcelBuffer, @NonNull ByteBuffer wrapperBuffer, - @NonNull MemoryRecordsBuilder builder) { + @NonNull BatchAwareMemoryRecordsBuilder builder) { records.forEach(kafkaRecord -> { if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) && kafkaRecord.headers().length > 0 diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java index e12e3889a6..e3921b3055 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java @@ -19,13 +19,17 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Stream; +import java.util.stream.StreamSupport; import javax.crypto.SecretKey; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.CompressionType; import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.TimestampType; import org.apache.kafka.common.utils.ByteBufferOutputStream; import org.apache.kafka.common.utils.ByteUtils; import org.assertj.core.api.InstanceOfAssertFactories; @@ -39,6 +43,7 @@ import io.kroxylicious.filter.encryption.EncryptionScheme; import io.kroxylicious.filter.encryption.RecordField; +import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import io.kroxylicious.kms.provider.kroxylicious.inmemory.InMemoryEdek; import io.kroxylicious.kms.provider.kroxylicious.inmemory.InMemoryKms; import io.kroxylicious.kms.provider.kroxylicious.inmemory.UnitTestingKmsService; @@ -105,6 +110,47 @@ void shouldEncryptRecordValue() { .isEqualTo(value); } + @Test + void shouldPreserveMultipleBatches() { + var kmsService = UnitTestingKmsService.newInstance(); + InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + + var kekId = kms.generateKey(); + + var value = new byte[]{ 1, 2, 3 }; + Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); + + var value2 = new byte[]{ 4, 5, 6 }; + Record record2 = RecordTestUtils.record(2, ByteBuffer.wrap(value2)); + BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(1000)); + builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1); + builder.appendWithOffset(1l, record); + builder.addBatch(CompressionType.GZIP, TimestampType.LOG_APPEND_TIME, 2); + builder.appendWithOffset(2l, record2); + MemoryRecords records = builder.build(); + + EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); + CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); + assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); + MemoryRecords encrypted = encryptedFuture.join(); + + assertThat(encrypted.batches()).hasSize(2); + List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); + MutableRecordBatch first = batches.get(0); + assertThat(first.compressionType()).isEqualTo(CompressionType.NONE); + assertThat(first.timestampType()).isEqualTo(TimestampType.CREATE_TIME); + assertThat(first.baseOffset()).isEqualTo(1L); + assertThat(first).hasSize(1); + + MutableRecordBatch second = batches.get(1); + // should we keep the client's compression type? + assertThat(second.compressionType()).isEqualTo(CompressionType.GZIP); + assertThat(second.timestampType()).isEqualTo(TimestampType.LOG_APPEND_TIME); + assertThat(second.baseOffset()).isEqualTo(2L); + assertThat(second).hasSize(1); + } + @NonNull private static CompletionStage doDecrypt(InBandKeyManager km, String topic, int partition, List encrypted, List decrypted) { From 94bf00e68f398ed68a18de0481fbb21051d4a66a Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 12:06:38 +1300 Subject: [PATCH 04/11] Add failing test demonstrating that transactions are broken Signed-off-by: Robert Young --- .../EnvelopeEncryptionFilterIT.java | 112 ++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java b/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java index 3ef869e04d..ec4a0e16e8 100644 --- a/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java +++ b/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java @@ -10,6 +10,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.stream.Stream; @@ -19,10 +20,12 @@ import org.apache.kafka.clients.consumer.ConsumerRecords; import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.TopicPartition; import org.assertj.core.api.InstanceOfAssertFactories; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; @@ -81,6 +84,115 @@ void roundTripSingleRecord(KafkaCluster cluster, Topic topic, TestKmsFacade testKmsFacade) { + var testKekManager = testKmsFacade.getTestKekManager(); + testKekManager.generateKek(topic.name()); + + var builder = proxy(cluster); + + builder.addToFilters(buildEncryptionFilterDefinition(testKmsFacade)); + + try (var tester = kroxyliciousTester(builder); + var producer = tester.producer(Map.of(ProducerConfig.TRANSACTIONAL_ID_CONFIG, UUID.randomUUID().toString())); + var consumer = tester.consumer(Map.of(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest", + ConsumerConfig.GROUP_ID_CONFIG, UUID.randomUUID().toString(), + ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"))) { + producer.initTransactions(); + withTransaction(producer, transactionProducer -> { + producer.send(new ProducerRecord<>(topic.name(), HELLO_WORLD)).get(5, TimeUnit.SECONDS); + }).commitTransaction(); + consumer.subscribe(List.of(topic.name())); + var records = consumer.poll(Duration.ofSeconds(2)); + assertThat(records.iterator()) + .toIterable() + .singleElement() + .extracting(ConsumerRecord::value) + .isEqualTo(HELLO_WORLD); + } + } + + // check that records from aborted transaction are not exposed to read_committed clients + @TestTemplate + void roundTripTransactionalAbort(KafkaCluster cluster, Topic topic, TestKmsFacade testKmsFacade) { + var testKekManager = testKmsFacade.getTestKekManager(); + testKekManager.generateKek(topic.name()); + + var builder = proxy(cluster); + + builder.addToFilters(buildEncryptionFilterDefinition(testKmsFacade)); + + try (var tester = kroxyliciousTester(builder); + var producer = tester.producer(Map.of(ProducerConfig.TRANSACTIONAL_ID_CONFIG, UUID.randomUUID().toString())); + var consumer = tester.consumer(Map.of(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest", + ConsumerConfig.GROUP_ID_CONFIG, UUID.randomUUID().toString(), + ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"))) { + producer.initTransactions(); + // send to the same partition to demonstrate a message appended to the same partition after the abort is made available + String key = "key"; + withTransaction(producer, transactionProducer -> { + producer.send(new ProducerRecord<>(topic.name(), key, "aborted message")).get(5, TimeUnit.SECONDS); + }).abortTransaction(); + + withTransaction(producer, transactionProducer -> { + producer.send(new ProducerRecord<>(topic.name(), key, HELLO_WORLD)).get(5, TimeUnit.SECONDS); + }).commitTransaction(); + + consumer.subscribe(List.of(topic.name())); + var records = consumer.poll(Duration.ofSeconds(2)); + assertThat(records.iterator()) + .toIterable() + .singleElement() + .extracting(ConsumerRecord::value) + .isEqualTo(HELLO_WORLD); + } + } + + // check that records from uncommitted transaction are not exposed to read_committed clients + @TestTemplate + void roundTripTransactionalIsolation(KafkaCluster cluster, Topic topic, TestKmsFacade testKmsFacade) { + var testKekManager = testKmsFacade.getTestKekManager(); + testKekManager.generateKek(topic.name()); + + var builder = proxy(cluster); + + builder.addToFilters(buildEncryptionFilterDefinition(testKmsFacade)); + + try (var tester = kroxyliciousTester(builder); + var producer = tester.producer(Map.of(ProducerConfig.TRANSACTIONAL_ID_CONFIG, UUID.randomUUID().toString())); + var consumer = tester.consumer(Map.of(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest", + ConsumerConfig.GROUP_ID_CONFIG, UUID.randomUUID().toString(), + ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"))) { + producer.initTransactions(); + + withTransaction(producer, transactionProducer -> { + transactionProducer.send(new ProducerRecord<>(topic.name(), "uncommitted message")).get(5, TimeUnit.SECONDS); + }); + + consumer.subscribe(List.of(topic.name())); + var records = consumer.poll(Duration.ofSeconds(2)); + assertThat(records.iterator()) + .isExhausted(); + } + } + + interface ExceptionalConsumer { + void accept(T t) throws Exception; + } + + Producer withTransaction(Producer producer, ExceptionalConsumer> consumer) { + producer.beginTransaction(); + try { + consumer.accept(producer); + } + catch (Exception e) { + throw new RuntimeException(e); + } + return producer; + } + @TestTemplate void roundTripManyRecordsFromDifferentProducers(KafkaCluster cluster, Topic topic, TestKmsFacade testKmsFacade) throws Exception { var testKekManager = testKmsFacade.getTestKekManager(); From 7615d36af8307426c60b615d8b390c97e9cfb99e Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 13:49:52 +1300 Subject: [PATCH 05/11] Refactor decryption to be batch aware Algorithm: 1. iterate over all records in the MemoryRecords extracting unique Edeks 2. asynchronously resolve all unique Edeks 3. do a second batch-aware iteration to decrypt, assuming that we have a decryptor located for every unique Edek Control batches should never be encrypted so we rely on the absense of the encryption header to signal they are unencrypted Signed-off-by: Robert Young --- .../encryption/inband/InBandKeyManager.java | 162 ++++++++++-------- .../filter/encryption/inband/Parcel.java | 4 +- .../inband/InBandKeyManagerTest.java | 125 +++++++++++++- .../filter/encryption/inband/ParcelTest.java | 4 +- .../EnvelopeEncryptionFilterIT.java | 3 - 5 files changed, 219 insertions(+), 79 deletions(-) diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index 7eee88ea9a..0c98bcc6cc 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -9,11 +9,15 @@ import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.function.IntFunction; +import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; @@ -40,7 +44,6 @@ import io.kroxylicious.filter.encryption.EnvelopeEncryptionFilter; import io.kroxylicious.filter.encryption.KeyManager; import io.kroxylicious.filter.encryption.RecordField; -import io.kroxylicious.filter.encryption.WrapperVersion; import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import io.kroxylicious.kms.service.Kms; import io.kroxylicious.kms.service.Serde; @@ -252,33 +255,42 @@ private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionS for (MutableRecordBatch batch : memoryRecords.batches()) { List records = StreamSupport.stream(batch.spliterator(), false).toList(); builder.addBatchLike(batch); - var maxParcelSize = records.stream() - .mapToInt(kafkaRecord -> Parcel.sizeOfParcel( - encryptionVersion.parcelVersion(), - encryptionScheme.recordFields(), - kafkaRecord)) - .filter(value -> value > 0) - .max() - .orElseThrow(); - var maxWrapperSize = records.stream() - .mapToInt(kafkaRecord -> sizeOfWrapper(keyContext, maxParcelSize)) - .filter(value -> value > 0) - .max() - .orElseThrow(); - ByteBuffer parcelBuffer = bufferPool.acquire(maxParcelSize); - ByteBuffer wrapperBuffer = bufferPool.acquire(maxWrapperSize); - try { - encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder); + if (batch.isControlBatch()) { + // the proxy should not encounter these on the produce path as it's written by the transaction co-ordinator + // broker side. No user data is contained, so we do not need to encrypt. + for (Record record : batch) { + builder.append(record); + } } - finally { - if (wrapperBuffer != null) { - bufferPool.release(wrapperBuffer); + else { + var maxParcelSize = records.stream() + .mapToInt(kafkaRecord -> Parcel.sizeOfParcel( + encryptionVersion.parcelVersion(), + encryptionScheme.recordFields(), + kafkaRecord)) + .filter(value -> value > 0) + .max() + .orElseThrow(); + var maxWrapperSize = records.stream() + .mapToInt(kafkaRecord -> sizeOfWrapper(keyContext, maxParcelSize)) + .filter(value -> value > 0) + .max() + .orElseThrow(); + ByteBuffer parcelBuffer = bufferPool.acquire(maxParcelSize); + ByteBuffer wrapperBuffer = bufferPool.acquire(maxWrapperSize); + try { + encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder, batch); } - if (parcelBuffer != null) { - bufferPool.release(parcelBuffer); + finally { + if (wrapperBuffer != null) { + bufferPool.release(wrapperBuffer); + } + if (parcelBuffer != null) { + bufferPool.release(parcelBuffer); + } } + keyContext.recordEncryptions(records.size()); } - keyContext.recordEncryptions(records.size()); } return CompletableFuture.completedFuture(null); } @@ -295,7 +307,8 @@ private void encryptRecords(@NonNull EncryptionScheme encryptionScheme, @NonNull List records, @NonNull ByteBuffer parcelBuffer, @NonNull ByteBuffer wrapperBuffer, - @NonNull BatchAwareMemoryRecordsBuilder builder) { + @NonNull BatchAwareMemoryRecordsBuilder builder, + @NonNull MutableRecordBatch batch) { records.forEach(kafkaRecord -> { if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) && kafkaRecord.headers().length > 0 @@ -401,51 +414,70 @@ public CompletionStage decrypt(@NonNull String topicName, int par // no records to transform, return input without modification return CompletableFuture.completedFuture(records); } - List encryptionRequests = recordStream(records).toList(); + List descriptions = describeBatches(records); // it is possible to encounter MemoryRecords that have had all their records compacted away, but // the recordbatch metadata still exists. https://kafka.apache.org/documentation/#recordbatch - if (encryptionRequests.isEmpty()) { + if (descriptions.stream().allMatch(batchDescription -> batchDescription.recordCount == 0)) { return CompletableFuture.completedFuture(records); } - ByteBufferOutputStream buffer = allocateBufferForDecode(records, bufferAllocator); - MemoryRecordsBuilder outputBuilder = recordsBuilder(buffer, records); - return decrypt(topicName, partition, recordStream(records).toList(), outputBuilder).thenApply(unused -> outputBuilder.build()); + Set uniqueEdeks = extractEdeks(topicName, partition, records); + CompletionStage> decryptors = resolveAll(uniqueEdeks); + CompletionStage objectCompletionStage = decryptors.thenApply( + encryptorMap -> decrypt(topicName, partition, records, new BatchAwareMemoryRecordsBuilder(allocateBufferForDecode(records, bufferAllocator)), + encryptorMap)); + return objectCompletionStage.thenApply(BatchAwareMemoryRecordsBuilder::build); } - @NonNull - private CompletionStage decrypt(String topicName, - int partition, - @NonNull List records, - @NonNull MemoryRecordsBuilder builder) { - var decryptStateStages = new ArrayList>(records.size()); + private CompletionStage> resolveAll(Set uniqueEdeks) { + CompletionStage>> join = EnvelopeEncryptionFilter.join( + uniqueEdeks.stream().map(e -> decryptorCache.get(e).thenApply(aesGcmEncryptor -> Map.entry(e, aesGcmEncryptor))).toList()); + return join.thenApply(entries -> entries.stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); + } - for (Record kafkaRecord : records) { + private Set extractEdeks(String topicName, int partition, MemoryRecords records) { + Set edeks = new HashSet<>(); + Serde serde = kms.edekSerde(); + for (Record kafkaRecord : records.records()) { var decryptionVersion = decryptionVersion(topicName, partition, kafkaRecord); - if (decryptionVersion == null) { - decryptStateStages.add(CompletableFuture.completedStage(new DecryptState(kafkaRecord, kafkaRecord.value(), null, null))); - } - else { - // right now (because we only support topic name based kek selection) once we've resolved the first value we - // can keep the lock and process all the records + if (decryptionVersion == EncryptionVersion.V1) { ByteBuffer wrapper = kafkaRecord.value(); - decryptStateStages.add( - resolveEncryptor(decryptionVersion.wrapperVersion(), wrapper).thenApply(enc -> new DecryptState(kafkaRecord, wrapper, decryptionVersion, enc))); + var edekLength = ByteUtils.readUnsignedVarint(wrapper); + ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); + var edek = serde.deserialize(slice); + edeks.add(edek); } } + return edeks; + } - return EnvelopeEncryptionFilter.join(decryptStateStages) - .thenApply(decryptStates -> { - decryptStates.forEach(decryptState -> { - if (decryptState.encryptor() == null) { - Record record = decryptState.kafkaRecord(); - builder.appendWithOffset(record.offset(), record.timestamp(), record.key(), decryptState.valueWrapper(), record.headers()); - } - else { - decryptRecord(decryptState.decryptionVersion(), decryptState.encryptor(), decryptState.valueWrapper(), decryptState.kafkaRecord(), builder); - } - }); - return null; - }); + @NonNull + private BatchAwareMemoryRecordsBuilder decrypt(String topicName, + int partition, + @NonNull MemoryRecords records, + @NonNull BatchAwareMemoryRecordsBuilder builder, + Map encryptorMap) { + for (MutableRecordBatch batch : records.batches()) { + builder.addBatchLike(batch); + for (Record kafkaRecord : batch) { + var decryptionVersion = decryptionVersion(topicName, partition, kafkaRecord); + if (decryptionVersion == null) { + builder.append(kafkaRecord); + } + else if (decryptionVersion == EncryptionVersion.V1) { + ByteBuffer wrapper = kafkaRecord.value(); + var edekLength = ByteUtils.readUnsignedVarint(wrapper); + ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); + var edek = edekSerde.deserialize(slice); + wrapper.position(wrapper.position() + edekLength); + AesGcmEncryptor aesGcmEncryptor = encryptorMap.get(edek); + if (aesGcmEncryptor == null) { + throw new RuntimeException("no encryptor loaded for edek, " + edek); + } + decryptRecord(EncryptionVersion.V1, aesGcmEncryptor, wrapper, kafkaRecord, builder); + } + } + } + return builder; } private ByteBufferOutputStream allocateBufferForDecode(MemoryRecords memoryRecords, IntFunction allocator) { @@ -458,7 +490,7 @@ private void decryptRecord(EncryptionVersion decryptionVersion, AesGcmEncryptor encryptor, ByteBuffer wrapper, Record kafkaRecord, - @NonNull MemoryRecordsBuilder builder) { + @NonNull BatchAwareMemoryRecordsBuilder builder) { var aadSpec = AadSpec.fromCode(wrapper.get()); ByteBuffer aad = switch (aadSpec) { case NONE -> ByteUtils.EMPTY_BUF; @@ -473,18 +505,6 @@ private void decryptRecord(EncryptionVersion decryptionVersion, Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, builder); } - private CompletionStage resolveEncryptor(WrapperVersion wrapperVersion, ByteBuffer wrapper) { - switch (wrapperVersion) { - case V1: - var edekLength = ByteUtils.readUnsignedVarint(wrapper); - ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); - var edek = edekSerde.deserialize(slice); - wrapper.position(wrapper.position() + edekLength); - return decryptorCache.get(edek); - } - throw new EncryptionException("Unknown wrapper version " + wrapperVersion); - } - private ByteBuffer decryptParcel(ByteBuffer ciphertextParcel, AesGcmEncryptor encryptor) { ByteBuffer plaintext = ciphertextParcel.duplicate(); encryptor.decrypt(ciphertextParcel, plaintext); diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java index 1188bd3f83..debc3b066e 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/Parcel.java @@ -12,7 +12,6 @@ import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; -import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.utils.ByteUtils; import org.apache.kafka.common.utils.Utils; @@ -20,6 +19,7 @@ import io.kroxylicious.filter.encryption.EncryptionException; import io.kroxylicious.filter.encryption.ParcelVersion; import io.kroxylicious.filter.encryption.RecordField; +import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; @@ -58,7 +58,7 @@ static void writeParcel(ParcelVersion parcelVersion, Set recordFiel static void readParcel(ParcelVersion parcelVersion, ByteBuffer parcel, Record encryptedRecord, - @NonNull MemoryRecordsBuilder builder) { + @NonNull BatchAwareMemoryRecordsBuilder builder) { switch (parcelVersion) { case V1: var parcelledValue = readRecordValue(parcel); diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java index e3921b3055..ea180107db 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java @@ -26,9 +26,13 @@ import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.ControlRecordType; import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.SimpleRecord; import org.apache.kafka.common.record.TimestampType; import org.apache.kafka.common.utils.ByteBufferOutputStream; import org.apache.kafka.common.utils.ByteUtils; @@ -118,7 +122,7 @@ void shouldPreserveMultipleBatches() { var kekId = kms.generateKey(); - var value = new byte[]{ 1, 2, 3 }; + byte[] value = { 1, 2, 3 }; Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); var value2 = new byte[]{ 4, 5, 6 }; @@ -134,6 +138,8 @@ void shouldPreserveMultipleBatches() { CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); MemoryRecords encrypted = encryptedFuture.join(); + record.value().rewind(); + record2.value().rewind(); assertThat(encrypted.batches()).hasSize(2); List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); @@ -149,6 +155,123 @@ void shouldPreserveMultipleBatches() { assertThat(second.timestampType()).isEqualTo(TimestampType.LOG_APPEND_TIME); assertThat(second.baseOffset()).isEqualTo(2L); assertThat(second).hasSize(1); + + CompletableFuture decryptedFuture = km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new).toCompletableFuture(); + assertThat(decryptedFuture).succeedsWithin(Duration.ZERO); + MemoryRecords decrypted = decryptedFuture.join(); + + assertThat(decrypted.batches()).hasSize(2); + List decryptedBatches = StreamSupport.stream(decrypted.batches().spliterator(), false).toList(); + MutableRecordBatch firstDecrypted = decryptedBatches.get(0); + assertThat(firstDecrypted.compressionType()).isEqualTo(CompressionType.NONE); + assertThat(firstDecrypted.timestampType()).isEqualTo(TimestampType.CREATE_TIME); + assertThat(firstDecrypted.baseOffset()).isEqualTo(1L); + assertThat(firstDecrypted).hasSize(1); + assertThat(firstDecrypted.iterator()) + .toIterable() + .singleElement() + .extracting(RecordTestUtils::recordValueAsBytes) + .isEqualTo(value); + + MutableRecordBatch secondDecrypted = decryptedBatches.get(1); + assertThat(secondDecrypted.compressionType()).isEqualTo(CompressionType.GZIP); + assertThat(secondDecrypted.timestampType()).isEqualTo(TimestampType.LOG_APPEND_TIME); + assertThat(secondDecrypted.baseOffset()).isEqualTo(2L); + assertThat(secondDecrypted).hasSize(1); + assertThat(secondDecrypted.iterator()) + .toIterable() + .singleElement() + .extracting(RecordTestUtils::recordValueAsBytes) + .isEqualTo(value2); + + } + + @Test + void shouldPreserveControlBatch() { + var kmsService = UnitTestingKmsService.newInstance(); + InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + + var kekId = kms.generateKey(); + + byte[] value = { 1, 2, 3 }; + Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); + BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(1000)); + builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1); + builder.appendWithOffset(1L, record); + byte[] controlBatchValue = { 4, 5, 6 }; + RecordBatch controlBatch = controlBatch(2, controlBatchValue); + builder.addBatchLike(controlBatch); + builder.append(controlBatch.iterator().next()); + MemoryRecords records = builder.build(); + + EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); + CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); + assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); + MemoryRecords encrypted = encryptedFuture.join(); + record.value().rewind(); + + assertThat(encrypted.batches()).hasSize(2); + List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); + MutableRecordBatch first = batches.get(0); + assertThat(first.compressionType()).isEqualTo(CompressionType.NONE); + assertThat(first.timestampType()).isEqualTo(TimestampType.CREATE_TIME); + assertThat(first.baseOffset()).isEqualTo(1L); + assertThat(first).hasSize(1); + + MutableRecordBatch second = batches.get(1); + // should we keep the client's compression type? + assertThat(second.compressionType()).isEqualTo(controlBatch.compressionType()); + assertThat(second.timestampType()).isEqualTo(controlBatch.timestampType()); + assertThat(second.baseOffset()).isEqualTo(controlBatch.baseOffset()); + assertThat(second.isControlBatch()).isTrue(); + assertThat(second).hasSize(1); + // control batches are not encrypted + assertThat(second.iterator()) + .toIterable() + .singleElement() + .extracting(RecordTestUtils::recordValueAsBytes) + .isEqualTo(controlBatchValue); + + CompletableFuture decryptedFuture = km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new).toCompletableFuture(); + assertThat(decryptedFuture).succeedsWithin(Duration.ZERO); + MemoryRecords decrypted = decryptedFuture.join(); + + assertThat(decrypted.batches()).hasSize(2); + List decryptedBatches = StreamSupport.stream(decrypted.batches().spliterator(), false).toList(); + MutableRecordBatch firstDecrypted = decryptedBatches.get(0); + assertThat(firstDecrypted.compressionType()).isEqualTo(CompressionType.NONE); + assertThat(firstDecrypted.timestampType()).isEqualTo(TimestampType.CREATE_TIME); + assertThat(firstDecrypted.baseOffset()).isEqualTo(1L); + assertThat(firstDecrypted).hasSize(1); + assertThat(firstDecrypted.iterator()) + .toIterable() + .singleElement() + .extracting(RecordTestUtils::recordValueAsBytes) + .isEqualTo(value); + + MutableRecordBatch secondDecrypted = decryptedBatches.get(1); + assertThat(secondDecrypted.compressionType()).isEqualTo(controlBatch.compressionType()); + assertThat(secondDecrypted.timestampType()).isEqualTo(controlBatch.timestampType()); + assertThat(secondDecrypted.baseOffset()).isEqualTo(controlBatch.baseOffset()); + assertThat(secondDecrypted.isControlBatch()).isTrue(); + assertThat(secondDecrypted).hasSize(1); + // control batch value is preserved + assertThat(second.iterator()) + .toIterable() + .singleElement() + .extracting(RecordTestUtils::recordValueAsBytes) + .isEqualTo(controlBatchValue); + + } + + private static RecordBatch controlBatch(int baseOffset, byte[] arbitraryValue) { + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(ByteBuffer.allocate(1000), RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, baseOffset, 1L, 1L, (short) 1, 1, false, true, 1, 1); + byte[] key = { 0, 0, (byte) ControlRecordType.ABORT.type(), (byte) (ControlRecordType.ABORT.type() >> 8) }; + builder.appendControlRecordWithOffset(baseOffset, new SimpleRecord(1L, key, arbitraryValue)); + MemoryRecords controlBatchRecords = builder.build(); + return controlBatchRecords.firstBatch(); } @NonNull diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java index 1457c30790..23bbeffd6d 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/ParcelTest.java @@ -12,7 +12,6 @@ import java.util.stream.Stream; import org.apache.kafka.common.header.internals.RecordHeader; -import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.Record; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -21,6 +20,7 @@ import io.kroxylicious.filter.encryption.ParcelVersion; import io.kroxylicious.filter.encryption.RecordField; +import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import io.kroxylicious.test.record.RecordTestUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -57,7 +57,7 @@ void shouldRoundTrip(Set fields, Record record) { buffer.flip(); - MemoryRecordsBuilder mockBuilder = Mockito.mock(MemoryRecordsBuilder.class); + BatchAwareMemoryRecordsBuilder mockBuilder = Mockito.mock(BatchAwareMemoryRecordsBuilder.class); Parcel.readParcel(ParcelVersion.V1, buffer, record, mockBuilder); verify(mockBuilder).appendWithOffset(record.offset(), record.timestamp(), record.key(), expectedValue, record.headers()); assertThat(buffer.remaining()).isEqualTo(0); diff --git a/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java b/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java index ec4a0e16e8..15e5bf60a0 100644 --- a/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java +++ b/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java @@ -25,7 +25,6 @@ import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.TopicPartition; import org.assertj.core.api.InstanceOfAssertFactories; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; @@ -84,8 +83,6 @@ void roundTripSingleRecord(KafkaCluster cluster, Topic topic, TestKmsFacade testKmsFacade) { var testKekManager = testKmsFacade.getTestKekManager(); From 898acd95fe16653b0174e40a0f194c48656ba9ed Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 15:29:54 +1300 Subject: [PATCH 06/11] Envelope encryption: pass-through empty record batches Why: We don't have to do any encryption work on batches with all records removed by compaction. But we do want to preserve them to keep the proxy transparent and returning the same response as Kafka Signed-off-by: Robert Young --- .../test/record/RecordTestUtils.java | 14 ++++- .../encryption/inband/InBandKeyManager.java | 63 ++++++++++--------- .../BatchAwareMemoryRecordsBuilder.java | 16 +++++ .../inband/InBandKeyManagerTest.java | 63 +++++++++++++++++++ 4 files changed, 126 insertions(+), 30 deletions(-) diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java index 0ccb8e5154..29c693bcab 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java @@ -343,7 +343,12 @@ public static MemoryRecords memoryRecords(@NonNull List records) { * @see Apache Kafka RecordBatch documentation */ public static MemoryRecords memoryRecordsWithAllRecordsRemoved() { - try (MemoryRecordsBuilder memoryRecordsBuilder = defaultMemoryRecordsBuilder(DEFAULT_MAGIC_VALUE)) { + return memoryRecordsWithAllRecordsRemoved(0L); + } + + @NonNull + public static MemoryRecords memoryRecordsWithAllRecordsRemoved(long baseOffset) { + try (MemoryRecordsBuilder memoryRecordsBuilder = memoryRecordsBuilder(DEFAULT_MAGIC_VALUE, baseOffset)) { // append arbitrary record memoryRecordsBuilder.append(DEFAULT_TIMESTAMP, new byte[]{ 1, 2, 3 }, new byte[]{ 1, 2, 3 }); MemoryRecords records = memoryRecordsBuilder.build(); @@ -366,12 +371,17 @@ protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { } private static MemoryRecordsBuilder defaultMemoryRecordsBuilder(byte magic) { + return memoryRecordsBuilder(magic, 0L); + } + + @NonNull + private static MemoryRecordsBuilder memoryRecordsBuilder(byte magic, long baseOffset) { return new MemoryRecordsBuilder( ByteBuffer.allocate(1024), magic, CompressionType.NONE, TimestampType.CREATE_TIME, - 0L, + baseOffset, 0L, 0L, (short) 0, diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index 0c98bcc6cc..190496e498 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -239,7 +239,7 @@ private CompletionStage attemptEncrypt(String topicName, int partition, @N } else { // todo ensure that a failure during encryption terminates the entire operation with a failed future - return encrypt(encryptionScheme, records, builder, keyContext); + return encrypt(encryptionScheme, records, builder, keyContext, batchDescriptions); } } } @@ -251,18 +251,17 @@ private CompletionStage attemptEncrypt(String topicName, int partition, @N private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionScheme, @NonNull MemoryRecords memoryRecords, @NonNull BatchAwareMemoryRecordsBuilder builder, - @NonNull KeyContext keyContext) { + @NonNull KeyContext keyContext, + @NonNull List batchDescriptions) { + int i = 0; for (MutableRecordBatch batch : memoryRecords.batches()) { - List records = StreamSupport.stream(batch.spliterator(), false).toList(); - builder.addBatchLike(batch); - if (batch.isControlBatch()) { - // the proxy should not encounter these on the produce path as it's written by the transaction co-ordinator - // broker side. No user data is contained, so we do not need to encrypt. - for (Record record : batch) { - builder.append(record); - } + BatchDescription batchDescription = batchDescriptions.get(i++); + if (batchDescription.recordCount == 0 || batch.isControlBatch()) { + builder.writeBatch(batch); } else { + List records = StreamSupport.stream(batch.spliterator(), false).toList(); + builder.addBatchLike(batch); var maxParcelSize = records.stream() .mapToInt(kafkaRecord -> Parcel.sizeOfParcel( encryptionVersion.parcelVersion(), @@ -424,7 +423,7 @@ public CompletionStage decrypt(@NonNull String topicName, int par CompletionStage> decryptors = resolveAll(uniqueEdeks); CompletionStage objectCompletionStage = decryptors.thenApply( encryptorMap -> decrypt(topicName, partition, records, new BatchAwareMemoryRecordsBuilder(allocateBufferForDecode(records, bufferAllocator)), - encryptorMap)); + encryptorMap, descriptions)); return objectCompletionStage.thenApply(BatchAwareMemoryRecordsBuilder::build); } @@ -455,25 +454,33 @@ private BatchAwareMemoryRecordsBuilder decrypt(String topicName, int partition, @NonNull MemoryRecords records, @NonNull BatchAwareMemoryRecordsBuilder builder, - Map encryptorMap) { + @NonNull Map encryptorMap, + @NonNull List descriptions) { + int i = 0; for (MutableRecordBatch batch : records.batches()) { - builder.addBatchLike(batch); - for (Record kafkaRecord : batch) { - var decryptionVersion = decryptionVersion(topicName, partition, kafkaRecord); - if (decryptionVersion == null) { - builder.append(kafkaRecord); - } - else if (decryptionVersion == EncryptionVersion.V1) { - ByteBuffer wrapper = kafkaRecord.value(); - var edekLength = ByteUtils.readUnsignedVarint(wrapper); - ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); - var edek = edekSerde.deserialize(slice); - wrapper.position(wrapper.position() + edekLength); - AesGcmEncryptor aesGcmEncryptor = encryptorMap.get(edek); - if (aesGcmEncryptor == null) { - throw new RuntimeException("no encryptor loaded for edek, " + edek); + BatchDescription batchDescription = descriptions.get(i++); + if (batchDescription.recordCount == 0 || batch.isControlBatch()) { + builder.writeBatch(batch); + } + else { + builder.addBatchLike(batch); + for (Record kafkaRecord : batch) { + var decryptionVersion = decryptionVersion(topicName, partition, kafkaRecord); + if (decryptionVersion == null) { + builder.append(kafkaRecord); + } + else if (decryptionVersion == EncryptionVersion.V1) { + ByteBuffer wrapper = kafkaRecord.value(); + var edekLength = ByteUtils.readUnsignedVarint(wrapper); + ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); + var edek = edekSerde.deserialize(slice); + wrapper.position(wrapper.position() + edekLength); + AesGcmEncryptor aesGcmEncryptor = encryptorMap.get(edek); + if (aesGcmEncryptor == null) { + throw new RuntimeException("no encryptor loaded for edek, " + edek); + } + decryptRecord(EncryptionVersion.V1, aesGcmEncryptor, wrapper, kafkaRecord, builder); } - decryptRecord(EncryptionVersion.V1, aesGcmEncryptor, wrapper, kafkaRecord, builder); } } } diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java index 0e2c147b09..e50914b14d 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java @@ -16,6 +16,7 @@ import org.apache.kafka.common.record.EndTransactionMarker; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; import org.apache.kafka.common.record.SimpleRecord; @@ -144,6 +145,21 @@ public BatchAwareMemoryRecordsBuilder addBatch(CompressionType compressionType, templateBatch.deleteHorizonMs().orElse(RecordBatch.NO_TIMESTAMP)); } + /** + * Directly appends a batch, intended to be used for passing through unmodified batches. Writes + * the previous batch to the stream if required + * @param templateBatch The batch to use as a source of batch parameters + * @return this builder + */ + public @NonNull BatchAwareMemoryRecordsBuilder writeBatch(MutableRecordBatch templateBatch) { + if (haveBatch()) { + appendCurrentBatch(); + } + templateBatch.writeTo(buffer); + builder = null; + return this; + } + private void maybeAppendCurrentBatch() { if (haveBatch()) { appendCurrentBatch(); diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java index ea180107db..8d00274389 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java @@ -265,6 +265,69 @@ void shouldPreserveControlBatch() { } + @Test + void shouldPreserveMultipleBatches_IncludingEmptyBatch() { + var kmsService = UnitTestingKmsService.newInstance(); + InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + + var kekId = kms.generateKey(); + + byte[] value = { 1, 2, 3 }; + Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); + BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(1000)); + builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1); + builder.appendWithOffset(1L, record); + + MemoryRecords empty = RecordTestUtils.memoryRecordsWithAllRecordsRemoved(2L); + MutableRecordBatch emptyBatch = empty.batches().iterator().next(); + builder.writeBatch(emptyBatch); + MemoryRecords records = builder.build(); + + EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); + CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); + assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); + MemoryRecords encrypted = encryptedFuture.join(); + record.value().rewind(); + + assertThat(encrypted.batches()).hasSize(2); + List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); + MutableRecordBatch first = batches.get(0); + assertThat(first.compressionType()).isEqualTo(CompressionType.NONE); + assertThat(first.timestampType()).isEqualTo(TimestampType.CREATE_TIME); + assertThat(first.baseOffset()).isEqualTo(1L); + assertThat(first).hasSize(1); + + MutableRecordBatch second = batches.get(1); + // should we keep the client's compression type? + assertThat(second.compressionType()).isEqualTo(emptyBatch.compressionType()); + assertThat(second.timestampType()).isEqualTo(emptyBatch.timestampType()); + assertThat(second.baseOffset()).isEqualTo(emptyBatch.baseOffset()); + assertThat(second).hasSize(0); + + CompletableFuture decryptedFuture = km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new).toCompletableFuture(); + assertThat(decryptedFuture).succeedsWithin(Duration.ZERO); + MemoryRecords decrypted = decryptedFuture.join(); + + assertThat(decrypted.batches()).hasSize(2); + List decryptedBatches = StreamSupport.stream(decrypted.batches().spliterator(), false).toList(); + MutableRecordBatch firstDecrypted = decryptedBatches.get(0); + assertThat(firstDecrypted.compressionType()).isEqualTo(CompressionType.NONE); + assertThat(firstDecrypted.timestampType()).isEqualTo(TimestampType.CREATE_TIME); + assertThat(firstDecrypted.baseOffset()).isEqualTo(1L); + assertThat(firstDecrypted).hasSize(1); + assertThat(firstDecrypted.iterator()) + .toIterable() + .singleElement() + .extracting(RecordTestUtils::recordValueAsBytes) + .isEqualTo(value); + + MutableRecordBatch secondDecrypted = decryptedBatches.get(1); + assertThat(secondDecrypted.compressionType()).isEqualTo(emptyBatch.compressionType()); + assertThat(secondDecrypted.timestampType()).isEqualTo(emptyBatch.timestampType()); + assertThat(secondDecrypted.baseOffset()).isEqualTo(emptyBatch.baseOffset()); + } + private static RecordBatch controlBatch(int baseOffset, byte[] arbitraryValue) { MemoryRecordsBuilder builder = new MemoryRecordsBuilder(ByteBuffer.allocate(1000), RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, TimestampType.CREATE_TIME, baseOffset, 1L, 1L, (short) 1, 1, false, true, 1, 1); From 26fa9a026aff0f430770949a6820c99f97b2199a Mon Sep 17 00:00:00 2001 From: Robert Young Date: Mon, 22 Jan 2024 16:11:21 +1300 Subject: [PATCH 07/11] Apply review suggestions Signed-off-by: Robert Young --- .../test/record/RecordTestUtils.java | 14 ++- .../encryption/inband/InBandKeyManager.java | 101 +++++------------- .../BatchAwareMemoryRecordsBuilder.java | 8 +- .../BatchAwareMemoryRecordsBuilderTest.java | 83 ++++++++++++++ 4 files changed, 122 insertions(+), 84 deletions(-) diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java index 29c693bcab..0c1987703f 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java @@ -329,6 +329,15 @@ public static MemoryRecords memoryRecords(@NonNull List records) { } } + /** + * Simulates a MemoryRecords that contained some records, but then had all it's records removed by log compaction. Sets + * the baseOffset of the single batch within the MemoryRecords to 0. + * @see RecordTestUtils#memoryRecordsWithAllRecordsRemoved(long) + */ + public static MemoryRecords memoryRecordsWithAllRecordsRemoved() { + return memoryRecordsWithAllRecordsRemoved(0L); + } + /** * This is a special case that is different from {@link MemoryRecords#EMPTY}. An empty MemoryRecords is * backed by a 0-length buffer. In this case we are simulating a MemoryRecords that contained some @@ -340,12 +349,9 @@ public static MemoryRecords memoryRecords(@NonNull List records) { * the log when all the records in the batch are cleaned but batch is still retained in order to preserve * a producer's last sequence number. *

+ * @param baseOffset the baseOffset of the single batch contained in the output MemoryRecords * @see Apache Kafka RecordBatch documentation */ - public static MemoryRecords memoryRecordsWithAllRecordsRemoved() { - return memoryRecordsWithAllRecordsRemoved(0L); - } - @NonNull public static MemoryRecords memoryRecordsWithAllRecordsRemoved(long baseOffset) { try (MemoryRecordsBuilder memoryRecordsBuilder = memoryRecordsBuilder(DEFAULT_MAGIC_VALUE, baseOffset)) { diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index 190496e498..ac973e6a76 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -12,22 +12,18 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.function.IntFunction; import java.util.stream.Collectors; -import java.util.stream.Stream; import java.util.stream.StreamSupport; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; import org.apache.kafka.common.record.MemoryRecords; -import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; -import org.apache.kafka.common.record.RecordBatch; import org.apache.kafka.common.utils.BufferSupplier; import org.apache.kafka.common.utils.ByteBufferOutputStream; import org.apache.kafka.common.utils.ByteUtils; @@ -124,26 +120,6 @@ private CompletableFuture makeKeyContext(@NonNull K kekId) { }).toCompletableFuture(); } - record BatchDescription(int index, MutableRecordBatch batch, int recordCount) { - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BatchDescription that = (BatchDescription) o; - return index == that.index && recordCount == that.recordCount; - } - - @Override - public int hashCode() { - return Objects.hash(index, recordCount); - } - } - @Override @NonNull @SuppressWarnings("java:S2445") @@ -157,27 +133,26 @@ public CompletionStage encrypt(@NonNull String topicName, return CompletableFuture.completedFuture(records); } - List descriptions = describeBatches(records); + List batchRecordCounts = batchRecordCounts(records); // it is possible to encounter MemoryRecords that have had all their records compacted away, but // the recordbatch metadata still exists. https://kafka.apache.org/documentation/#recordbatch - if (descriptions.stream().allMatch(batchDescription -> batchDescription.recordCount == 0)) { + if (batchRecordCounts.stream().allMatch(size -> size == 0)) { return CompletableFuture.completedFuture(records); } BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(allocateBufferForEncode(records, bufferAllocator)); - return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, 0, descriptions).thenApply(unused -> builder.build()); + return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, 0, batchRecordCounts).thenApply(unused -> builder.build()); } @NonNull - private static List describeBatches(@NonNull MemoryRecords records) { - int batchIndex = 0; - List descriptions = new ArrayList<>(); + private static List batchRecordCounts(@NonNull MemoryRecords records) { + List sizes = new ArrayList<>(); for (MutableRecordBatch batch : records.batches()) { - descriptions.add(new BatchDescription(batchIndex++, batch, batchSize(batch))); + sizes.add(recordCount(batch)); } - return descriptions; + return sizes; } - private static int batchSize(MutableRecordBatch batch) { + private static int recordCount(MutableRecordBatch batch) { Integer count = batch.countOrNull(); if (count == null) { // for magic <2 count will be null @@ -192,28 +167,6 @@ private static int batchSize(MutableRecordBatch batch) { return count; } - @NonNull - private static Stream recordStream(MemoryRecords memoryRecords) { - return StreamSupport.stream(memoryRecords.records().spliterator(), false); - } - - private static MemoryRecordsBuilder recordsBuilder(@NonNull ByteBufferOutputStream buffer, @NonNull MemoryRecords records) { - RecordBatch firstBatch = records.firstBatch(); - return new MemoryRecordsBuilder(buffer, - firstBatch.magic(), - firstBatch.compressionType(), // TODO we might not want to use the client's compression - firstBatch.timestampType(), - firstBatch.baseOffset(), - 0L, - firstBatch.producerId(), - firstBatch.producerEpoch(), - firstBatch.baseSequence(), - firstBatch.isTransactional(), - firstBatch.isControlBatch(), - firstBatch.partitionLeaderEpoch(), - 0); - } - private ByteBufferOutputStream allocateBufferForEncode(MemoryRecords records, IntFunction bufferAllocator) { int sizeEstimate = 2 * records.sizeInBytes(); // Accurate estimation is tricky without knowing the record sizes @@ -222,28 +175,28 @@ private ByteBufferOutputStream allocateBufferForEncode(MemoryRecords records, In @SuppressWarnings("java:S2445") private CompletionStage attemptEncrypt(String topicName, int partition, @NonNull EncryptionScheme encryptionScheme, @NonNull MemoryRecords records, - BatchAwareMemoryRecordsBuilder builder, int attempt, List batchDescriptions) { - int recordsCount = batchDescriptions.stream().mapToInt(value -> value.recordCount).sum(); + BatchAwareMemoryRecordsBuilder builder, int attempt, List batchRecordCounts) { + int allRecordsCount = batchRecordCounts.stream().mapToInt(value -> value).sum(); if (attempt >= MAX_ATTEMPTS) { return CompletableFuture.failedFuture( - new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + recordsCount + " records for topic " + topicName + " partition " + new RequestNotSatisfiable("failed to reserve an EDEK to encrypt " + allRecordsCount + " records for topic " + topicName + " partition " + partition + " after " + attempt + " attempts")); } return currentDekContext(encryptionScheme.kekId()).thenCompose(keyContext -> { synchronized (keyContext) { // if it's not alive we know a previous encrypt call has removed this stage from the cache and fall through to retry encrypt if (!keyContext.isDestroyed()) { - if (!keyContext.hasAtLeastRemainingEncryptions(recordsCount)) { + if (!keyContext.hasAtLeastRemainingEncryptions(allRecordsCount)) { // remove the key context from the cache, then call encrypt again to drive caffeine to recreate it rotateKeyContext(encryptionScheme, keyContext); } else { // todo ensure that a failure during encryption terminates the entire operation with a failed future - return encrypt(encryptionScheme, records, builder, keyContext, batchDescriptions); + return encrypt(encryptionScheme, records, builder, keyContext, batchRecordCounts); } } } - return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, attempt + 1, batchDescriptions); + return attemptEncrypt(topicName, partition, encryptionScheme, records, builder, attempt + 1, batchRecordCounts); }); } @@ -252,11 +205,11 @@ private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionS @NonNull MemoryRecords memoryRecords, @NonNull BatchAwareMemoryRecordsBuilder builder, @NonNull KeyContext keyContext, - @NonNull List batchDescriptions) { + @NonNull List batchRecordCounts) { int i = 0; for (MutableRecordBatch batch : memoryRecords.batches()) { - BatchDescription batchDescription = batchDescriptions.get(i++); - if (batchDescription.recordCount == 0 || batch.isControlBatch()) { + Integer batchRecordCount = batchRecordCounts.get(i++); + if (batchRecordCount == 0 || batch.isControlBatch()) { builder.writeBatch(batch); } else { @@ -278,7 +231,7 @@ private CompletableFuture encrypt(@NonNull EncryptionScheme encryptionS ByteBuffer parcelBuffer = bufferPool.acquire(maxParcelSize); ByteBuffer wrapperBuffer = bufferPool.acquire(maxWrapperSize); try { - encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder, batch); + encryptRecords(encryptionScheme, keyContext, records, parcelBuffer, wrapperBuffer, builder); } finally { if (wrapperBuffer != null) { @@ -306,8 +259,7 @@ private void encryptRecords(@NonNull EncryptionScheme encryptionScheme, @NonNull List records, @NonNull ByteBuffer parcelBuffer, @NonNull ByteBuffer wrapperBuffer, - @NonNull BatchAwareMemoryRecordsBuilder builder, - @NonNull MutableRecordBatch batch) { + @NonNull BatchAwareMemoryRecordsBuilder builder) { records.forEach(kafkaRecord -> { if (encryptionScheme.recordFields().contains(RecordField.RECORD_HEADER_VALUES) && kafkaRecord.headers().length > 0 @@ -402,9 +354,6 @@ private CompletableFuture makeDecryptor(E edek) { .thenApply(AesGcmEncryptor::forDecrypt).toCompletableFuture(); } - private record DecryptState(@NonNull Record kafkaRecord, @NonNull ByteBuffer valueWrapper, @Nullable EncryptionVersion decryptionVersion, - @Nullable AesGcmEncryptor encryptor) {} - @NonNull @Override public CompletionStage decrypt(@NonNull String topicName, int partition, @NonNull MemoryRecords records, @@ -413,17 +362,17 @@ public CompletionStage decrypt(@NonNull String topicName, int par // no records to transform, return input without modification return CompletableFuture.completedFuture(records); } - List descriptions = describeBatches(records); + List batchRecordCounts = batchRecordCounts(records); // it is possible to encounter MemoryRecords that have had all their records compacted away, but // the recordbatch metadata still exists. https://kafka.apache.org/documentation/#recordbatch - if (descriptions.stream().allMatch(batchDescription -> batchDescription.recordCount == 0)) { + if (batchRecordCounts.stream().allMatch(recordCount -> recordCount == 0)) { return CompletableFuture.completedFuture(records); } Set uniqueEdeks = extractEdeks(topicName, partition, records); CompletionStage> decryptors = resolveAll(uniqueEdeks); CompletionStage objectCompletionStage = decryptors.thenApply( encryptorMap -> decrypt(topicName, partition, records, new BatchAwareMemoryRecordsBuilder(allocateBufferForDecode(records, bufferAllocator)), - encryptorMap, descriptions)); + encryptorMap, batchRecordCounts)); return objectCompletionStage.thenApply(BatchAwareMemoryRecordsBuilder::build); } @@ -455,11 +404,11 @@ private BatchAwareMemoryRecordsBuilder decrypt(String topicName, @NonNull MemoryRecords records, @NonNull BatchAwareMemoryRecordsBuilder builder, @NonNull Map encryptorMap, - @NonNull List descriptions) { + @NonNull List batchRecordCounts) { int i = 0; for (MutableRecordBatch batch : records.batches()) { - BatchDescription batchDescription = descriptions.get(i++); - if (batchDescription.recordCount == 0 || batch.isControlBatch()) { + Integer batchRecordCount = batchRecordCounts.get(i++); + if (batchRecordCount == 0 || batch.isControlBatch()) { builder.writeBatch(batch); } else { diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java index e50914b14d..da1c150b52 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java @@ -147,15 +147,15 @@ public BatchAwareMemoryRecordsBuilder addBatch(CompressionType compressionType, /** * Directly appends a batch, intended to be used for passing through unmodified batches. Writes - * the previous batch to the stream if required - * @param templateBatch The batch to use as a source of batch parameters + * and closes the previous MemoryRecordBuilder batch to the stream if required + * @param batch The batch to write to the buffer * @return this builder */ - public @NonNull BatchAwareMemoryRecordsBuilder writeBatch(MutableRecordBatch templateBatch) { + public @NonNull BatchAwareMemoryRecordsBuilder writeBatch(@NonNull MutableRecordBatch batch) { if (haveBatch()) { appendCurrentBatch(); } - templateBatch.writeTo(buffer); + batch.writeTo(buffer); builder = null; return this; } diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java index 3c6529dac6..341a1161fe 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java @@ -11,8 +11,10 @@ import java.util.List; import java.util.stream.StreamSupport; +import org.apache.kafka.common.header.Header; import org.apache.kafka.common.record.CompressionType; import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; @@ -23,6 +25,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import io.kroxylicious.test.record.RecordTestUtils; + import edu.umd.cs.findbugs.annotations.NonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -42,6 +46,85 @@ void shouldRequireABatchBeforeAppend() { .hasMessageContaining("You must start a batch"); } + @Test + void shouldBePossibleToWriteBatchDirectly() { + // Given + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + MemoryRecords input = RecordTestUtils.memoryRecords("a", "b"); + MutableRecordBatch recordBatch = input.batchIterator().next(); + + // When + builder.writeBatch(recordBatch); + MemoryRecords output = builder.build(); + + // Then + assertThat(output).isEqualTo(input); + } + + @Test + void shouldBePossibleToWriteBatchAfterBuildingABatch() { + // Given + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 0L); + byte[] value1 = { 4, 5, 6 }; + builder.appendWithOffset(0L, 1L, new byte[]{ 1, 2, 3 }, value1, new Header[]{}); + byte[] value2 = { 10, 11, 12 }; + MemoryRecords input = RecordTestUtils.memoryRecords(RecordBatch.CURRENT_MAGIC_VALUE, 1L, 1L, new byte[]{ 7, 8, 9 }, value2); + MutableRecordBatch recordBatch = input.batchIterator().next(); + + // When + builder.writeBatch(recordBatch); + MemoryRecords output = builder.build(); + + // Then + List batches = StreamSupport.stream(output.batches().spliterator(), false).toList(); + assertThat(batches).hasSize(2); + + var batch1 = batches.get(0); + assertThat(batch1.countOrNull()).isEqualTo(1); + Record batch1Record = batch1.iterator().next(); + assertThat(batch1Record.value()).isEqualTo(ByteBuffer.wrap(value1)); + assertThat(batch1Record.offset()).isZero(); + + var batch2 = batches.get(1); + assertThat(batch2.countOrNull()).isEqualTo(1); + Record batch2Record = batch2.iterator().next(); + assertThat(batch2Record.value()).isEqualTo(ByteBuffer.wrap(value2)); + assertThat(batch2Record.offset()).isEqualTo(1); + } + + @Test + void shouldBePossibleToBuildABatchAfterWritingBatch() { + // Given + byte[] value1 = { 10, 11, 12 }; + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + MemoryRecords input = RecordTestUtils.memoryRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, 1L, new byte[]{ 7, 8, 9 }, value1); + MutableRecordBatch recordBatch = input.batchIterator().next(); + builder.writeBatch(recordBatch); + + // When + builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1L); + byte[] value2 = { 4, 5, 6 }; + builder.appendWithOffset(1L, 1L, new byte[]{ 1, 2, 3 }, value2, new Header[]{}); + MemoryRecords output = builder.build(); + + // Then + List batches = StreamSupport.stream(output.batches().spliterator(), false).toList(); + assertThat(batches).hasSize(2); + + var batch1 = batches.get(0); + assertThat(batch1.countOrNull()).isEqualTo(1); + Record batch1Record = batch1.iterator().next(); + assertThat(batch1Record.value()).isEqualTo(ByteBuffer.wrap(value1)); + assertThat(batch1Record.offset()).isZero(); + + var batch2 = batches.get(1); + assertThat(batch2.countOrNull()).isEqualTo(1); + Record batch2Record = batch2.iterator().next(); + assertThat(batch2Record.value()).isEqualTo(ByteBuffer.wrap(value2)); + assertThat(batch2Record.offset()).isEqualTo(1); + } + @Test void shouldPreventAppendAfterBuild1() { // Given From 81c12da621dcfb8913d812e660479c5e7d9573e9 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 23 Jan 2024 11:57:10 +1300 Subject: [PATCH 08/11] Refactor BatchAwareMemoryRecordsBuilder to track if it has been closed Why: With the new direct passthrough ability where we can hand it an existing batch and have it written to the buffer, we can no longer exclusively rely on the state of the builder variable to track whether we've finished working with the BatchAwareMemoryRecordsBuilder. It can happen that the last builder in the variable is closed by the time we call build(). Instead we will add a closed variable. Since the builder is marked NonThreadSafe we do not consider async calls to build/close. Signed-off-by: Robert Young --- .../BatchAwareMemoryRecordsBuilder.java | 30 ++++- .../BatchAwareMemoryRecordsBuilderTest.java | 111 +++++++++++++++++- 2 files changed, 131 insertions(+), 10 deletions(-) diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java index da1c150b52..dff6b6da29 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java @@ -36,6 +36,7 @@ public class BatchAwareMemoryRecordsBuilder implements AutoCloseable { private final ByteBufferOutputStream buffer; private MemoryRecordsBuilder builder = null; + private boolean closed = false; /** * Initialize a new instance, which will append into the given buffer. @@ -59,6 +60,12 @@ private void checkHasBatch() { } } + private void checkIfClosed() { + if (closed) { + throw new IllegalStateException("Builder is closed"); + } + } + /** * Starts a batch * @param magic @@ -87,6 +94,7 @@ private void checkHasBatch() { boolean isControlBatch, int partitionLeaderEpoch, long deleteHorizonMs) { + checkIfClosed(); maybeAppendCurrentBatch(); // MRB respects the initial position() of buffer, so this doesn't overwrite anything already in buffer builder = new MemoryRecordsBuilder(buffer, @@ -152,11 +160,11 @@ public BatchAwareMemoryRecordsBuilder addBatch(CompressionType compressionType, * @return this builder */ public @NonNull BatchAwareMemoryRecordsBuilder writeBatch(@NonNull MutableRecordBatch batch) { + checkIfClosed(); if (haveBatch()) { appendCurrentBatch(); } batch.writeTo(buffer); - builder = null; return this; } @@ -177,6 +185,7 @@ private void appendCurrentBatch() { * @return This builder */ public @NonNull BatchAwareMemoryRecordsBuilder append(SimpleRecord record) { + checkIfClosed(); checkHasBatch(); builder.append(record); return this; @@ -188,6 +197,7 @@ private void appendCurrentBatch() { * @return This builder */ public @NonNull BatchAwareMemoryRecordsBuilder append(Record record) { + checkIfClosed(); checkHasBatch(); builder.append(record); return this; @@ -200,6 +210,7 @@ private void appendCurrentBatch() { * @return This builder */ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record record) { + checkIfClosed(); checkHasBatch(); builder.appendWithOffset(offset, record); return this; @@ -215,6 +226,7 @@ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record recor * @return This builder */ public @NonNull BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, long timestamp, byte[] key, byte[] value, Header[] headers) { + checkIfClosed(); checkHasBatch(); builder.appendWithOffset(offset, timestamp, key, value, headers); return this; @@ -234,12 +246,14 @@ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record recor ByteBuffer key, ByteBuffer value, Header[] headers) { + checkIfClosed(); checkHasBatch(); builder.appendWithOffset(offset, timestamp, key, value, headers); return this; } public @NonNull BatchAwareMemoryRecordsBuilder appendControlRecordWithOffset(long offset, @NonNull SimpleRecord record) { + checkIfClosed(); checkHasBatch(); builder.appendControlRecordWithOffset(offset, record); return this; @@ -247,6 +261,7 @@ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record recor public @NonNull BatchAwareMemoryRecordsBuilder appendEndTxnMarker(long timestamp, @NonNull EndTransactionMarker marker) { + checkIfClosed(); checkHasBatch(); builder.appendEndTxnMarker(timestamp, marker); return this; @@ -261,13 +276,16 @@ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record recor * @return the memory records */ public @NonNull MemoryRecords build() { - boolean needsFlip = builder == null || !builder.isClosed(); - maybeAppendCurrentBatch(); - ByteBuffer buf = this.buffer.buffer(); - if (needsFlip) { + if (closed) { + return MemoryRecords.readableRecords(this.buffer.buffer()); + } + else { + closed = true; + maybeAppendCurrentBatch(); + ByteBuffer buf = this.buffer.buffer(); buf.flip(); + return MemoryRecords.readableRecords(buf); } - return MemoryRecords.readableRecords(buf); } /** diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java index 341a1161fe..88f8505284 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java @@ -14,6 +14,7 @@ import org.apache.kafka.common.header.Header; import org.apache.kafka.common.record.CompressionType; import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.EndTransactionMarker; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; @@ -31,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.in; class BatchAwareMemoryRecordsBuilderTest { @@ -59,6 +61,8 @@ void shouldBePossibleToWriteBatchDirectly() { // Then assertThat(output).isEqualTo(input); + + assertThat(builder.build()).describedAs("Build should be idempotent").isEqualTo(input); } @Test @@ -136,9 +140,51 @@ void shouldPreventAppendAfterBuild1() { // Then assertThatThrownBy(() -> builder.append((Record) null)) .isExactlyInstanceOf(IllegalStateException.class) - // Batchless is a special case: We use the MRB's isClosed() to detect append-after-build - // which saves needing our own field but means the error message is not ideal in this case - .hasMessageContaining("You must start a batch"); + .hasMessageContaining("Builder is closed"); + } + + @Test + void shouldPreventAddBatchAfterBuild() { + // Given + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + + // When + builder.build(); + + // Then + assertThatThrownBy(() -> { + builder.addBatch(RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0, + 0, + 0, + (short) 0, + 0, + false, + false, + 0, + 0); + }) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessageContaining("Builder is closed"); + } + + @Test + void shouldPreventAddBatchLikeAfterBuild() { + // Given + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + RecordBatch batch = RecordTestUtils.memoryRecords("key", "value").firstBatch(); + + // When + builder.build(); + + // Then + assertThatThrownBy(() -> { + builder.addBatchLike(batch); + }) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessageContaining("Builder is closed"); } @Test @@ -164,7 +210,64 @@ void shouldPreventAppendAfterBuild2() { // Then assertThatThrownBy(() -> builder.append((Record) null)) .isExactlyInstanceOf(IllegalStateException.class) - .hasMessageContaining("This builder has been built"); + .hasMessageContaining("Builder is closed"); + } + + @Test + void shouldPreventAppendControlRecordAfterBuild() { + // Given + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + builder.addBatch(RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0, + 0, + 0, + (short) 0, + 0, + false, + false, + 0, + 0); + + // When + builder.build(); + + // Then + assertThatThrownBy(() -> { + SimpleRecord controlRecord = controlRecord(); + builder.appendControlRecordWithOffset(1, controlRecord); + }) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessageContaining("Builder is closed"); + } + + @Test + void shouldPreventAppendEndTxnMarkerRecordAfterBuild() { + // Given + var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); + builder.addBatch(RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0, + 0, + 0, + (short) 0, + 0, + false, + false, + 0, + 0); + + // When + builder.build(); + + // Then + assertThatThrownBy(() -> { + builder.appendEndTxnMarker(1, new EndTransactionMarker(ControlRecordType.ABORT, 1)); + }) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessageContaining("Builder is closed"); } // 0 batches From dacc5ce79503e302ccbd66d2c73ea924fd62f0bf Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 23 Jan 2024 13:14:51 +1300 Subject: [PATCH 09/11] Apply review suggestions Signed-off-by: Robert Young --- .../test/assertj/HeaderAssert.java | 4 +- .../test/assertj/MemoryRecordsAssert.java | 13 +- .../test/assertj/RecordAssert.java | 48 +- .../test/assertj/RecordBatchAssert.java | 127 ++++- .../test/record/RecordTestUtils.java | 181 ++++++- .../kroxylicious/test/assertj/Assertions.java | 17 + .../test/assertj/HeaderAssertTest.java | 62 +++ .../test/assertj/KafkaAssertionsTest.java | 126 ----- .../test/assertj/MemoryRecordsAssertTest.java | 100 ++++ .../test/assertj/RecordAssertTest.java | 234 +++++++++ .../test/assertj/RecordBatchAssertTest.java | 244 +++++++++ .../encryption/inband/InBandKeyManager.java | 4 +- .../BatchAwareMemoryRecordsBuilder.java | 6 +- .../EnvelopeEncryptionFilterTest.java | 6 +- .../inband/InBandKeyManagerTest.java | 480 ++++++++---------- .../BatchAwareMemoryRecordsBuilderTest.java | 9 +- .../EnvelopeEncryptionFilterIT.java | 9 +- 17 files changed, 1223 insertions(+), 447 deletions(-) create mode 100644 kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/Assertions.java create mode 100644 kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java delete mode 100644 kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/KafkaAssertionsTest.java create mode 100644 kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/MemoryRecordsAssertTest.java create mode 100644 kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordAssertTest.java create mode 100644 kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordBatchAssertTest.java diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java index cca62d0d2e..d952a2f3ea 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/HeaderAssert.java @@ -13,6 +13,7 @@ public class HeaderAssert extends AbstractAssert { protected HeaderAssert(Header header) { super(header, HeaderAssert.class); + describedAs(header == null ? "null header" : "header"); } public static HeaderAssert assertThat(Header actual) { @@ -29,7 +30,8 @@ public HeaderAssert hasKeyEqualTo(String expected) { public HeaderAssert hasValueEqualTo(String expected) { isNotNull(); - Assertions.assertThat(new String(actual.value())) + String valueString = actual.value() == null ? null : new String(actual.value()); + Assertions.assertThat(valueString) .describedAs("header value") .isEqualTo(expected); return this; diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/MemoryRecordsAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/MemoryRecordsAssert.java index d0821d3d9e..9e4238a082 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/MemoryRecordsAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/MemoryRecordsAssert.java @@ -22,6 +22,7 @@ public class MemoryRecordsAssert extends AbstractAssert { protected MemoryRecordsAssert(MemoryRecords memoryRecords) { super(memoryRecords, MemoryRecordsAssert.class); + describedAs(memoryRecords == null ? "null memory records" : "memory records"); } public static MemoryRecordsAssert assertThat(MemoryRecords actual) { @@ -47,7 +48,7 @@ public Iterable batches() { isNotNull(); return () -> { var it = actual.batches().iterator(); - return new Iterator() { + return new Iterator<>() { @Override public boolean hasNext() { return it.hasNext(); @@ -62,17 +63,27 @@ public RecordBatchAssert next() { } public RecordBatchAssert firstBatch() { + isNotNull(); + isNotEmpty(); return batchesIterable() .first(new InstanceOfAssertFactory<>(RecordBatch.class, RecordBatchAssert::assertThat)) .describedAs("first batch"); } public RecordBatchAssert lastBatch() { + isNotNull(); + isNotEmpty(); return batchesIterable() .last(new InstanceOfAssertFactory<>(RecordBatch.class, RecordBatchAssert::assertThat)) .describedAs("last batch"); } + private void isNotEmpty() { + Assertions.assertThat(actual.batches()) + .describedAs("number of batches") + .hasSizeGreaterThan(0); + } + public MemoryRecordsAssert hasNumBatches(int expected) { isNotNull(); Assertions.assertThat(actual.batches()) diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java index aef88a5614..65b3e4b7e3 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java @@ -9,8 +9,10 @@ import org.apache.kafka.common.header.Header; import org.apache.kafka.common.record.Record; import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.AbstractByteArrayAssert; import org.assertj.core.api.AbstractLongAssert; import org.assertj.core.api.AbstractObjectAssert; +import org.assertj.core.api.AbstractStringAssert; import org.assertj.core.api.Assertions; import org.assertj.core.api.ObjectArrayAssert; @@ -19,6 +21,7 @@ public class RecordAssert extends AbstractAssert { protected RecordAssert(Record record) { super(record, RecordAssert.class); + describedAs(record == null ? "null record" : "record"); } public static RecordAssert assertThat(Record actual) { @@ -26,6 +29,7 @@ public static RecordAssert assertThat(Record actual) { } public RecordAssert hasOffsetEqualTo(int expect) { + isNotNull(); AbstractLongAssert offset = offsetAssert(); offset.isEqualTo(expect); return this; @@ -38,6 +42,7 @@ private AbstractLongAssert offsetAssert() { } public RecordAssert hasTimestampEqualTo(int expect) { + isNotNull(); AbstractLongAssert timestamp = timestampAssert(); timestamp.isEqualTo(expect); return this; @@ -50,11 +55,13 @@ private AbstractLongAssert timestampAssert() { } private AbstractObjectAssert keyStrAssert() { + isNotNull(); return Assertions.assertThat(actual).extracting(RecordTestUtils::recordKeyAsString) .describedAs("record key"); } public RecordAssert hasKeyEqualTo(String expect) { + isNotNull(); Assertions.assertThat(actual).extracting(RecordTestUtils::recordKeyAsString) .describedAs("record key") .isEqualTo(expect); @@ -62,22 +69,49 @@ public RecordAssert hasKeyEqualTo(String expect) { } public RecordAssert hasNullKey() { - keyStrAssert() - .isNull(); + isNotNull(); + keyStrAssert().isNull(); return this; } - private AbstractObjectAssert valueStrAssert() { - return Assertions.assertThat(actual).extracting(RecordTestUtils::recordValueAsString) + private AbstractStringAssert valueStrAssert() { + isNotNull(); + return Assertions.assertThat(RecordTestUtils.recordValueAsString(actual)) + .describedAs("record value"); + } + + private AbstractByteArrayAssert valueBytesAssert() { + isNotNull(); + return Assertions.assertThat(RecordTestUtils.recordValueAsBytes(actual)) .describedAs("record value"); } public RecordAssert hasValueEqualTo(String expect) { + isNotNull(); valueStrAssert().isEqualTo(expect); return this; } + public RecordAssert hasValueEqualTo(byte[] expect) { + isNotNull(); + valueBytesAssert().isEqualTo(expect); + return this; + } + + public RecordAssert hasValueNotEqualTo(String notExpected) { + isNotNull(); + valueStrAssert().isNotEqualTo(notExpected); + return this; + } + + public RecordAssert hasValueEqualTo(Record expected) { + isNotNull(); + hasValueEqualTo(RecordTestUtils.recordValueAsBytes(expected)); + return this; + } + public RecordAssert hasNullValue() { + isNotNull(); Assertions.assertThat(actual).extracting(RecordTestUtils::recordValueAsString) .describedAs("record value") .isNull(); @@ -91,33 +125,39 @@ public ObjectArrayAssert
headersAssert() { } public RecordAssert hasEmptyHeaders() { + isNotNull(); headersAssert().isEmpty(); return this; } public HeaderAssert singleHeader() { + isNotNull(); headersAssert().singleElement(); return HeaderAssert.assertThat(actual.headers()[0]) .describedAs("record header"); } public RecordAssert hasHeadersSize(int expect) { + isNotNull(); headersAssert().hasSize(expect); return this; } public RecordAssert containsHeaderWithKey(String expectedKey) { + isNotNull(); headersAssert().anyMatch(h -> h.key().equals(expectedKey)); return this; } public HeaderAssert firstHeader() { + isNotNull(); headersAssert().isNotEmpty(); return HeaderAssert.assertThat(actual.headers()[0]) .describedAs("first record header"); } public HeaderAssert lastHeader() { + isNotNull(); headersAssert().isNotEmpty(); return HeaderAssert.assertThat(actual.headers()[actual.headers().length - 1]) .describedAs("last record header"); diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordBatchAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordBatchAssert.java index 1f9e83af54..308a8c9ee9 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordBatchAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordBatchAssert.java @@ -7,18 +7,23 @@ package io.kroxylicious.test.assertj; import java.util.Iterator; +import java.util.OptionalLong; import org.apache.kafka.common.record.CompressionType; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; import org.assertj.core.api.AbstractAssert; import org.assertj.core.api.Assertions; import org.assertj.core.api.InstanceOfAssertFactory; import org.assertj.core.api.IterableAssert; +import edu.umd.cs.findbugs.annotations.NonNull; + public class RecordBatchAssert extends AbstractAssert { protected RecordBatchAssert(RecordBatch batch) { super(batch, RecordBatchAssert.class); + describedAs(batch == null ? "null record batch" : "record batch"); } public static RecordBatchAssert assertThat(RecordBatch actual) { @@ -33,7 +38,7 @@ public RecordBatchAssert hasSizeInBytes(int expected) { return this; } - public RecordBatchAssert hasBaseOffset(int expected) { + public RecordBatchAssert hasBaseOffset(long expected) { isNotNull(); Assertions.assertThat(actual.baseOffset()) .describedAs("baseOffset") @@ -53,17 +58,127 @@ public RecordBatchAssert hasCompressionType(CompressionType expected) { isNotNull(); Assertions.assertThat(actual.compressionType()) .describedAs("compressionType") + .isNotNull() .isEqualTo(expected); return this; } public RecordBatchAssert hasNumRecords(int expected) { + isNotNull(); Assertions.assertThat(actual) .describedAs("records") .hasSize(expected); return this; } + public RecordBatchAssert hasMagic(byte magic) { + isNotNull(); + Assertions.assertThat(actual.magic()) + .describedAs("magic") + .isEqualTo(magic); + return this; + } + + public RecordBatchAssert isControlBatch(boolean expected) { + isNotNull(); + Assertions.assertThat(actual.isControlBatch()) + .describedAs("controlBatch") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert isTransactional(boolean expected) { + isNotNull(); + Assertions.assertThat(actual.isTransactional()) + .describedAs("transactional") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasPartitionLeaderEpoch(int expected) { + isNotNull(); + Assertions.assertThat(actual.partitionLeaderEpoch()) + .describedAs("partitionLeaderEpoch") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasDeleteHorizonMs(OptionalLong expected) { + isNotNull(); + Assertions.assertThat(actual.deleteHorizonMs()) + .describedAs("deleteHorizonMs") + .isNotNull() + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasLastOffset(long expected) { + isNotNull(); + Assertions.assertThat(actual.lastOffset()) + .describedAs("lastOffset") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasMetadataMatching(RecordBatch batch) { + isNotNull(); + hasBaseOffset(batch.baseOffset()); + hasBaseSequence(batch.baseSequence()); + hasCompressionType(batch.compressionType()); + isControlBatch(batch.isControlBatch()); + isTransactional(batch.isTransactional()); + hasMagic(batch.magic()); + hasTimestampType(batch.timestampType()); + hasPartitionLeaderEpoch(batch.partitionLeaderEpoch()); + hasDeleteHorizonMs(batch.deleteHorizonMs()); + hasLastOffset(batch.lastOffset()); + hasMaxTimestamp(batch.maxTimestamp()); + hasProducerId(batch.producerId()); + hasProducerEpoch(batch.producerEpoch()); + hasLastSequence(batch.lastSequence()); + return this; + } + + public RecordBatchAssert hasLastSequence(int expected) { + isNotNull(); + Assertions.assertThat(actual.lastSequence()) + .describedAs("lastSequence") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasProducerEpoch(short expected) { + isNotNull(); + Assertions.assertThat(actual.producerEpoch()) + .describedAs("producerEpoch") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasProducerId(long expected) { + isNotNull(); + Assertions.assertThat(actual.producerId()) + .describedAs("producerId") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasMaxTimestamp(long expected) { + isNotNull(); + Assertions.assertThat(actual.maxTimestamp()) + .describedAs("maxTimestamp") + .isEqualTo(expected); + return this; + } + + public RecordBatchAssert hasTimestampType(TimestampType expected) { + isNotNull(); + Assertions.assertThat(actual.timestampType()) + .describedAs("timestampType") + .isEqualTo(expected); + return this; + } + private IterableAssert recordIterable() { isNotNull(); IterableAssert records = IterableAssert.assertThatIterable(actual) @@ -72,19 +187,27 @@ private IterableAssert recordIterable() { } public RecordAssert firstRecord() { + isNotNull(); + isNotEmpty(); return recordIterable() .first(new InstanceOfAssertFactory<>(Record.class, RecordAssert::assertThat)) .describedAs("first record"); } public RecordAssert lastRecord() { + isNotNull(); + isNotEmpty(); return recordIterable() .last(new InstanceOfAssertFactory<>(Record.class, RecordAssert::assertThat)) .describedAs("last record"); } + @NonNull + private IterableAssert isNotEmpty() { + return Assertions.assertThat(actual).describedAs(descriptionText()).hasSizeGreaterThan(0); + } + public Iterable records() { - recordIterable().isNotEmpty(); isNotNull(); return () -> { return new Iterator() { diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java index 0c1987703f..c02d7b227d 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java @@ -13,12 +13,16 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.EndTransactionMarker; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; import org.apache.kafka.common.record.TimestampType; import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; import edu.umd.cs.findbugs.annotations.NonNull; @@ -147,6 +151,21 @@ public static Record record(String key, return record(DEFAULT_MAGIC_VALUE, DEFAULT_OFFSET, DEFAULT_TIMESTAMP, key, value, headers); } + /** + * Return a Record with the given key, value, offset and headers + * @param offset + * @param key + * @param value + * @param headers + * @return The record + */ + public static Record record(long offset, + String key, + String value, + Header... headers) { + return record(DEFAULT_MAGIC_VALUE, offset, DEFAULT_TIMESTAMP, key, value, headers); + } + /** * Return a Record with the given key, value and headers * @param key @@ -178,7 +197,7 @@ public static Record record(byte magic, Header... headers) { // This is a bit of a rigmarole, but it ensures that calls to getSizeInBytes() // on the returned Record is actually correct - MemoryRecords mr = memoryRecords(magic, offset, timestamp, key, value, headers); + MemoryRecords mr = singleElementMemoryRecords(magic, offset, timestamp, key, value, headers); return MemoryRecords.readableRecords(mr.buffer()).records().iterator().next(); } @@ -200,7 +219,7 @@ public static Record record(byte magic, Header... headers) { // This is a bit of a rigmarole, but it ensures that calls to getSizeInBytes() // on the returned Record is actually correct - MemoryRecords mr = memoryRecords(magic, offset, timestamp, key, value, headers); + MemoryRecords mr = singleElementMemoryRecords(magic, offset, timestamp, key, value, headers); return MemoryRecords.readableRecords(mr.buffer()).records().iterator().next(); } @@ -222,28 +241,74 @@ public static Record record(byte magic, Header... headers) { // This is a bit of a rigmarole, but it ensures that calls to getSizeInBytes() // on the returned Record is actually correct - MemoryRecords mr = memoryRecords(magic, offset, timestamp, key, value, headers); + MemoryRecords mr = singleElementMemoryRecords(magic, offset, timestamp, key, value, headers); return MemoryRecords.readableRecords(mr.buffer()).records().iterator().next(); } + /** + * Return a singleton RecordBatch containing a single Record with the given key, value and headers. + * The batch will use the current magic. The baseOffset and offset of the record will be 0 + * @see RecordTestUtils#singleElementRecordBatch(long, String, String, Header[]) + */ + public static MutableRecordBatch singleElementRecordBatch(String key, + String value, + Header... headers) { + return singleElementRecordBatch(DEFAULT_OFFSET, key, value, headers); + } + /** * Return a singleton RecordBatch containing a single Record with the given key, value and headers. * The batch will use the current magic. + * @param offset baseOffset of the single batch and offset of the single record within it * @param key * @param value * @param headers - * @return The record + * @return The record batch */ - public static RecordBatch recordBatch(String key, - String value, - Header... headers) { - return memoryRecords(DEFAULT_MAGIC_VALUE, - DEFAULT_OFFSET, + public static MutableRecordBatch singleElementRecordBatch(long offset, String key, String value, Header[] headers) { + return singleElementMemoryRecords(DEFAULT_MAGIC_VALUE, + offset, DEFAULT_TIMESTAMP, key, value, headers) - .firstBatch(); + .batches().iterator().next(); + } + + /** + * Return a singleton RecordBatch containing a single Record with the given key, value and headers. + * The batch will use the current magic. + * @param baseOffset baseOffset of the single batch and offset of the single record within it + * @return The record batch + */ + public static MutableRecordBatch singleElementRecordBatch(byte magic, + long baseOffset, + CompressionType compressionType, + TimestampType timestampType, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch, + byte[] key, + byte[] value, + Header... headers) { + MemoryRecords records = memoryRecordsWithoutCopy(magic, baseOffset, compressionType, timestampType, logAppendTime, producerId, producerEpoch, baseSequence, + isTransactional, isControlBatch, partitionLeaderEpoch, 0L, key, value, headers); + return records.batches().iterator().next(); + } + + /** + * Return a singleton RecordBatch with all records removed. This simulates the case where compaction removes all + * records but retains the batch metadata. The batch will use the current magic. + * @param offset baseOffset of the single batch and offset of the single record within it + * @return The batch + * @see RecordTestUtils#memoryRecordsWithAllRecordsRemoved(long) + */ + public static MutableRecordBatch recordBatchWithAllRecordsRemoved(long offset) { + return memoryRecordsWithAllRecordsRemoved(offset).batchIterator().next(); } /** @@ -254,8 +319,8 @@ public static RecordBatch recordBatch(String key, * @param headers * @return The record */ - public static MemoryRecords memoryRecords(String key, String value, Header... headers) { - return memoryRecords(DEFAULT_MAGIC_VALUE, + public static MemoryRecords singleElementMemoryRecords(String key, String value, Header... headers) { + return singleElementMemoryRecords(DEFAULT_MAGIC_VALUE, DEFAULT_OFFSET, DEFAULT_TIMESTAMP, key, @@ -263,6 +328,19 @@ public static MemoryRecords memoryRecords(String key, String value, Header... he headers); } + /** + * Return a MemoryRecords containing the specified batches + */ + public static MemoryRecords memoryRecords(MutableRecordBatch... batches) { + ByteBufferOutputStream outputStream = new ByteBufferOutputStream(1000); + for (MutableRecordBatch batch : batches) { + batch.writeTo(outputStream); + } + ByteBuffer buffer = outputStream.buffer(); + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + /** * Return a MemoryRecords containing a single RecordBatch containing a single Record with the given key, value and headers. * The batch will use the current magic. @@ -274,7 +352,7 @@ public static MemoryRecords memoryRecords(String key, String value, Header... he * @param headers * @return The record */ - public static MemoryRecords memoryRecords(byte magic, long offset, long timestamp, ByteBuffer key, ByteBuffer value, Header... headers) { + public static MemoryRecords singleElementMemoryRecords(byte magic, long offset, long timestamp, ByteBuffer key, ByteBuffer value, Header... headers) { return memoryRecordsWithoutCopy(magic, offset, timestamp, bytesOf(key), bytesOf(value), headers); } @@ -289,8 +367,10 @@ public static MemoryRecords memoryRecords(byte magic, long offset, long timestam * @param headers * @return The record */ - public static MemoryRecords memoryRecords(byte magic, long offset, long timestamp, String key, String value, Header... headers) { - return memoryRecordsWithoutCopy(magic, offset, timestamp, key.getBytes(StandardCharsets.UTF_8), value.getBytes(StandardCharsets.UTF_8), headers); + public static MemoryRecords singleElementMemoryRecords(byte magic, long offset, long timestamp, String key, String value, Header... headers) { + byte[] keyBytes = key == null ? null : key.getBytes(StandardCharsets.UTF_8); + byte[] valueBytes = value == null ? null : value.getBytes(StandardCharsets.UTF_8); + return memoryRecordsWithoutCopy(magic, offset, timestamp, keyBytes, valueBytes, headers); } /** @@ -304,7 +384,7 @@ public static MemoryRecords memoryRecords(byte magic, long offset, long timestam * @param headers * @return The record */ - public static MemoryRecords memoryRecords(byte magic, long offset, long timestamp, byte[] key, byte[] value, Header... headers) { + public static MemoryRecords singleElementMemoryRecords(byte magic, long offset, long timestamp, byte[] key, byte[] value, Header... headers) { // No need to copy the arrays because their contents are written to a ByteBuffer and not retained return memoryRecordsWithoutCopy(magic, offset, timestamp, key, value, headers); } @@ -316,6 +396,28 @@ private static MemoryRecords memoryRecordsWithoutCopy(byte magic, long offset, l } } + private static MemoryRecords memoryRecordsWithoutCopy(byte magic, + long baseOffset, + CompressionType compressionType, + TimestampType timestampType, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch, + long timestamp, + byte[] key, + byte[] value, + Header... headers) { + try (MemoryRecordsBuilder memoryRecordsBuilder = memoryRecordsBuilder(magic, baseOffset, compressionType, timestampType, logAppendTime, producerId, producerEpoch, + baseSequence, isTransactional, isControlBatch, partitionLeaderEpoch)) { + memoryRecordsBuilder.appendWithOffset(baseOffset, timestamp, key, value, headers); + return memoryRecordsBuilder.build(); + } + } + /** * Return a MemoryRecords containing a single RecordBatch containing multiple Records. * The batch will use the current magic. @@ -382,19 +484,48 @@ private static MemoryRecordsBuilder defaultMemoryRecordsBuilder(byte magic) { @NonNull private static MemoryRecordsBuilder memoryRecordsBuilder(byte magic, long baseOffset) { + return memoryRecordsBuilder(magic, baseOffset, CompressionType.NONE, TimestampType.CREATE_TIME, 0L, 0L, (short) 0, 0, false, false, 0); + } + + @NonNull + private static MemoryRecordsBuilder memoryRecordsBuilder(byte magic, + long baseOffset, + CompressionType compressionType, + TimestampType timestampType, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch) { return new MemoryRecordsBuilder( ByteBuffer.allocate(1024), magic, - CompressionType.NONE, - TimestampType.CREATE_TIME, + compressionType, + timestampType, baseOffset, - 0L, - 0L, - (short) 0, - 0, - false, - false, - 0, + logAppendTime, + producerId, + producerEpoch, + baseSequence, + isTransactional, + isControlBatch, + partitionLeaderEpoch, 0); } + + /** + * Generate a record batch set to be transaction and a control batch containing a single + * end transaction marker record of type abort + * @param baseOffset base offset of the batch + * @return batch + */ + public static MutableRecordBatch abortTransactionControlBatch(int baseOffset) { + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(ByteBuffer.allocate(1000), RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, baseOffset, 1L, 1L, (short) 1, 1, true, true, 1, 1); + builder.appendEndTxnMarker(1l, new EndTransactionMarker(ControlRecordType.ABORT, 1)); + MemoryRecords controlBatchRecords = builder.build(); + return controlBatchRecords.batchIterator().next(); + } } diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/Assertions.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/Assertions.java new file mode 100644 index 0000000000..3403a7c0a2 --- /dev/null +++ b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/Assertions.java @@ -0,0 +1,17 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.test.assertj; + +import org.assertj.core.api.ThrowableAssert; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class Assertions { + public static void throwsAssertionErrorContaining(ThrowableAssert.ThrowingCallable shouldRaiseThrowable, String description) { + assertThatThrownBy(shouldRaiseThrowable).isInstanceOf(AssertionError.class).hasMessageContaining(description); + } +} diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java new file mode 100644 index 0000000000..ee57afdc05 --- /dev/null +++ b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/HeaderAssertTest.java @@ -0,0 +1,62 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.test.assertj; + +import java.nio.charset.StandardCharsets; + +import org.apache.kafka.common.header.internals.RecordHeader; +import org.assertj.core.api.ThrowingConsumer; +import org.junit.jupiter.api.Test; + +import static io.kroxylicious.test.assertj.Assertions.throwsAssertionErrorContaining; + +class HeaderAssertTest { + + @Test + void testHeaderHasKeyEqualTo() { + RecordHeader header = new RecordHeader("foo", null); + HeaderAssert headerAssert = KafkaAssertions.assertThat(header); + headerAssert.hasKeyEqualTo("foo"); + throwsAssertionErrorContaining(() -> headerAssert.hasKeyEqualTo("bar"), "[header key]"); + assertThrowsIfHeaderNull(nullAssert -> nullAssert.hasKeyEqualTo("any")); + } + + @Test + void testHeaderHasNullValue() { + RecordHeader nullValue = new RecordHeader("foo", null); + HeaderAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + + RecordHeader nonNullValue = new RecordHeader("foo", new byte[]{ 1, 2, 3 }); + HeaderAssert nonNullValueAssert = KafkaAssertions.assertThat(nonNullValue); + + nullValueAssert.hasNullValue(); + throwsAssertionErrorContaining(nonNullValueAssert::hasNullValue, "[header value]"); + assertThrowsIfHeaderNull(HeaderAssert::hasNullValue); + } + + @Test + void testHeaderHasValueEqualTo() { + RecordHeader nullValue = new RecordHeader("foo", null); + HeaderAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + + RecordHeader nonNullValue = new RecordHeader("foo", "abc".getBytes(StandardCharsets.UTF_8)); + HeaderAssert nonNullValueAssert = KafkaAssertions.assertThat(nonNullValue); + + nullValueAssert.hasValueEqualTo(null); + nonNullValueAssert.hasValueEqualTo("abc"); + throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo("other"), "[header value]"); + throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo(null), "[header value]"); + throwsAssertionErrorContaining(() -> nullValueAssert.hasValueEqualTo("other"), "[header value]"); + assertThrowsIfHeaderNull(nullAssert -> nullAssert.hasValueEqualTo("any")); + } + + void assertThrowsIfHeaderNull(ThrowingConsumer action) { + HeaderAssert headerAssert = KafkaAssertions.assertThat((RecordHeader) null); + throwsAssertionErrorContaining(() -> action.accept(headerAssert), "[null header]"); + } + +} \ No newline at end of file diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/KafkaAssertionsTest.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/KafkaAssertionsTest.java deleted file mode 100644 index 3d651dd4e8..0000000000 --- a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/KafkaAssertionsTest.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright Kroxylicious Authors. - * - * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 - */ - -package io.kroxylicious.test.assertj; - -import java.nio.charset.StandardCharsets; - -import org.apache.kafka.common.header.internals.RecordHeader; -import org.apache.kafka.common.record.CompressionType; -import org.apache.kafka.common.record.MemoryRecords; -import org.apache.kafka.common.record.Record; -import org.apache.kafka.common.record.RecordBatch; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; - -import static io.kroxylicious.test.record.RecordTestUtils.memoryRecords; -import static io.kroxylicious.test.record.RecordTestUtils.record; -import static io.kroxylicious.test.record.RecordTestUtils.recordBatch; - -class KafkaAssertionsTest { - - @Test - void testHeader() { - RecordHeader header = new RecordHeader("foo", null); - HeaderAssert headerAssert = KafkaAssertions.assertThat(header); - - // key - headerAssert.hasKeyEqualTo("foo"); - Assertions.assertThatThrownBy(() -> headerAssert.hasKeyEqualTo("bar")).hasMessageContaining("[header key]"); - - // value - headerAssert.hasNullValue(); - HeaderAssert headerAssert1 = KafkaAssertions.assertThat(new RecordHeader("foo", new byte[0])); - Assertions.assertThatThrownBy(() -> headerAssert1.hasNullValue()).hasMessageContaining("[header value]"); - HeaderAssert headerAssert2 = KafkaAssertions.assertThat(new RecordHeader("foo", "abc".getBytes(StandardCharsets.UTF_8))); - headerAssert2.hasValueEqualTo("abc"); - Assertions.assertThatThrownBy(() -> headerAssert2.hasValueEqualTo("xyz")).hasMessageContaining("[header value]"); - - } - - @Test - void testRecord() { - Record record = record("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); - RecordAssert recordAssert = KafkaAssertions.assertThat(record); - - // offset - recordAssert.hasOffsetEqualTo(0); - Assertions.assertThatThrownBy(() -> recordAssert.hasOffsetEqualTo(1)).hasMessageContaining("[record offset]"); - - // timestamp - recordAssert.hasTimestampEqualTo(0); - Assertions.assertThatThrownBy(() -> recordAssert.hasTimestampEqualTo(1)).hasMessageContaining("[record timestamp]"); - - // key - recordAssert.hasKeyEqualTo("KEY"); - Assertions.assertThatThrownBy(() -> recordAssert.hasKeyEqualTo("NOT_KEY")).hasMessageContaining("[record key]"); - Assertions.assertThatThrownBy(() -> recordAssert.hasNullKey()).hasMessageContaining("[record key]"); - - // value - recordAssert.hasValueEqualTo("VALUE"); - Assertions.assertThatThrownBy(() -> recordAssert.hasValueEqualTo("NOT_VALUE")).hasMessageContaining("[record value]"); - Assertions.assertThatThrownBy(() -> recordAssert.hasNullValue()).hasMessageContaining("[record value]"); - - // headers - recordAssert.hasHeadersSize(1); - recordAssert.containsHeaderWithKey("HEADER"); - recordAssert.firstHeader().hasKeyEqualTo("HEADER"); - recordAssert.lastHeader().hasKeyEqualTo("HEADER"); - Assertions.assertThatThrownBy(() -> recordAssert.hasHeadersSize(2)).hasMessageContaining("[record headers]"); - Assertions.assertThatThrownBy(() -> recordAssert.hasEmptyHeaders()).hasMessageContaining("[record headers]"); - Assertions.assertThatThrownBy(() -> recordAssert.containsHeaderWithKey("NOT_HEADER")).hasMessageContaining("[record headers]"); - Assertions.assertThatThrownBy(() -> recordAssert.firstHeader().hasKeyEqualTo("NOT_HEADER")).hasMessageContaining("[header key]"); - Assertions.assertThatThrownBy(() -> recordAssert.lastHeader().hasKeyEqualTo("NOT_HEADER")).hasMessageContaining("[header key]"); - } - - @Test - void testRecordBatch() { - RecordBatch batch = recordBatch("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); - RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); - - // sizeInBytes - batchAssert.hasSizeInBytes(96); - Assertions.assertThatThrownBy(() -> batchAssert.hasSizeInBytes(1)).hasMessageContaining("[sizeInBytes]"); - - // Base offset - batchAssert.hasBaseOffset(0); - Assertions.assertThatThrownBy(() -> batchAssert.hasBaseOffset(1)).hasMessageContaining("[baseOffset]"); - - // Base sequence - batchAssert.hasBaseSequence(0); - Assertions.assertThatThrownBy(() -> batchAssert.hasBaseSequence(1)).hasMessageContaining("[baseSequence]"); - - // compression type - batchAssert.hasCompressionType(CompressionType.NONE); - Assertions.assertThatThrownBy(() -> batchAssert.hasCompressionType(CompressionType.GZIP)).hasMessageContaining("[compressionType]"); - - // records - batchAssert.hasNumRecords(1); - batchAssert.firstRecord().hasKeyEqualTo("KEY"); - batchAssert.lastRecord().hasKeyEqualTo("KEY"); - Assertions.assertThatThrownBy(() -> batchAssert.hasNumRecords(2)).hasMessageContaining("[records]"); - } - - @Test - void testMemoryRecords() { - MemoryRecords records = memoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); - MemoryRecordsAssert recordsAssert = KafkaAssertions.assertThat(records); - - // Num batches - recordsAssert.hasNumBatches(1); - Assertions.assertThatThrownBy(() -> recordsAssert.hasNumBatches(2)).hasMessageContaining("[number of batches]"); - - // Batch sizes - recordsAssert.hasBatchSizes(1); - recordsAssert.firstBatch().firstRecord().hasKeyEqualTo("KEY"); - recordsAssert.lastBatch().firstRecord().hasKeyEqualTo("KEY"); - Assertions.assertThatThrownBy(() -> recordsAssert.hasBatchSizes(2)).hasMessageContaining("[batch sizes]"); - - // sizeInByte - recordsAssert.hasSizeInBytes(96); - Assertions.assertThatThrownBy(() -> recordsAssert.hasSizeInBytes(1)).hasMessageContaining("[sizeInBytes]"); - } -} diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/MemoryRecordsAssertTest.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/MemoryRecordsAssertTest.java new file mode 100644 index 0000000000..a5389b6116 --- /dev/null +++ b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/MemoryRecordsAssertTest.java @@ -0,0 +1,100 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.test.assertj; + +import java.nio.charset.StandardCharsets; + +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.assertj.core.api.ThrowingConsumer; +import org.junit.jupiter.api.Test; + +import io.kroxylicious.test.record.RecordTestUtils; + +import static io.kroxylicious.test.assertj.Assertions.throwsAssertionErrorContaining; + +class MemoryRecordsAssertTest { + + @Test + void testHasNumBatches() { + MemoryRecords records = RecordTestUtils.singleElementMemoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + MemoryRecordsAssert recordsAssert = KafkaAssertions.assertThat(records); + recordsAssert.hasNumBatches(1); + throwsAssertionErrorContaining(() -> recordsAssert.hasNumBatches(2), "[number of batches]"); + assertThrowsIfMemoryRecordsNull(nullAssert -> nullAssert.hasNumBatches(1)); + } + + @Test + void testHasBatchSizes() { + MemoryRecords records = RecordTestUtils.singleElementMemoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + MemoryRecordsAssert recordsAssert = KafkaAssertions.assertThat(records); + recordsAssert.hasBatchSizes(1); + throwsAssertionErrorContaining(() -> recordsAssert.hasBatchSizes(2), "[batch sizes]"); + assertThrowsIfMemoryRecordsNull(nullAssert -> nullAssert.hasBatchSizes(1)); + } + + @Test + void testFirstBatch() { + MemoryRecords records = RecordTestUtils.singleElementMemoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + MutableRecordBatch batch1 = RecordTestUtils.singleElementRecordBatch("FIRST", "value"); + MutableRecordBatch batch2 = RecordTestUtils.singleElementRecordBatch("LAST", "value"); + MemoryRecords multiBatch = RecordTestUtils.memoryRecords(batch1, batch2); + MemoryRecords empty = MemoryRecords.EMPTY; + KafkaAssertions.assertThat(records).firstBatch().firstRecord().hasKeyEqualTo("KEY"); + KafkaAssertions.assertThat(multiBatch).firstBatch().firstRecord().hasKeyEqualTo("FIRST"); + throwsAssertionErrorContaining(() -> KafkaAssertions.assertThat(empty).firstBatch(), "number of batches"); + assertThrowsIfMemoryRecordsNull(MemoryRecordsAssert::firstBatch); + } + + @Test + void testLastBatch() { + MemoryRecords records = RecordTestUtils.singleElementMemoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + MutableRecordBatch batch1 = RecordTestUtils.singleElementRecordBatch("FIRST", "value"); + MutableRecordBatch batch2 = RecordTestUtils.singleElementRecordBatch("LAST", "value"); + MemoryRecords multiBatch = RecordTestUtils.memoryRecords(batch1, batch2); + MemoryRecords empty = MemoryRecords.EMPTY; + KafkaAssertions.assertThat(records).lastBatch().firstRecord().hasKeyEqualTo("KEY"); + KafkaAssertions.assertThat(multiBatch).lastBatch().firstRecord().hasKeyEqualTo("LAST"); + throwsAssertionErrorContaining(() -> KafkaAssertions.assertThat(empty).lastBatch(), "number of batches"); + assertThrowsIfMemoryRecordsNull(MemoryRecordsAssert::lastBatch); + } + + @Test + void testBatches() { + MemoryRecords records = RecordTestUtils.singleElementMemoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + MutableRecordBatch batch1 = RecordTestUtils.singleElementRecordBatch("KEY", "value"); + MutableRecordBatch batch2 = RecordTestUtils.singleElementRecordBatch("KEY", "value"); + MemoryRecords multiBatch = RecordTestUtils.memoryRecords(batch1, batch2); + MemoryRecords empty = MemoryRecords.EMPTY; + for (RecordBatchAssert batch : KafkaAssertions.assertThat(records).batches()) { + batch.firstRecord().hasKeyEqualTo("KEY"); + } + for (RecordBatchAssert batch : KafkaAssertions.assertThat(multiBatch).batches()) { + batch.firstRecord().hasKeyEqualTo("KEY"); + } + for (RecordBatchAssert batch : KafkaAssertions.assertThat(empty).batches()) { + batch.firstRecord().hasKeyEqualTo("KEY"); + } + assertThrowsIfMemoryRecordsNull(MemoryRecordsAssert::batches); + } + + @Test + void testHasSizeInBytes() { + MemoryRecords records = RecordTestUtils.singleElementMemoryRecords("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + MemoryRecordsAssert recordsAssert = KafkaAssertions.assertThat(records); + recordsAssert.hasSizeInBytes(96); + throwsAssertionErrorContaining(() -> recordsAssert.hasSizeInBytes(1), "[sizeInBytes]"); + assertThrowsIfMemoryRecordsNull(nullAssert -> nullAssert.hasSizeInBytes(1)); + } + + void assertThrowsIfMemoryRecordsNull(ThrowingConsumer action) { + MemoryRecordsAssert headerAssert = KafkaAssertions.assertThat((MemoryRecords) null); + throwsAssertionErrorContaining(() -> action.accept(headerAssert), "[null memory records]"); + } + +} diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordAssertTest.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordAssertTest.java new file mode 100644 index 0000000000..a43481fb5f --- /dev/null +++ b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordAssertTest.java @@ -0,0 +1,234 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.test.assertj; + +import java.nio.charset.StandardCharsets; + +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.Record; +import org.assertj.core.api.ThrowingConsumer; +import org.junit.jupiter.api.Test; + +import io.kroxylicious.test.record.RecordTestUtils; + +import static io.kroxylicious.test.assertj.Assertions.throwsAssertionErrorContaining; +import static io.kroxylicious.test.record.RecordTestUtils.record; + +class RecordAssertTest { + + @Test + void testRecordHasOffsetEqualTo() { + Record record = record("KEY", "VALUE"); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + recordAssert.hasOffsetEqualTo(0); + throwsAssertionErrorContaining(() -> recordAssert.hasOffsetEqualTo(1), "[record offset]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasOffsetEqualTo(1)); + } + + @Test + void testRecordHasTimestampEqualTo() { + Record record = record("KEY", "VALUE"); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + recordAssert.hasTimestampEqualTo(0); + throwsAssertionErrorContaining(() -> recordAssert.hasTimestampEqualTo(1), "[record timestamp]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasTimestampEqualTo(1)); + } + + @Test + void testRecordHasKeyEqualTo() { + Record record = record("KEY", "VALUE"); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + recordAssert.hasKeyEqualTo("KEY"); + throwsAssertionErrorContaining(() -> recordAssert.hasKeyEqualTo("NOT_KEY"), "[record key]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasKeyEqualTo("NOT_KEY")); + } + + @Test + void testRecordHasNullKey() { + Record record = record("KEY", "VALUE"); + Record nullKeyRecord = record(null, "VALUE"); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + KafkaAssertions.assertThat(nullKeyRecord).hasNullKey(); + throwsAssertionErrorContaining(recordAssert::hasNullKey, "[record key]"); + assertThrowsIfRecordNull(RecordAssert::hasNullValue); + } + + @Test + void testRecordHasValueEqualToString() { + Record record = record("KEY", "VALUE"); + Record nullValue = record("KEY", (String) null); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + RecordAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + recordAssert.hasValueEqualTo("VALUE"); + nullValueAssert.hasValueEqualTo((String) null); + throwsAssertionErrorContaining(() -> recordAssert.hasValueEqualTo("NOT_VALUE"), "[record value]"); + throwsAssertionErrorContaining(() -> nullValueAssert.hasValueEqualTo("ANY"), "[record value]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasValueEqualTo("ANY")); + } + + @Test + void testRecordHasValueNotEqualToString() { + Record record = record("KEY", "VALUE"); + Record nullValue = record("KEY", (String) null); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + RecordAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + recordAssert.hasValueNotEqualTo("OTHER"); + nullValueAssert.hasValueNotEqualTo("OTHER"); + throwsAssertionErrorContaining(() -> recordAssert.hasValueNotEqualTo("VALUE"), "[record value]"); + throwsAssertionErrorContaining(() -> nullValueAssert.hasValueNotEqualTo(null), "[record value]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasValueNotEqualTo("ANY")); + } + + @Test + void testRecordHasNullValue() { + Record record = record("KEY", "VALUE"); + Record nullValue = record("KEY", (String) null); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + RecordAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + nullValueAssert.hasNullValue(); + throwsAssertionErrorContaining(recordAssert::hasNullValue, "[record value]"); + assertThrowsIfRecordNull(RecordAssert::hasNullValue); + } + + @Test + void testRecordHasValueEqualToByteArray() { + Record record = record("KEY", "VALUE"); + Record nullValue = record("KEY", (String) null); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + RecordAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + recordAssert.hasValueEqualTo("VALUE".getBytes(StandardCharsets.UTF_8)); + nullValueAssert.hasValueEqualTo((String) null); + throwsAssertionErrorContaining(() -> recordAssert.hasValueEqualTo("NOT_VALUE".getBytes(StandardCharsets.UTF_8)), "[record value]"); + throwsAssertionErrorContaining(() -> nullValueAssert.hasValueEqualTo("ANY".getBytes(StandardCharsets.UTF_8)), "[record value]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasValueEqualTo("ANY".getBytes(StandardCharsets.UTF_8))); + } + + @Test + void testRecordHasValueEqualToRecord() { + Record record = record("KEY", "VALUE"); + Record nullValue = record("KEY", (String) null); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + RecordAssert nullValueAssert = KafkaAssertions.assertThat(nullValue); + recordAssert.hasValueEqualTo(RecordTestUtils.record("KEY", "VALUE")); + throwsAssertionErrorContaining(() -> recordAssert.hasValueEqualTo(RecordTestUtils.record("KEY", "NOT_VALUE")), "[record value]"); + throwsAssertionErrorContaining(() -> nullValueAssert.hasValueEqualTo(RecordTestUtils.record("KEY", "ANY")), "[record value]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasValueEqualTo(RecordTestUtils.record("KEY", "ANY"))); + } + + @Test + void testRecordHasHeadersSize() { + Record record = record("KEY", "VALUE", new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + Record emptyHeaders = record("KEY", (String) null); + RecordAssert recordAssert = KafkaAssertions.assertThat(record); + RecordAssert emptyHeadersAssert = KafkaAssertions.assertThat(emptyHeaders); + recordAssert.hasHeadersSize(1); + emptyHeadersAssert.hasHeadersSize(0); + throwsAssertionErrorContaining(() -> recordAssert.hasHeadersSize(2), "[record headers]"); + throwsAssertionErrorContaining(() -> emptyHeadersAssert.hasHeadersSize(1), "[record headers]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.hasHeadersSize(1)); + } + + @Test + void testRecordContainsHeaderWithKey() { + String headerKeyA = "HEADER_KEY_A"; + Record singleHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + String headerKeyB = "HEADER_KEY_B"; + Record multHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8)), + new RecordHeader(headerKeyB, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + Record emptyHeaders = record("KEY", (String) null); + RecordAssert singleHeaderAssert = KafkaAssertions.assertThat(singleHeader); + RecordAssert multiHeaderAssert = KafkaAssertions.assertThat(multHeader); + RecordAssert emptyHeadersAssert = KafkaAssertions.assertThat(emptyHeaders); + + singleHeaderAssert.containsHeaderWithKey(headerKeyA); + throwsAssertionErrorContaining(() -> singleHeaderAssert.containsHeaderWithKey("NOT_HEADER"), "[record headers]"); + + multiHeaderAssert.containsHeaderWithKey(headerKeyA); + multiHeaderAssert.containsHeaderWithKey(headerKeyB); + throwsAssertionErrorContaining(() -> multiHeaderAssert.containsHeaderWithKey("NOT_HEADER"), "[record headers]"); + + throwsAssertionErrorContaining(() -> emptyHeadersAssert.containsHeaderWithKey("ANY"), "[record headers]"); + assertThrowsIfRecordNull(nullAssert -> nullAssert.containsHeaderWithKey("ANY")); + } + + @Test + void testRecordFirstHeader() { + String headerKeyA = "HEADER_KEY_A"; + Record singleHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + String headerKeyB = "HEADER_KEY_B"; + Record multiHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8)), + new RecordHeader(headerKeyB, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + Record emptyHeaders = record("KEY", (String) null); + RecordAssert singleHeaderAssert = KafkaAssertions.assertThat(singleHeader); + RecordAssert multiHeaderAssert = KafkaAssertions.assertThat(multiHeader); + RecordAssert emptyHeadersAssert = KafkaAssertions.assertThat(emptyHeaders); + + singleHeaderAssert.firstHeader().hasKeyEqualTo(headerKeyA); + multiHeaderAssert.firstHeader().hasKeyEqualTo(headerKeyA); + throwsAssertionErrorContaining(emptyHeadersAssert::firstHeader, "[record headers]"); + assertThrowsIfRecordNull(RecordAssert::firstHeader); + } + + @Test + void testRecordLastHeader() { + String headerKeyA = "HEADER_KEY_A"; + Record singleHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + String headerKeyB = "HEADER_KEY_B"; + Record multiHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8)), + new RecordHeader(headerKeyB, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + Record emptyHeaders = record("KEY", (String) null); + RecordAssert singleHeaderAssert = KafkaAssertions.assertThat(singleHeader); + RecordAssert multiHeaderAssert = KafkaAssertions.assertThat(multiHeader); + RecordAssert emptyHeadersAssert = KafkaAssertions.assertThat(emptyHeaders); + + singleHeaderAssert.lastHeader().hasKeyEqualTo(headerKeyA); + multiHeaderAssert.lastHeader().hasKeyEqualTo(headerKeyB); + throwsAssertionErrorContaining(emptyHeadersAssert::lastHeader, "[record headers]"); + assertThrowsIfRecordNull(RecordAssert::lastHeader); + } + + @Test + void testRecordSingleHeader() { + String headerKeyA = "HEADER_KEY_A"; + Record singleHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + String headerKeyB = "HEADER_KEY_B"; + Record multiHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8)), + new RecordHeader(headerKeyB, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + Record emptyHeaders = record("KEY", (String) null); + RecordAssert singleHeaderAssert = KafkaAssertions.assertThat(singleHeader); + RecordAssert multiHeaderAssert = KafkaAssertions.assertThat(multiHeader); + RecordAssert emptyHeadersAssert = KafkaAssertions.assertThat(emptyHeaders); + + singleHeaderAssert.singleHeader().hasKeyEqualTo(headerKeyA); + throwsAssertionErrorContaining(multiHeaderAssert::singleHeader, "[record headers]"); + throwsAssertionErrorContaining(emptyHeadersAssert::singleHeader, "[record headers]"); + assertThrowsIfRecordNull(RecordAssert::singleHeader); + } + + @Test + void testRecordHasEmptyHeaders() { + String headerKeyA = "HEADER_KEY_A"; + Record singleHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + String headerKeyB = "HEADER_KEY_B"; + Record multiHeader = record("KEY", "VALUE", new RecordHeader(headerKeyA, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8)), + new RecordHeader(headerKeyB, "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + Record emptyHeaders = record("KEY", (String) null); + RecordAssert singleHeaderAssert = KafkaAssertions.assertThat(singleHeader); + RecordAssert multiHeaderAssert = KafkaAssertions.assertThat(multiHeader); + RecordAssert emptyHeadersAssert = KafkaAssertions.assertThat(emptyHeaders); + + emptyHeadersAssert.hasEmptyHeaders(); + throwsAssertionErrorContaining(multiHeaderAssert::hasEmptyHeaders, "[record headers]"); + throwsAssertionErrorContaining(singleHeaderAssert::hasEmptyHeaders, "[record headers]"); + assertThrowsIfRecordNull(RecordAssert::hasEmptyHeaders); + } + + void assertThrowsIfRecordNull(ThrowingConsumer action) { + RecordAssert recordAssert = KafkaAssertions.assertThat((Record) null); + throwsAssertionErrorContaining(() -> action.accept(recordAssert), "[null record]"); + } +} diff --git a/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordBatchAssertTest.java b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordBatchAssertTest.java new file mode 100644 index 0000000000..37ac30a9b7 --- /dev/null +++ b/kroxylicious-filter-test-support/src/test/java/io/kroxylicious/test/assertj/RecordBatchAssertTest.java @@ -0,0 +1,244 @@ +/* + * Copyright Kroxylicious Authors. + * + * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0 + */ + +package io.kroxylicious.test.assertj; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.OptionalLong; + +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.assertj.core.api.ThrowingConsumer; +import org.junit.jupiter.api.Test; + +import io.kroxylicious.test.record.RecordTestUtils; + +import static io.kroxylicious.test.assertj.Assertions.throwsAssertionErrorContaining; +import static io.kroxylicious.test.record.RecordTestUtils.singleElementRecordBatch; + +class RecordBatchAssertTest { + + @Test + void testRecordBatchHasSizeInBytes() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasSizeInBytes(76); + throwsAssertionErrorContaining(() -> batchAssert.hasSizeInBytes(1), "[sizeInBytes]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasSizeInBytes(1)); + } + + @Test + void testRecordBatchHasBaseOffset() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasBaseOffset(0L); + throwsAssertionErrorContaining(() -> batchAssert.hasBaseOffset(1L), "[baseOffset]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasBaseOffset(1L)); + } + + @Test + void testRecordBatchHasBaseSequence() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasBaseSequence(0); + throwsAssertionErrorContaining(() -> batchAssert.hasBaseSequence(1), "[baseSequence]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasBaseSequence(1)); + } + + @Test + void testRecordBatchHasCompressionType() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasCompressionType(CompressionType.NONE); + throwsAssertionErrorContaining(() -> batchAssert.hasCompressionType(CompressionType.GZIP), "[compressionType]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasCompressionType(CompressionType.GZIP)); + } + + @Test + void testRecordBatchHasMagic() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasMagic(RecordBatch.CURRENT_MAGIC_VALUE); + throwsAssertionErrorContaining(() -> batchAssert.hasMagic((byte) 1), "[magic]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasMagic((byte) 1)); + } + + @Test + void testRecordBatchIsControlBatch() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatch controlBatch = RecordTestUtils.abortTransactionControlBatch(1); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + RecordBatchAssert controlBatchAssert = KafkaAssertions.assertThat(controlBatch); + batchAssert.isControlBatch(false); + throwsAssertionErrorContaining(() -> batchAssert.isControlBatch(true), "[controlBatch]"); + controlBatchAssert.isControlBatch(true); + throwsAssertionErrorContaining(() -> controlBatchAssert.isControlBatch(false), "[controlBatch]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.isControlBatch(false)); + } + + @Test + void testRecordBatchIsTransactional() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatch transactionalBatch = RecordTestUtils.abortTransactionControlBatch(1); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + RecordBatchAssert transactionalBatchAssert = KafkaAssertions.assertThat(transactionalBatch); + batchAssert.isTransactional(false); + throwsAssertionErrorContaining(() -> batchAssert.isTransactional(true), "[transactional]"); + transactionalBatchAssert.isTransactional(true); + throwsAssertionErrorContaining(() -> transactionalBatchAssert.isTransactional(false), "[transactional]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.isTransactional(false)); + } + + @Test + void testRecordBatchHasPartitionLeaderEpoch() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasPartitionLeaderEpoch(0); + throwsAssertionErrorContaining(() -> batchAssert.hasPartitionLeaderEpoch(1), "[partitionLeaderEpoch]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasPartitionLeaderEpoch(1)); + } + + @Test + void testRecordBatchHasDeleteHorizonMs() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasDeleteHorizonMs(OptionalLong.empty()); + throwsAssertionErrorContaining(() -> batchAssert.hasDeleteHorizonMs(OptionalLong.of(1L)), "[deleteHorizonMs]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasDeleteHorizonMs(OptionalLong.of(1L))); + } + + @Test + void testRecordBatchHasLastOffset() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasLastOffset(0L); + throwsAssertionErrorContaining(() -> batchAssert.hasLastOffset(1L), "[lastOffset]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasLastOffset(1L)); + } + + @Test + void testRecordBatchHasLastSequence() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasLastSequence(0); + throwsAssertionErrorContaining(() -> batchAssert.hasLastSequence(1), "[lastSequence]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasLastSequence(1)); + } + + @Test + void testRecordBatchHasProducerEpoch() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasProducerEpoch((short) 0); + throwsAssertionErrorContaining(() -> batchAssert.hasProducerEpoch((short) 1), "[producerEpoch]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasProducerEpoch((short) 1)); + } + + @Test + void testRecordBatchHasProducerId() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasProducerId(0L); + throwsAssertionErrorContaining(() -> batchAssert.hasProducerId(1L), "[producerId]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasProducerId(1L)); + } + + @Test + void testRecordBatchHasMaxTimestamp() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasMaxTimestamp(0L); + throwsAssertionErrorContaining(() -> batchAssert.hasMaxTimestamp(1L), "[maxTimestamp]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasMaxTimestamp(1L)); + } + + @Test + void testRecordBatchHasTimestampType() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasTimestampType(TimestampType.CREATE_TIME); + throwsAssertionErrorContaining(() -> batchAssert.hasTimestampType(TimestampType.LOG_APPEND_TIME), "[timestampType]"); + throwsAssertionErrorContaining(() -> batchAssert.hasTimestampType(null), "[timestampType]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasTimestampType(TimestampType.LOG_APPEND_TIME)); + } + + @Test + void testRecordBatchHasMetadataMatching() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatch batchSameMetadata = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE", + new RecordHeader("HEADER", "HEADER_VALUE".getBytes(StandardCharsets.UTF_8))); + RecordBatch batchDifferentMetadata = singleElementRecordBatch(RecordBatch.CURRENT_MAGIC_VALUE, 1L, CompressionType.GZIP, TimestampType.CREATE_TIME, 1L, 1L, + (short) 1, 1, false, + false, 1, + "KEY".getBytes( + StandardCharsets.UTF_8), + "VALUE".getBytes(StandardCharsets.UTF_8), new RecordHeader[]{}); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasMetadataMatching(batch); + batchAssert.hasMetadataMatching(batchSameMetadata); + throwsAssertionErrorContaining(() -> batchAssert.hasMetadataMatching(batchDifferentMetadata), "[baseOffset]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasMetadataMatching(batch)); + } + + @Test + void testRecordBatchHasNumRecords() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatchAssert batchAssert = KafkaAssertions.assertThat(batch); + batchAssert.hasNumRecords(1); + throwsAssertionErrorContaining(() -> batchAssert.hasNumRecords(2), "[records]"); + assertThrowsIfRecordBatchNull(nullAssert -> nullAssert.hasNumRecords(1)); + } + + @Test + void testRecordBatchFirstRecord() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatch emptyByCompaction = RecordTestUtils.recordBatchWithAllRecordsRemoved(1L); + RecordBatch multipleRecordsBatch = RecordTestUtils.memoryRecords(List.of(RecordTestUtils.record(0L, "KEY", "a"), RecordTestUtils.record(1L, "KEY2", "b"))) + .firstBatch(); + KafkaAssertions.assertThat(batch).firstRecord().hasKeyEqualTo("KEY"); + KafkaAssertions.assertThat(multipleRecordsBatch).firstRecord().hasKeyEqualTo("KEY"); + throwsAssertionErrorContaining(() -> KafkaAssertions.assertThat(emptyByCompaction).firstRecord(), "[record batch]"); + assertThrowsIfRecordBatchNull(RecordBatchAssert::firstRecord); + } + + @Test + void testRecordBatchLastRecord() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatch emptyByCompaction = RecordTestUtils.recordBatchWithAllRecordsRemoved(1L); + RecordBatch multipleRecordsBatch = RecordTestUtils.memoryRecords(List.of(RecordTestUtils.record(0L, "KEY", "a"), RecordTestUtils.record(1L, "KEY2", "b"))) + .firstBatch(); + KafkaAssertions.assertThat(batch).lastRecord().hasKeyEqualTo("KEY"); + KafkaAssertions.assertThat(multipleRecordsBatch).lastRecord().hasKeyEqualTo("KEY2"); + throwsAssertionErrorContaining(() -> KafkaAssertions.assertThat(emptyByCompaction).lastRecord(), "[record batch]"); + assertThrowsIfRecordBatchNull(RecordBatchAssert::lastRecord); + } + + @Test + void testRecords() { + RecordBatch batch = RecordTestUtils.singleElementRecordBatch("KEY", "VALUE"); + RecordBatch emptyByCompaction = RecordTestUtils.recordBatchWithAllRecordsRemoved(1L); + RecordBatch multipleRecordsBatch = RecordTestUtils.memoryRecords(List.of(RecordTestUtils.record(0L, "KEY", "a"), RecordTestUtils.record(1L, "KEY", "b"))) + .firstBatch(); + for (RecordAssert record : KafkaAssertions.assertThat(batch).records()) { + record.hasKeyEqualTo("KEY"); + } + for (RecordAssert record : KafkaAssertions.assertThat(multipleRecordsBatch).records()) { + record.hasKeyEqualTo("KEY"); + } + for (RecordAssert record : KafkaAssertions.assertThat(emptyByCompaction).records()) { + record.hasKeyEqualTo("KEY"); + } + assertThrowsIfRecordBatchNull(RecordBatchAssert::records); + } + + void assertThrowsIfRecordBatchNull(ThrowingConsumer action) { + RecordBatchAssert batchAssert = KafkaAssertions.assertThat((RecordBatch) null); + throwsAssertionErrorContaining(() -> action.accept(batchAssert), "[null record batch]"); + } +} diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index ac973e6a76..bb9ced6124 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -370,10 +370,10 @@ public CompletionStage decrypt(@NonNull String topicName, int par } Set uniqueEdeks = extractEdeks(topicName, partition, records); CompletionStage> decryptors = resolveAll(uniqueEdeks); - CompletionStage objectCompletionStage = decryptors.thenApply( + CompletionStage decryptStage = decryptors.thenApply( encryptorMap -> decrypt(topicName, partition, records, new BatchAwareMemoryRecordsBuilder(allocateBufferForDecode(records, bufferAllocator)), encryptorMap, batchRecordCounts)); - return objectCompletionStage.thenApply(BatchAwareMemoryRecordsBuilder::build); + return decryptStage.thenApply(BatchAwareMemoryRecordsBuilder::build); } private CompletionStage> resolveAll(Set uniqueEdeks) { diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java index dff6b6da29..e5a3db16ac 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java @@ -276,16 +276,18 @@ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record recor * @return the memory records */ public @NonNull MemoryRecords build() { + ByteBuffer buffer; if (closed) { - return MemoryRecords.readableRecords(this.buffer.buffer()); + buffer = this.buffer.buffer(); } else { closed = true; maybeAppendCurrentBatch(); ByteBuffer buf = this.buffer.buffer(); buf.flip(); - return MemoryRecords.readableRecords(buf); + buffer = buf; } + return MemoryRecords.readableRecords(buffer); } /** diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/EnvelopeEncryptionFilterTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/EnvelopeEncryptionFilterTest.java index c30c1bd131..50d60ed547 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/EnvelopeEncryptionFilterTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/EnvelopeEncryptionFilterTest.java @@ -120,9 +120,9 @@ void setUp() { return CompletableFuture.completedFuture(copy); }); - when(keyManager.encrypt(any(), anyInt(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(RecordTestUtils.memoryRecords("key", "value"))); + when(keyManager.encrypt(any(), anyInt(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(RecordTestUtils.singleElementMemoryRecords("key", "value"))); - when(keyManager.decrypt(any(), anyInt(), any(), any())).thenReturn(CompletableFuture.completedFuture(RecordTestUtils.memoryRecords("decrypt", "decrypt"))); + when(keyManager.decrypt(any(), anyInt(), any(), any())).thenReturn(CompletableFuture.completedFuture(RecordTestUtils.singleElementMemoryRecords("decrypt", "decrypt"))); encryptionFilter = new EnvelopeEncryptionFilter<>(keyManager, kekSelector); } @@ -199,7 +199,7 @@ void shouldDecryptEncryptedRecords() { .setTopic(ENCRYPTED_TOPIC) .setPartitions(List.of(new PartitionData().setRecords(makeRecord(ENCRYPTED_MESSAGE_BYTES))))); - MemoryRecords decryptedRecords = RecordTestUtils.memoryRecords("key", "value"); + MemoryRecords decryptedRecords = RecordTestUtils.singleElementMemoryRecords("key", "value"); when(keyManager.decrypt(any(), anyInt(), any(), any())).thenReturn(CompletableFuture.completedFuture(decryptedRecords)); // When diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java index 8d00274389..f779241d6b 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/inband/InBandKeyManagerTest.java @@ -7,6 +7,7 @@ package io.kroxylicious.filter.encryption.inband; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.EnumSet; @@ -19,20 +20,16 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.stream.Stream; -import java.util.stream.StreamSupport; import javax.crypto.SecretKey; import org.apache.kafka.common.header.Header; import org.apache.kafka.common.header.internals.RecordHeader; import org.apache.kafka.common.record.CompressionType; -import org.apache.kafka.common.record.ControlRecordType; import org.apache.kafka.common.record.MemoryRecords; -import org.apache.kafka.common.record.MemoryRecordsBuilder; import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Record; import org.apache.kafka.common.record.RecordBatch; -import org.apache.kafka.common.record.SimpleRecord; import org.apache.kafka.common.record.TimestampType; import org.apache.kafka.common.utils.ByteBufferOutputStream; import org.apache.kafka.common.utils.ByteUtils; @@ -47,11 +44,11 @@ import io.kroxylicious.filter.encryption.EncryptionScheme; import io.kroxylicious.filter.encryption.RecordField; -import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder; import io.kroxylicious.kms.provider.kroxylicious.inmemory.InMemoryEdek; import io.kroxylicious.kms.provider.kroxylicious.inmemory.InMemoryKms; import io.kroxylicious.kms.provider.kroxylicious.inmemory.UnitTestingKmsService; import io.kroxylicious.kms.service.KmsException; +import io.kroxylicious.test.assertj.MemoryRecordsAssert; import io.kroxylicious.test.record.RecordTestUtils; import edu.umd.cs.findbugs.annotations.NonNull; @@ -67,6 +64,12 @@ class InBandKeyManagerTest { + private static final String ARBITRARY_KEY = "key"; + private static final String ARBITRARY_KEY_2 = "key2"; + private static final String ARBITRARY_VALUE = "value"; + private static final String ARBITRARY_VALUE_2 = "value2"; + private static final Header[] ABSENT_HEADERS = {}; + @Test void shouldBeAbleToDependOnRecordHeaderEquality() { // The InBandKeyManager relies internally on RecordHeader implementing equals @@ -82,9 +85,8 @@ void shouldBeAbleToDependOnRecordHeaderEquality() { @Test void shouldEncryptRecordValue() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -115,226 +117,152 @@ void shouldEncryptRecordValue() { } @Test - void shouldPreserveMultipleBatches() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + void shouldPreserveMultipleBatchesOnEncrypt() { + // given + InMemoryKms kms = getInMemoryKms(); + EncryptionScheme scheme = createScheme(kms); + var km = createKeyManager(kms, 500_000); + + MutableRecordBatch firstBatch = RecordTestUtils.singleElementRecordBatch(RecordBatch.CURRENT_MAGIC_VALUE, 1L, CompressionType.GZIP, TimestampType.CREATE_TIME, 2L, + 3L, + (short) 4, 5, false, false, 1, ARBITRARY_KEY.getBytes( + StandardCharsets.UTF_8), + ARBITRARY_VALUE.getBytes(StandardCharsets.UTF_8)); + + MutableRecordBatch secondBatch = RecordTestUtils.singleElementRecordBatch(RecordBatch.CURRENT_MAGIC_VALUE, 2L, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 9L, 10L, + (short) 11, 12, false, false, 2, ARBITRARY_KEY_2.getBytes( + StandardCharsets.UTF_8), + ARBITRARY_VALUE_2.getBytes(StandardCharsets.UTF_8)); + MemoryRecords records = RecordTestUtils.memoryRecords(firstBatch, secondBatch); - var kekId = kms.generateKey(); + // when + MemoryRecords encrypted = assertImmediateSuccessAndGet(encrypt(km, scheme, records)); - byte[] value = { 1, 2, 3 }; - Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); - - var value2 = new byte[]{ 4, 5, 6 }; - Record record2 = RecordTestUtils.record(2, ByteBuffer.wrap(value2)); - BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(1000)); - builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1); - builder.appendWithOffset(1l, record); - builder.addBatch(CompressionType.GZIP, TimestampType.LOG_APPEND_TIME, 2); - builder.appendWithOffset(2l, record2); - MemoryRecords records = builder.build(); - - EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); - CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); - assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); - MemoryRecords encrypted = encryptedFuture.join(); - record.value().rewind(); - record2.value().rewind(); + // then + MemoryRecordsAssert encryptedAssert = MemoryRecordsAssert.assertThat(encrypted); + encryptedAssert.hasNumBatches(2); + encryptedAssert.firstBatch().hasMetadataMatching(firstBatch).hasNumRecords(1).firstRecord().hasValueNotEqualTo(ARBITRARY_VALUE); + encryptedAssert.lastBatch().hasMetadataMatching(secondBatch).hasNumRecords(1).firstRecord().hasValueNotEqualTo(ARBITRARY_VALUE_2); + } - assertThat(encrypted.batches()).hasSize(2); - List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); - MutableRecordBatch first = batches.get(0); - assertThat(first.compressionType()).isEqualTo(CompressionType.NONE); - assertThat(first.timestampType()).isEqualTo(TimestampType.CREATE_TIME); - assertThat(first.baseOffset()).isEqualTo(1L); - assertThat(first).hasSize(1); - - MutableRecordBatch second = batches.get(1); - // should we keep the client's compression type? - assertThat(second.compressionType()).isEqualTo(CompressionType.GZIP); - assertThat(second.timestampType()).isEqualTo(TimestampType.LOG_APPEND_TIME); - assertThat(second.baseOffset()).isEqualTo(2L); - assertThat(second).hasSize(1); - - CompletableFuture decryptedFuture = km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new).toCompletableFuture(); - assertThat(decryptedFuture).succeedsWithin(Duration.ZERO); - MemoryRecords decrypted = decryptedFuture.join(); - - assertThat(decrypted.batches()).hasSize(2); - List decryptedBatches = StreamSupport.stream(decrypted.batches().spliterator(), false).toList(); - MutableRecordBatch firstDecrypted = decryptedBatches.get(0); - assertThat(firstDecrypted.compressionType()).isEqualTo(CompressionType.NONE); - assertThat(firstDecrypted.timestampType()).isEqualTo(TimestampType.CREATE_TIME); - assertThat(firstDecrypted.baseOffset()).isEqualTo(1L); - assertThat(firstDecrypted).hasSize(1); - assertThat(firstDecrypted.iterator()) - .toIterable() - .singleElement() - .extracting(RecordTestUtils::recordValueAsBytes) - .isEqualTo(value); + @Test + void shouldPreserveMultipleBatchesOnDecrypt() { + // given + InMemoryKms kms = getInMemoryKms(); + EncryptionScheme scheme = createScheme(kms); + var km = createKeyManager(kms, 500_000); + + MutableRecordBatch firstBatch = RecordTestUtils.singleElementRecordBatch(RecordBatch.CURRENT_MAGIC_VALUE, 1L, CompressionType.GZIP, TimestampType.CREATE_TIME, 2L, + 3L, + (short) 4, 5, false, false, 1, ARBITRARY_KEY.getBytes( + StandardCharsets.UTF_8), + ARBITRARY_VALUE.getBytes(StandardCharsets.UTF_8)); + + MutableRecordBatch secondBatch = RecordTestUtils.singleElementRecordBatch(RecordBatch.CURRENT_MAGIC_VALUE, 2L, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 9L, 10L, + (short) 11, 12, false, false, 2, ARBITRARY_KEY_2.getBytes( + StandardCharsets.UTF_8), + ARBITRARY_VALUE_2.getBytes(StandardCharsets.UTF_8)); + MemoryRecords records = RecordTestUtils.memoryRecords(firstBatch, secondBatch); + MemoryRecords encrypted = assertImmediateSuccessAndGet(encrypt(km, scheme, records)); - MutableRecordBatch secondDecrypted = decryptedBatches.get(1); - assertThat(secondDecrypted.compressionType()).isEqualTo(CompressionType.GZIP); - assertThat(secondDecrypted.timestampType()).isEqualTo(TimestampType.LOG_APPEND_TIME); - assertThat(secondDecrypted.baseOffset()).isEqualTo(2L); - assertThat(secondDecrypted).hasSize(1); - assertThat(secondDecrypted.iterator()) - .toIterable() - .singleElement() - .extracting(RecordTestUtils::recordValueAsBytes) - .isEqualTo(value2); + // when + MemoryRecords decrypted = assertImmediateSuccessAndGet(decrypt(km, encrypted)); + // then + MemoryRecordsAssert decryptedAssert = MemoryRecordsAssert.assertThat(decrypted); + decryptedAssert.hasNumBatches(2); + decryptedAssert.firstBatch().hasMetadataMatching(firstBatch).hasNumRecords(1).firstRecord().hasValueEqualTo(ARBITRARY_VALUE); + decryptedAssert.lastBatch().hasMetadataMatching(secondBatch).hasNumRecords(1).firstRecord().hasValueEqualTo(ARBITRARY_VALUE_2); } @Test - void shouldPreserveControlBatch() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + void shouldPreserveControlBatchOnEncrypt() { + // given + InMemoryKms kms = getInMemoryKms(); + EncryptionScheme scheme = createScheme(kms); + var km = createKeyManager(kms, 500_000); - var kekId = kms.generateKey(); + MutableRecordBatch firstBatch = RecordTestUtils.singleElementRecordBatch(1L, ARBITRARY_KEY, ARBITRARY_VALUE, ABSENT_HEADERS); + MutableRecordBatch controlBatch = RecordTestUtils.abortTransactionControlBatch(2); + Record controlRecord = controlBatch.iterator().next(); + MemoryRecords records = RecordTestUtils.memoryRecords(firstBatch, controlBatch); - byte[] value = { 1, 2, 3 }; - Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); - BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(1000)); - builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1); - builder.appendWithOffset(1L, record); - byte[] controlBatchValue = { 4, 5, 6 }; - RecordBatch controlBatch = controlBatch(2, controlBatchValue); - builder.addBatchLike(controlBatch); - builder.append(controlBatch.iterator().next()); - MemoryRecords records = builder.build(); - - EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); - CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); - assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); - MemoryRecords encrypted = encryptedFuture.join(); - record.value().rewind(); + // when + MemoryRecords encrypted = assertImmediateSuccessAndGet(encrypt(km, scheme, records)); - assertThat(encrypted.batches()).hasSize(2); - List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); - MutableRecordBatch first = batches.get(0); - assertThat(first.compressionType()).isEqualTo(CompressionType.NONE); - assertThat(first.timestampType()).isEqualTo(TimestampType.CREATE_TIME); - assertThat(first.baseOffset()).isEqualTo(1L); - assertThat(first).hasSize(1); - - MutableRecordBatch second = batches.get(1); - // should we keep the client's compression type? - assertThat(second.compressionType()).isEqualTo(controlBatch.compressionType()); - assertThat(second.timestampType()).isEqualTo(controlBatch.timestampType()); - assertThat(second.baseOffset()).isEqualTo(controlBatch.baseOffset()); - assertThat(second.isControlBatch()).isTrue(); - assertThat(second).hasSize(1); - // control batches are not encrypted - assertThat(second.iterator()) - .toIterable() - .singleElement() - .extracting(RecordTestUtils::recordValueAsBytes) - .isEqualTo(controlBatchValue); - - CompletableFuture decryptedFuture = km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new).toCompletableFuture(); - assertThat(decryptedFuture).succeedsWithin(Duration.ZERO); - MemoryRecords decrypted = decryptedFuture.join(); - - assertThat(decrypted.batches()).hasSize(2); - List decryptedBatches = StreamSupport.stream(decrypted.batches().spliterator(), false).toList(); - MutableRecordBatch firstDecrypted = decryptedBatches.get(0); - assertThat(firstDecrypted.compressionType()).isEqualTo(CompressionType.NONE); - assertThat(firstDecrypted.timestampType()).isEqualTo(TimestampType.CREATE_TIME); - assertThat(firstDecrypted.baseOffset()).isEqualTo(1L); - assertThat(firstDecrypted).hasSize(1); - assertThat(firstDecrypted.iterator()) - .toIterable() - .singleElement() - .extracting(RecordTestUtils::recordValueAsBytes) - .isEqualTo(value); + // then + MemoryRecordsAssert encryptedAssert = MemoryRecordsAssert.assertThat(encrypted); + encryptedAssert.hasNumBatches(2); + encryptedAssert.firstBatch().hasMetadataMatching(firstBatch).hasNumRecords(1).firstRecord().hasValueNotEqualTo(ARBITRARY_VALUE); + encryptedAssert.lastBatch().hasMetadataMatching(controlBatch).hasNumRecords(1).firstRecord().hasValueEqualTo(controlRecord); + } - MutableRecordBatch secondDecrypted = decryptedBatches.get(1); - assertThat(secondDecrypted.compressionType()).isEqualTo(controlBatch.compressionType()); - assertThat(secondDecrypted.timestampType()).isEqualTo(controlBatch.timestampType()); - assertThat(secondDecrypted.baseOffset()).isEqualTo(controlBatch.baseOffset()); - assertThat(secondDecrypted.isControlBatch()).isTrue(); - assertThat(secondDecrypted).hasSize(1); - // control batch value is preserved - assertThat(second.iterator()) - .toIterable() - .singleElement() - .extracting(RecordTestUtils::recordValueAsBytes) - .isEqualTo(controlBatchValue); + @Test + void shouldPreserveControlBatchOnDecrypt() { + // given + InMemoryKms kms = getInMemoryKms(); + EncryptionScheme scheme = createScheme(kms); + var km = createKeyManager(kms, 500_000); + + MutableRecordBatch firstBatch = RecordTestUtils.singleElementRecordBatch(1L, ARBITRARY_KEY, ARBITRARY_VALUE, ABSENT_HEADERS); + MutableRecordBatch controlBatch = RecordTestUtils.abortTransactionControlBatch(2); + Record controlRecord = controlBatch.iterator().next(); + MemoryRecords records = RecordTestUtils.memoryRecords(firstBatch, controlBatch); + MemoryRecords encrypted = assertImmediateSuccessAndGet(encrypt(km, scheme, records)); + // when + MemoryRecords decrypted = assertImmediateSuccessAndGet(decrypt(km, encrypted)); + + // then + MemoryRecordsAssert decryptedAssert = MemoryRecordsAssert.assertThat(decrypted); + decryptedAssert.hasNumBatches(2); + decryptedAssert.firstBatch().hasMetadataMatching(firstBatch).hasNumRecords(1).firstRecord().hasValueEqualTo(ARBITRARY_VALUE); + decryptedAssert.lastBatch().hasMetadataMatching(controlBatch).hasNumRecords(1).firstRecord().hasValueEqualTo(controlRecord); } @Test - void shouldPreserveMultipleBatches_IncludingEmptyBatch() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); - - var kekId = kms.generateKey(); + void shouldPreserveEmptyBatchOnEncrypt() { + // given + InMemoryKms kms = getInMemoryKms(); + EncryptionScheme scheme = createScheme(kms); + var km = createKeyManager(kms, 500_000); - byte[] value = { 1, 2, 3 }; - Record record = RecordTestUtils.record(1, ByteBuffer.wrap(value)); - BatchAwareMemoryRecordsBuilder builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(1000)); - builder.addBatch(CompressionType.NONE, TimestampType.CREATE_TIME, 1); - builder.appendWithOffset(1L, record); - - MemoryRecords empty = RecordTestUtils.memoryRecordsWithAllRecordsRemoved(2L); - MutableRecordBatch emptyBatch = empty.batches().iterator().next(); - builder.writeBatch(emptyBatch); - MemoryRecords records = builder.build(); - - EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); - CompletableFuture encryptedFuture = km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new).toCompletableFuture(); - assertThat(encryptedFuture).succeedsWithin(Duration.ZERO); - MemoryRecords encrypted = encryptedFuture.join(); - record.value().rewind(); + MutableRecordBatch firstBatch = RecordTestUtils.singleElementRecordBatch(1L, ARBITRARY_KEY, ARBITRARY_VALUE, ABSENT_HEADERS); + MutableRecordBatch emptyBatch = RecordTestUtils.recordBatchWithAllRecordsRemoved(2L); + MemoryRecords records = RecordTestUtils.memoryRecords(firstBatch, emptyBatch); - assertThat(encrypted.batches()).hasSize(2); - List batches = StreamSupport.stream(encrypted.batches().spliterator(), false).toList(); - MutableRecordBatch first = batches.get(0); - assertThat(first.compressionType()).isEqualTo(CompressionType.NONE); - assertThat(first.timestampType()).isEqualTo(TimestampType.CREATE_TIME); - assertThat(first.baseOffset()).isEqualTo(1L); - assertThat(first).hasSize(1); - - MutableRecordBatch second = batches.get(1); - // should we keep the client's compression type? - assertThat(second.compressionType()).isEqualTo(emptyBatch.compressionType()); - assertThat(second.timestampType()).isEqualTo(emptyBatch.timestampType()); - assertThat(second.baseOffset()).isEqualTo(emptyBatch.baseOffset()); - assertThat(second).hasSize(0); - - CompletableFuture decryptedFuture = km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new).toCompletableFuture(); - assertThat(decryptedFuture).succeedsWithin(Duration.ZERO); - MemoryRecords decrypted = decryptedFuture.join(); - - assertThat(decrypted.batches()).hasSize(2); - List decryptedBatches = StreamSupport.stream(decrypted.batches().spliterator(), false).toList(); - MutableRecordBatch firstDecrypted = decryptedBatches.get(0); - assertThat(firstDecrypted.compressionType()).isEqualTo(CompressionType.NONE); - assertThat(firstDecrypted.timestampType()).isEqualTo(TimestampType.CREATE_TIME); - assertThat(firstDecrypted.baseOffset()).isEqualTo(1L); - assertThat(firstDecrypted).hasSize(1); - assertThat(firstDecrypted.iterator()) - .toIterable() - .singleElement() - .extracting(RecordTestUtils::recordValueAsBytes) - .isEqualTo(value); + // when + MemoryRecords encrypted = assertImmediateSuccessAndGet(encrypt(km, scheme, records)); - MutableRecordBatch secondDecrypted = decryptedBatches.get(1); - assertThat(secondDecrypted.compressionType()).isEqualTo(emptyBatch.compressionType()); - assertThat(secondDecrypted.timestampType()).isEqualTo(emptyBatch.timestampType()); - assertThat(secondDecrypted.baseOffset()).isEqualTo(emptyBatch.baseOffset()); + // then + MemoryRecordsAssert encryptedAssert = MemoryRecordsAssert.assertThat(encrypted); + encryptedAssert.hasNumBatches(2); + encryptedAssert.firstBatch().hasMetadataMatching(firstBatch).hasNumRecords(1).firstRecord().hasValueNotEqualTo(ARBITRARY_VALUE); + encryptedAssert.lastBatch().hasMetadataMatching(emptyBatch).hasNumRecords(0); } - private static RecordBatch controlBatch(int baseOffset, byte[] arbitraryValue) { - MemoryRecordsBuilder builder = new MemoryRecordsBuilder(ByteBuffer.allocate(1000), RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, - TimestampType.CREATE_TIME, baseOffset, 1L, 1L, (short) 1, 1, false, true, 1, 1); - byte[] key = { 0, 0, (byte) ControlRecordType.ABORT.type(), (byte) (ControlRecordType.ABORT.type() >> 8) }; - builder.appendControlRecordWithOffset(baseOffset, new SimpleRecord(1L, key, arbitraryValue)); - MemoryRecords controlBatchRecords = builder.build(); - return controlBatchRecords.firstBatch(); + @Test + void shouldPreserveEmptyBatchOnDecrypt() { + // given + InMemoryKms kms = getInMemoryKms(); + EncryptionScheme scheme = createScheme(kms); + var km = createKeyManager(kms, 500_000); + + MutableRecordBatch firstBatch = RecordTestUtils.singleElementRecordBatch(1L, ARBITRARY_KEY, ARBITRARY_VALUE, ABSENT_HEADERS); + MutableRecordBatch emptyBatch = RecordTestUtils.recordBatchWithAllRecordsRemoved(2L); + MemoryRecords records = RecordTestUtils.memoryRecords(firstBatch, emptyBatch); + MemoryRecords encrypted = assertImmediateSuccessAndGet(encrypt(km, scheme, records)); + + // when + MemoryRecords decrypted = assertImmediateSuccessAndGet(decrypt(km, encrypted)); + + // then + MemoryRecordsAssert decryptedAssert = MemoryRecordsAssert.assertThat(decrypted); + decryptedAssert.hasNumBatches(2); + decryptedAssert.firstBatch().hasMetadataMatching(firstBatch).hasNumRecords(1).firstRecord().hasValueEqualTo(ARBITRARY_VALUE); + decryptedAssert.lastBatch().hasMetadataMatching(emptyBatch).hasNumRecords(0); } @NonNull @@ -360,9 +288,8 @@ private static CompletionStage doEncrypt(InBandKeyManager(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -385,9 +312,8 @@ void shouldTolerateEncryptingAndDecryptingEmptyRecordValue() { @Test void decryptSupportsUnencryptedRecordValue() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); byte[] recBytes = { 1, 2, 3 }; Record record = RecordTestUtils.record(recBytes); @@ -408,9 +334,8 @@ static List decryptSupportsEmptyRecordBatches() { @ParameterizedTest @MethodSource void decryptSupportsEmptyRecordBatches(MemoryRecords records) { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); assertThat(km.decrypt("foo", 1, records, ByteBufferOutputStream::new)) .succeedsWithin(Duration.ZERO).isSameAs(records); } @@ -418,9 +343,8 @@ void decryptSupportsEmptyRecordBatches(MemoryRecords records) { // we do not want to break compaction tombstoning by creating a parcel for the null value case @Test void nullRecordValuesShouldNotBeModifiedAtEncryptTime() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -439,9 +363,8 @@ void nullRecordValuesShouldNotBeModifiedAtEncryptTime() { // value is null. @Test void nullRecordValuesAreIncompatibleWithHeaderEncryption() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -458,9 +381,8 @@ void nullRecordValuesAreIncompatibleWithHeaderEncryption() { @Test void shouldTolerateEncryptingEmptyBatch() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -474,22 +396,19 @@ void shouldTolerateEncryptingEmptyBatch() { @Test void shouldTolerateEncryptingSingleBatchMemoryRecordsWithNoRecords() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); - var kekId = kms.generateKey(); - EncryptionScheme scheme = new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); + EncryptionScheme scheme = createScheme(kms); MemoryRecords records = RecordTestUtils.memoryRecordsWithAllRecordsRemoved(); - assertThat(km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new)).succeedsWithin(Duration.ZERO).isSameAs(records); + assertThat(encrypt(km, scheme, records)).succeedsWithin(Duration.ZERO).isSameAs(records); } @Test void encryptionRetry() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + InMemoryKms kms = getInMemoryKms(); var kekId = kms.generateKey(); // configure 1 encryption per dek but then try to encrypt 2 records, will destroy and retry - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 1); + var km = createKeyManager(kms, 1); var value = ByteBuffer.wrap(new byte[]{ 1, 2, 3 }); var value2 = ByteBuffer.wrap(new byte[]{ 4, 5, 6 }); @@ -507,12 +426,11 @@ void encryptionRetry() { @Test void dekCreationRetryFailurePropagatedToEncryptCompletionStage() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + InMemoryKms kms = getInMemoryKms(); var kekId = kms.generateKey(); InMemoryKms spyKms = Mockito.spy(kms); when(spyKms.generateDekPair(kekId)).thenReturn(CompletableFuture.failedFuture(new EncryptorCreationException("failed to create that DEK"))); - var km = new InBandKeyManager<>(spyKms, BufferPool.allocating(), 500000); + var km = createKeyManager(spyKms, 500000); var value = ByteBuffer.wrap(new byte[]{ 1, 2, 3 }); var value2 = ByteBuffer.wrap(new byte[]{ 4, 5, 6 }); @@ -529,13 +447,12 @@ void dekCreationRetryFailurePropagatedToEncryptCompletionStage() { @Test void edekDecryptionRetryFailurePropagatedToDecryptCompletionStage() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + InMemoryKms kms = getInMemoryKms(); var kekId = kms.generateKey(); InMemoryKms spyKms = Mockito.spy(kms); doReturn(CompletableFuture.failedFuture(new KmsException("failed to create that DEK"))).when(spyKms).decryptEdek(any()); - var km = new InBandKeyManager<>(spyKms, BufferPool.allocating(), 50000); + var km = createKeyManager(spyKms, 50000); var value = ByteBuffer.wrap(new byte[]{ 1, 2, 3 }); var value2 = ByteBuffer.wrap(new byte[]{ 4, 5, 6 }); @@ -556,13 +473,12 @@ void edekDecryptionRetryFailurePropagatedToDecryptCompletionStage() { @Test void afterWeFailToLoadADekTheNextEncryptionAttemptCanSucceed() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + InMemoryKms kms = getInMemoryKms(); var kekId = kms.generateKey(); InMemoryKms spyKms = Mockito.spy(kms); when(spyKms.generateDekPair(kekId)).thenReturn(CompletableFuture.failedFuture(new KmsException("failed to create that DEK"))); - var km = new InBandKeyManager<>(spyKms, BufferPool.allocating(), 50000); + var km = createKeyManager(spyKms, 50000); var value = ByteBuffer.wrap(new byte[]{ 1, 2, 3 }); var value2 = ByteBuffer.wrap(new byte[]{ 4, 5, 6 }); @@ -591,9 +507,8 @@ void afterWeFailToLoadADekTheNextEncryptionAttemptCanSucceed() { @Test void shouldEncryptRecordValueForMultipleRecords() throws ExecutionException, InterruptedException, TimeoutException { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -625,9 +540,8 @@ void shouldEncryptRecordValueForMultipleRecords() throws ExecutionException, Int @Test void shouldGenerateNewDekIfOldDekHasNoRemainingEncryptions() throws ExecutionException, InterruptedException, TimeoutException { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 2); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 2); var kekId = kms.generateKey(); @@ -664,9 +578,8 @@ void shouldGenerateNewDekIfOldDekHasNoRemainingEncryptions() throws ExecutionExc @Test void shouldGenerateNewDekIfOldOneHasSomeRemainingEncryptionsButNotEnoughForWholeBatch() throws ExecutionException, InterruptedException, TimeoutException { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 3); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 3); var kekId = kms.generateKey(); @@ -704,9 +617,8 @@ void shouldGenerateNewDekIfOldOneHasSomeRemainingEncryptionsButNotEnoughForWhole @Test void shouldUseSameDekForMultipleBatches() throws ExecutionException, InterruptedException, TimeoutException { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 4); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 4); var kekId = kms.generateKey(); @@ -743,9 +655,8 @@ void shouldUseSameDekForMultipleBatches() throws ExecutionException, Interrupted @Test void shouldEncryptRecordHeaders() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -773,9 +684,8 @@ void shouldEncryptRecordHeaders() { @Test void shouldEncryptRecordHeadersForMultipleRecords() throws ExecutionException, InterruptedException, TimeoutException { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -812,9 +722,8 @@ void shouldEncryptRecordHeadersForMultipleRecords() throws ExecutionException, I @Test void shouldPropagateHeadersInClearWhenNotEncryptingHeaders() { - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 500_000); + InMemoryKms kms = getInMemoryKms(); + var km = createKeyManager(kms, 500_000); var kekId = kms.generateKey(); @@ -852,14 +761,13 @@ void decryptPreservesOrdering(long offsetA, long offsetB) { var topic = "topic"; var partition = 1; - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + InMemoryKms kms = getInMemoryKms(); var kekId1 = kms.generateKey(); var kekId2 = kms.generateKey(); var spyKms = Mockito.spy(kms); - var km = new InBandKeyManager<>(spyKms, BufferPool.allocating(), 50000); + var km = createKeyManager(spyKms, 50000); byte[] rec1Bytes = { 1, 2, 3 }; byte[] rec2Bytes = { 4, 5, 6 }; @@ -913,11 +821,10 @@ void decryptPreservesOrdering_RecordSetIncludeUnencrypted() { var topic = "topic"; var partition = 1; - var kmsService = UnitTestingKmsService.newInstance(); - InMemoryKms kms = kmsService.buildKms(new UnitTestingKmsService.Config()); + InMemoryKms kms = getInMemoryKms(); var kekId = kms.generateKey(); - var km = new InBandKeyManager<>(kms, BufferPool.allocating(), 50000); + var km = createKeyManager(kms, 50000); byte[] rec1Bytes = { 1, 2, 3 }; byte[] rec2Bytes = { 4, 5, 6 }; @@ -960,6 +867,12 @@ public TestingDek getSerializedGeneratedEdek(InMemoryKms kms, int i) { return new TestingDek(bytes); } + private T assertImmediateSuccessAndGet(CompletionStage stage) { + CompletableFuture future = stage.toCompletableFuture(); + assertThat(future).succeedsWithin(Duration.ZERO); + return future.join(); + } + @NonNull private static List extractEdeks(List encrypted) { List deks = encrypted.stream() @@ -975,4 +888,31 @@ private static List extractEdeks(List encrypted) { return deks; } + @NonNull + private static InBandKeyManager createKeyManager(InMemoryKms kms, int maxEncryptionsPerDek) { + return new InBandKeyManager<>(kms, BufferPool.allocating(), maxEncryptionsPerDek); + } + + @NonNull + private static EncryptionScheme createScheme(InMemoryKms kms) { + var kekId = kms.generateKey(); + return new EncryptionScheme<>(kekId, EnumSet.of(RecordField.RECORD_VALUE)); + } + + @NonNull + private static CompletionStage decrypt(InBandKeyManager km, MemoryRecords encrypted) { + return km.decrypt("topic", 1, encrypted, ByteBufferOutputStream::new); + } + + @NonNull + private static InMemoryKms getInMemoryKms() { + var kmsService = UnitTestingKmsService.newInstance(); + return kmsService.buildKms(new UnitTestingKmsService.Config()); + } + + @NonNull + private static CompletionStage encrypt(InBandKeyManager km, EncryptionScheme scheme, MemoryRecords records) { + return km.encrypt("topic", 1, scheme, records, ByteBufferOutputStream::new); + } + } diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java index 88f8505284..5884bc0431 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java @@ -32,7 +32,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.in; class BatchAwareMemoryRecordsBuilderTest { @@ -52,7 +51,7 @@ void shouldRequireABatchBeforeAppend() { void shouldBePossibleToWriteBatchDirectly() { // Given var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); - MemoryRecords input = RecordTestUtils.memoryRecords("a", "b"); + MemoryRecords input = RecordTestUtils.singleElementMemoryRecords("a", "b"); MutableRecordBatch recordBatch = input.batchIterator().next(); // When @@ -73,7 +72,7 @@ void shouldBePossibleToWriteBatchAfterBuildingABatch() { byte[] value1 = { 4, 5, 6 }; builder.appendWithOffset(0L, 1L, new byte[]{ 1, 2, 3 }, value1, new Header[]{}); byte[] value2 = { 10, 11, 12 }; - MemoryRecords input = RecordTestUtils.memoryRecords(RecordBatch.CURRENT_MAGIC_VALUE, 1L, 1L, new byte[]{ 7, 8, 9 }, value2); + MemoryRecords input = RecordTestUtils.singleElementMemoryRecords(RecordBatch.CURRENT_MAGIC_VALUE, 1L, 1L, new byte[]{ 7, 8, 9 }, value2); MutableRecordBatch recordBatch = input.batchIterator().next(); // When @@ -102,7 +101,7 @@ void shouldBePossibleToBuildABatchAfterWritingBatch() { // Given byte[] value1 = { 10, 11, 12 }; var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); - MemoryRecords input = RecordTestUtils.memoryRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, 1L, new byte[]{ 7, 8, 9 }, value1); + MemoryRecords input = RecordTestUtils.singleElementMemoryRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, 1L, new byte[]{ 7, 8, 9 }, value1); MutableRecordBatch recordBatch = input.batchIterator().next(); builder.writeBatch(recordBatch); @@ -174,7 +173,7 @@ void shouldPreventAddBatchAfterBuild() { void shouldPreventAddBatchLikeAfterBuild() { // Given var builder = new BatchAwareMemoryRecordsBuilder(new ByteBufferOutputStream(100)); - RecordBatch batch = RecordTestUtils.memoryRecords("key", "value").firstBatch(); + RecordBatch batch = RecordTestUtils.singleElementMemoryRecords("key", "value").firstBatch(); // When builder.build(); diff --git a/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java b/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java index 15e5bf60a0..f3d4d7422f 100644 --- a/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java +++ b/kroxylicious-integration-tests/src/test/java/io/kroxylicious/proxy/encryption/EnvelopeEncryptionFilterIT.java @@ -25,6 +25,7 @@ import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.common.TopicPartition; import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.ThrowingConsumer; import org.junit.jupiter.api.TestTemplate; import org.junit.jupiter.api.extension.ExtendWith; @@ -175,14 +176,10 @@ void roundTripTransactionalIsolation(KafkaCluster cluster, Topic topic, TestKmsF } } - interface ExceptionalConsumer { - void accept(T t) throws Exception; - } - - Producer withTransaction(Producer producer, ExceptionalConsumer> consumer) { + Producer withTransaction(Producer producer, ThrowingConsumer> action) { producer.beginTransaction(); try { - consumer.accept(producer); + action.accept(producer); } catch (Exception e) { throw new RuntimeException(e); From b8f3b05a49a1d0ba3dc272b5c21aad31929a8e04 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 24 Jan 2024 11:18:23 +1300 Subject: [PATCH 10/11] Appease sonar Signed-off-by: Robert Young --- .../test/assertj/RecordAssert.java | 14 +++--- .../test/record/RecordTestUtils.java | 29 ++++++++----- .../encryption/inband/InBandKeyManager.java | 43 +++++++++++-------- .../BatchAwareMemoryRecordsBuilder.java | 8 ++-- .../BatchAwareMemoryRecordsBuilderTest.java | 5 ++- 5 files changed, 58 insertions(+), 41 deletions(-) diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java index 65b3e4b7e3..7c08af6b00 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/assertj/RecordAssert.java @@ -19,6 +19,10 @@ import io.kroxylicious.test.record.RecordTestUtils; public class RecordAssert extends AbstractAssert { + + private static final String RECORD_VALUE_DESCRIPTION = "record value"; + private static final String RECORD_KEY_DESCRIPTION = "record key"; + protected RecordAssert(Record record) { super(record, RecordAssert.class); describedAs(record == null ? "null record" : "record"); @@ -57,13 +61,13 @@ private AbstractLongAssert timestampAssert() { private AbstractObjectAssert keyStrAssert() { isNotNull(); return Assertions.assertThat(actual).extracting(RecordTestUtils::recordKeyAsString) - .describedAs("record key"); + .describedAs(RECORD_KEY_DESCRIPTION); } public RecordAssert hasKeyEqualTo(String expect) { isNotNull(); Assertions.assertThat(actual).extracting(RecordTestUtils::recordKeyAsString) - .describedAs("record key") + .describedAs(RECORD_KEY_DESCRIPTION) .isEqualTo(expect); return this; } @@ -77,13 +81,13 @@ public RecordAssert hasNullKey() { private AbstractStringAssert valueStrAssert() { isNotNull(); return Assertions.assertThat(RecordTestUtils.recordValueAsString(actual)) - .describedAs("record value"); + .describedAs(RECORD_VALUE_DESCRIPTION); } private AbstractByteArrayAssert valueBytesAssert() { isNotNull(); return Assertions.assertThat(RecordTestUtils.recordValueAsBytes(actual)) - .describedAs("record value"); + .describedAs(RECORD_VALUE_DESCRIPTION); } public RecordAssert hasValueEqualTo(String expect) { @@ -113,7 +117,7 @@ public RecordAssert hasValueEqualTo(Record expected) { public RecordAssert hasNullValue() { isNotNull(); Assertions.assertThat(actual).extracting(RecordTestUtils::recordValueAsString) - .describedAs("record value") + .describedAs(RECORD_VALUE_DESCRIPTION) .isNull(); return this; } diff --git a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java index c02d7b227d..dca9fc2808 100644 --- a/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java +++ b/kroxylicious-filter-test-support/src/main/java/io/kroxylicious/test/record/RecordTestUtils.java @@ -6,6 +6,8 @@ package io.kroxylicious.test.record; +import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.List; @@ -332,13 +334,17 @@ public static MemoryRecords singleElementMemoryRecords(String key, String value, * Return a MemoryRecords containing the specified batches */ public static MemoryRecords memoryRecords(MutableRecordBatch... batches) { - ByteBufferOutputStream outputStream = new ByteBufferOutputStream(1000); - for (MutableRecordBatch batch : batches) { - batch.writeTo(outputStream); + try (ByteBufferOutputStream outputStream = new ByteBufferOutputStream(1000)) { + for (MutableRecordBatch batch : batches) { + batch.writeTo(outputStream); + } + ByteBuffer buffer = outputStream.buffer(); + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + catch (IOException e) { + throw new UncheckedIOException(e); } - ByteBuffer buffer = outputStream.buffer(); - buffer.flip(); - return MemoryRecords.readableRecords(buffer); } /** @@ -522,10 +528,11 @@ private static MemoryRecordsBuilder memoryRecordsBuilder(byte magic, * @return batch */ public static MutableRecordBatch abortTransactionControlBatch(int baseOffset) { - MemoryRecordsBuilder builder = new MemoryRecordsBuilder(ByteBuffer.allocate(1000), RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, - TimestampType.CREATE_TIME, baseOffset, 1L, 1L, (short) 1, 1, true, true, 1, 1); - builder.appendEndTxnMarker(1l, new EndTransactionMarker(ControlRecordType.ABORT, 1)); - MemoryRecords controlBatchRecords = builder.build(); - return controlBatchRecords.batchIterator().next(); + try (MemoryRecordsBuilder builder = new MemoryRecordsBuilder(ByteBuffer.allocate(1000), RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, baseOffset, 1L, 1L, (short) 1, 1, true, true, 1, 1)) { + builder.appendEndTxnMarker(1l, new EndTransactionMarker(ControlRecordType.ABORT, 1)); + MemoryRecords controlBatchRecords = builder.build(); + return controlBatchRecords.batchIterator().next(); + } } } diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java index bb9ced6124..fb552141e8 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/inband/InBandKeyManager.java @@ -412,30 +412,35 @@ private BatchAwareMemoryRecordsBuilder decrypt(String topicName, builder.writeBatch(batch); } else { - builder.addBatchLike(batch); - for (Record kafkaRecord : batch) { - var decryptionVersion = decryptionVersion(topicName, partition, kafkaRecord); - if (decryptionVersion == null) { - builder.append(kafkaRecord); - } - else if (decryptionVersion == EncryptionVersion.V1) { - ByteBuffer wrapper = kafkaRecord.value(); - var edekLength = ByteUtils.readUnsignedVarint(wrapper); - ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); - var edek = edekSerde.deserialize(slice); - wrapper.position(wrapper.position() + edekLength); - AesGcmEncryptor aesGcmEncryptor = encryptorMap.get(edek); - if (aesGcmEncryptor == null) { - throw new RuntimeException("no encryptor loaded for edek, " + edek); - } - decryptRecord(EncryptionVersion.V1, aesGcmEncryptor, wrapper, kafkaRecord, builder); - } - } + decryptBatch(topicName, partition, builder, encryptorMap, batch); } } return builder; } + private void decryptBatch(String topicName, int partition, @NonNull BatchAwareMemoryRecordsBuilder builder, @NonNull Map encryptorMap, + MutableRecordBatch batch) { + builder.addBatchLike(batch); + for (Record kafkaRecord : batch) { + var decryptionVersion = decryptionVersion(topicName, partition, kafkaRecord); + if (decryptionVersion == null) { + builder.append(kafkaRecord); + } + else if (decryptionVersion == EncryptionVersion.V1) { + ByteBuffer wrapper = kafkaRecord.value(); + var edekLength = ByteUtils.readUnsignedVarint(wrapper); + ByteBuffer slice = wrapper.slice(wrapper.position(), edekLength); + var edek = edekSerde.deserialize(slice); + wrapper.position(wrapper.position() + edekLength); + AesGcmEncryptor aesGcmEncryptor = encryptorMap.get(edek); + if (aesGcmEncryptor == null) { + throw new EncryptionException("no encryptor loaded for edek, " + edek); + } + decryptRecord(EncryptionVersion.V1, aesGcmEncryptor, wrapper, kafkaRecord, builder); + } + } + } + private ByteBufferOutputStream allocateBufferForDecode(MemoryRecords memoryRecords, IntFunction allocator) { int sizeEstimate = memoryRecords.sizeInBytes(); return allocator.apply(sizeEstimate); diff --git a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java index e5a3db16ac..37afc93123 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/main/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilder.java @@ -276,18 +276,18 @@ public BatchAwareMemoryRecordsBuilder appendWithOffset(long offset, Record recor * @return the memory records */ public @NonNull MemoryRecords build() { - ByteBuffer buffer; + ByteBuffer recordsBuff; if (closed) { - buffer = this.buffer.buffer(); + recordsBuff = this.buffer.buffer(); } else { closed = true; maybeAppendCurrentBatch(); ByteBuffer buf = this.buffer.buffer(); buf.flip(); - buffer = buf; + recordsBuff = buf; } - return MemoryRecords.readableRecords(buffer); + return MemoryRecords.readableRecords(recordsBuff); } /** diff --git a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java index 5884bc0431..f990b0f562 100644 --- a/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java +++ b/kroxylicious-filters/kroxylicious-encryption/src/test/java/io/kroxylicious/filter/encryption/records/BatchAwareMemoryRecordsBuilderTest.java @@ -233,8 +233,8 @@ void shouldPreventAppendControlRecordAfterBuild() { builder.build(); // Then + SimpleRecord controlRecord = controlRecord(); assertThatThrownBy(() -> { - SimpleRecord controlRecord = controlRecord(); builder.appendControlRecordWithOffset(1, controlRecord); }) .isExactlyInstanceOf(IllegalStateException.class) @@ -262,8 +262,9 @@ void shouldPreventAppendEndTxnMarkerRecordAfterBuild() { builder.build(); // Then + EndTransactionMarker marker = new EndTransactionMarker(ControlRecordType.ABORT, 1); assertThatThrownBy(() -> { - builder.appendEndTxnMarker(1, new EndTransactionMarker(ControlRecordType.ABORT, 1)); + builder.appendEndTxnMarker(1, marker); }) .isExactlyInstanceOf(IllegalStateException.class) .hasMessageContaining("Builder is closed"); From 2d255346bc9414db417b4b0b731b8ed595a50291 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 24 Jan 2024 12:27:25 +1300 Subject: [PATCH 11/11] Changelog Signed-off-by: Robert Young --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 712c3a90ad..cee7ed1918 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Please enumerate **all user-facing** changes using format `