Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix "zombie" http output websocket connections #1

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 94 additions & 24 deletions internal/impl/io/output_http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ const (
hsoFieldCORS = "cors"
hsoFieldCORSEnabled = "enabled"
hsoFieldCORSAllowedOrigins = "allowed_origins"
hsoFieldWriteWait = "write_wait"
hsoFieldPongWait = "pong_wait"
hsoFieldPingPeriod = "ping_period"
)

type hsoConfig struct {
Expand All @@ -54,6 +57,9 @@ type hsoConfig struct {
CertFile string
KeyFile string
CORS httpserver.CORSConfig
WriteWait time.Duration
PongWait time.Duration
PingPeriod time.Duration
}

func hsoConfigFromParsed(pConf *service.ParsedConfig) (conf hsoConfig, err error) {
Expand Down Expand Up @@ -95,6 +101,15 @@ func hsoConfigFromParsed(pConf *service.ParsedConfig) (conf hsoConfig, err error
if conf.CORS, err = corsConfigFromParsed(pConf.Namespace(hsoFieldCORS)); err != nil {
return
}
if conf.WriteWait, err = pConf.FieldDuration(hsoFieldWriteWait); err != nil {
return
}
if conf.PongWait, err = pConf.FieldDuration(hsoFieldPongWait); err != nil {
return
}
if conf.PingPeriod, err = pConf.FieldDuration(hsoFieldPingPeriod); err != nil {
return
}
return
}

Expand Down Expand Up @@ -145,6 +160,18 @@ Please note, messages are considered delivered as soon as the data is written to
Advanced().
Default(""),
service.NewInternalField(corsSpec),
service.NewDurationField(hsoFieldWriteWait).
Description("The time allowed to write a message to the websocket.").
Default("10s").
Advanced(),
service.NewDurationField(hsoFieldPongWait).
Description("The time allowed to read the next pong message from the client.").
Default("60s").
Advanced(),
service.NewDurationField(hsoFieldPingPeriod).
Description("Send pings to client with this period. Must be less than pong wait.").
Default("54s").
Advanced(),
)
}

Expand Down Expand Up @@ -393,50 +420,93 @@ func (h *httpServerOutput) wsHandler(w http.ResponseWriter, r *http.Request) {
defer func() {
if err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
h.log.Warn("Websocket request failed: %v\n", err)
h.log.Warn("WebSocket request failed: %v", err)
return
}
}()

upgrader := websocket.Upgrader{}

var ws *websocket.Conn
if ws, err = upgrader.Upgrade(w, r, nil); err != nil {
// Upgrade the HTTP connection to a WebSocket connection
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
h.log.Warn("WebSocket upgrade failed: %v", err)
return
}
defer ws.Close()

ctx, done := h.shutSig.SoftStopCtx(r.Context())
defer done()
ws.SetReadLimit(512)
if err := ws.SetReadDeadline(time.Now().Add(h.conf.PongWait)); err != nil {
h.log.Warn("Failed to set read deadline: %v", err)
return
}

for !h.shutSig.IsSoftStopSignalled() {
var ts message.Transaction
var open bool
ws.SetPongHandler(func(string) error {
return ws.SetReadDeadline(time.Now().Add(h.conf.PongWait))
})

// Start a goroutine to read messages (to process control frames)
done := make(chan struct{})
go func() {
defer close(done)
for {
_, _, err := ws.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
h.log.Warn("WebSocket read error: %v", err)
}
break
}
}
}()

// Start ticker to send ping messages to the client periodically
ticker := time.NewTicker(h.conf.PingPeriod)
defer ticker.Stop()

ctx, doneCtx := h.shutSig.SoftStopCtx(r.Context())
defer doneCtx()

for !h.shutSig.IsSoftStopSignalled() {
select {
case ts, open = <-h.transactions:
case ts, open := <-h.transactions:
if !open {
// If the transactions channel is closed, trigger server shutdown
go h.TriggerCloseNow()
return
}
case <-r.Context().Done():
// Write messages to the client
var writeErr error
for _, msg := range message.GetAllBytes(ts.Payload) {
_ = ws.SetWriteDeadline(time.Now().Add(h.conf.WriteWait))
if writeErr = ws.WriteMessage(websocket.BinaryMessage, msg); writeErr != nil {
break
}
h.mWSBatchSent.Incr(1)
h.mWSSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload)))
}
if writeErr != nil {
h.mWSError.Incr(1)
_ = ts.Ack(ctx, writeErr)
return // Exit the loop on write error
}
_ = ts.Ack(ctx, nil)
case <-ticker.C:
// Send a ping message to the client
//nolint:errcheck // this function does not actually return an error
ws.SetWriteDeadline(time.Now().Add(h.conf.WriteWait))
if err := ws.WriteMessage(websocket.PingMessage, nil); err != nil {
h.log.Warn("WebSocket ping error: %v", err)
return
}
case <-done:
// The read goroutine has exited, indicating the client has disconnected
h.log.Debug("WebSocket client disconnected")
return
case <-h.shutSig.SoftStopChan():
case <-ctx.Done():
// The context has been canceled (e.g., server is shutting down)
return
}

var werr error
for _, msg := range message.GetAllBytes(ts.Payload) {
if werr = ws.WriteMessage(websocket.BinaryMessage, msg); werr != nil {
break
}
h.mWSBatchSent.Incr(1)
h.mWSSent.Incr(int64(batch.MessageCollapsedCount(ts.Payload)))
}
if werr != nil {
h.mWSError.Incr(1)
}
_ = ts.Ack(ctx, werr)
}
}

Expand Down
27 changes: 27 additions & 0 deletions website/docs/components/outputs/http_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ output:
cors:
enabled: false
allowed_origins: []
write_wait: 10s
pong_wait: 60s
ping_period: 54s
```

</TabItem>
Expand Down Expand Up @@ -172,4 +175,28 @@ An explicit list of origins that are allowed for CORS requests.
Type: `array`
Default: `[]`

### `write_wait`

The maximum time allowed for writing a message (including ping messages) to the websocket client.


Type: `string`
Default: `"10s"`

### `pong_wait`

The maximum time to wait for the next pong message from the client. The deadline is reset each time a pong is received; if it expires the connection is closed.


Type: `string`
Default: `"60s"`

### `ping_period`

The interval at which ping messages are sent to the client. Must be less than `pong_wait`.


Type: `string`
Default: `"54s"`