diff --git a/pkg/provider/tcp/README.md b/pkg/provider/tcp/README.md index 1492183..1c5777f 100644 --- a/pkg/provider/tcp/README.md +++ b/pkg/provider/tcp/README.md @@ -13,6 +13,7 @@ The TCP Provider is configured through the platform-health server's configuratio * `name` (required): The name of the TCP service instance, used to identify the service in the health reports. * `host` (required): The hostname or IP address of the TCP service to monitor. * `port` (default: `80`): The port number of the TCP service to monitor. +* `invert` (default: `false`): Reverse logic to report "unhealthy" if port is open and "healthy" if it is closed. * `timeout` (default: `1s`): The maximum time to wait for a connection to be established before timing out. ### Example diff --git a/pkg/provider/tcp/tcp.go b/pkg/provider/tcp/tcp.go index 7d0b7b2..8d65df8 100644 --- a/pkg/provider/tcp/tcp.go +++ b/pkg/provider/tcp/tcp.go @@ -20,6 +20,7 @@ type TCP struct { Name string `mapstructure:"name"` Host string `mapstructure:"host"` Port int `mapstructure:"port" default:"80"` + Invert bool `mapstructure:"invert" default:"false"` Timeout time.Duration `mapstructure:"timeout" default:"1s"` } @@ -33,6 +34,7 @@ func (i *TCP) LogValue() slog.Value { slog.String("host", i.Host), slog.Int("port", i.Port), slog.Any("timeout", i.Timeout), + slog.Bool("invert", i.Invert), } return slog.GroupValue(logAttr...) } @@ -66,9 +68,17 @@ func (i *TCP) GetHealth(ctx context.Context) *ph.HealthCheckResponse { dialer := &net.Dialer{} conn, err := dialer.DialContext(ctx, "tcp", address) if err != nil { - return component.Unhealthy(err.Error()) + if i.Invert { + return component.Healthy() + } else { + return component.Unhealthy(err.Error()) + } + } else { + _ = conn.Close() + if i.Invert { + return component.Unhealthy("port open") + } else { + return component.Healthy() + } } - _ = conn.Close() - - return component.Healthy() } diff --git a/pkg/provider/tcp/tcp_test.go b/pkg/provider/tcp/tcp_test.go index 947f111..fd3ffcc 100644 --- a/pkg/provider/tcp/tcp_test.go +++ b/pkg/provider/tcp/tcp_test.go @@ -25,30 +25,54 @@ func TestTCP(t *testing.T) { } defer listener.Close() + port := listener.Addr().(*net.TCPAddr).Port + tests := []struct { - name string - port int - status ph.Status + name string + port int + invert bool + timeout time.Duration + expected ph.Status }{ { - name: "Port open", - port: listener.Addr().(*net.TCPAddr).Port, - status: ph.Status_HEALTHY, + name: "Port open", + port: port, + expected: ph.Status_HEALTHY, + }, + { + name: "Port closed", + port: 1, + expected: ph.Status_UNHEALTHY, + }, + { + name: "Port closed, expect failure", + port: 1, + invert: true, + expected: ph.Status_HEALTHY, + }, + { + name: "Unexpected timeout", + port: port, + timeout: time.Nanosecond, + expected: ph.Status_UNHEALTHY, }, { - name: "Port closed", - port: 1, - status: ph.Status_UNHEALTHY, + name: "Expected timeout", + port: port, + invert: true, + timeout: time.Nanosecond, + expected: ph.Status_HEALTHY, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { instance := &tcp.TCP{ - Name: "TestTCP", + Name: tt.name, Host: "localhost", Port: tt.port, - Timeout: time.Second, + Invert: tt.invert, + Timeout: tt.timeout, } instance.SetDefaults() @@ -56,8 +80,8 @@ func TestTCP(t *testing.T) { assert.NotNil(t, result) assert.Equal(t, tcp.TypeTCP, result.GetType()) - assert.Equal(t, instance.Name, result.GetName()) - assert.Equal(t, tt.status, result.GetStatus()) + assert.Equal(t, tt.name, result.GetName()) + assert.Equal(t, tt.expected, result.GetStatus()) }) } }