diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index 7926ed5b..07186835 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -116,7 +116,7 @@ func (d *Dispatcher) StartSessions(ctx context.Context, domains []universal.Doma err = <-results // The aggregateContext is canceled if one of the handshakes fails. We don't want to return // the Canceled error if ErrProtocolNotSupported is present. - if !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) { return err } } diff --git a/internal/dispatcher/dispatcher_test.go b/internal/dispatcher/dispatcher_test.go index e660b73f..79bec2e7 100644 --- a/internal/dispatcher/dispatcher_test.go +++ b/internal/dispatcher/dispatcher_test.go @@ -19,6 +19,7 @@ import ( ) var errOutboxFull = errors.New("dispatcher: outbox full") +var errDropMessage = errors.New("dispatcher: simulated dropped message") var errTimeout = errors.New("dispatcher: simulated timeout") var testPayload = []byte("ack") var quiescentDelay = 250 * time.Millisecond @@ -246,7 +247,11 @@ func (d *dummyConnector) Send(ctx context.Context, buffer []byte) error { err := d.errorQueue[0] d.errorQueue = d.errorQueue[1:] d.lock.Unlock() - return err + if err == errDropMessage { + return nil + } else if err != nil { + return err + } } if err := proto.Unmarshal(buffer, &message); err != nil { return err @@ -646,6 +651,36 @@ func TestConnect(t *testing.T) { } } +func TestWaitForAllSessions(t *testing.T) { + conn := newDummyConnector(t) + defer conn.Close() + + // Configure the Connector to only respond to the first of two handshakes + conn.EnqueueSendError(nil) + conn.EnqueueSendError(errDropMessage) + + key, err := authentication.NewECDHPrivateKey(rand.Reader) + if err != nil { + t.Fatalf("Couldn't create private key: %s", err) + } + + dispatcher, err := New(conn, key) + if err != nil { + t.Fatalf("Couldn't initialize dispatcher: %s", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), quiescentDelay) + defer cancel() + + if err := dispatcher.Start(ctx); err != nil { + t.Fatal(err) + } + + if err := dispatcher.StartSessions(ctx, nil); err != context.DeadlineExceeded { + t.Fatalf("Unexpected error: %s", err) + } +} + func TestRetrySend(t *testing.T) { dispatcher, conn := getTestSetup(t) defer conn.Close()