diff --git a/connect/connection.go b/connect/connection.go index 695965bf..ac0504c4 100644 --- a/connect/connection.go +++ b/connect/connection.go @@ -21,7 +21,23 @@ type connectReport struct { err error } -func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData) { +type connectOpt func(opts *connectOpts) +type connectOpts struct { + notifyConnectedChan chan struct{} +} + +func withNotifyConnectedChan(ch chan struct{}) connectOpt { + return func(opts *connectOpts) { + opts.notifyConnectedChan = ch + } +} + +func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData, opts ...connectOpt) { + o := connectOpts{} + for _, opt := range opts { + opt(&o) + } + // Set up connection (including connect handshake protocol) preparedConn, err := h.prepareConnection(ctx, data) if err != nil { @@ -37,6 +53,12 @@ func (h *connectHandler) connect(ctx context.Context, data connectionEstablishDa // Notify that the connection was established h.notifyConnectedChan <- struct{}{} + // If an additional notification channel was provided, notify it as well + if o.notifyConnectedChan != nil { + o.notifyConnectedChan <- struct{}{} + close(o.notifyConnectedChan) + } + // Set up connection lifecycle logic (receiving messages, handling requests, etc.) err = h.handleConnection(ctx, data, preparedConn.ws, preparedConn.gatewayHost) if err != nil { @@ -145,7 +167,7 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs case <-ctx.Done(): return case <-heartbeatTicker.C: - err := wsproto.Write(ctx, ws, &connectproto.ConnectMessage{ + err := wsproto.Write(context.Background(), ws, &connectproto.ConnectMessage{ Kind: connectproto.GatewayMessageType_WORKER_HEARTBEAT, }) if err != nil { @@ -199,26 +221,22 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs // Gateway is draining and will not accept new connections. // We must reconnect to a different gateway, only then can we close the old connection. - waitUntilConnected, doneWaiting := context.WithCancel(context.Background()) + waitUntilConnected, doneWaiting := context.WithTimeout(context.Background(), 10*time.Second) defer doneWaiting() - // Intercept connected signal and pass it to the main goroutine - notifyConnectedInterceptChan := make(chan struct{}) + // Set up local notification listener + notifyConnectedChan := make(chan struct{}) go func() { - <-h.notifyConnectedChan - notifyConnectedInterceptChan <- struct{}{} + <-notifyConnectedChan doneWaiting() }() - // Establish new connection and pass close reports back to the main goroutine - go h.connect(context.Background(), data) - - cancel() + // Establish new connection, notify the routine above when the new connection is established + go h.connect(context.Background(), data, withNotifyConnectedChan(notifyConnectedChan)) // Wait until the new connection is established before closing the old one - select { - case <-waitUntilConnected.Done(): - case <-time.After(10 * time.Second): + <-waitUntilConnected.Done() + if errors.Is(waitUntilConnected.Err(), context.DeadlineExceeded) { h.logger.Error("timed out waiting for new connection to be established") }