Skip to content

Commit

Permalink
Fix missing handshake timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 19, 2023
1 parent 1cbd1ab commit fe3cd98
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
10 changes: 6 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (c *Client) openStream(ctx context.Context) (net.Conn, error) {
if err != nil {
continue
}
stream, err = session.Open()
stream, err = session.OpenContext(ctx)
if err != nil {
continue
}
Expand Down Expand Up @@ -168,6 +168,8 @@ func (c *Client) offer(ctx context.Context) (abstractSession, error) {
}

func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
ctx, cancel := context.WithTimeout(ctx, TCPTimeout)
defer cancel()
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination)
if err != nil {
return nil, err
Expand All @@ -192,7 +194,7 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return nil, err
}
if c.brutal.Enabled {
err = c.brutalExchange(conn, session)
err = c.brutalExchange(ctx, conn, session)
if err != nil {
conn.Close()
session.Close()
Expand All @@ -203,8 +205,8 @@ func (c *Client) offerNew(ctx context.Context) (abstractSession, error) {
return session, nil
}

func (c *Client) brutalExchange(sessionConn net.Conn, session abstractSession) error {
stream, err := session.Open()
func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error {
stream, err := session.OpenContext(ctx)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions h2mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http
}
}

func (s *h2MuxServerSession) Open() (net.Conn, error) {
func (s *h2MuxServerSession) OpenContext(ctx context.Context) (net.Conn, error) {
return nil, os.ErrInvalid
}

Expand Down Expand Up @@ -197,7 +197,7 @@ func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) {
s.Close()
}

func (s *h2MuxClientSession) Open() (net.Conn, error) {
func (s *h2MuxClientSession) OpenContext(ctx context.Context) (net.Conn, error) {
pipeInReader, pipeInWriter := io.Pipe()
request := &http.Request{
Method: http.MethodConnect,
Expand All @@ -206,7 +206,7 @@ func (s *h2MuxClientSession) Open() (net.Conn, error) {
}
conn := newLateHTTPConn(pipeInWriter)
go func() {
response, err := s.transport.RoundTrip(request)
response, err := s.transport.RoundTrip(request.WithContext(ctx))
if err != nil {
conn.setup(nil, err)
} else if response.StatusCode != 200 {
Expand Down
9 changes: 7 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mux

import (
"context"
"io"
"net"
"reflect"
Expand All @@ -12,7 +13,7 @@ import (
)

type abstractSession interface {
Open() (net.Conn, error)
OpenContext(ctx context.Context) (net.Conn, error)
Accept() (net.Conn, error)
NumStreams() int
Close() error
Expand Down Expand Up @@ -80,7 +81,7 @@ type smuxSession struct {
*smux.Session
}

func (s *smuxSession) Open() (net.Conn, error) {
func (s *smuxSession) OpenContext(context.Context) (net.Conn, error) {
return s.OpenStream()
}

Expand All @@ -96,6 +97,10 @@ type yamuxSession struct {
*yamux.Session
}

func (y *yamuxSession) OpenContext(context.Context) (net.Conn, error) {
return y.OpenStream()
}

func (y *yamuxSession) CanTakeNewRequest() bool {
return true
}
Expand Down

0 comments on commit fe3cd98

Please sign in to comment.