From c6857dd58efa68d3df42a1f05a74bbc6c7038b3b Mon Sep 17 00:00:00 2001 From: Alex Konradi Date: Tue, 9 Jul 2024 11:58:43 -0400 Subject: [PATCH] Don't assume EOF on 0-length read for Node Remove the special handling code that detects EOF now that the upstream crate has fixed the bug that was being worked around. This also fixes a bug where EOF was being incorrectly detected when the provided buffer was empty. Add a test case to prevent regression in the future. --- .../libsignal/media/InputStreamTest.java | 22 ++++++++++ .../org/signal/libsignal/internal/Native.java | 1 + node/Native.d.ts | 1 + node/ts/test/IoTest.ts | 24 +++++++++++ rust/bridge/shared/src/testing/mod.rs | 41 ++++++++++++++++++- rust/bridge/shared/types/Cargo.toml | 2 +- rust/bridge/shared/types/src/node/io.rs | 19 ++------- swift/Sources/SignalFfi/signal_ffi.h | 2 + .../Tests/LibSignalClientTests/IoTests.swift | 21 ++++++++++ 9 files changed, 116 insertions(+), 17 deletions(-) create mode 100644 java/client/src/test/java/org/signal/libsignal/media/InputStreamTest.java create mode 100644 node/ts/test/IoTest.ts create mode 100644 swift/Tests/LibSignalClientTests/IoTests.swift diff --git a/java/client/src/test/java/org/signal/libsignal/media/InputStreamTest.java b/java/client/src/test/java/org/signal/libsignal/media/InputStreamTest.java new file mode 100644 index 000000000..27553ba3b --- /dev/null +++ b/java/client/src/test/java/org/signal/libsignal/media/InputStreamTest.java @@ -0,0 +1,22 @@ +// +// Copyright 2024 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +package org.signal.libsignal.media; + +import static org.junit.Assert.assertArrayEquals; + +import java.io.ByteArrayInputStream; +import org.junit.Test; +import org.signal.libsignal.internal.Native; + +public class InputStreamTest { + + @Test + public void testReadIntoEmptyBuffer() { + byte[] data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(); + assertArrayEquals( + Native.TESTING_InputStreamReadIntoZeroLengthSlice(new ByteArrayInputStream(data)), data); + } +} diff --git a/java/shared/java/org/signal/libsignal/internal/Native.java b/java/shared/java/org/signal/libsignal/internal/Native.java index 2fc9fffa0..67f62aa41 100644 --- a/java/shared/java/org/signal/libsignal/internal/Native.java +++ b/java/shared/java/org/signal/libsignal/internal/Native.java @@ -639,6 +639,7 @@ private Native() {} public static native CompletableFuture TESTING_FutureProducesPointerType(long asyncRuntime, int input); public static native CompletableFuture TESTING_FutureSuccess(long asyncRuntime, int input); public static native CompletableFuture TESTING_FutureThrowsCustomErrorType(long asyncRuntime); + public static native byte[] TESTING_InputStreamReadIntoZeroLengthSlice(InputStream capsAlphabetInput); public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle); public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime); public static native String TESTING_OtherTestingHandleType_getValue(long handle); diff --git a/node/Native.d.ts b/node/Native.d.ts index 358248136..463a57d67 100644 --- a/node/Native.d.ts +++ b/node/Native.d.ts @@ -505,6 +505,7 @@ export function TESTING_FutureFailure(asyncRuntime: Wrapper, input: string): Promise; export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper, input: number): Promise; export function TESTING_FutureSuccess(asyncRuntime: Wrapper, input: number): Promise; +export function TESTING_InputStreamReadIntoZeroLengthSlice(capsAlphabetInput: InputStream): Promise; export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime; export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper): Promise; export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper): string; diff --git a/node/ts/test/IoTest.ts b/node/ts/test/IoTest.ts new file mode 100644 index 000000000..3951e53c8 --- /dev/null +++ b/node/ts/test/IoTest.ts @@ -0,0 +1,24 @@ +// +// Copyright 2024 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +import { assert, use } from 'chai'; +import * as chaiAsPromised from 'chai-as-promised'; + +import * as Native from '../../Native'; +import { Uint8ArrayInputStream } from './ioutil'; + +use(chaiAsPromised); + +const CAPS_ALPHABET_INPUT = Buffer.from('ABCDEFGHIJKLMNOPQRSTUVWXYZ'); + +describe('InputStream', () => { + it('handles reads into empty buffers', async () => { + const input = new Uint8ArrayInputStream(CAPS_ALPHABET_INPUT); + const output = await Native.TESTING_InputStreamReadIntoZeroLengthSlice( + input + ); + assert.deepEqual(output.compare(CAPS_ALPHABET_INPUT), 0); + }); +}); diff --git a/rust/bridge/shared/src/testing/mod.rs b/rust/bridge/shared/src/testing/mod.rs index cecd82862..82a58bac1 100644 --- a/rust/bridge/shared/src/testing/mod.rs +++ b/rust/bridge/shared/src/testing/mod.rs @@ -3,7 +3,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // -use futures_util::FutureExt; +use futures_util::{AsyncReadExt as _, FutureExt}; +use io::{AsyncInput, InputStream}; use libsignal_bridge_macros::*; use libsignal_bridge_types::support::*; use libsignal_bridge_types::*; @@ -228,3 +229,41 @@ fn TESTING_ProcessBytestringArray(input: Vec<&[u8]>) -> Box<[Vec]> { .collect::>>() .into_boxed_slice() } + +#[bridge_fn] +async fn TESTING_InputStreamReadIntoZeroLengthSlice( + caps_alphabet_input: &mut dyn InputStream, +) -> Vec { + let mut async_input = AsyncInput::new(caps_alphabet_input, 26); + let first = { + let mut buf = [0; 10]; + async_input + .read_exact(&mut buf) + .await + .expect("can read first"); + buf + }; + { + let mut zero_length_array = [0; 0]; + assert_eq!( + async_input + .read(&mut zero_length_array) + .await + .expect("can do zero-length read"), + 0 + ); + } + let remainder = { + let mut buf = Vec::with_capacity(16); + async_input + .read_to_end(&mut buf) + .await + .expect("can read to end"); + buf + }; + + assert_eq!(&first, b"ABCDEFGHIJ"); + assert_eq!(remainder, b"KLMNOPQRSTUVWXYZ"); + + first.into_iter().chain(remainder).collect() +} diff --git a/rust/bridge/shared/types/Cargo.toml b/rust/bridge/shared/types/Cargo.toml index c1c85f154..b7149a37a 100644 --- a/rust/bridge/shared/types/Cargo.toml +++ b/rust/bridge/shared/types/Cargo.toml @@ -33,7 +33,7 @@ bincode = "1.0" cfg-if = "1.0" derive-where = "1.2.5" displaydoc = "0.2" -futures-util = "0.3.7" +futures-util = "0.3.30" hex = "0.4.3" hkdf = "0.12" hmac = "0.12.0" diff --git a/rust/bridge/shared/types/src/node/io.rs b/rust/bridge/shared/types/src/node/io.rs index 8db7ca0c4..4087173d8 100644 --- a/rust/bridge/shared/types/src/node/io.rs +++ b/rust/bridge/shared/types/src/node/io.rs @@ -17,7 +17,6 @@ use std::sync::Arc; pub struct NodeInputStream { js_channel: Channel, stream_object: Arc>, - eof_reached: Cell, } pub struct NodeSyncInputStream<'a> { @@ -30,7 +29,6 @@ impl NodeInputStream { Self { js_channel: cx.channel(), stream_object: Arc::new(stream.root(cx)), - eof_reached: Default::default(), } } @@ -52,9 +50,6 @@ impl NodeInputStream { Err(error) => Err(ThrownException::from_value(cx, error)), }) .await?; - if read_data.is_empty() { - self.eof_reached.set(true); - } Ok(read_data) } @@ -94,16 +89,10 @@ impl Finalize for NodeInputStream { impl InputStream for NodeInputStream { fn read<'out, 'a: 'out>(&'a self, buf: &mut [u8]) -> IoResult> { let amount = buf.len() as u32; - if self.eof_reached.get() { - // If we read again after eof was reached, we can end up hitting an unreachable!() in - // futures::io::FillBuf due to issue rust-lang/futures#2727. - Ok(InputStreamRead::Ready { amount_read: 0 }) - } else { - Ok(InputStreamRead::Pending(Box::pin( - self.do_read(amount) - .map_err(|err| IoError::new(IoErrorKind::Other, err)), - ))) - } + Ok(InputStreamRead::Pending(Box::pin( + self.do_read(amount) + .map_err(|err| IoError::new(IoErrorKind::Other, err)), + ))) } async fn skip(&self, amount: u64) -> IoResult<()> { diff --git a/swift/Sources/SignalFfi/signal_ffi.h b/swift/Sources/SignalFfi/signal_ffi.h index 0e5e893d8..4dfc34736 100644 --- a/swift/Sources/SignalFfi/signal_ffi.h +++ b/swift/Sources/SignalFfi/signal_ffi.h @@ -1748,6 +1748,8 @@ SignalFfiError *signal_testing_return_string_array(SignalStringArray *out); SignalFfiError *signal_testing_process_bytestring_array(SignalBytestringArray *out, SignalBorrowedSliceOfBuffers input); +SignalFfiError *signal_testing_input_stream_read_into_zero_length_slice(SignalOwnedBuffer *out, const SignalInputStream *caps_alphabet_input); + SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime); SignalFfiError *signal_testing_only_completes_by_cancellation(SignalCPromisebool *promise, const SignalTokioAsyncContext *async_runtime); diff --git a/swift/Tests/LibSignalClientTests/IoTests.swift b/swift/Tests/LibSignalClientTests/IoTests.swift new file mode 100644 index 000000000..c870827ca --- /dev/null +++ b/swift/Tests/LibSignalClientTests/IoTests.swift @@ -0,0 +1,21 @@ +// +// Copyright 2024 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +@testable import LibSignalClient +@testable import SignalFfi +import XCTest + +class IoTests: TestCaseBase { + func testReadIntoEmptyBuffer() throws { + let input = [UInt8]("ABCDEFGHIJKLMNOPQRSTUVWXYZ".utf8) + let inputStream = SignalInputStreamAdapter(input) + let output = try withInputStream(inputStream) { input in + try invokeFnReturningArray { output in + SignalFfi.signal_testing_input_stream_read_into_zero_length_slice(output, input) + } + } + XCTAssertEqual(input, output) + } +}