diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index ec201400..3735d13d 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -544,4 +544,14 @@ impl Send { true } } + + pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) { + if let Ok(next_id) = self.next_stream_id { + // Peer::is_local_init should have been called beforehand + debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated()); + if id >= next_id { + self.next_stream_id = id.next_id(); + } + } + } } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index ac762c8f..4962db8d 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -865,6 +865,24 @@ impl Inner { let key = match self.store.find_entry(id) { Entry::Occupied(e) => e.key(), Entry::Vacant(e) => { + // Resetting a stream we don't know about? That could be OK... + // + // 1. As a server, we just received a request, but that request + // was bad, so we're resetting before even accepting it. + // This is totally fine. + // + // 2. The remote may have sent us a frame on new stream that + // it's *not* supposed to have done, and thus, we don't know + // the stream. In that case, sending a reset will "open" the + // stream in our store. Maybe that should be a connection + // error instead? At least for now, we need to update what + // our vision of the next stream is. + if self.counts.peer().is_local_init(id) { + // We normally would open this stream, so update our + // next-send-id record. + self.actions.send.maybe_reset_next_stream_id(id); + } + let stream = Stream::new(id, 0, 0); e.insert(stream) diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index 91ef4939..f2b2efc1 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1022,3 +1022,60 @@ async fn srv_window_update_on_lower_stream_id() { }; join(srv, client).await; } + +// See https://github.com/hyperium/h2/issues/570 +#[tokio::test] +async fn reset_new_stream_before_send() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + // Send unexpected headers, that depends on itself, causing a framing error. + srv.send_bytes(&[ + 0, 0, 0x6, // len + 0x1, // type (headers) + 0x25, // flags (eos, eoh, pri) + 0, 0, 0, 0x3, // stream id + 0, 0, 0, 0x3, // dependency + 2, // weight + 0x88, // HPACK :status=200 + ]) + .await; + srv.recv_frame(frames::reset(3).protocol_error()).await; + srv.recv_frame( + frames::headers(5) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(5).response(200).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let resp = conn + .drive(client.get("https://example.com/")) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // req number 2 + let resp = conn + .drive(client.get("https://example.com/")) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + conn.await.expect("client"); + }; + + join(srv, client).await; +}