diff --git a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs index 51dd77c7..c8175ae6 100644 --- a/src/Nerdbank.Streams/MultiplexingStream.Channel.cs +++ b/src/Nerdbank.Streams/MultiplexingStream.Channel.cs @@ -1067,10 +1067,6 @@ public override bool TryRead(out ReadResult readResult) public override ValueTask CompleteAsync(Exception? exception = null) => this.inner.CompleteAsync(exception); - public override Task CopyToAsync(PipeWriter destination, CancellationToken cancellationToken = default) => this.inner.CopyToAsync(destination, cancellationToken); - - public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default) => this.inner.CopyToAsync(destination, cancellationToken); - [Obsolete] public override void OnWriterCompleted(Action callback, object? state) => this.inner.OnWriterCompleted(callback, state); diff --git a/test/Nerdbank.Streams.Tests/MultiplexingStreamV2Tests.cs b/test/Nerdbank.Streams.Tests/MultiplexingStreamV2Tests.cs index e516f4fb..a0fd8bb9 100644 --- a/test/Nerdbank.Streams.Tests/MultiplexingStreamV2Tests.cs +++ b/test/Nerdbank.Streams.Tests/MultiplexingStreamV2Tests.cs @@ -136,6 +136,28 @@ public async Task Backpressure_ExistingPipe() await writeTask; } + [Fact] + public async Task Backpressure_CopyToAsync() + { + long backpressureThreshold = this.mx1.DefaultChannelReceivingWindowSize; + (MultiplexingStream.Channel a, MultiplexingStream.Channel b) = await this.EstablishChannelsAsync("a"); + + byte[]? hugeChunk = new byte[backpressureThreshold * 2]; // enough to fill the remote and local windows + a.Output.Write(hugeChunk); + Task flushTask = Task.Run(async delegate + { + await a.Output.FlushAsync(this.TimeoutToken); + await a.Output.CompleteAsync(); + }); + + // Now read from the channel and verify it unblocks the writer, using CopyToAsync specifically. + long drainedBytesCount = await this.DrainReaderTillCompletedAsync(b.Input, useCopyToAsync: true); + Assert.Equal(hugeChunk.Length, drainedBytesCount); + + await flushTask.WithCancellation(this.TimeoutToken); + await CompleteChannelsAsync(a, b); + } + /// /// Regression test for #253. /// diff --git a/test/Nerdbank.Streams.Tests/TestBase.cs b/test/Nerdbank.Streams.Tests/TestBase.cs index 8a05a7a5..61293dbb 100644 --- a/test/Nerdbank.Streams.Tests/TestBase.cs +++ b/test/Nerdbank.Streams.Tests/TestBase.cs @@ -146,17 +146,30 @@ public async Task DrainAsync(PipeReader reader, long requiredLength) } } - public async Task DrainReaderTillCompletedAsync(PipeReader reader) + public async Task DrainReaderTillCompletedAsync(PipeReader reader, bool useCopyToAsync = false) { - while (true) + long bytesDrained = 0; + if (useCopyToAsync) { - ReadResult readResult = await reader.ReadAsync(this.TimeoutToken); - reader.AdvanceTo(readResult.Buffer.End); - if (readResult.IsCompleted) + MemoryStream ms = new(); + await reader.CopyToAsync(ms, this.TimeoutToken); + bytesDrained = ms.Length; + } + else + { + while (true) { - break; + ReadResult readResult = await reader.ReadAsync(this.TimeoutToken); + bytesDrained += readResult.Buffer.Length; + reader.AdvanceTo(readResult.Buffer.End); + if (readResult.IsCompleted) + { + break; + } } } + + return bytesDrained; } internal byte[] GetBuffer(int length)