Skip to content

Commit

Permalink
Don't assume EOF on 0-length read for Node
Browse files Browse the repository at this point in the history
Remove the special handling code that detects EOF now that the upstream crate 
has fixed the bug that was being worked around. This also fixes a bug where EOF 
was being incorrectly detected when the provided buffer was empty. Add a test 
case to prevent regression in the future.
  • Loading branch information
akonradi-signal authored Jul 9, 2024
1 parent 2feac34 commit c6857dd
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

package org.signal.libsignal.media;

import static org.junit.Assert.assertArrayEquals;

import java.io.ByteArrayInputStream;
import org.junit.Test;
import org.signal.libsignal.internal.Native;

public class InputStreamTest {

@Test
public void testReadIntoEmptyBuffer() {
byte[] data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes();
assertArrayEquals(
Native.TESTING_InputStreamReadIntoZeroLengthSlice(new ByteArrayInputStream(data)), data);
}
}
1 change: 1 addition & 0 deletions java/shared/java/org/signal/libsignal/internal/Native.java
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ private Native() {}
public static native CompletableFuture<Long> TESTING_FutureProducesPointerType(long asyncRuntime, int input);
public static native CompletableFuture<Integer> TESTING_FutureSuccess(long asyncRuntime, int input);
public static native CompletableFuture<Void> TESTING_FutureThrowsCustomErrorType(long asyncRuntime);
public static native byte[] TESTING_InputStreamReadIntoZeroLengthSlice(InputStream capsAlphabetInput);
public static native void TESTING_NonSuspendingBackgroundThreadRuntime_Destroy(long handle);
public static native CompletableFuture TESTING_OnlyCompletesByCancellation(long asyncRuntime);
public static native String TESTING_OtherTestingHandleType_getValue(long handle);
Expand Down
1 change: 1 addition & 0 deletions node/Native.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ export function TESTING_FutureFailure(asyncRuntime: Wrapper<NonSuspendingBackgro
export function TESTING_FutureProducesOtherPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: string): Promise<OtherTestingHandleType>;
export function TESTING_FutureProducesPointerType(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<TestingHandleType>;
export function TESTING_FutureSuccess(asyncRuntime: Wrapper<NonSuspendingBackgroundThreadRuntime>, input: number): Promise<number>;
export function TESTING_InputStreamReadIntoZeroLengthSlice(capsAlphabetInput: InputStream): Promise<Buffer>;
export function TESTING_NonSuspendingBackgroundThreadRuntime_New(): NonSuspendingBackgroundThreadRuntime;
export function TESTING_OnlyCompletesByCancellation(asyncRuntime: Wrapper<TokioAsyncContext>): Promise<void>;
export function TESTING_OtherTestingHandleType_getValue(handle: Wrapper<OtherTestingHandleType>): string;
Expand Down
24 changes: 24 additions & 0 deletions node/ts/test/IoTest.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

import { assert, use } from 'chai';
import * as chaiAsPromised from 'chai-as-promised';

import * as Native from '../../Native';
import { Uint8ArrayInputStream } from './ioutil';

use(chaiAsPromised);

const CAPS_ALPHABET_INPUT = Buffer.from('ABCDEFGHIJKLMNOPQRSTUVWXYZ');

describe('InputStream', () => {
it('handles reads into empty buffers', async () => {
const input = new Uint8ArrayInputStream(CAPS_ALPHABET_INPUT);
const output = await Native.TESTING_InputStreamReadIntoZeroLengthSlice(
input
);
assert.deepEqual(output.compare(CAPS_ALPHABET_INPUT), 0);
});
});
41 changes: 40 additions & 1 deletion rust/bridge/shared/src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
// SPDX-License-Identifier: AGPL-3.0-only
//

use futures_util::FutureExt;
use futures_util::{AsyncReadExt as _, FutureExt};
use io::{AsyncInput, InputStream};
use libsignal_bridge_macros::*;
use libsignal_bridge_types::support::*;
use libsignal_bridge_types::*;
Expand Down Expand Up @@ -228,3 +229,41 @@ fn TESTING_ProcessBytestringArray(input: Vec<&[u8]>) -> Box<[Vec<u8>]> {
.collect::<Vec<Vec<u8>>>()
.into_boxed_slice()
}

#[bridge_fn]
async fn TESTING_InputStreamReadIntoZeroLengthSlice(
caps_alphabet_input: &mut dyn InputStream,
) -> Vec<u8> {
let mut async_input = AsyncInput::new(caps_alphabet_input, 26);
let first = {
let mut buf = [0; 10];
async_input
.read_exact(&mut buf)
.await
.expect("can read first");
buf
};
{
let mut zero_length_array = [0; 0];
assert_eq!(
async_input
.read(&mut zero_length_array)
.await
.expect("can do zero-length read"),
0
);
}
let remainder = {
let mut buf = Vec::with_capacity(16);
async_input
.read_to_end(&mut buf)
.await
.expect("can read to end");
buf
};

assert_eq!(&first, b"ABCDEFGHIJ");
assert_eq!(remainder, b"KLMNOPQRSTUVWXYZ");

first.into_iter().chain(remainder).collect()
}
2 changes: 1 addition & 1 deletion rust/bridge/shared/types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bincode = "1.0"
cfg-if = "1.0"
derive-where = "1.2.5"
displaydoc = "0.2"
futures-util = "0.3.7"
futures-util = "0.3.30"
hex = "0.4.3"
hkdf = "0.12"
hmac = "0.12.0"
Expand Down
19 changes: 4 additions & 15 deletions rust/bridge/shared/types/src/node/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use std::sync::Arc;
pub struct NodeInputStream {
js_channel: Channel,
stream_object: Arc<Root<JsObject>>,
eof_reached: Cell<bool>,
}

pub struct NodeSyncInputStream<'a> {
Expand All @@ -30,7 +29,6 @@ impl NodeInputStream {
Self {
js_channel: cx.channel(),
stream_object: Arc::new(stream.root(cx)),
eof_reached: Default::default(),
}
}

Expand All @@ -52,9 +50,6 @@ impl NodeInputStream {
Err(error) => Err(ThrownException::from_value(cx, error)),
})
.await?;
if read_data.is_empty() {
self.eof_reached.set(true);
}
Ok(read_data)
}

