From 75d319711c2d87c586ed8b1edb09ce3adcaf0c6d Mon Sep 17 00:00:00 2001 From: Alex Hamlin Date: Thu, 2 Nov 2023 21:39:50 -0700 Subject: [PATCH] Use contexts to shut down WebSocket handlers I tried this out a long time ago in an experimental branch, alongside an evaluation of a different WebSocket library. I don't see a need to switch the WebSocket implementation at this time, but this half of the change seems like a reasonable simplification. --- internal/api/tunerstatus.go | 32 ++++++++++++++------------------ internal/api/webrtc.go | 34 +++++++++++++++------------------- 2 files changed, 29 insertions(+), 37 deletions(-) diff --git a/internal/api/tunerstatus.go b/internal/api/tunerstatus.go index f92f64c..aac8d2b 100644 --- a/internal/api/tunerstatus.go +++ b/internal/api/tunerstatus.go @@ -1,6 +1,7 @@ package api import ( + "context" "log" "net/http" "sync" @@ -12,30 +13,32 @@ import ( ) type tunerStatusHandler struct { - tuner *tuner.Tuner - socket *websocket.Conn - watch watch.Watch - shutdownErr chan error - waitGroup sync.WaitGroup + tuner *tuner.Tuner + socket *websocket.Conn + watch watch.Watch + ctx context.Context + shutdown context.CancelCauseFunc + waitGroup sync.WaitGroup } func (h *Handler) handleSocketTunerStatus(w http.ResponseWriter, r *http.Request) { + ctx, shutdown := context.WithCancelCause(r.Context()) tsh := &tunerStatusHandler{ - tuner: h.tuner, - shutdownErr: make(chan error, 1), + tuner: h.tuner, + ctx: ctx, + shutdown: shutdown, } tsh.ServeHTTP(w, r) } func (tsh *tunerStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var err error - tsh.logf("Starting new connection") defer func() { tsh.waitForCleanup() - tsh.logf("Finished with error: %v", err) + tsh.logf("Connection done: %v", context.Cause(tsh.ctx)) }() + var err error tsh.socket, err = websocketUpgrader.Upgrade(w, r, nil) if err != nil { return @@ -51,7 +54,7 @@ func (tsh *tunerStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) tsh.watch = tsh.tuner.WatchStatus(tsh.sendNewTunerStatus) defer tsh.watch.Cancel() - err = <-tsh.shutdownErr + <-tsh.ctx.Done() } func (tsh *tunerStatusHandler) sendNewTunerStatus(s tuner.Status) { @@ -75,13 +78,6 @@ func (tsh *tunerStatusHandler) drainClient() { } } -func (tsh *tunerStatusHandler) shutdown(err error) { - select { - case tsh.shutdownErr <- err: - default: - } -} - func (tsh *tunerStatusHandler) waitForCleanup() { if tsh.watch != nil { tsh.watch.Wait() diff --git a/internal/api/webrtc.go b/internal/api/webrtc.go index 60227bf..c6c3b80 100644 --- a/internal/api/webrtc.go +++ b/internal/api/webrtc.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "log" "net/http" @@ -47,31 +48,33 @@ func init() { } type webrtcHandler struct { - tuner *tuner.Tuner - socket *websocket.Conn - rtcPeer *webrtc.PeerConnection - watch watch.Watch - shutdownErr chan error - waitGroup sync.WaitGroup + tuner *tuner.Tuner + socket *websocket.Conn + rtcPeer *webrtc.PeerConnection + watch watch.Watch + ctx context.Context + shutdown context.CancelCauseFunc + waitGroup sync.WaitGroup } func (h *Handler) handleSocketWebRTCPeer(w http.ResponseWriter, r *http.Request) { + ctx, shutdown := context.WithCancelCause(r.Context()) wh := &webrtcHandler{ - tuner: h.tuner, - shutdownErr: make(chan error, 1), + tuner: h.tuner, + ctx: ctx, + shutdown: shutdown, } wh.ServeHTTP(w, r) } func (wh *webrtcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var err error - wh.logf("Starting new connection") defer func() { wh.waitForCleanup() - wh.logf("Finished with error: %v", err) + wh.logf("Connection done: %v", context.Cause(wh.ctx)) }() + var err error wh.socket, err = websocketUpgrader.Upgrade(w, r, nil) if err != nil { return @@ -93,7 +96,7 @@ func (wh *webrtcHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { wh.watch = wh.tuner.WatchTracks(wh.handleTrackUpdate) defer wh.watch.Cancel() - err = <-wh.shutdownErr + <-wh.ctx.Done() } func (wh *webrtcHandler) handleClientSessionAnswers() { @@ -207,13 +210,6 @@ func (wh *webrtcHandler) hasTransceivers() bool { return len(wh.rtcPeer.GetTransceivers()) > 0 } -func (wh *webrtcHandler) shutdown(err error) { - select { - case wh.shutdownErr <- err: - default: - } -} - func (wh *webrtcHandler) waitForCleanup() { if wh.watch != nil { wh.watch.Wait()