From fe3cd98428043fa3c4483d53a363a1df8fe6fc2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 19 Dec 2023 23:16:01 +0800 Subject: [PATCH] Fix missing handshake timeout --- client.go | 10 ++++++---- h2mux.go | 6 +++--- session.go | 9 +++++++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 5fc126a..519f7cd 100644 --- a/client.go +++ b/client.go @@ -113,7 +113,7 @@ func (c *Client) openStream(ctx context.Context) (net.Conn, error) { if err != nil { continue } - stream, err = session.Open() + stream, err = session.OpenContext(ctx) if err != nil { continue } @@ -168,6 +168,8 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) { } func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { + ctx, cancel := context.WithTimeout(ctx, TCPTimeout) + defer cancel() conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) if err != nil { return nil, err @@ -192,7 +194,7 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { return nil, err } if c.brutal.Enabled { - err = c.brutalExchange(conn, session) + err = c.brutalExchange(ctx, conn, session) if err != nil { conn.Close() session.Close() @@ -203,8 +205,8 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { return session, nil } -func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error { - stream, err := session.Open() +func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error { + stream, err := session.OpenContext(ctx) if err != nil { return err } diff --git a/h2mux.go b/h2mux.go index a67ef70..21f16cb 100644 --- a/h2mux.go +++ b/h2mux.go @@ -64,7 +64,7 @@ func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http } } -func (s *h2MuxServerSession) Open() (net.Conn, error) { +func (s *h2MuxServerSession) OpenContext(ctx context.Context) (net.Conn, error) { return nil, os.ErrInvalid } @@ -197,7 +197,7 @@ func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) { s.Close() } -func (s *h2MuxClientSession) Open() (net.Conn, error) { +func (s *h2MuxClientSession) OpenContext(ctx context.Context) (net.Conn, error) { pipeInReader, pipeInWriter := io.Pipe() request := &http.Request{ Method: http.MethodConnect, @@ -206,7 +206,7 @@ func (s *h2MuxClientSession) Open() (net.Conn, error) { } conn := newLateHTTPConn(pipeInWriter) go func() { - response, err := s.transport.RoundTrip(request) + response, err := s.transport.RoundTrip(request.WithContext(ctx)) if err != nil { conn.setup(nil, err) } else if response.StatusCode != 200 { diff --git a/session.go b/session.go index 2dc37b3..524f573 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package mux import ( + "context" "io" "net" "reflect" @@ -12,7 +13,7 @@ import ( ) type abstractSession interface { - Open() (net.Conn, error) + OpenContext(ctx context.Context) (net.Conn, error) Accept() (net.Conn, error) NumStreams() int Close() error @@ -80,7 +81,7 @@ type smuxSession struct { *smux.Session } -func (s *smuxSession) Open() (net.Conn, error) { +func (s *smuxSession) OpenContext(context.Context) (net.Conn, error) { return s.OpenStream() } @@ -96,6 +97,10 @@ type yamuxSession struct { *yamux.Session } +func (y *yamuxSession) OpenContext(context.Context) (net.Conn, error) { + return y.OpenStream() +} + func (y *yamuxSession) CanTakeNewRequest() bool { return true }