Expand Down Expand Up @@ -94,16 +89,10 @@ impl Finalize for NodeInputStream {
impl InputStream for NodeInputStream {
fn read<'out, 'a: 'out>(&'a self, buf: &mut [u8]) -> IoResult<InputStreamRead<'out>> {
let amount = buf.len() as u32;
if self.eof_reached.get() {
// If we read again after eof was reached, we can end up hitting an unreachable!() in
// futures::io::FillBuf due to issue rust-lang/futures#2727.
Ok(InputStreamRead::Ready { amount_read: 0 })
} else {
Ok(InputStreamRead::Pending(Box::pin(
self.do_read(amount)
.map_err(|err| IoError::new(IoErrorKind::Other, err)),
)))
}
Ok(InputStreamRead::Pending(Box::pin(
self.do_read(amount)
.map_err(|err| IoError::new(IoErrorKind::Other, err)),
)))
}

async fn skip(&self, amount: u64) -> IoResult<()> {
Expand Down
2 changes: 2 additions & 0 deletions swift/Sources/SignalFfi/signal_ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,8 @@ SignalFfiError *signal_testing_return_string_array(SignalStringArray *out);

SignalFfiError *signal_testing_process_bytestring_array(SignalBytestringArray *out, SignalBorrowedSliceOfBuffers input);

SignalFfiError *signal_testing_input_stream_read_into_zero_length_slice(SignalOwnedBuffer *out, const SignalInputStream *caps_alphabet_input);

SignalFfiError *signal_testing_cdsi_lookup_response_convert(SignalCPromiseFfiCdsiLookupResponse *promise, const SignalTokioAsyncContext *async_runtime);

SignalFfiError *signal_testing_only_completes_by_cancellation(SignalCPromisebool *promise, const SignalTokioAsyncContext *async_runtime);
Expand Down
21 changes: 21 additions & 0 deletions swift/Tests/LibSignalClientTests/IoTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//
// Copyright 2024 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

@testable import LibSignalClient
@testable import SignalFfi
import XCTest

class IoTests: TestCaseBase {
func testReadIntoEmptyBuffer() throws {
let input = [UInt8]("ABCDEFGHIJKLMNOPQRSTUVWXYZ".utf8)
let inputStream = SignalInputStreamAdapter(input)
let output = try withInputStream(inputStream) { input in
try invokeFnReturningArray { output in
SignalFfi.signal_testing_input_stream_read_into_zero_length_slice(output, input)
}
}
XCTAssertEqual(input, output)
}
}

0 comments on commit c6857dd

Please sign in to comment.