diff --git a/accept.go b/accept.go index 6e214111..d3ba3258 100644 --- a/accept.go +++ b/accept.go @@ -165,12 +165,15 @@ var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) +} + +func secWebSocketAccept(secWebSocketKey string) string { h := sha1.New() - h.Write([]byte(key)) + h.Write([]byte(secWebSocketKey)) h.Write(keyGUID) - responseKey := base64.StdEncoding.EncodeToString(h.Sum(nil)) - w.Header().Set("Sec-WebSocket-Accept", responseKey) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) } func authenticateOrigin(r *http.Request) error { diff --git a/dial.go b/dial.go index 64d2820d..1983f89a 100644 --- a/dial.go +++ b/dial.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "io" "io/ioutil" + "math/rand" "net/http" "net/url" "strings" @@ -30,11 +31,6 @@ type DialOptions struct { Subprotocols []string } -// We use this key for all client requests as the Sec-WebSocket-Key header doesn't do anything. -// See https://stackoverflow.com/a/37074398/4283659. -// We also use the same mask key for every message as it too does not make a difference. -var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16)) - // Dial performs a WebSocket handshake on the given url with the given options. // The response is the WebSocket handshake response from the server. // If an error occurs, the returned response may be non nil. However, you can only @@ -82,7 +78,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey()) if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } @@ -101,7 +97,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res } }() - err = verifyServerResponse(resp) + err = verifyServerResponse(req, resp) if err != nil { return nil, resp, err } @@ -118,12 +114,13 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res closer: rwc, client: true, } + c.extractBufioWriterBuf(rwc) c.init() return c, resp, nil } -func verifyServerResponse(resp *http.Response) error { +func verifyServerResponse(r *http.Request, resp *http.Response) error { if resp.StatusCode != http.StatusSwitchingProtocols { return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } @@ -136,8 +133,12 @@ func verifyServerResponse(resp *http.Response) error { return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } - // We do not care about Sec-WebSocket-Accept because it does not matter. - // See the secWebSocketKey global variable. + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { + return xerrors.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + resp.Header.Get("Sec-WebSocket-Accept"), + r.Header.Get("Sec-WebSocket-Key"), + ) + } return nil } @@ -176,3 +177,9 @@ func getBufioWriter(w io.Writer) *bufio.Writer { func returnBufioWriter(bw *bufio.Writer) { bufioWriterPool.Put(bw) } + +func makeSecWebSocketKey() string { + b := make([]byte, 16) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} diff --git a/dial_test.go b/dial_test.go index 02aaa4fc..6f0deef9 100644 --- a/dial_test.go +++ b/dial_test.go @@ -38,6 +38,16 @@ func Test_verifyServerHandshake(t *testing.T) { }, success: false, }, + { + name: "badSecWebSocketAccept", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Accept", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, { name: "success", response: func(w http.ResponseWriter) { @@ -58,7 +68,15 @@ func Test_verifyServerHandshake(t *testing.T) { tc.response(w) resp := w.Result() - err := verifyServerResponse(resp) + r := httptest.NewRequest("GET", "/", nil) + key := makeSecWebSocketKey() + r.Header.Set("Sec-WebSocket-Key", key) + + if resp.Header.Get("Sec-WebSocket-Accept") == "" { + resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + } + + err := verifyServerResponse(r, resp) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) } diff --git a/export_test.go b/export_test.go new file mode 100644 index 00000000..22ad76fc --- /dev/null +++ b/export_test.go @@ -0,0 +1,3 @@ +package websocket + +var Compute = handleSecWebSocketKey diff --git a/websocket.go b/websocket.go index 71d505f5..50744326 100644 --- a/websocket.go +++ b/websocket.go @@ -3,6 +3,7 @@ package websocket import ( "bufio" "context" + cryptorand "crypto/rand" "fmt" "io" "io/ioutil" @@ -26,8 +27,11 @@ type Conn struct { subprotocol string br *bufio.Reader bw *bufio.Writer - closer io.Closer - client bool + // writeBuf is used for masking, its the buffer in bufio.Writer. + // Only used by the client. + writeBuf []byte + closer io.Closer + client bool // read limit for a message in bytes. msgReadLimit int64 @@ -581,22 +585,22 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // See the Writer method if you want to stream a message. The docs on Writer // regarding concurrency also apply to this method. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - err := c.write(ctx, typ, p) + _, err := c.write(ctx, typ, p) if err != nil { return xerrors.Errorf("failed to write msg: %w", err) } return nil } -func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { err := c.acquireLock(ctx, c.writeMsgLock) if err != nil { - return err + return 0, err } defer c.releaseLock(c.writeMsgLock) - err = c.writeFrame(ctx, true, opcode(typ), p) - return err + n, err := c.writeFrame(ctx, true, opcode(typ), p) + return n, err } // messageWriter enables writing to a WebSocket connection. @@ -620,12 +624,12 @@ func (w *messageWriter) write(p []byte) (int, error) { if w.closed { return 0, xerrors.Errorf("cannot use closed writer") } - err := w.c.writeFrame(w.ctx, false, w.opcode, p) + n, err := w.c.writeFrame(w.ctx, false, w.opcode, p) if err != nil { - return 0, xerrors.Errorf("failed to write data frame: %w", err) + return n, xerrors.Errorf("failed to write data frame: %w", err) } w.opcode = opContinuation - return len(p), nil + return n, nil } // Close flushes the frame to the connection. @@ -644,7 +648,7 @@ func (w *messageWriter) close() error { } w.closed = true - err := w.c.writeFrame(w.ctx, true, w.opcode, nil) + _, err := w.c.writeFrame(w.ctx, true, w.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } @@ -654,7 +658,7 @@ func (w *messageWriter) close() error { } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - err := c.writeFrame(ctx, true, opcode, p) + _, err := c.writeFrame(ctx, true, opcode, p) if err != nil { return xerrors.Errorf("failed to write control frame: %w", err) } @@ -662,26 +666,32 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } // writeFrame handles all writes to the connection. -// We never mask inside here because our mask key is always 0,0,0,0. -// See comment on secWebSocketKey for why. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { h := header{ fin: fin, opcode: opcode, masked: c.client, payloadLength: int64(len(p)), } + + if c.client { + _, err := io.ReadFull(cryptorand.Reader, h.maskKey[:]) + if err != nil { + return 0, xerrors.Errorf("failed to generate masking key: %w", err) + } + } + b2 := marshalHeader(h) err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { - return err + return 0, err } defer c.releaseLock(c.writeFrameLock) select { case <-c.closed: - return c.closeErr + return 0, c.closeErr case c.setWriteTimeout <- ctx: } @@ -705,17 +715,49 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte _, err = c.bw.Write(b2) if err != nil { - return writeErr(err) - } - _, err = c.bw.Write(p) - if err != nil { - return writeErr(err) + return 0, writeErr(err) + } + + var n int + if c.client { + var keypos int + for len(p) > 0 { + if c.bw.Available() == 0 { + err = c.bw.Flush() + if err != nil { + return n, writeErr(err) + } + } + + // Start of next write in the buffer. + i := c.bw.Buffered() + + p2 := p + if len(p) > c.bw.Available() { + p2 = p[:c.bw.Available()] + } + + n2, err := c.bw.Write(p2) + if err != nil { + return n, writeErr(err) + } + + keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2]) + + p = p[n2:] + n += n2 + } + } else { + n, err = c.bw.Write(p) + if err != nil { + return n, writeErr(err) + } } if fin { err = c.bw.Flush() if err != nil { - return writeErr(err) + return n, writeErr(err) } } @@ -723,11 +765,11 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte // the context expires. select { case <-c.closed: - return c.closeErr + return n, c.closeErr case c.setWriteTimeout <- context.Background(): } - return nil + return n, nil } func (c *Conn) writePong(p []byte) error { @@ -842,3 +884,23 @@ func (c *Conn) ping(ctx context.Context) error { return nil } } + +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer +// and stores it in c.writeBuf. +func (c *Conn) extractBufioWriterBuf(w io.Writer) { + c.bw.Reset(writerFunc(func(p2 []byte) (int, error) { + c.writeBuf = p2[:cap(p2)] + return len(p2), nil + })) + + c.bw.WriteByte(0) + c.bw.Flush() + + c.bw.Reset(w) +} diff --git a/websocket_test.go b/websocket_test.go index 00e510c8..9d867b50 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -68,7 +68,6 @@ func TestHandshake(t *testing.T) { checkHeader("Connection", "Upgrade") checkHeader("Upgrade", "websocket") - checkHeader("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=") checkHeader("Sec-WebSocket-Protocol", "myproto") c.Close(websocket.StatusNormalClosure, "")