Skip to content

Commit

Permalink
Merge pull request kroxylicious#908 from SamBarker/chasing-coverage
Browse files Browse the repository at this point in the history
Chasing coverage
  • Loading branch information
SamBarker authored Jan 23, 2024
2 parents 9609340 + cf4739d commit 1e52156
Show file tree
Hide file tree
Showing 16 changed files with 531 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.kroxylicious.proxy.plugin.PluginImplName;
import io.kroxylicious.proxy.plugin.Plugins;

import edu.umd.cs.findbugs.annotations.NonNull;

@Plugin(configType = FetchResponseTransformationFilterFactory.Config.class)
public class FetchResponseTransformationFilterFactory
implements FilterFactory<Config, Config> {
Expand All @@ -27,6 +29,7 @@ public Config initialize(FilterFactoryContext context, Config config) {
return Plugins.requireConfig(this, config);
}

@NonNull
@Override
public FetchResponseTransformationFilter createFilter(FilterFactoryContext context,
Config configuration) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.kroxylicious.proxy.plugin.PluginImplName;
import io.kroxylicious.proxy.plugin.Plugins;

import edu.umd.cs.findbugs.annotations.NonNull;

import static org.junit.jupiter.api.Assertions.fail;

@Plugin(configType = ExampleFilterFactory.Config.class)
Expand All @@ -33,6 +35,7 @@ public Config initialize(FilterFactoryContext context, Config config) {
return Plugins.requireConfig(this, config);
}

@NonNull
@Override
public Filter createFilter(FilterFactoryContext context, Config configuration) {
fail("unexpected call");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import io.kroxylicious.proxy.plugin.Plugin;

import edu.umd.cs.findbugs.annotations.NonNull;

@Plugin(configType = Void.class)
public class ApiVersionsMarkingFilterFactory implements FilterFactory<Void, Void> {

Expand All @@ -16,6 +18,7 @@ public Void initialize(FilterFactoryContext context, Void config) {
return null;
}

@NonNull
@Override
public ApiVersionsMarkingFilter createFilter(FilterFactoryContext context, Void configuration) {
return new ApiVersionsMarkingFilter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import io.kroxylicious.proxy.plugin.Plugin;
import io.kroxylicious.proxy.plugin.Plugins;

import edu.umd.cs.findbugs.annotations.NonNull;

@Plugin(configType = FixedClientIdFilterConfig.class)
public class FixedClientIdFilterFactory implements FilterFactory<FixedClientIdFilterConfig, FixedClientIdFilterConfig> {

Expand All @@ -18,6 +20,7 @@ public FixedClientIdFilterConfig initialize(FilterFactoryContext context, FixedC
return Plugins.requireConfig(this, config);
}

@NonNull
@Override
public FixedClientIdFilter createFilter(FilterFactoryContext context, FixedClientIdFilterConfig configuration) {
return new FixedClientIdFilter(configuration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import io.kroxylicious.proxy.plugin.Plugin;
import io.kroxylicious.proxy.plugin.Plugins;

import edu.umd.cs.findbugs.annotations.NonNull;

@Plugin(configType = OutOfBandSendFilterConfig.class)
public class OutOfBandSendFilterFactory implements FilterFactory<OutOfBandSendFilterConfig, OutOfBandSendFilterConfig> {

Expand All @@ -18,6 +20,7 @@ public OutOfBandSendFilterConfig initialize(FilterFactoryContext context, OutOfB
return Plugins.requireConfig(this, config);
}

@NonNull
@Override
public OutOfBandSendFilter createFilter(FilterFactoryContext context, OutOfBandSendFilterConfig configuration) {
return new OutOfBandSendFilter(configuration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import io.kroxylicious.proxy.plugin.Plugin;

import edu.umd.cs.findbugs.annotations.NonNull;

@Plugin(configType = RejectingCreateTopicFilter.RejectingCreateTopicFilterConfig.class)
public class RejectingCreateTopicFilterFactory
implements FilterFactory<RejectingCreateTopicFilter.RejectingCreateTopicFilterConfig, RejectingCreateTopicFilter.RejectingCreateTopicFilterConfig> {
Expand All @@ -19,6 +21,7 @@ public RejectingCreateTopicFilter.RejectingCreateTopicFilterConfig initialize(Fi
return config;
}

@NonNull
@Override
public RejectingCreateTopicFilter createFilter(FilterFactoryContext context, RejectingCreateTopicFilter.RejectingCreateTopicFilterConfig configuration) {
return new RejectingCreateTopicFilter(context, configuration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import io.kroxylicious.proxy.plugin.Plugin;
import io.kroxylicious.proxy.plugin.Plugins;

import edu.umd.cs.findbugs.annotations.NonNull;

@Plugin(configType = RequestResponseMarkingFilter.RequestResponseMarkingFilterConfig.class)
public class RequestResponseMarkingFilterFactory
implements FilterFactory<RequestResponseMarkingFilter.RequestResponseMarkingFilterConfig, RequestResponseMarkingFilter.RequestResponseMarkingFilterConfig> {
Expand All @@ -19,6 +21,7 @@ public RequestResponseMarkingFilter.RequestResponseMarkingFilterConfig initializ
return Plugins.requireConfig(this, config);
}

@NonNull
@Override
public RequestResponseMarkingFilter createFilter(FilterFactoryContext context,
RequestResponseMarkingFilter.RequestResponseMarkingFilterConfig configuration) {
Expand All @@ -27,7 +30,7 @@ public RequestResponseMarkingFilter createFilter(FilterFactoryContext context,

public enum Direction {
REQUEST,
RESPONSE;
RESPONSE
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.kroxylicious.proxy.bootstrap.FilterChainFactory;
import io.kroxylicious.proxy.config.PluginFactoryRegistry;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.internal.codec.KafkaRequestDecoder;
import io.kroxylicious.proxy.internal.codec.KafkaResponseEncoder;
import io.kroxylicious.proxy.internal.filter.ApiVersionsIntersectFilter;
Expand All @@ -39,6 +40,8 @@
import io.kroxylicious.proxy.internal.net.EndpointReconciler;
import io.kroxylicious.proxy.internal.net.VirtualClusterBinding;
import io.kroxylicious.proxy.internal.net.VirtualClusterBindingResolver;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.kroxylicious.proxy.tag.VisibleForTesting;

public class KafkaProxyInitializer extends ChannelInitializer<SocketChannel> {

Expand Down Expand Up @@ -75,74 +78,85 @@ public void initChannel(SocketChannel ch) {
var bindingAddress = ch.parent().localAddress().getAddress().isAnyLocalAddress() ? Optional.<String> empty()
: Optional.of(ch.localAddress().getAddress().getHostAddress());
if (tls) {
LOGGER.debug("Adding SSL/SNI handler");
pipeline.addLast(new SniHandler((sniHostname, promise) -> {
try {
var stage = virtualClusterBindingResolver.resolve(Endpoint.createEndpoint(bindingAddress, targetPort, tls), sniHostname);
// completes the netty promise when then resolution completes (success/otherwise).
var unused = stage.handle((binding, t) -> {
try {
initTlsChannel(ch, pipeline, bindingAddress, targetPort);
}
else {
initPlainChannel(ch, pipeline, bindingAddress, targetPort);
}
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private void initPlainChannel(SocketChannel ch, ChannelPipeline pipeline, Optional<String> bindingAddress, int targetPort) {
pipeline.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
virtualClusterBindingResolver.resolve(Endpoint.createEndpoint(bindingAddress, targetPort, tls), null)
.handle((binding, t) -> {
if (t != null) {
promise.setFailure(t);
ctx.fireExceptionCaught(t);
return null;
}
var virtualCluster = binding.virtualCluster();
var sslContext = virtualCluster.getDownstreamSslContext();
if (sslContext.isEmpty()) {
promise.setFailure(new IllegalStateException("Virtual cluster %s does not provide SSL context".formatted(virtualCluster)));
}
else {
try {
KafkaProxyInitializer.this.addHandlers(ch, binding);
promise.setSuccess(sslContext.get());
ctx.fireChannelActive();
}
}
catch (Throwable t1) {
promise.setFailure(t1);
}
return null;
});
return promise;
}
catch (Throwable cause) {
return promise.setFailure(cause);
}
}) {

@Override
protected void onLookupComplete(ChannelHandlerContext ctx, Future<SslContext> future) throws Exception {
super.onLookupComplete(ctx, future);
ctx.fireChannelActive();
}
});
}
else {
pipeline.addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) {
var stage = virtualClusterBindingResolver.resolve(Endpoint.createEndpoint(bindingAddress, targetPort, tls), null);
var unused = stage.handle((binding, t) -> {
catch (Throwable t1) {
ctx.fireExceptionCaught(t1);
}
finally {
pipeline.remove(this);
}
return null;
});
}
});
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private void initTlsChannel(SocketChannel ch, ChannelPipeline pipeline, Optional<String> bindingAddress, int targetPort) {
LOGGER.debug("Adding SSL/SNI handler");
pipeline.addLast(new SniHandler((sniHostname, promise) -> {
try {
var stage = virtualClusterBindingResolver.resolve(Endpoint.createEndpoint(bindingAddress, targetPort, tls), sniHostname);
// completes the netty promise when then resolution completes (success/otherwise).
stage.handle((binding, t) -> {
try {
if (t != null) {
ctx.fireExceptionCaught(t);
promise.setFailure(t);
return null;
}
try {
KafkaProxyInitializer.this.addHandlers(ch, binding);
ctx.fireChannelActive();
}
catch (Throwable t1) {
ctx.fireExceptionCaught(t1);
var virtualCluster = binding.virtualCluster();
var sslContext = virtualCluster.getDownstreamSslContext();
if (sslContext.isEmpty()) {
promise.setFailure(new IllegalStateException("Virtual cluster %s does not provide SSL context".formatted(virtualCluster)));
}
finally {
pipeline.remove(this);
else {
KafkaProxyInitializer.this.addHandlers(ch, binding);
promise.setSuccess(sslContext.get());
}
return null;
});
}
});
}
}
catch (Throwable t1) {
promise.setFailure(t1);
}
return null;
});
return promise;
}
catch (Throwable cause) {
return promise.setFailure(cause);
}
}) {

@Override
protected void onLookupComplete(ChannelHandlerContext ctx, Future<SslContext> future) throws Exception {
super.onLookupComplete(ctx, future);
ctx.fireChannelActive();
}
});
}

private void addHandlers(SocketChannel ch, VirtualClusterBinding binding) {
@VisibleForTesting
void addHandlers(SocketChannel ch, VirtualClusterBinding binding) {
var virtualCluster = binding.virtualCluster();
ChannelPipeline pipeline = ch.pipeline();
if (virtualCluster.isLogNetwork()) {
Expand Down Expand Up @@ -173,8 +187,41 @@ private void addHandlers(SocketChannel ch, VirtualClusterBinding binding) {
}

ApiVersionsServiceImpl apiVersionService = new ApiVersionsServiceImpl();
var frontendHandler = new KafkaProxyFrontendHandler(context -> {
List<FilterAndInvoker> apiVersionFilters = dp.isAuthenticationOffloadEnabled() ? List.of()
final NetFilter netFilter = new InitalizerNetFilter(dp, apiVersionService, ch, binding, pfr, filterChainFactory, endpointReconciler);
var frontendHandler = new KafkaProxyFrontendHandler(netFilter, dp, virtualCluster, apiVersionService);

pipeline.addLast("netHandler", frontendHandler);

LOGGER.debug("{}: Initial pipeline: {}", ch, pipeline);
}

@VisibleForTesting
static class InitalizerNetFilter implements NetFilter {

private final SaslDecodePredicate decodePredicate;
private final ApiVersionsServiceImpl apiVersionService;
private final SocketChannel ch;
private final VirtualCluster virtualCluster;
private final VirtualClusterBinding binding;
private final PluginFactoryRegistry pfr;
private final FilterChainFactory filterChainFactory;
private final EndpointReconciler endpointReconciler;

InitalizerNetFilter(SaslDecodePredicate decodePredicate, ApiVersionsServiceImpl apiVersionService, SocketChannel ch,
VirtualClusterBinding binding, PluginFactoryRegistry pfr, FilterChainFactory filterChainFactory, EndpointReconciler endpointReconciler) {
this.decodePredicate = decodePredicate;
this.apiVersionService = apiVersionService;
this.ch = ch;
this.virtualCluster = binding.virtualCluster();
this.binding = binding;
this.pfr = pfr;
this.filterChainFactory = filterChainFactory;
this.endpointReconciler = endpointReconciler;
}

@Override
public void selectServer(NetFilter.NetFilterContext context) {
List<FilterAndInvoker> apiVersionFilters = decodePredicate.isAuthenticationOffloadEnabled() ? List.of()
: FilterAndInvoker.build(new ApiVersionsIntersectFilter(apiVersionService));

NettyFilterContext filterContext = new NettyFilterContext(ch.eventLoop(), pfr);
Expand All @@ -194,11 +241,6 @@ private void addHandlers(SocketChannel ch, VirtualClusterBinding binding) {
}

context.initiateConnect(target, filters);
}, dp, virtualCluster, apiVersionService);

pipeline.addLast("netHandler", frontendHandler);

LOGGER.debug("{}: Initial pipeline: {}", ch, pipeline);
}
}

}
Loading

0 comments on commit 1e52156

Please sign in to comment.