From 54c85cb26308b89846d4671f23954dce088da2b0 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Wed, 25 Oct 2023 10:20:49 +0100 Subject: [PATCH] Fix thread-safety issues in TCPThroughputBenchmark (#2537) Motivation: Several thread-safety issues were missed in code review. This patch fixes them. Modifications: - Removed the use of an unstructured Task, replaced with eventLoop.execute to ServerHandler's EventLoop. - Stopped ClientHandler reaching into the benchmark object without any synchronization, used promises and event loop hops instead. Result: Thread safety is back --- .../TCPThroughputBenchmark.swift | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift b/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift index a1b88cabac..aaf3f8acb6 100644 --- a/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift +++ b/Sources/NIOPerformanceTester/TCPThroughputBenchmark.swift @@ -32,21 +32,21 @@ final class TCPThroughputBenchmark: Benchmark { private var clientChannel: Channel! private var message: ByteBuffer! - private var isDonePromise: EventLoopPromise! + private var serverEventLoop: EventLoop! final class ServerHandler: ChannelInboundHandler { public typealias InboundIn = ByteBuffer public typealias OutboundOut = ByteBuffer - private var channel: Channel! + private var context: ChannelHandlerContext! public func channelActive(context: ChannelHandlerContext) { - self.channel = context.channel + self.context = context } public func send(_ message: ByteBuffer, times count: Int) { for _ in 0..? - init(_ benchmark: TCPThroughputBenchmark) { - self.benchmark = benchmark + init() { self.messagesReceived = 0 } + func prepareRun(expectedMessages: Int, promise: EventLoopPromise) { + self.expectedMessages = expectedMessages + self.completionPromise = promise + } + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.messagesReceived += 1 - if (self.benchmark.messages == self.messagesReceived) { - self.benchmark.isDonePromise.succeed() + + if (self.expectedMessages == self.messagesReceived) { + let promise = self.completionPromise + self.messagesReceived = 0 + self.expectedMessages = nil + self.completionPromise = nil + + promise!.succeed() } } } @@ -95,12 +106,12 @@ final class TCPThroughputBenchmark: Benchmark { func setUp() throws { self.group = MultiThreadedEventLoopGroup(numberOfThreads: 4) - let connectionEstablished: EventLoopPromise = self.group.next().makePromise() + let connectionEstablished: EventLoopPromise = self.group.next().makePromise() self.serverChannel = try ServerBootstrap(group: self.group) .childChannelInitializer { channel in self.serverHandler = ServerHandler() - connectionEstablished.succeed() + connectionEstablished.succeed(channel.eventLoop) return channel.pipeline.addHandler(self.serverHandler) } .bind(host: "127.0.0.1", port: 0) @@ -109,7 +120,7 @@ final class TCPThroughputBenchmark: Benchmark { self.clientChannel = try ClientBootstrap(group: group) .channelInitializer { channel in channel.pipeline.addHandler(ByteToMessageHandler(StreamDecoder())).flatMap { _ in - channel.pipeline.addHandler(ClientHandler(self)) + channel.pipeline.addHandler(ClientHandler()) } } .connect(to: serverChannel.localAddress!) @@ -122,7 +133,7 @@ final class TCPThroughputBenchmark: Benchmark { } self.message = message - try connectionEstablished.futureResult.wait() + self.serverEventLoop = try connectionEstablished.futureResult.wait() } func tearDown() { @@ -132,12 +143,22 @@ final class TCPThroughputBenchmark: Benchmark { } func run() throws -> Int { - self.isDonePromise = self.group.next().makePromise() - Task { - self.serverHandler.send(self.message, times: self.messages) + let isDonePromise = self.clientChannel.eventLoop.makePromise(of: Void.self) + let clientChannel = self.clientChannel! + let expectedMessages = self.messages + + try clientChannel.eventLoop.submit { + try clientChannel.pipeline.syncOperations.handler(type: ClientHandler.self).prepareRun(expectedMessages: expectedMessages, promise: isDonePromise) + }.wait() + + let serverHandler = self.serverHandler! + let message = self.message! + let messages = self.messages + + self.serverEventLoop.execute { + serverHandler.send(message, times: messages) } - try self.isDonePromise.futureResult.wait() - self.isDonePromise = nil + try isDonePromise.futureResult.wait() return 0 } }