Add RecordEncryptor and RecordDecryptor (kroxylicious#899)
* Add RecordEncryptor
* Add the RecordDecryptor

Signed-off-by: Tom Bentley <[email protected]>
tombentley authored Feb 1, 2024
1 parent 360cf78 commit 96817b0
Showing 12 changed files with 667 additions and 21 deletions.
------------------------------------------------------------
@@ -6,8 +6,11 @@
 
 package io.kroxylicious.test.assertj;
 
+import java.nio.charset.StandardCharsets;
+
 import org.apache.kafka.common.header.Header;
 import org.assertj.core.api.AbstractAssert;
+import org.assertj.core.api.AbstractByteArrayAssert;
 import org.assertj.core.api.Assertions;
 
 public class HeaderAssert extends AbstractAssert<HeaderAssert, Header> {
@@ -29,20 +32,25 @@ public HeaderAssert hasKeyEqualTo(String expected) {
     }
 
     public HeaderAssert hasValueEqualTo(String expected) {
-        isNotNull();
-        String valueString = actual.value() == null ? null : new String(actual.value());
-        Assertions.assertThat(valueString)
-                .describedAs("header value")
-                .isEqualTo(expected);
+        valueAssert().isEqualTo(expected == null ? null : expected.getBytes(StandardCharsets.UTF_8));
         return this;
     }
+
+    public HeaderAssert hasValueEqualTo(byte[] expected) {
+        valueAssert().isEqualTo(expected);
+        return this;
+    }
 
     public HeaderAssert hasNullValue() {
-        isNotNull();
-        Assertions.assertThat(actual.value())
-                .describedAs("header value")
-                .isNull();
+        valueAssert().isNull();
         return this;
     }
 
+    private AbstractByteArrayAssert<?> valueAssert() {
+        isNotNull();
+        AbstractByteArrayAssert<?> headerValue = Assertions.assertThat(actual.value())
+                .describedAs("header value");
+        return headerValue;
+    }
+
 }
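
The refactor routes every value assertion through the new `valueAssert()` helper, and the `byte[]` overload lets tests compare raw header bytes directly; it also makes a bare `null` argument ambiguous, which is why the test diff further down adds `(String)` casts. A hedged usage sketch (assuming `KafkaAssertions.assertThat` lives alongside `HeaderAssert`, as the test diff below suggests, and that these asserts chain by returning `this`):

import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.header.internals.RecordHeader;

import io.kroxylicious.test.assertj.KafkaAssertions; // location assumed from the package above

class HeaderAssertExample {
    void example() {
        RecordHeader header = new RecordHeader("foo", "abc".getBytes(StandardCharsets.UTF_8));
        KafkaAssertions.assertThat(header)
                .hasValueEqualTo("abc")                                   // String overload, compared as UTF-8 bytes
                .hasValueEqualTo("abc".getBytes(StandardCharsets.UTF_8)); // new byte[] overload
        KafkaAssertions.assertThat(new RecordHeader("bar", null))
                .hasNullValue();
    }
}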
------------------------------------------------------------
@@ -32,7 +32,7 @@ public static RecordAssert assertThat(Record actual) {
         return new RecordAssert(actual);
     }
 
-    public RecordAssert hasOffsetEqualTo(int expect) {
+    public RecordAssert hasOffsetEqualTo(long expect) {
         isNotNull();
         AbstractLongAssert<?> offset = offsetAssert();
         offset.isEqualTo(expect);
@@ -45,7 +45,7 @@ private AbstractLongAssert<?> offsetAssert() {
                 .describedAs("record offset");
     }
 
-    public RecordAssert hasTimestampEqualTo(int expect) {
+    public RecordAssert hasTimestampEqualTo(long expect) {
         isNotNull();
         AbstractLongAssert<?> timestamp = timestampAssert();
         timestamp.isEqualTo(expect);
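
Widening these parameters from `int` to `long` matches Kafka's actual offset and timestamp types, so tests can assert real broker offsets and epoch-millisecond timestamps without casts or truncation. A minimal sketch (chaining assumed, per the AssertJ style used throughout; the import path is assumed to match `HeaderAssert` above):

import org.apache.kafka.common.record.Record;

import io.kroxylicious.test.assertj.RecordAssert; // package assumed

class RecordAssertExample {
    void example(Record record) {
        RecordAssert.assertThat(record)
                .hasOffsetEqualTo(0L)
                .hasTimestampEqualTo(1_706_745_600_000L); // epoch millis; would not fit in an int
    }
}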
------------------------------------------------------------
@@ -59,6 +59,10 @@ private static byte[] bytesOf(ByteBuffer buffer) {
         return result;
     }
 
+    public static String asString(ByteBuffer buffer) {
+        return buffer == null ? null : new String(bytesOf(buffer), StandardCharsets.UTF_8);
+    }
+
     /**
      * Get a copy of the bytes contained in the given {@code record}'s value, without changing the
      * {@code record}'s {@link ByteBuffer#position() position}, {@link ByteBuffer#limit() limit} or {@link ByteBuffer#mark() mark}.
@@ -70,15 +74,15 @@ public static byte[] recordValueAsBytes(Record record) {
     }
 
     public static String recordValueAsString(Record record) {
-        return record.value() == null ? null : new String(bytesOf(record.value()), StandardCharsets.UTF_8);
+        return asString(record.value());
     }
 
     public static byte[] recordKeyAsBytes(Record record) {
         return bytesOf(record.key());
     }
 
     public static String recordKeyAsString(Record record) {
-        return record.key() == null ? null : new String(bytesOf(record.key()), StandardCharsets.UTF_8);
+        return asString(record.key());
     }
 
     /**
@@ -247,6 +251,18 @@ public static Record record(byte magic,
         return MemoryRecords.readableRecords(mr.buffer()).records().iterator().next();
     }
 
+    public static Record record(long offset, long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers) {
+        return record(RecordBatch.CURRENT_MAGIC_VALUE, offset, timestamp, key, value, headers);
+    }
+
+    public static Record record(long offset, long timestamp, String key, String value, Header[] headers) {
+        return record(RecordBatch.CURRENT_MAGIC_VALUE, offset, timestamp, key, value, headers);
+    }
+
+    public static Record record(long offset, long timestamp, byte[] key, byte[] value, Header[] headers) {
+        return record(RecordBatch.CURRENT_MAGIC_VALUE, offset, timestamp, key, value, headers);
+    }
+
     /**
      * Return a singleton RecordBatch containing a single Record with the given key, value and headers.
      * The batch will use the current magic. The baseOffset and offset of the record will be 0
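
The new overloads default the magic byte to `RecordBatch.CURRENT_MAGIC_VALUE`, so tests only spell out the fields they care about. A hedged sketch (this view hides the file path, so the utility class name `RecordTestUtils` is assumed):

import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.internals.RecordHeader;
import org.apache.kafka.common.record.Record;

class RecordFactoryExample {
    Record example() {
        Header[] headers = { new RecordHeader("h1", new byte[]{ 1 }) };
        // Creates a current-magic record with the given offset, timestamp, key, value and headers.
        // Class name assumed; the diff view does not show the file path.
        return RecordTestUtils.record(0L, 1_706_745_600_000L, "my-key", "my-value", headers);
    }
}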
------------------------------------------------------------
@@ -46,10 +46,10 @@ void testHeaderHasValueEqualTo() {
         RecordHeader nonNullValue = new RecordHeader("foo", "abc".getBytes(StandardCharsets.UTF_8));
         HeaderAssert nonNullValueAssert = KafkaAssertions.assertThat(nonNullValue);
 
-        nullValueAssert.hasValueEqualTo(null);
+        nullValueAssert.hasValueEqualTo((String) null);
         nonNullValueAssert.hasValueEqualTo("abc");
         throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo("other"), "[header value]");
-        throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo(null), "[header value]");
+        throwsAssertionErrorContaining(() -> nonNullValueAssert.hasValueEqualTo((String) null), "[header value]");
         throwsAssertionErrorContaining(() -> nullValueAssert.hasValueEqualTo("other"), "[header value]");
         assertThrowsIfHeaderNull(nullAssert -> nullAssert.hasValueEqualTo("any"));
     }
------------------------------------------------------------
@@ -0,0 +1,31 @@
/*
* Copyright Kroxylicious Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/

package io.kroxylicious.filter.encryption.inband;

import org.apache.kafka.common.record.Record;

import io.kroxylicious.filter.encryption.EncryptionVersion;

import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;

/**
* Helper class to group together some state for decryption.
* Either both, or neither, of the given {@code decryptionVersion} and {@code encryptor} should be null.
 * @param kafkaRecord The record to be decrypted
 * @param decryptionVersion The encryption version the record was encrypted with, or null if the record is not encrypted
 * @param encryptor The encryptor with which to decrypt the record, or null if the record is not encrypted
*/
record DecryptState(@NonNull Record kafkaRecord,
                    @Nullable EncryptionVersion decryptionVersion,
                    @Nullable AesGcmEncryptor encryptor) {

    DecryptState {
        // Both must be null (record not encrypted) or both non-null (record can be decrypted)
        if (decryptionVersion == null ^ encryptor == null) {
            throw new IllegalArgumentException();
        }
    }
}
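
The compact constructor uses XOR to reject the two mixed states. A sketch of what it accepts and rejects (a same-package caller, since `DecryptState` is package-private; `rec`, `version` and `encryptor` are placeholder arguments):

// Inside io.kroxylicious.filter.encryption.inband
static void illustrate(Record rec, EncryptionVersion version, AesGcmEncryptor encryptor) {
    new DecryptState(rec, version, encryptor); // ok: both non-null, record will be decrypted
    new DecryptState(rec, null, null);         // ok: both null, record passes through untouched
    new DecryptState(rec, version, null);      // throws IllegalArgumentException
    new DecryptState(rec, null, encryptor);    // throws IllegalArgumentException
}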
------------------------------------------------------------
@@ -463,7 +463,9 @@ private void decryptRecord(EncryptionVersion decryptionVersion,
         synchronized (encryptor) {
             plaintextParcel = decryptParcel(wrapper.slice(), encryptor);
         }
-        Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, builder);
+        Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, kafkaRecord, (v, h) -> {
+            builder.appendWithOffset(kafkaRecord.offset(), kafkaRecord.timestamp(), kafkaRecord.key(), v, h);
+        });
     }
 
     private ByteBuffer decryptParcel(ByteBuffer ciphertextParcel, AesGcmEncryptor encryptor) {
------------------------------------------------------------
@@ -9,6 +9,7 @@
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.Set;
+import java.util.function.BiConsumer;
 
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.header.internals.RecordHeader;
@@ -19,7 +20,6 @@
 import io.kroxylicious.filter.encryption.EncryptionException;
 import io.kroxylicious.filter.encryption.ParcelVersion;
 import io.kroxylicious.filter.encryption.RecordField;
-import io.kroxylicious.filter.encryption.records.BatchAwareMemoryRecordsBuilder;
 
 import edu.umd.cs.findbugs.annotations.NonNull;
 import edu.umd.cs.findbugs.annotations.Nullable;
@@ -58,7 +58,7 @@ static void writeParcel(ParcelVersion parcelVersion, Set<RecordField> recordFiel
     static void readParcel(ParcelVersion parcelVersion,
                            ByteBuffer parcel,
                            Record encryptedRecord,
-                           @NonNull BatchAwareMemoryRecordsBuilder builder) {
+                           @NonNull BiConsumer<ByteBuffer, Header[]> consumer) {
         switch (parcelVersion) {
             case V1:
                 var parcelledValue = readRecordValue(parcel);
@@ -80,7 +80,7 @@ static void readParcel(ParcelVersion parcelVersion,
                     usedHeaders = parcelledHeaders;
                 }
                 ByteBuffer parcelledBuffer = parcelledValue == ABSENT_VALUE ? encryptedRecord.value() : parcelledValue;
-                builder.appendWithOffset(encryptedRecord.offset(), encryptedRecord.timestamp(), encryptedRecord.key(), parcelledBuffer, usedHeaders);
+                consumer.accept(parcelledBuffer, usedHeaders);
                 break;
             default:
                 throw new EncryptionException("Unknown parcel version " + parcelVersion);
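
Swapping the `BatchAwareMemoryRecordsBuilder` parameter for a `BiConsumer<ByteBuffer, Header[]>` decouples `Parcel` from any one output type: each caller decides what to do with the decrypted value and headers, as both call sites in this commit show. The general shape, sketched (arguments as in those callers):

// Illustrative only: how a caller consumes the parcel contents.
static void consumeParcel(ParcelVersion parcelVersion, ByteBuffer plaintextParcel, Record encryptedRecord) {
    Parcel.readParcel(parcelVersion, plaintextParcel, encryptedRecord, (value, headers) -> {
        // value: the decrypted record value (or the original value when none was parcelled)
        // headers: the headers to attach to the rebuilt record
    });
}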
------------------------------------------------------------
@@ -0,0 +1,115 @@
/*
* Copyright Kroxylicious Authors.
*
* Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
*/

package io.kroxylicious.filter.encryption.inband;

import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.function.IntFunction;

import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.utils.ByteUtils;

import io.kroxylicious.filter.encryption.AadSpec;
import io.kroxylicious.filter.encryption.CipherCode;
import io.kroxylicious.filter.encryption.records.RecordTransform;

import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;

/**
* A {@link RecordTransform} that decrypts records that were previously encrypted by {@link RecordEncryptor}.
*/
public class RecordDecryptor implements RecordTransform {

    private final IntFunction<DecryptState> keys;
    int index = 0;
    private ByteBuffer transformedValue;
    private Header[] transformedHeaders;

    public RecordDecryptor(@NonNull IntFunction<DecryptState> keys) {
        Objects.requireNonNull(keys);
        this.keys = keys;
    }

    @Override
    public void init(@NonNull Record record) {
        DecryptState decryptState = keys.apply(index);
        if (decryptState == null || decryptState.encryptor() == null) {
            // No decryption state for this record: pass the value and headers through unchanged
            transformedValue = record.value();
            transformedHeaders = record.headers();
            return;
        }
        var encryptor = decryptState.encryptor();
        var decryptionVersion = decryptState.decryptionVersion();
        var wrapper = record.value();
        // Skip the edek
        var edekLength = ByteUtils.readUnsignedVarint(wrapper);
        wrapper.position(wrapper.position() + edekLength);

        var aadSpec = AadSpec.fromCode(wrapper.get()); // consumes the AAD spec byte
        ByteBuffer aad = switch (aadSpec) {
            case NONE -> ByteUtils.EMPTY_BUF;
        };

        var cipherCode = CipherCode.fromCode(wrapper.get()); // consumes the cipher code byte

        ByteBuffer plaintextParcel;
        synchronized (encryptor) {
            plaintextParcel = decryptParcel(wrapper.slice(), encryptor);
        }
        Parcel.readParcel(decryptionVersion.parcelVersion(), plaintextParcel, record, (v, h) -> {
            transformedValue = v;
            transformedHeaders = h;
        });
    }

    private ByteBuffer decryptParcel(ByteBuffer ciphertextParcel, AesGcmEncryptor encryptor) {
        ByteBuffer plaintext = ciphertextParcel.duplicate();
        encryptor.decrypt(ciphertextParcel, plaintext);
        plaintext.flip();
        return plaintext;
    }

    @Override
    public long transformOffset(@NonNull Record record) {
        return record.offset();
    }

    @Override
    public long transformTimestamp(@NonNull Record record) {
        return record.timestamp();
    }

    @Nullable
    @Override
    public ByteBuffer transformKey(@NonNull Record record) {
        return record.key();
    }

    @Nullable
    @Override
    public ByteBuffer transformValue(@NonNull Record record) {
        return transformedValue == null ? null : transformedValue.duplicate();
    }

    @Nullable
    @Override
    public Header[] transformHeaders(@NonNull Record record) {
        return transformedHeaders;
    }

    @Override
    public void resetAfterTransform(@NonNull Record record) {
        if (transformedValue != null) {
            transformedValue.clear();
        }
        transformedHeaders = null;
        index++;
    }
}
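
`RecordDecryptor` is a stateful, per-record transform: the caller is expected to `init` a record, read the transformed fields, then call `resetAfterTransform`, with `index` stepping through the batch in order. A hedged sketch of that lifecycle (the real driver is not part of this diff; `batchRecords`, `builder` and `stateFor` are placeholders, and `appendWithOffset` has the signature seen in the decryptRecord diff above):

static void decryptBatch(List<Record> batchRecords, BatchAwareMemoryRecordsBuilder builder,
                         IntFunction<DecryptState> stateFor) {
    RecordDecryptor decryptor = new RecordDecryptor(stateFor);
    for (Record r : batchRecords) {
        decryptor.init(r); // decrypts (or passes through) and caches the value/headers
        builder.appendWithOffset(
                decryptor.transformOffset(r),
                decryptor.transformTimestamp(r),
                decryptor.transformKey(r),
                decryptor.transformValue(r),
                decryptor.transformHeaders(r));
        decryptor.resetAfterTransform(r); // clears cached state and advances the index
    }
}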
… 4 more changed files (including the new RecordEncryptor) are not shown in this view.
