Skip to content

Commit

Permalink
Merge pull request #89 from nhooyr/sec
Browse files Browse the repository at this point in the history
Add WebSocket masking and correctly use Sec-WebSocket-Key in client
  • Loading branch information
nhooyr authored Jun 4, 2019
2 parents 4130a30 + abcbea0 commit 169cdbc
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 40 deletions.
9 changes: 6 additions & 3 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,15 @@ var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) {
key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
}

func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New()
h.Write([]byte(key))
h.Write([]byte(secWebSocketKey))
h.Write(keyGUID)

responseKey := base64.StdEncoding.EncodeToString(h.Sum(nil))
w.Header().Set("Sec-WebSocket-Accept", responseKey)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

func authenticateOrigin(r *http.Request) error {
Expand Down
27 changes: 17 additions & 10 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/base64"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"strings"
Expand All @@ -30,11 +31,6 @@ type DialOptions struct {
Subprotocols []string
}

// We use this key for all client requests as the Sec-WebSocket-Key header doesn't do anything.
// See https://stackoverflow.com/a/37074398/4283659.
// We also use the same mask key for every message as it too does not make a difference.
var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16))

// Dial performs a WebSocket handshake on the given url with the given options.
// The response is the WebSocket handshake response from the server.
// If an error occurs, the returned response may be non nil. However, you can only
Expand Down Expand Up @@ -82,7 +78,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
Expand All @@ -101,7 +97,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
}
}()

err = verifyServerResponse(resp)
err = verifyServerResponse(req, resp)
if err != nil {
return nil, resp, err
}
Expand All @@ -118,12 +114,13 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
closer: rwc,
client: true,
}
c.extractBufioWriterBuf(rwc)
c.init()

return c, resp, nil
}

func verifyServerResponse(resp *http.Response) error {
func verifyServerResponse(r *http.Request, resp *http.Response) error {
if resp.StatusCode != http.StatusSwitchingProtocols {
return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
Expand All @@ -136,8 +133,12 @@ func verifyServerResponse(resp *http.Response) error {
return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}

// We do not care about Sec-WebSocket-Accept because it does not matter.
// See the secWebSocketKey global variable.
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
return xerrors.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
r.Header.Get("Sec-WebSocket-Key"),
)
}

return nil
}
Expand Down Expand Up @@ -176,3 +177,9 @@ func getBufioWriter(w io.Writer) *bufio.Writer {
func returnBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}

