From 76d3a107fd6719a1e6ccdea3a4d4c1300a976752 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 31 May 2019 18:25:27 -0400 Subject: [PATCH 1/5] Polish for stable release --- .gitignore | 3 -- README.md | 31 ++++++++++---- ci/bench/entrypoint.sh | 14 +++---- ci/lint/entrypoint.sh | 2 +- ci/out/.gitignore | 1 + ci/test/entrypoint.sh | 12 +++--- example_echo_test.go | 1 - limitedreader.go | 34 +++++++++++++++ websocket.go | 94 ++++++++++++++++++++++++------------------ websocket_test.go | 49 +++++++++++++++++++--- wsjson/wsjson.go | 23 +++-------- wspb/wspb.go | 10 +---- 12 files changed, 175 insertions(+), 99 deletions(-) delete mode 100644 .gitignore create mode 100644 ci/out/.gitignore create mode 100644 limitedreader.go diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7fffaa26..00000000 --- a/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -wstest_reports -websocket.test -profs diff --git a/README.md b/README.md index 9c44e4e1..afd5abba 100644 --- a/README.md +++ b/README.md @@ -5,23 +5,25 @@ websocket is a minimal and idiomatic WebSocket library for Go. -This library is not final and the API is subject to change. +This library is now production ready but some parts of the API are marked as experimental. + +Please feel free to open an issue for feedback. ## Install ```bash -go get nhooyr.io/websocket@v0.2.0 +go get nhooyr.io/websocket ``` ## Features - Minimal and idiomatic API -- Tiny codebase at 1400 lines +- Tiny codebase at 1700 lines - First class context.Context support - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Zero dependencies outside of the stdlib for the core library - JSON and ProtoBuf helpers in the wsjson and wspb subpackages -- High performance +- High performance, memory reuse wherever possible - Concurrent reads and writes out of the box ## Roadmap @@ -88,8 +90,9 @@ c.Close(websocket.StatusNormalClosure, "") - net.Conn is never exposed as WebSocket over HTTP/2 will not have a net.Conn. - Using net/http's Client for dialing means we do not have to reinvent dialing hooks and configurations like other WebSocket libraries -- We do not support the compression extension because Go's compress/flate library is very memory intensive - and browsers do not handle WebSocket compression intelligently. See [#5](https://github.com/nhooyr/websocket/issues/5) +- We do not support the deflate compression extension because Go's compress/flate library + is very memory intensive and browsers do not handle WebSocket compression intelligently. + See [#5](https://github.com/nhooyr/websocket/issues/5) ## Comparison @@ -111,7 +114,7 @@ Just compare the godoc of The API for nhooyr/websocket has been designed such that there is only one way to do things which makes it easy to use correctly. Not only is the API simpler, the implementation is -only 1400 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain, +only 1700 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain, more code to test, more code to document and more surface area for bugs. The future of gorilla/websocket is also uncertain. See [gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370). @@ -124,8 +127,18 @@ it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. Some more advantages of nhooyr/websocket are that it supports concurrent reads, writes and makes it very easy to close the connection with a status code and reason. -In terms of performance, the only difference is nhooyr/websocket is forced to use one extra -goroutine for context.Context support. Otherwise, they perform identically. +nhooyr/websocket also responds to pings, pongs and close frames in a separate goroutine so that +your application doesn't always need to read from the connection unless it expects a data message. +gorilla/websocket requires you to constantly read from the connection to respond to control frames +even if you don't expect the peer to send any messages. + +In terms of performance, the differences depend on your application code. nhooyr/websocket +reuses buffers efficiently out of the box whereas gorilla/websocket does not. As mentioned +above, it also supports concurrent readers and writers out of the box. + +The only performance downside to nhooyr/websocket is that uses two extra goroutines. One for +reading pings, pongs and close frames async to application code and another to support +context.Context cancellation. This costs 4 KB of memory which is fairly cheap. ### x/net/websocket diff --git a/ci/bench/entrypoint.sh b/ci/bench/entrypoint.sh index 5f7dcf73..a8350c9d 100755 --- a/ci/bench/entrypoint.sh +++ b/ci/bench/entrypoint.sh @@ -2,16 +2,14 @@ source ci/lib.sh || exit 1 -mkdir -p profs - -go test --vet=off --run=^$ -bench=. \ - -cpuprofile=profs/cpu \ - -memprofile=profs/mem \ - -blockprofile=profs/block \ - -mutexprofile=profs/mutex \ +go test --vet=off --run=^$ -bench=. -o=ci/out/websocket.test \ + -cpuprofile=ci/out/cpu.prof \ + -memprofile=ci/out/mem.prof \ + -blockprofile=ci/out/block.prof \ + -mutexprofile=ci/out/mutex.prof \ . set +x echo -echo "profiles are in ./profs +echo "profiles are in ./ci/out/*.prof keep in mind that every profiler Go provides is enabled so that may skew the benchmarks" diff --git a/ci/lint/entrypoint.sh b/ci/lint/entrypoint.sh index 09c31683..62f74022 100755 --- a/ci/lint/entrypoint.sh +++ b/ci/lint/entrypoint.sh @@ -7,5 +7,5 @@ source ci/lib.sh || exit 1 shellcheck ./**/*.sh ) -go vet -composites=false -lostcancel=false ./... +go vet ./... go run golang.org/x/lint/golint -set_exit_status ./... diff --git a/ci/out/.gitignore b/ci/out/.gitignore new file mode 100644 index 00000000..72e8ffc0 --- /dev/null +++ b/ci/out/.gitignore @@ -0,0 +1 @@ +* diff --git a/ci/test/entrypoint.sh b/ci/test/entrypoint.sh index 2a39593f..c9a0e80a 100755 --- a/ci/test/entrypoint.sh +++ b/ci/test/entrypoint.sh @@ -2,8 +2,6 @@ source ci/lib.sh || exit 1 -mkdir -p profs - set +x echo echo "this step includes benchmarks for race detection and coverage purposes @@ -12,15 +10,15 @@ accurate numbers" echo set -x -go test -race -coverprofile=profs/coverage --vet=off -bench=. ./... -go tool cover -func=profs/coverage +go test -race -coverprofile=ci/out/coverage.prof --vet=off -bench=. ./... +go tool cover -func=ci/out/coverage.prof if [[ $CI ]]; then - bash <(curl -s https://codecov.io/bash) -f profs/coverage + bash <(curl -s https://codecov.io/bash) -f ci/out/coverage.prof else - go tool cover -html=profs/coverage -o=profs/coverage.html + go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html set +x echo - echo "please open profs/coverage.html to see detailed test coverage stats" + echo "please open ci/out/coverage.html to see detailed test coverage stats" fi diff --git a/example_echo_test.go b/example_echo_test.go index 405c7a41..f1da25af 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -51,7 +51,6 @@ func Example_echo() { // Now we dial the server, send the messages and echo the responses. err = client("ws://" + l.Addr().String()) - time.Sleep(time.Second) if err != nil { log.Fatalf("client failed: %v", err) } diff --git a/limitedreader.go b/limitedreader.go new file mode 100644 index 00000000..63bf40c4 --- /dev/null +++ b/limitedreader.go @@ -0,0 +1,34 @@ +package websocket + +import ( + "fmt" + "io" + + "golang.org/x/xerrors" +) + +type limitedReader struct { + c *Conn + r io.Reader + left int64 + limit int64 +} + +func (lr *limitedReader) Read(p []byte) (int, error) { + if lr.limit == 0 { + lr.limit = lr.left + } + + if lr.left <= 0 { + msg := fmt.Sprintf("read limited at %v bytes", lr.limit) + lr.c.Close(StatusPolicyViolation, msg) + return 0, xerrors.Errorf(msg) + } + + if int64(len(p)) > lr.left { + p = p[:lr.left] + } + n, err := lr.r.Read(p) + lr.left -= int64(n) + return n, err +} diff --git a/websocket.go b/websocket.go index db2e82e7..cd709d61 100644 --- a/websocket.go +++ b/websocket.go @@ -22,6 +22,12 @@ import ( // // Please be sure to call Close on the connection when you // are finished with it to release resources. +// +// Control (ping, pong, close) frames will be responded to in a separate goroutine +// so if you do not expect any data messages, you do not need +// to read from the connection. However, if the peer +// sends a data message, further pings, pongs and close frames will not +// be read if you do not read the message from the connection. type Conn struct { subprotocol string br *bufio.Reader @@ -35,21 +41,21 @@ type Conn struct { closeErr error closed chan struct{} - writeDataLock chan struct{} + writeMsgLock chan struct{} writeFrameLock chan struct{} readMsgLock chan struct{} + readFrameLock chan struct{} readMsg chan header readMsgDone chan struct{} - readFrameLock chan struct{} setReadTimeout chan context.Context setWriteTimeout chan context.Context setConnContext chan context.Context getConnContext chan context.Context - pingListenerMu sync.Mutex - pingListener map[string]chan<- struct{} + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} } // Context returns a context derived from parent that will be cancelled @@ -91,7 +97,8 @@ func (c *Conn) close(err error) { close(c.closed) // This ensures every goroutine that interacts - // with the conn closes before it can interact with the connection + // with the conn returns before it can actually do anything and + // receives c.closeErr. c.readFrameLock <- struct{}{} c.writeFrameLock <- struct{}{} @@ -114,20 +121,20 @@ func (c *Conn) init() { c.msgReadLimit = 32768 - c.writeDataLock = make(chan struct{}, 1) + c.writeMsgLock = make(chan struct{}, 1) c.writeFrameLock = make(chan struct{}, 1) - c.readMsg = make(chan header) - c.readMsgDone = make(chan struct{}) c.readMsgLock = make(chan struct{}, 1) c.readFrameLock = make(chan struct{}, 1) + c.readMsg = make(chan header) + c.readMsgDone = make(chan struct{}) c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) c.setConnContext = make(chan context.Context) c.getConnContext = make(chan context.Context) - c.pingListener = make(map[string]chan<- struct{}) + c.activePings = make(map[string]chan<- struct{}) runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) @@ -199,11 +206,6 @@ func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() parentCtx := context.Background() - cancelCtx := func() {} - defer func() { - // We do not defer cancelCtx directly because its value may change. - cancelCtx() - }() for { select { @@ -219,8 +221,9 @@ func (c *Conn) timeoutLoop() { c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err())) return case parentCtx = <-c.setConnContext: - var ctx context.Context - ctx, cancelCtx = context.WithCancel(parentCtx) + ctx, cancelCtx := context.WithCancel(parentCtx) + defer cancelCtx() + select { case <-c.closed: return @@ -256,11 +259,11 @@ func (c *Conn) handleControl(h header) { case opPing: c.writePong(b) case opPong: - c.pingListenerMu.Lock() - listener, ok := c.pingListener[string(b)] - c.pingListenerMu.Unlock() + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() if ok { - close(listener) + close(pong) } case opClose: ce, err := parseClosePayload(b) @@ -278,7 +281,7 @@ func (c *Conn) handleControl(h header) { } } -func (c *Conn) readTillData() (header, error) { +func (c *Conn) readTillMsg() (header, error) { for { h, err := c.readHeader() if err != nil { @@ -330,12 +333,18 @@ func (c *Conn) readHeader() (header, error) { func (c *Conn) readLoop() { for { - h, err := c.readTillData() + h, err := c.readTillMsg() if err != nil { c.close(err) return } + if h.opcode == opContinuation && + h.fin && + h.payloadLength == 0 { + c.releaseLock(c.readMsgLock) + } + select { case <-c.closed: return @@ -438,11 +447,11 @@ func (c *Conn) releaseLock(lock chan struct{}) { func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error { if !opcode.controlOp() { - err := c.acquireLock(ctx, c.writeDataLock) + err := c.acquireLock(ctx, c.writeMsgLock) if err != nil { return err } - defer c.releaseLock(c.writeDataLock) + defer c.releaseLock(c.writeMsgLock) } err := c.writeFrame(ctx, header{ @@ -450,7 +459,7 @@ func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error opcode: opcode, }, p) if err != nil { - return xerrors.Errorf("failed to write frame: %v", err) + return xerrors.Errorf("failed to write frame: %w", err) } return nil } @@ -471,7 +480,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.acquireLock(ctx, c.writeDataLock) + err := c.acquireLock(ctx, c.writeMsgLock) if err != nil { return nil, err } @@ -567,7 +576,7 @@ func (w *messageWriter) close() error { return err } - w.c.releaseLock(w.c.writeDataLock) + w.c.releaseLock(w.c.writeMsgLock) return nil } @@ -575,18 +584,21 @@ func (w *messageWriter) close() error { // It returns the type of the message and a reader to read it. // The passed context will also bound the reader. // -// Your application must keep reading messages for the Conn to automatically respond to ping -// and close frames and not become stuck waiting for a data message to be read. -// Please ensure to read the full message from io.Reader. +// If you do not read from the reader till EOF, the connection will hang. // -// You can only read a single message at a time so do not call this method -// concurrently. +// You do not need to explicitly read from the connection to reply to control frames. +// Please see the docs on the Conn type. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { typ, r, err := c.reader(ctx) if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - return typ, io.LimitReader(r, c.msgReadLimit), nil + readLimit := atomic.LoadInt64(&c.msgReadLimit) + return typ, &limitedReader{ + c: c, + r: r, + left: readLimit, + }, nil } func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { @@ -717,6 +729,8 @@ func (r *messageReader) read(p []byte) (int, error) { // It applies to the Reader and Read methods. // // By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusPolicyViolation. func (c *Conn) SetReadLimit(n int64) { atomic.StoreInt64(&c.msgReadLimit, n) } @@ -744,14 +758,14 @@ func (c *Conn) ping(ctx context.Context) error { pong := make(chan struct{}) - c.pingListenerMu.Lock() - c.pingListener[p] = pong - c.pingListenerMu.Unlock() + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() defer func() { - c.pingListenerMu.Lock() - delete(c.pingListener, p) - c.pingListenerMu.Unlock() + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() }() err := c.writeMessage(ctx, opPing, []byte(p)) @@ -762,6 +776,8 @@ func (c *Conn) ping(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() + case <-c.closed: + return c.closeErr case <-pong: return nil } diff --git a/websocket_test.go b/websocket_test.go index f1905c30..c1e28d5f 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -409,6 +409,43 @@ func TestHandshake(t *testing.T) { return nil }, }, + { + name: "readLimit", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + _, _, err = c.Read(r.Context()) + if err == nil { + return xerrors.Errorf("expected error but got nil") + } + return nil + }, + client: func(ctx context.Context, u string) error { + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) + if err != nil { + return err + } + + err = c.Ping(ctx) + + var ce websocket.CloseError + if !xerrors.As(err, &ce) || ce.Code != websocket.StatusPolicyViolation { + return xerrors.Errorf("unexpected error: %w", err) + } + + return nil + }, + }, } for _, tc := range testCases { @@ -477,7 +514,7 @@ func TestAutobahnServer(t *testing.T) { defer s.Close() spec := map[string]interface{}{ - "outdir": "wstest_reports/server", + "outdir": "ci/out/wstestServerReports", "servers": []interface{}{ map[string]interface{}{ "agent": "main", @@ -487,7 +524,7 @@ func TestAutobahnServer(t *testing.T) { "cases": []string{"*"}, "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } - specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json") + specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") if err != nil { t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) } @@ -516,7 +553,7 @@ func TestAutobahnServer(t *testing.T) { t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) } - checkWSTestIndex(t, "./wstest_reports/server/index.json") + checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } func echoLoop(ctx context.Context, c *websocket.Conn) { @@ -594,11 +631,11 @@ func TestAutobahnClient(t *testing.T) { spec := map[string]interface{}{ "url": "ws://localhost:9001", - "outdir": "wstest_reports/client", + "outdir": "ci/out/wstestClientReports", "cases": []string{"*"}, "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } - specFile, err := ioutil.TempFile("", "websocket_fuzzingserver.json") + specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") if err != nil { t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) } @@ -682,7 +719,7 @@ func TestAutobahnClient(t *testing.T) { } c.Close(websocket.StatusNormalClosure, "") - checkWSTestIndex(t, "./wstest_reports/client/index.json") + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") } func checkWSTestIndex(t *testing.T, path string) { diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index d85700bc..9d74b5bd 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,10 +1,9 @@ -// Package wsjson provides helpers for JSON messages. +// Package wsjson provides websocket helpers for JSON messages. package wsjson import ( "context" "encoding/json" - "io" "golang.org/x/xerrors" @@ -21,7 +20,7 @@ func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { } func read(ctx context.Context, c *websocket.Conn, v interface{}) error { - typ, r, err := c.Reader(ctx) + typ, b, err := c.Read(ctx) if err != nil { return err } @@ -31,21 +30,9 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) } - d := json.NewDecoder(r) - err = d.Decode(v) + err = json.Unmarshal(b, v) if err != nil { - return xerrors.Errorf("failed to decode json: %w", err) - } - - // Have to ensure we read till EOF. - // Unfortunate but necessary evil for now. Can improve later. - // The code to do this automatically gets complicated fast because - // we support concurrent reading. - // So the Reader has to synchronize with Read somehow. - // Maybe its best to bring back the old readLoop? - _, err = r.Read([]byte{0}) - if !xerrors.Is(err, io.EOF) { - return xerrors.Errorf("more data than needed in reader") + return xerrors.Errorf("failed to unmarshal json: %w", err) } return nil @@ -66,6 +53,8 @@ func write(ctx context.Context, c *websocket.Conn, v interface{}) error { return err } + // We use Encode because it automatically enables buffer reuse without us + // needing to do anything. Though see https://github.com/golang/go/issues/27735 e := json.NewEncoder(w) err = e.Encode(v) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index edffede1..bcac9ea4 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -1,9 +1,8 @@ -// Package wspb provides helpers for protobuf messages. +// Package wspb provides websocket helpers for protobuf messages. package wspb import ( "context" - "io/ioutil" "github.com/golang/protobuf/proto" "golang.org/x/xerrors" @@ -21,7 +20,7 @@ func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { } func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - typ, r, err := c.Reader(ctx) + typ, b, err := c.Read(ctx) if err != nil { return err } @@ -31,11 +30,6 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) } - b, err := ioutil.ReadAll(r) - if err != nil { - return xerrors.Errorf("failed to read message: %w", err) - } - err = proto.Unmarshal(b, v) if err != nil { return xerrors.Errorf("failed to unmarshal protobuf: %w", err) From 4bfe0e9b4055f8d3847a09dea372c3d91de9d994 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 31 May 2019 18:36:10 -0400 Subject: [PATCH 2/5] Document wspb and wsjson buffer reuse See #71 --- wsjson/wsjson.go | 3 +++ wspb/wspb.go | 2 ++ 2 files changed, 5 insertions(+) diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 9d74b5bd..994ffad1 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -11,6 +11,8 @@ import ( ) // Read reads a json message from c into v. +// If the message is larger than 128 bytes, it will use a buffer +// from a pool instead of performing an allocation. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { err := read(ctx, c, v) if err != nil { @@ -39,6 +41,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { } // Write writes the json message v to c. +// It uses json.Encoder which automatically reuses buffers. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { err := write(ctx, c, v) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index bcac9ea4..e6c91693 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -11,6 +11,7 @@ import ( ) // Read reads a protobuf message from c into v. +// It will reuse buffers to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err := read(ctx, c, v) if err != nil { @@ -39,6 +40,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { } // Write writes the protobuf message v to c. +// It will reuse buffers to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { err := write(ctx, c, v) if err != nil { From 6f3d9b364fc9651e3721aba9a8edf4f12436db2f Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 31 May 2019 23:27:55 -0400 Subject: [PATCH 3/5] Polish further --- README.md | 19 ++- accept.go | 12 +- dial.go | 10 +- example_echo_test.go | 2 +- messagetype.go | 2 + statuscode.go | 16 +-- websocket.go | 335 ++++++++++++++++++++++++------------------- websocket_test.go | 22 ++- 8 files changed, 239 insertions(+), 179 deletions(-) diff --git a/README.md b/README.md index afd5abba..f1fc5896 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,6 @@ websocket is a minimal and idiomatic WebSocket library for Go. -This library is now production ready but some parts of the API are marked as experimental. - -Please feel free to open an issue for feedback. - ## Install ```bash @@ -23,8 +19,8 @@ go get nhooyr.io/websocket - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Zero dependencies outside of the stdlib for the core library - JSON and ProtoBuf helpers in the wsjson and wspb subpackages -- High performance, memory reuse wherever possible -- Concurrent reads and writes out of the box +- High performance, memory reuse by default +- Concurrent writes out of the box ## Roadmap @@ -124,8 +120,8 @@ also uses net/http's Client and ResponseWriter directly for WebSocket handshakes gorilla/websocket writes its handshakes to the underlying net.Conn which means it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. -Some more advantages of nhooyr/websocket are that it supports concurrent reads, -writes and makes it very easy to close the connection with a status code and reason. +Some more advantages of nhooyr/websocket are that it supports concurrent writes and +makes it very easy to close the connection with a status code and reason. nhooyr/websocket also responds to pings, pongs and close frames in a separate goroutine so that your application doesn't always need to read from the connection unless it expects a data message. @@ -134,11 +130,12 @@ even if you don't expect the peer to send any messages. In terms of performance, the differences depend on your application code. nhooyr/websocket reuses buffers efficiently out of the box whereas gorilla/websocket does not. As mentioned -above, it also supports concurrent readers and writers out of the box. +above, it also supports concurrent writers out of the box. -The only performance downside to nhooyr/websocket is that uses two extra goroutines. One for +The only performance con to nhooyr/websocket is that uses two extra goroutines. One for reading pings, pongs and close frames async to application code and another to support -context.Context cancellation. This costs 4 KB of memory which is fairly cheap. +context.Context cancellation. This costs 4 KB of memory which is cheap compared +to the benefits. ### x/net/websocket diff --git a/accept.go b/accept.go index a80f70aa..6e214111 100644 --- a/accept.go +++ b/accept.go @@ -76,12 +76,12 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { } // Accept accepts a WebSocket handshake from a client and upgrades the -// the connection to WebSocket. +// the connection to a WebSocket. // // Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. // -// The returned connection will be bound by r.Context(). Use c.Context() to change +// The returned connection will be bound by r.Context(). Use conn.Context() to change // the bounding context. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) @@ -107,7 +107,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, hj, ok := w.(http.Hijacker) if !ok { - err = xerrors.New("response writer must implement http.Hijacker") + err = xerrors.New("passed ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } @@ -115,7 +115,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") - handleKey(w, r) + handleSecWebSocketKey(w, r) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { @@ -163,7 +163,7 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") -func handleKey(w http.ResponseWriter, r *http.Request) { +func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { key := r.Header.Get("Sec-WebSocket-Key") h := sha1.New() h.Write([]byte(key)) @@ -185,5 +185,5 @@ func authenticateOrigin(r *http.Request) error { if strings.EqualFold(u.Host, r.Host) { return nil } - return xerrors.Errorf("request origin %q is not authorized for host %q", origin, r.Host) + return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } diff --git a/dial.go b/dial.go index 53acd32c..64d2820d 100644 --- a/dial.go +++ b/dial.go @@ -18,9 +18,9 @@ import ( // DialOptions represents the options available to pass to Dial. type DialOptions struct { // HTTPClient is the http client used for the handshake. - // Its Transport must use HTTP/1.1 and return writable bodies - // for WebSocket handshakes. This was introduced in Go 1.12. - // http.Transport does this all correctly. + // Its Transport must return writable bodies + // for WebSocket handshakes. + // http.Transport does this correctly beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. @@ -30,7 +30,7 @@ type DialOptions struct { Subprotocols []string } -// We use this key for all client requests as the Sec-WebSocket-Key header is useless. +// We use this key for all client requests as the Sec-WebSocket-Key header doesn't do anything. // See https://stackoverflow.com/a/37074398/4283659. // We also use the same mask key for every message as it too does not make a difference. var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16)) @@ -108,7 +108,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { - return nil, resp, xerrors.Errorf("response body is not a read write closer: %T", rwc) + return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) } c := &Conn{ diff --git a/example_echo_test.go b/example_echo_test.go index f1da25af..6923bc04 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -20,7 +20,7 @@ import ( // dials the server and then sends 5 different messages // and prints out the server's responses. func Example_echo() { - // First we listen on port 0, that means the OS will + // First we listen on port 0 which means the OS will // assign us a random free port. This is the listener // the server will serve on and the client will connect to. l, err := net.Listen("tcp", "localhost:0") diff --git a/messagetype.go b/messagetype.go index 1fd9cd6e..6a1205ee 100644 --- a/messagetype.go +++ b/messagetype.go @@ -13,3 +13,5 @@ const ( // MessageBinary is for binary messages like Protobufs. MessageBinary MessageType = MessageType(opBinary) ) + +// Above I've explicitly included the types of the constants for stringer. diff --git a/statuscode.go b/statuscode.go index c7b20367..661c6693 100644 --- a/statuscode.go +++ b/statuscode.go @@ -60,7 +60,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if len(p) < 2 { - return CloseError{}, xerrors.Errorf("close payload too small, cannot even contain the 2 byte status code") + return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ @@ -78,13 +78,13 @@ func parseClosePayload(p []byte) (CloseError, error) { // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { - if code >= StatusNormalClosure && code <= statusTLSHandshake { - switch code { - case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake: - return false - default: - return true - } + switch code { + case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true } if code >= 3000 && code <= 4999 { return true diff --git a/websocket.go b/websocket.go index cd709d61..bb184eb4 100644 --- a/websocket.go +++ b/websocket.go @@ -21,13 +21,7 @@ import ( // All methods may be called concurrently. // // Please be sure to call Close on the connection when you -// are finished with it to release resources. -// -// Control (ping, pong, close) frames will be responded to in a separate goroutine -// so if you do not expect any data messages, you do not need -// to read from the connection. However, if the peer -// sends a data message, further pings, pongs and close frames will not -// be read if you do not read the message from the connection. +// are finished with it to release the associated resources. type Conn struct { subprotocol string br *bufio.Reader @@ -35,18 +29,28 @@ type Conn struct { closer io.Closer client bool + // In bytes. msgReadLimit int64 closeOnce sync.Once closeErr error closed chan struct{} + // writeMsgLock is acquired to write a multi frame message. writeMsgLock chan struct{} + // writeFrameLock is acquired to write a single frame. + // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} + // readMsgLock is acquired to read a message with Reader. readMsgLock chan struct{} + // readFrameLock is acquired to read from bw. readFrameLock chan struct{} + // readMsg is used by messageReader to receive frames from + // readLoop. readMsg chan header + // readMsgDone is used to tell the readLoop to continue after + // messageReader has read a frame. readMsgDone chan struct{} setReadTimeout chan context.Context @@ -62,7 +66,7 @@ type Conn struct { // when the connection is closed or broken. // If the parent context is cancelled, the connection will be closed. // -// This is an experimental API that may be removed in the future. +// This is an experimental API. // Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 func (c *Conn) Context(parent context.Context) context.Context { select { @@ -87,24 +91,25 @@ func (c *Conn) close(err error) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) - cerr := c.closer.Close() - if err != nil { - cerr = err - } - - c.closeErr = xerrors.Errorf("websocket closed: %w", cerr) - + c.closeErr = xerrors.Errorf("websocket closed: %w", err) close(c.closed) - // This ensures every goroutine that interacts - // with the conn returns before it can actually do anything and - // receives c.closeErr. - c.readFrameLock <- struct{}{} - c.writeFrameLock <- struct{}{} + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.closer.Close() // See comment in dial.go if c.client { + // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer + // and we can safely return them. + // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent + // a deadlock. + // As of now, this is in writeFrame, readPayload and readHeader. + c.readFrameLock <- struct{}{} returnBufioReader(c.br) + + c.writeFrameLock <- struct{}{} returnBufioWriter(c.bw) } }) @@ -144,61 +149,78 @@ func (c *Conn) init() { go c.readLoop() } +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + err := c.writeFrame(ctx, true, opcode, p) + if err != nil { + return xerrors.Errorf("failed to write control frame: %w", err) + } + return nil +} + +// writeFrame handles all writes to the connection. // We never mask inside here because our mask key is always 0,0,0,0. -// See comment on secWebSocketKey. -func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) (err error) { - err = c.acquireLock(ctx, c.writeFrameLock) +// See comment on secWebSocketKey for why. +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { + h := header{ + fin: fin, + opcode: opcode, + masked: c.client, + payloadLength: int64(len(p)), + } + b2 := marshalHeader(h) + + err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { return err } defer c.releaseLock(c.writeFrameLock) select { - case <-ctx.Done(): - return ctx.Err() case <-c.closed: return c.closeErr case c.setWriteTimeout <- ctx: } - defer func() { - // We have to remove the write timeout, even if ctx is cancelled. + + writeErr := func(err error) error { select { case <-c.closed: - return - case c.setWriteTimeout <- context.Background(): + return c.closeErr + default: } - }() - defer func() { - if err != nil { - // We need to always release the lock first before closing the connection to ensure - // the lock can be acquired inside close. - c.releaseLock(c.writeFrameLock) - c.close(err) - } - }() + err = xerrors.Errorf("failed to write to connection: %w", err) + // We need to release the lock first before closing the connection to ensure + // the lock can be acquired inside close to ensure no one can access c.bw. + c.releaseLock(c.writeFrameLock) + c.close(err) - h.masked = c.client - h.payloadLength = int64(len(p)) + return err + } - b2 := marshalHeader(h) _, err = c.bw.Write(b2) if err != nil { - return xerrors.Errorf("failed to write to connection: %w", err) + return writeErr(err) } _, err = c.bw.Write(p) if err != nil { - return xerrors.Errorf("failed to write to connection: %w", err) - + return writeErr(err) } - if h.fin { - err := c.bw.Flush() + if fin { + err = c.bw.Flush() if err != nil { - return xerrors.Errorf("failed to write to connection: %w", err) + return writeErr(err) } } + // We already finished writing, no need to potentially brick the connection if + // the context expires. + select { + case <-c.closed: + return c.closeErr + case c.setWriteTimeout <- context.Background(): + } + return nil } @@ -244,10 +266,13 @@ func (c *Conn) handleControl(h header) { return } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + b := make([]byte, h.payloadLength) - _, err := io.ReadFull(c.br, b) + + _, err := c.readPayload(ctx, b) if err != nil { - c.close(xerrors.Errorf("failed to read control frame payload: %w", err)) return } @@ -289,12 +314,9 @@ func (c *Conn) readTillMsg() (header, error) { } if h.rsv1 || h.rsv2 || h.rsv3 { - ce := CloseError{ - Code: StatusProtocolError, - Reason: fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3), - } - c.Close(ce.Code, ce.Reason) - return header{}, ce + err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + c.Close(StatusProtocolError, err.Error()) + return header{}, err } if h.opcode.controlOp() { @@ -306,12 +328,9 @@ func (c *Conn) readTillMsg() (header, error) { case opBinary, opText, opContinuation: return h, nil default: - ce := CloseError{ - Code: StatusProtocolError, - Reason: fmt.Sprintf("unknown opcode %v", h.opcode), - } - c.Close(ce.Code, ce.Reason) - return header{}, ce + err := xerrors.Errorf("received unknown opcode %v", h.opcode) + c.Close(StatusProtocolError, err.Error()) + return header{}, err } } } @@ -325,7 +344,10 @@ func (c *Conn) readHeader() (header, error) { h, err := readHeader(c.br) if err != nil { - return header{}, xerrors.Errorf("failed to read header: %w", err) + err := xerrors.Errorf("failed to read header: %w", err) + c.releaseLock(c.readFrameLock) + c.close(err) + return header{}, err } return h, nil @@ -335,16 +357,9 @@ func (c *Conn) readLoop() { for { h, err := c.readTillMsg() if err != nil { - c.close(err) return } - if h.opcode == opContinuation && - h.fin && - h.payloadLength == 0 { - c.releaseLock(c.readMsgLock) - } - select { case <-c.closed: return @@ -363,7 +378,7 @@ func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeMessage(ctx, opPong, p) + err := c.writeControl(ctx, opPong, p) return err } @@ -411,7 +426,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeMessage(ctx, opClose, p) + err := c.writeControl(ctx, opClose, p) c.close(cerr) @@ -445,32 +460,13 @@ func (c *Conn) releaseLock(lock chan struct{}) { } } -func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error { - if !opcode.controlOp() { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return err - } - defer c.releaseLock(c.writeMsgLock) - } - - err := c.writeFrame(ctx, header{ - fin: true, - opcode: opcode, - }, p) - if err != nil { - return xerrors.Errorf("failed to write frame: %w", err) - } - return nil -} - // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // -// Ensure you close the writer once you have written the entire message. -// Concurrent calls to Writer are ok. -// Only one writer can be open at a time so Writer will block if there is -// another goroutine with an open writer until that writer is closed. +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { wc, err := c.writer(ctx, typ) if err != nil { @@ -494,6 +490,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // Read is a convenience method to read a single message from the connection. // // See the Reader method if you want to be able to reuse buffers or want to stream a message. +// The docs on Reader apply to this metohd as well. // // This is an experimental API, please let me know how you feel about it in // https://github.com/nhooyr/websocket/issues/62 @@ -513,12 +510,31 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Write is a convenience method to write a message to the connection. // -// See the Writer method if you want to stream a message. +// See the Writer method if you want to stream a message. The docs on Writer +// regarding concurrency also apply to this method. // // This is an experimental API, please let me know how you feel about it in // https://github.com/nhooyr/websocket/issues/62 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - return c.writeMessage(ctx, opcode(typ), p) + err := c.write(ctx, typ, p) + if err != nil { + return xerrors.Errorf("failed to write msg: %w", err) + } + return nil +} + +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { + err := c.acquireLock(ctx, c.writeMsgLock) + if err != nil { + return err + } + defer c.releaseLock(c.writeMsgLock) + + err = c.writeFrame(ctx, true, opcode(typ), p) + if err != nil { + return err + } + return nil } // messageWriter enables writing to a WebSocket connection. @@ -542,11 +558,9 @@ func (w *messageWriter) write(p []byte) (int, error) { if w.closed { return 0, xerrors.Errorf("cannot use closed writer") } - err := w.c.writeFrame(w.ctx, header{ - opcode: w.opcode, - }, p) + err := w.c.writeFrame(w.ctx, false, w.opcode, p) if err != nil { - return 0, err + return 0, xerrors.Errorf("failed to write data frame: %w", err) } w.opcode = opContinuation return len(p), nil @@ -568,12 +582,9 @@ func (w *messageWriter) close() error { } w.closed = true - err := w.c.writeFrame(w.ctx, header{ - fin: true, - opcode: w.opcode, - }, nil) + err := w.c.writeFrame(w.ctx, true, w.opcode, nil) if err != nil { - return err + return xerrors.Errorf("failed to write fin frame: %w", err) } w.c.releaseLock(w.c.writeMsgLock) @@ -584,20 +595,30 @@ func (w *messageWriter) close() error { // It returns the type of the message and a reader to read it. // The passed context will also bound the reader. // -// If you do not read from the reader till EOF, the connection will hang. +// Control (ping, pong, close) frames will be handled automatically +// in a separate goroutine so if you do not expect any data messages, +// you do not need to read from the connection. However, if the peer +// sends a data message, further pings, pongs and close frames will not +// be read if you do not read the message from the connection. // -// You do not need to explicitly read from the connection to reply to control frames. -// Please see the docs on the Conn type. +// If you do not read from the reader till EOF, nothing further will be read from the connection. +// Only one reader can be open at a time, multiple calls will block until the previous reader +// is read to completion. +// TODO remove concurrent reads. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + // We could handle the case of json.Decoder where the message may not be read + // till EOF but would still be read till the end of data. E.g. if the other side + // sends a fin frame after the message, we could allow the code to continue and + // just pick off but the code for that gets complicated and if there is real data + // after the JSON object, Reader would block until the timeout is hit typ, r, err := c.reader(ctx) if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - readLimit := atomic.LoadInt64(&c.msgReadLimit) return typ, &limitedReader{ c: c, r: r, - left: readLimit, + left: atomic.LoadInt64(&c.msgReadLimit), }, nil } @@ -614,12 +635,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro return 0, nil, ctx.Err() case h := <-c.readMsg: if h.opcode == opContinuation { - ce := CloseError{ - Code: StatusProtocolError, - Reason: "continuation frame not after data or text frame", - } - c.Close(ce.Code, ce.Reason) - return 0, nil, ce + err := xerrors.Errorf("received continuation frame not after data or text frame") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err } return MessageType(h.opcode), &messageReader{ ctx: ctx, @@ -661,14 +679,13 @@ func (r *messageReader) read(p []byte) (int, error) { select { case <-r.c.closed: return 0, r.c.closeErr + case <-r.ctx.Done(): + return 0, r.ctx.Err() case h := <-r.c.readMsg: if h.opcode != opContinuation { - ce := CloseError{ - Code: StatusProtocolError, - Reason: "cannot read new data frame when previous frame is not finished", - } - r.c.Close(ce.Code, ce.Reason) - return 0, ce + err := xerrors.Errorf("received new data frame without finishing the previous frame") + r.c.Close(StatusProtocolError, err.Error()) + return 0, err } r.h = &h } @@ -678,24 +695,7 @@ func (r *messageReader) read(p []byte) (int, error) { p = p[:r.h.payloadLength] } - select { - case <-r.c.closed: - return 0, r.c.closeErr - case r.c.setReadTimeout <- r.ctx: - } - - err := r.c.acquireLock(r.ctx, r.c.readFrameLock) - if err != nil { - return 0, err - } - n, err := io.ReadFull(r.c.br, p) - r.c.releaseLock(r.c.readFrameLock) - - select { - case <-r.c.closed: - return 0, r.c.closeErr - case r.c.setReadTimeout <- context.Background(): - } + n, err := r.readPayload(p) r.h.payloadLength -= int64(n) if r.h.masked { @@ -703,8 +703,9 @@ func (r *messageReader) read(p []byte) (int, error) { } if err != nil { - r.c.close(xerrors.Errorf("failed to read control frame payload: %w", err)) - return n, r.c.closeErr + err := xerrors.Errorf("failed to read frame payload: %w", err) + r.c.close(err) + return n, err } if r.h.payloadLength == 0 { @@ -713,11 +714,13 @@ func (r *messageReader) read(p []byte) (int, error) { return n, r.c.closeErr case r.c.readMsgDone <- struct{}{}: } + if r.h.fin { r.eofed = true r.c.releaseLock(r.c.readMsgLock) return n, io.EOF } + r.maskPos = 0 r.h = nil } @@ -725,6 +728,50 @@ func (r *messageReader) read(p []byte) (int, error) { return n, nil } +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} + +func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return 0, err + } + defer c.releaseLock(c.readFrameLock) + + select { + case <-c.closed: + return 0, c.closeErr + case c.setReadTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + default: + } + err = xerrors.Errorf("failed to read from connection: %w", err) + c.releaseLock(c.readFrameLock) + c.close(err) + return n, err + } + + select { + case <-c.closed: + return 0, c.closeErr + case c.setReadTimeout <- context.Background(): + } + + return 0, err +} + // SetReadLimit sets the max number of bytes to read for a single message. // It applies to the Reader and Read methods. // @@ -742,7 +789,7 @@ func init() { // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // -// This API is experimental and subject to change. +// This API is experimental. // Please provide feedback in https://github.com/nhooyr/websocket/issues/1. func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx) @@ -768,7 +815,7 @@ func (c *Conn) ping(ctx context.Context) error { c.activePingsMu.Unlock() }() - err := c.writeMessage(ctx, opPing, []byte(p)) + err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } diff --git a/websocket_test.go b/websocket_test.go index c1e28d5f..b1c5b9d4 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -390,6 +390,11 @@ func TestHandshake(t *testing.T) { return err } + err = c.Write(r.Context(), websocket.MessageText, []byte("hi")) + if err != nil { + return err + } + c.Close(websocket.StatusNormalClosure, "") return nil }, @@ -405,6 +410,11 @@ func TestHandshake(t *testing.T) { return err } + _, _, err = c.Read(ctx) + if err != nil { + return err + } + c.Close(websocket.StatusNormalClosure, "") return nil }, @@ -521,7 +531,10 @@ func TestAutobahnServer(t *testing.T) { "url": strings.Replace(s.URL, "http", "ws", 1), }, }, - "cases": []string{"*"}, + "cases": []string{"*"}, + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. 7.5.1 is the same. + // 12.* and 13.* as we do not support compression. "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") @@ -630,9 +643,10 @@ func TestAutobahnClient(t *testing.T) { t.Parallel() spec := map[string]interface{}{ - "url": "ws://localhost:9001", - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, + "url": "ws://localhost:9001", + "outdir": "ci/out/wstestClientReports", + "cases": []string{"*"}, + // See TestAutobahnServer for the reasons why we exclude these. "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") From 9925b643d1f5926762fa2d880f2df18be6c2c1d7 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 1 Jun 2019 05:40:05 -0400 Subject: [PATCH 4/5] Remove concurrent reads feature Doesn't really add much. --- websocket.go | 132 ++++++++++++++++++++++++++------------------------- 1 file changed, 67 insertions(+), 65 deletions(-) diff --git a/websocket.go b/websocket.go index bb184eb4..e974002b 100644 --- a/websocket.go +++ b/websocket.go @@ -11,14 +11,14 @@ import ( "runtime" "strconv" "sync" - "sync/atomic" "time" "golang.org/x/xerrors" ) // Conn represents a WebSocket connection. -// All methods may be called concurrently. +// All methods may be called concurrently except for Reader, Read +// and SetReadLimit. // // Please be sure to call Close on the connection when you // are finished with it to release the associated resources. @@ -29,29 +29,30 @@ type Conn struct { closer io.Closer client bool - // In bytes. + // read limit for a message in bytes. msgReadLimit int64 closeOnce sync.Once closeErr error closed chan struct{} - // writeMsgLock is acquired to write a multi frame message. - writeMsgLock chan struct{} + // writeMsgLock is acquired to write a data message. + writeMsgLock chan struct{} // writeFrameLock is acquired to write a single frame. // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} - // readMsgLock is acquired to read a message with Reader. - readMsgLock chan struct{} + // Used to ensure the previous reader is read till EOF before allowing + // a new one. + previousReader *messageReader // readFrameLock is acquired to read from bw. readFrameLock chan struct{} // readMsg is used by messageReader to receive frames from // readLoop. - readMsg chan header + readMsg chan header // readMsgDone is used to tell the readLoop to continue after // messageReader has read a frame. - readMsgDone chan struct{} + readMsgDone chan struct{} setReadTimeout chan context.Context setWriteTimeout chan context.Context @@ -129,7 +130,6 @@ func (c *Conn) init() { c.writeMsgLock = make(chan struct{}, 1) c.writeFrameLock = make(chan struct{}, 1) - c.readMsgLock = make(chan struct{}, 1) c.readFrameLock = make(chan struct{}, 1) c.readMsg = make(chan header) c.readMsgDone = make(chan struct{}) @@ -271,7 +271,7 @@ func (c *Conn) handleControl(h header) { b := make([]byte, h.payloadLength) - _, err := c.readPayload(ctx, b) + _, err := c.readFramePayload(ctx, b) if err != nil { return } @@ -427,13 +427,11 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { defer cancel() err := c.writeControl(ctx, opClose, p) - - c.close(cerr) - if err != nil { return err } + c.close(cerr) if !xerrors.Is(c.closeErr, cerr) { return c.closeErr } @@ -444,6 +442,16 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { select { case <-ctx.Done(): + var err error + switch lock { + case c.writeFrameLock, c.writeMsgLock: + err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err()) + case c.readFrameLock: + err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err()) + default: + panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) + } + c.close(err) return ctx.Err() case <-c.closed: return c.closeErr @@ -490,7 +498,7 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // Read is a convenience method to read a single message from the connection. // // See the Reader method if you want to be able to reuse buffers or want to stream a message. -// The docs on Reader apply to this metohd as well. +// The docs on Reader apply to this method as well. // // This is an experimental API, please let me know how you feel about it in // https://github.com/nhooyr/websocket/issues/62 @@ -501,11 +509,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { } b, err := ioutil.ReadAll(r) - if err != nil { - return typ, b, err - } - - return typ, b, nil + return typ, b, err } // Write is a convenience method to write a message to the connection. @@ -531,10 +535,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { defer c.releaseLock(c.writeMsgLock) err = c.writeFrame(ctx, true, opcode(typ), p) - if err != nil { - return err - } - return nil + return err } // messageWriter enables writing to a WebSocket connection. @@ -591,9 +592,11 @@ func (w *messageWriter) close() error { return nil } -// Reader will wait until there is a WebSocket data message to read from the connection. +// Reader waits until there is a WebSocket data message to read +// from the connection. // It returns the type of the message and a reader to read it. // The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. // // Control (ping, pong, close) frames will be handled automatically // in a separate goroutine so if you do not expect any data messages, @@ -601,16 +604,8 @@ func (w *messageWriter) close() error { // sends a data message, further pings, pongs and close frames will not // be read if you do not read the message from the connection. // -// If you do not read from the reader till EOF, nothing further will be read from the connection. -// Only one reader can be open at a time, multiple calls will block until the previous reader -// is read to completion. -// TODO remove concurrent reads. +// Only one Reader may be open at a time. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - // We could handle the case of json.Decoder where the message may not be read - // till EOF but would still be read till the end of data. E.g. if the other side - // sends a fin frame after the message, we could allow the code to continue and - // just pick off but the code for that gets complicated and if there is real data - // after the JSON object, Reader would block until the timeout is hit typ, r, err := c.reader(ctx) if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) @@ -618,14 +613,13 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return typ, &limitedReader{ c: c, r: r, - left: atomic.LoadInt64(&c.msgReadLimit), + left: c.msgReadLimit, }, nil } -func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { - err = c.acquireLock(ctx, c.readMsgLock) - if err != nil { - return 0, nil, err +func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { + if c.previousReader.h != nil && c.previousReader.h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") } select { @@ -634,26 +628,42 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro case <-ctx.Done(): return 0, nil, ctx.Err() case h := <-c.readMsg: - if h.opcode == opContinuation { + if c.previousReader != nil && !c.previousReader.done { + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } + + if !h.fin || h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") + } + + c.previousReader.done = true + return c.reader(ctx) + } else if h.opcode == opContinuation { err := xerrors.Errorf("received continuation frame not after data or text frame") c.Close(StatusProtocolError, err.Error()) return 0, nil, err } - return MessageType(h.opcode), &messageReader{ + r := &messageReader{ ctx: ctx, h: &h, c: c, - }, nil + } + c.previousReader = r + return MessageType(h.opcode), r, nil } } // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { - ctx context.Context - maskPos int + ctx context.Context + c *Conn + h *header - c *Conn - eofed bool + maskPos int + done bool } // Read reads as many bytes as possible into p. @@ -665,13 +675,15 @@ func (r *messageReader) Read(p []byte) (int, error) { if xerrors.Is(err, io.EOF) { return n, io.EOF } - return n, xerrors.Errorf("failed to read: %w", err) + err = xerrors.Errorf("failed to read: %w", err) + r.c.close(err) + return n, err } return n, nil } func (r *messageReader) read(p []byte) (int, error) { - if r.eofed { + if r.done { return 0, xerrors.Errorf("cannot use EOFed reader") } @@ -695,7 +707,7 @@ func (r *messageReader) read(p []byte) (int, error) { p = p[:r.h.payloadLength] } - n, err := r.readPayload(p) + n, err := r.c.readFramePayload(r.ctx, p) r.h.payloadLength -= int64(n) if r.h.masked { @@ -703,8 +715,6 @@ func (r *messageReader) read(p []byte) (int, error) { } if err != nil { - err := xerrors.Errorf("failed to read frame payload: %w", err) - r.c.close(err) return n, err } @@ -716,8 +726,7 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.h.fin { - r.eofed = true - r.c.releaseLock(r.c.readMsgLock) + r.done = true return n, io.EOF } @@ -728,16 +737,7 @@ func (r *messageReader) read(p []byte) (int, error) { return n, nil } -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - -func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) { +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { err := c.acquireLock(ctx, c.readFrameLock) if err != nil { return 0, err @@ -779,7 +779,7 @@ func (c *Conn) readPayload(ctx context.Context, p []byte) (int, error) { // // When the limit is hit, the connection will be closed with StatusPolicyViolation. func (c *Conn) SetReadLimit(n int64) { - atomic.StoreInt64(&c.msgReadLimit, n) + c.msgReadLimit = n } func init() { @@ -794,7 +794,9 @@ func init() { func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx) if err != nil { - return xerrors.Errorf("failed to ping: %w", err) + err = xerrors.Errorf("failed to ping: %w", err) + c.close(err) + return err } return nil } From 38219393379f9e6838ecfa2fe4f2e9e56c4df47b Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 1 Jun 2019 06:21:16 -0400 Subject: [PATCH 5/5] Get CI passing --- README.md | 9 +- websocket.go | 849 ++++++++++++++++++++++++++------------------------- 2 files changed, 440 insertions(+), 418 deletions(-) diff --git a/README.md b/README.md index f1fc5896..4199423c 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ websocket is a minimal and idiomatic WebSocket library for Go. ## Install ```bash -go get nhooyr.io/websocket +go get nhooyr.io/websocket@v1.0.0 ``` ## Features @@ -19,7 +19,7 @@ go get nhooyr.io/websocket - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Zero dependencies outside of the stdlib for the core library - JSON and ProtoBuf helpers in the wsjson and wspb subpackages -- High performance, memory reuse by default +- Highly optimized by default - Concurrent writes out of the box ## Roadmap @@ -129,8 +129,9 @@ gorilla/websocket requires you to constantly read from the connection to respond even if you don't expect the peer to send any messages. In terms of performance, the differences depend on your application code. nhooyr/websocket -reuses buffers efficiently out of the box whereas gorilla/websocket does not. As mentioned -above, it also supports concurrent writers out of the box. +reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas +gorilla/websocket does not. As mentioned above, nhooyr/websocket also supports concurrent +writers out of the box. The only performance con to nhooyr/websocket is that uses two extra goroutines. One for reading pings, pongs and close frames async to application code and another to support diff --git a/websocket.go b/websocket.go index e974002b..d59812b8 100644 --- a/websocket.go +++ b/websocket.go @@ -63,65 +63,6 @@ type Conn struct { activePings map[string]chan<- struct{} } -// Context returns a context derived from parent that will be cancelled -// when the connection is closed or broken. -// If the parent context is cancelled, the connection will be closed. -// -// This is an experimental API. -// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 -func (c *Conn) Context(parent context.Context) context.Context { - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case c.setConnContext <- parent: - } - - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case ctx := <-c.getConnContext: - return ctx - } -} - -func (c *Conn) close(err error) { - c.closeOnce.Do(func() { - runtime.SetFinalizer(c, nil) - - c.closeErr = xerrors.Errorf("websocket closed: %w", err) - close(c.closed) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.closer.Close() - - // See comment in dial.go - if c.client { - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readPayload and readHeader. - c.readFrameLock <- struct{}{} - returnBufioReader(c.br) - - c.writeFrameLock <- struct{}{} - returnBufioWriter(c.bw) - } - }) -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - func (c *Conn) init() { c.closed = make(chan struct{}) @@ -149,79 +90,38 @@ func (c *Conn) init() { go c.readLoop() } -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - err := c.writeFrame(ctx, true, opcode, p) - if err != nil { - return xerrors.Errorf("failed to write control frame: %w", err) - } - return nil +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol } -// writeFrame handles all writes to the connection. -// We never mask inside here because our mask key is always 0,0,0,0. -// See comment on secWebSocketKey for why. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { - h := header{ - fin: fin, - opcode: opcode, - masked: c.client, - payloadLength: int64(len(p)), - } - b2 := marshalHeader(h) - - err := c.acquireLock(ctx, c.writeFrameLock) - if err != nil { - return err - } - defer c.releaseLock(c.writeFrameLock) - - select { - case <-c.closed: - return c.closeErr - case c.setWriteTimeout <- ctx: - } - - writeErr := func(err error) error { - select { - case <-c.closed: - return c.closeErr - default: - } +func (c *Conn) close(err error) { + c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) - err = xerrors.Errorf("failed to write to connection: %w", err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) + c.closeErr = xerrors.Errorf("websocket closed: %w", err) + close(c.closed) - return err - } + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.closer.Close() - _, err = c.bw.Write(b2) - if err != nil { - return writeErr(err) - } - _, err = c.bw.Write(p) - if err != nil { - return writeErr(err) - } + // See comment in dial.go + if c.client { + // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer + // and we can safely return them. + // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent + // a deadlock. + // As of now, this is in writeFrame, readFramePayload and readHeader. + c.readFrameLock <- struct{}{} + returnBufioReader(c.br) - if fin { - err = c.bw.Flush() - if err != nil { - return writeErr(err) + c.writeFrameLock <- struct{}{} + returnBufioWriter(c.bw) } - } - - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return c.closeErr - case c.setWriteTimeout <- context.Background(): - } - - return nil + }) } func (c *Conn) timeoutLoop() { @@ -255,60 +155,84 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) handleControl(h header) { - if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return +// Context returns a context derived from parent that will be cancelled +// when the connection is closed or broken. +// If the parent context is cancelled, the connection will be closed. +// +// This is an experimental API. +// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 +func (c *Conn) Context(parent context.Context) context.Context { + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case c.setConnContext <- parent: } - if !h.fin { - c.Close(StatusProtocolError, "control frame cannot be fragmented") - return + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case ctx := <-c.getConnContext: + return ctx } +} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - b := make([]byte, h.payloadLength) - - _, err := c.readFramePayload(ctx, b) - if err != nil { - return +func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { + select { + case <-ctx.Done(): + var err error + switch lock { + case c.writeFrameLock, c.writeMsgLock: + err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err()) + case c.readFrameLock: + err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err()) + default: + panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) + } + c.close(err) + return ctx.Err() + case <-c.closed: + return c.closeErr + case lock <- struct{}{}: + return nil } +} - if h.masked { - fastXOR(h.maskKey, 0, b) +func (c *Conn) releaseLock(lock chan struct{}) { + // Allow multiple releases. + select { + case <-lock: + default: } +} - switch h.opcode { - case opPing: - c.writePong(b) - case opPong: - c.activePingsMu.Lock() - pong, ok := c.activePings[string(b)] - c.activePingsMu.Unlock() - if ok { - close(pong) - } - case opClose: - ce, err := parseClosePayload(b) +func (c *Conn) readLoop() { + for { + h, err := c.readTillMsg() if err != nil { - c.close(xerrors.Errorf("received invalid close payload: %w", err)) return } - if ce.Code == StatusNoStatusRcvd { - c.writeClose(nil, ce) - } else { - c.Close(ce.Code, ce.Reason) + + select { + case <-c.closed: + return + case c.readMsg <- h: + } + + select { + case <-c.closed: + return + case <-c.readMsgDone: } - default: - panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } } func (c *Conn) readTillMsg() (header, error) { for { - h, err := c.readHeader() + h, err := c.readFrameHeader() if err != nil { return header{}, err } @@ -335,7 +259,7 @@ func (c *Conn) readTillMsg() (header, error) { } } -func (c *Conn) readHeader() (header, error) { +func (c *Conn) readFrameHeader() (header, error) { err := c.acquireLock(context.Background(), c.readFrameLock) if err != nil { return header{}, err @@ -353,119 +277,282 @@ func (c *Conn) readHeader() (header, error) { return h, nil } -func (c *Conn) readLoop() { - for { - h, err := c.readTillMsg() - if err != nil { - return - } +func (c *Conn) handleControl(h header) { + if h.payloadLength > maxControlFramePayload { + c.Close(StatusProtocolError, "control frame too large") + return + } - select { - case <-c.closed: - return - case c.readMsg <- h: - } - - select { - case <-c.closed: - return - case <-c.readMsgDone: - } + if !h.fin { + c.Close(StatusProtocolError, "control frame cannot be fragmented") + return } -} -func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opPong, p) - return err + b := make([]byte, h.payloadLength) + + _, err := c.readFramePayload(ctx, b) + if err != nil { + return + } + + if h.masked { + fastXOR(h.maskKey, 0, b) + } + + switch h.opcode { + case opPing: + c.writePong(b) + case opPong: + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() + if ok { + close(pong) + } + case opClose: + ce, err := parseClosePayload(b) + if err != nil { + c.close(xerrors.Errorf("received invalid close payload: %w", err)) + return + } + if ce.Code == StatusNoStatusRcvd { + c.writeClose(nil, ce) + } else { + c.Close(ce.Code, ce.Reason) + } + default: + panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) + } } -// Close closes the WebSocket connection with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5 seconds. -// The connection can only be closed once. Additional calls to Close -// are no-ops. +// Reader waits until there is a WebSocket data message to read +// from the connection. +// It returns the type of the message and a reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. // -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid -// sending a dynamic reason. +// Control (ping, pong, close) frames will be handled automatically +// in a separate goroutine so if you do not expect any data messages, +// you do not need to read from the connection. However, if the peer +// sends a data message, further pings, pongs and close frames will not +// be read if you do not read the message from the connection. // -// Close will unblock all goroutines interacting with the connection. -func (c *Conn) Close(code StatusCode, reason string) error { - err := c.exportedClose(code, reason) +// Only one Reader may be open at a time. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, r, err := c.reader(ctx) if err != nil { - return xerrors.Errorf("failed to close connection: %w", err) + return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - return nil + return typ, &limitedReader{ + c: c, + r: r, + left: c.msgReadLimit, + }, nil } -func (c *Conn) exportedClose(code StatusCode, reason string) error { - ce := CloseError{ - Code: code, - Reason: reason, +func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { + if c.previousReader != nil && c.previousReader.h != nil { + // The only way we know for sure the previous reader is not yet complete is + // if there is an active frame not yet fully read. + // Otherwise, a user may have read the last byte but not the EOF if the EOF + // is in the next frame so we check for that below. + return 0, nil, xerrors.Errorf("previous message not read to completion") } - // This function also will not wait for a close frame from the peer like the RFC - // wants because that makes no sense and I don't think anyone actually follows that. - // Definitely worth seeing what popular browsers do later. - p, err := ce.bytes() - if err != nil { - fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) - ce = CloseError{ - Code: StatusInternalError, + select { + case <-c.closed: + return 0, nil, c.closeErr + case <-ctx.Done(): + return 0, nil, ctx.Err() + case h := <-c.readMsg: + if c.previousReader != nil && !c.previousReader.done { + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } + + if !h.fin || h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") + } + + c.previousReader.done = true + + select { + case <-c.closed: + return 0, nil, c.closeErr + case c.readMsgDone <- struct{}{}: + } + + return c.reader(ctx) + } else if h.opcode == opContinuation { + err := xerrors.Errorf("received continuation frame not after data or text frame") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err } - p, _ = ce.bytes() + + r := &messageReader{ + ctx: ctx, + c: c, + + h: &h, + } + c.previousReader = r + return MessageType(h.opcode), r, nil } +} - return c.writeClose(p, ce) +// messageReader enables reading a data frame from the WebSocket connection. +type messageReader struct { + ctx context.Context + c *Conn + + h *header + maskPos int + done bool } -func (c *Conn) writeClose(p []byte, cerr CloseError) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() +// Read reads as many bytes as possible into p. +func (r *messageReader) Read(p []byte) (int, error) { + n, err := r.read(p) + if err != nil { + // Have to return io.EOF directly for now, we cannot wrap as xerrors + // isn't used in stdlib. + if xerrors.Is(err, io.EOF) { + return n, io.EOF + } + return n, xerrors.Errorf("failed to read: %w", err) + } + return n, nil +} + +func (r *messageReader) read(p []byte) (int, error) { + if r.done { + return 0, xerrors.Errorf("cannot use EOFed reader") + } + + if r.h == nil { + select { + case <-r.c.closed: + return 0, r.c.closeErr + case <-r.ctx.Done(): + r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err())) + return 0, r.ctx.Err() + case h := <-r.c.readMsg: + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data frame without finishing the previous frame") + r.c.Close(StatusProtocolError, err.Error()) + return 0, err + } + r.h = &h + } + } + + if int64(len(p)) > r.h.payloadLength { + p = p[:r.h.payloadLength] + } + + n, err := r.c.readFramePayload(r.ctx, p) + + r.h.payloadLength -= int64(n) + if r.h.masked { + r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + } - err := c.writeControl(ctx, opClose, p) if err != nil { - return err + return n, err } - c.close(cerr) - if !xerrors.Is(c.closeErr, cerr) { - return c.closeErr + if r.h.payloadLength == 0 { + select { + case <-r.c.closed: + return n, r.c.closeErr + case r.c.readMsgDone <- struct{}{}: + } + + fin := r.h.fin + + // Need to nil this as Reader uses it to check + // whether there is active data on the previous reader and + // now there isn't. + r.h = nil + + if fin { + r.done = true + return n, io.EOF + } + + r.maskPos = 0 } - return nil + return n, nil } -func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return 0, err + } + defer c.releaseLock(c.readFrameLock) + select { - case <-ctx.Done(): - var err error - switch lock { - case c.writeFrameLock, c.writeMsgLock: - err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err()) - case c.readFrameLock: - err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err()) + case <-c.closed: + return 0, c.closeErr + case c.setReadTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + case <-ctx.Done(): + err = ctx.Err() default: - panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) } + err = xerrors.Errorf("failed to read from connection: %w", err) + c.releaseLock(c.readFrameLock) c.close(err) - return ctx.Err() + return n, err + } + + select { case <-c.closed: - return c.closeErr - case lock <- struct{}{}: - return nil + return n, c.closeErr + case c.setReadTimeout <- context.Background(): } + + return n, err } -func (c *Conn) releaseLock(lock chan struct{}) { - // Allow multiple releases. - select { - case <-lock: - default: +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusPolicyViolation. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit = n +} + +// Read is a convenience method to read a single message from the connection. +// +// See the Reader method if you want to be able to reuse buffers or want to stream a message. +// The docs on Reader apply to this method as well. +// +// This is an experimental API, please let me know how you feel about it in +// https://github.com/nhooyr/websocket/issues/62 +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err } + + b, err := ioutil.ReadAll(r) + return typ, b, err } // Writer returns a writer bounded by the context that will write @@ -488,28 +575,11 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err if err != nil { return nil, err } - return &messageWriter{ - ctx: ctx, - opcode: opcode(typ), - c: c, - }, nil -} - -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method if you want to be able to reuse buffers or want to stream a message. -// The docs on Reader apply to this method as well. -// -// This is an experimental API, please let me know how you feel about it in -// https://github.com/nhooyr/websocket/issues/62 -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - typ, r, err := c.Reader(ctx) - if err != nil { - return 0, nil, err - } - - b, err := ioutil.ReadAll(r) - return typ, b, err + return &messageWriter{ + ctx: ctx, + opcode: opcode(typ), + c: c, + }, nil } // Write is a convenience method to write a message to the connection. @@ -592,194 +662,146 @@ func (w *messageWriter) close() error { return nil } -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. -// The passed context will also bound the reader. -// Ensure you read to EOF otherwise the connection will hang. -// -// Control (ping, pong, close) frames will be handled automatically -// in a separate goroutine so if you do not expect any data messages, -// you do not need to read from the connection. However, if the peer -// sends a data message, further pings, pongs and close frames will not -// be read if you do not read the message from the connection. -// -// Only one Reader may be open at a time. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - typ, r, err := c.reader(ctx) +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + err := c.writeFrame(ctx, true, opcode, p) if err != nil { - return 0, nil, xerrors.Errorf("failed to get reader: %w", err) + return xerrors.Errorf("failed to write control frame: %w", err) } - return typ, &limitedReader{ - c: c, - r: r, - left: c.msgReadLimit, - }, nil + return nil } -func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.previousReader.h != nil && c.previousReader.h.payloadLength > 0 { - return 0, nil, xerrors.Errorf("previous message not read to completion") +// writeFrame handles all writes to the connection. +// We never mask inside here because our mask key is always 0,0,0,0. +// See comment on secWebSocketKey for why. +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { + h := header{ + fin: fin, + opcode: opcode, + masked: c.client, + payloadLength: int64(len(p)), + } + b2 := marshalHeader(h) + + err := c.acquireLock(ctx, c.writeFrameLock) + if err != nil { + return err } + defer c.releaseLock(c.writeFrameLock) select { case <-c.closed: - return 0, nil, c.closeErr - case <-ctx.Done(): - return 0, nil, ctx.Err() - case h := <-c.readMsg: - if c.previousReader != nil && !c.previousReader.done { - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err - } - - if !h.fin || h.payloadLength > 0 { - return 0, nil, xerrors.Errorf("previous message not read to completion") - } + return c.closeErr + case c.setWriteTimeout <- ctx: + } - c.previousReader.done = true - return c.reader(ctx) - } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err - } - r := &messageReader{ - ctx: ctx, - h: &h, - c: c, + writeErr := func(err error) error { + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: } - c.previousReader = r - return MessageType(h.opcode), r, nil - } -} -// messageReader enables reading a data frame from the WebSocket connection. -type messageReader struct { - ctx context.Context - c *Conn + err = xerrors.Errorf("failed to write to connection: %w", err) + // We need to release the lock first before closing the connection to ensure + // the lock can be acquired inside close to ensure no one can access c.bw. + c.releaseLock(c.writeFrameLock) + c.close(err) - h *header - maskPos int - done bool -} + return err + } -// Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (int, error) { - n, err := r.read(p) + _, err = c.bw.Write(b2) if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as xerrors - // isn't used in stdlib. - if xerrors.Is(err, io.EOF) { - return n, io.EOF - } - err = xerrors.Errorf("failed to read: %w", err) - r.c.close(err) - return n, err + return writeErr(err) } - return n, nil -} - -func (r *messageReader) read(p []byte) (int, error) { - if r.done { - return 0, xerrors.Errorf("cannot use EOFed reader") + _, err = c.bw.Write(p) + if err != nil { + return writeErr(err) } - if r.h == nil { - select { - case <-r.c.closed: - return 0, r.c.closeErr - case <-r.ctx.Done(): - return 0, r.ctx.Err() - case h := <-r.c.readMsg: - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data frame without finishing the previous frame") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err - } - r.h = &h + if fin { + err = c.bw.Flush() + if err != nil { + return writeErr(err) } } - if int64(len(p)) > r.h.payloadLength { - p = p[:r.h.payloadLength] + // We already finished writing, no need to potentially brick the connection if + // the context expires. + select { + case <-c.closed: + return c.closeErr + case c.setWriteTimeout <- context.Background(): } - n, err := r.c.readFramePayload(r.ctx, p) + return nil +} - r.h.payloadLength -= int64(n) - if r.h.masked { - r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) - } +func (c *Conn) writePong(p []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.writeControl(ctx, opPong, p) + return err +} +// Close closes the WebSocket connection with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5 seconds. +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes otherwise an internal +// error will be sent to the peer. For this reason, you should avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection. +func (c *Conn) Close(code StatusCode, reason string) error { + err := c.exportedClose(code, reason) if err != nil { - return n, err + return xerrors.Errorf("failed to close connection: %w", err) } + return nil +} - if r.h.payloadLength == 0 { - select { - case <-r.c.closed: - return n, r.c.closeErr - case r.c.readMsgDone <- struct{}{}: - } +func (c *Conn) exportedClose(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, + } - if r.h.fin { - r.done = true - return n, io.EOF + // This function also will not wait for a close frame from the peer like the RFC + // wants because that makes no sense and I don't think anyone actually follows that. + // Definitely worth seeing what popular browsers do later. + p, err := ce.bytes() + if err != nil { + fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) + ce = CloseError{ + Code: StatusInternalError, } - - r.maskPos = 0 - r.h = nil + p, _ = ce.bytes() } - return n, nil + return c.writeClose(p, ce) } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - err := c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- ctx: - } +func (c *Conn) writeClose(p []byte, cerr CloseError) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - n, err := io.ReadFull(c.br, p) + err := c.writeControl(ctx, opClose, p) if err != nil { - select { - case <-c.closed: - return n, c.closeErr - default: - } - err = xerrors.Errorf("failed to read from connection: %w", err) - c.releaseLock(c.readFrameLock) - c.close(err) - return n, err + return err } - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- context.Background(): + c.close(cerr) + if !xerrors.Is(c.closeErr, cerr) { + return c.closeErr } - return 0, err -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusPolicyViolation. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit = n + return nil } func init() { @@ -794,9 +816,7 @@ func init() { func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx) if err != nil { - err = xerrors.Errorf("failed to ping: %w", err) - c.close(err) - return err + return xerrors.Errorf("failed to ping: %w", err) } return nil } @@ -823,10 +843,11 @@ func (c *Conn) ping(ctx context.Context) error { } select { - case <-ctx.Done(): - return ctx.Err() case <-c.closed: return c.closeErr + case <-ctx.Done(): + c.close(xerrors.Errorf("failed to ping: %w", ctx.Err())) + return ctx.Err() case <-pong: return nil }