func makeSecWebSocketKey() string {
b := make([]byte, 16)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
20 changes: 19 additions & 1 deletion dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ func Test_verifyServerHandshake(t *testing.T) {
},
success: false,
},
{
name: "badSecWebSocketAccept",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Accept", "xd")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "success",
response: func(w http.ResponseWriter) {
Expand All @@ -58,7 +68,15 @@ func Test_verifyServerHandshake(t *testing.T) {
tc.response(w)
resp := w.Result()

err := verifyServerResponse(resp)
r := httptest.NewRequest("GET", "/", nil)
key := makeSecWebSocketKey()
r.Header.Set("Sec-WebSocket-Key", key)

if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
}

err := verifyServerResponse(r, resp)
if (err == nil) != tc.success {
t.Fatalf("unexpected error: %+v", err)
}
Expand Down
3 changes: 3 additions & 0 deletions export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package websocket

var Compute = handleSecWebSocketKey
112 changes: 87 additions & 25 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package websocket
import (
"bufio"
"context"
cryptorand "crypto/rand"
"fmt"
"io"
"io/ioutil"
Expand All @@ -26,8 +27,11 @@ type Conn struct {
subprotocol string
br *bufio.Reader
bw *bufio.Writer
closer io.Closer
client bool
// writeBuf is used for masking, its the buffer in bufio.Writer.
// Only used by the client.
writeBuf []byte
closer io.Closer
client bool

// read limit for a message in bytes.
msgReadLimit int64
Expand Down Expand Up @@ -581,22 +585,22 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
// See the Writer method if you want to stream a message. The docs on Writer
// regarding concurrency also apply to this method.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
err := c.write(ctx, typ, p)
_, err := c.write(ctx, typ, p)
if err != nil {
return xerrors.Errorf("failed to write msg: %w", err)
}
return nil
}

func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
err := c.acquireLock(ctx, c.writeMsgLock)
if err != nil {
return err
return 0, err
}
defer c.releaseLock(c.writeMsgLock)

err = c.writeFrame(ctx, true, opcode(typ), p)
return err
n, err := c.writeFrame(ctx, true, opcode(typ), p)
return n, err
}

// messageWriter enables writing to a WebSocket connection.
Expand All @@ -620,12 +624,12 @@ func (w *messageWriter) write(p []byte) (int, error) {
if w.closed {
return 0, xerrors.Errorf("cannot use closed writer")
}
err := w.c.writeFrame(w.ctx, false, w.opcode, p)
n, err := w.c.writeFrame(w.ctx, false, w.opcode, p)
if err != nil {
return 0, xerrors.Errorf("failed to write data frame: %w", err)
return n, xerrors.Errorf("failed to write data frame: %w", err)
}
w.opcode = opContinuation
return len(p), nil
return n, nil
}

// Close flushes the frame to the connection.
Expand All @@ -644,7 +648,7 @@ func (w *messageWriter) close() error {
}
w.closed = true

err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
_, err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
if err != nil {
return xerrors.Errorf("failed to write fin frame: %w", err)
}
Expand All @@ -654,34 +658,40 @@ func (w *messageWriter) close() error {
}

func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
err := c.writeFrame(ctx, true, opcode, p)
_, err := c.writeFrame(ctx, true, opcode, p)
if err != nil {
return xerrors.Errorf("failed to write control frame: %w", err)
}
return nil
}

// writeFrame handles all writes to the connection.
// We never mask inside here because our mask key is always 0,0,0,0.
// See comment on secWebSocketKey for why.
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error {
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
h := header{
fin: fin,
opcode: opcode,
masked: c.client,
payloadLength: int64(len(p)),
}

if c.client {
_, err := io.ReadFull(cryptorand.Reader, h.maskKey[:])
if err != nil {
return 0, xerrors.Errorf("failed to generate masking key: %w", err)
}
}

b2 := marshalHeader(h)

err := c.acquireLock(ctx, c.writeFrameLock)
if err != nil {
return err
return 0, err
}
defer c.releaseLock(c.writeFrameLock)

select {
case <-c.closed:
return c.closeErr
return 0, c.closeErr
case c.setWriteTimeout <- ctx:
}

Expand All @@ -705,29 +715,61 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte

_, err = c.bw.Write(b2)
if err != nil {
return writeErr(err)
}
_, err = c.bw.Write(p)
if err != nil {
return writeErr(err)
return 0, writeErr(err)
}

var n int
if c.client {
var keypos int
for len(p) > 0 {
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, writeErr(err)
}
}

// Start of next write in the buffer.
i := c.bw.Buffered()

p2 := p
if len(p) > c.bw.Available() {
p2 = p[:c.bw.Available()]
}

n2, err := c.bw.Write(p2)
if err != nil {
return n, writeErr(err)
}

keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])

p = p[n2:]
n += n2
}
} else {
n, err = c.bw.Write(p)
if err != nil {
return n, writeErr(err)
}
}

if fin {
err = c.bw.Flush()
if err != nil {
return writeErr(err)
return n, writeErr(err)
}
}

// We already finished writing, no need to potentially brick the connection if
// the context expires.
select {
case <-c.closed:
return c.closeErr
return n, c.closeErr
case c.setWriteTimeout <- context.Background():
}

return nil
return n, nil
}

func (c *Conn) writePong(p []byte) error {
Expand Down Expand Up @@ -842,3 +884,23 @@ func (c *Conn) ping(ctx context.Context) error {
return nil
}
}

type writerFunc func(p []byte) (int, error)

func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}

// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and stores it in c.writeBuf.
func (c *Conn) extractBufioWriterBuf(w io.Writer) {
c.bw.Reset(writerFunc(func(p2 []byte) (int, error) {
c.writeBuf = p2[:cap(p2)]
return len(p2), nil
}))

c.bw.WriteByte(0)
c.bw.Flush()

c.bw.Reset(w)
}
1 change: 0 additions & 1 deletion websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func TestHandshake(t *testing.T) {

checkHeader("Connection", "Upgrade")
checkHeader("Upgrade", "websocket")
checkHeader("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=")
checkHeader("Sec-WebSocket-Protocol", "myproto")

c.Close(websocket.StatusNormalClosure, "")
Expand Down

0 comments on commit 169cdbc

Please sign in to comment.