diff --git a/internal/backend/workloadproxy/handler.go b/internal/backend/workloadproxy/handler.go index 11260584..40995749 100644 --- a/internal/backend/workloadproxy/handler.go +++ b/internal/backend/workloadproxy/handler.go @@ -174,7 +174,7 @@ func (h *HTTPHandler) parseServiceAliasFromHost(request *http.Request) string { return "" } - if isNewFormat := len(proxyServiceHostPrefixParts) == 2; isNewFormat { + if isNewFormat := proxyServiceHostPrefixParts[0] != LegacyHostPrefix; isNewFormat { return proxyServiceHostPrefixParts[0] } diff --git a/internal/backend/workloadproxy/handler_test.go b/internal/backend/workloadproxy/handler_test.go index 9dcff5a4..72026012 100644 --- a/internal/backend/workloadproxy/handler_test.go +++ b/internal/backend/workloadproxy/handler_test.go @@ -127,6 +127,18 @@ func TestHandler(t *testing.T) { t.Run("subdomain request with cookies", func(t *testing.T) { t.Parallel() + testSubdomainRequestWithCookies(ctx, t, mainURL, "instanceid") + }) + + t.Run("subdomain request with cookies - dash in instance name", func(t *testing.T) { + t.Parallel() + + testSubdomainRequestWithCookies(ctx, t, mainURL, "instance-id") + }) + + t.Run("subdomain request with cookies - legacy format", func(t *testing.T) { + t.Parallel() + next := &mockHandler{} proxyProvider := &mockProxyProvider{} accessValidator := &mockAccessValidator{} @@ -139,7 +151,7 @@ func TestHandler(t *testing.T) { testServiceAlias := "testsvc2" - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://%s-instanceid.proxy-us.example.com/example", testServiceAlias), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://%s-%s-instanceid.example.com/example", workloadproxy.LegacyHostPrefix, testServiceAlias), nil) require.NoError(t, err) testPublicKeyID := "test-public-key-id" @@ -158,39 +170,37 @@ func TestHandler(t *testing.T) { require.Equal(t, []string{testPublicKeyIDSignatureBase64}, accessValidator.publicKeyIDSignatureBase64s) require.Equal(t, []resource.ID{"test-cluster"}, accessValidator.clusterIDs) }) +} - t.Run("subdomain request with cookies - legacy format", func(t *testing.T) { - t.Parallel() - - next := &mockHandler{} - proxyProvider := &mockProxyProvider{} - accessValidator := &mockAccessValidator{} - logger := zaptest.NewLogger(t) +func testSubdomainRequestWithCookies(ctx context.Context, t *testing.T, mainURL *url.URL, instanceID string) { + next := &mockHandler{} + proxyProvider := &mockProxyProvider{} + accessValidator := &mockAccessValidator{} + logger := zaptest.NewLogger(t) - handler, err := workloadproxy.NewHTTPHandler(next, proxyProvider, accessValidator, mainURL, "proxy-us", logger) - require.NoError(t, err) + handler, err := workloadproxy.NewHTTPHandler(next, proxyProvider, accessValidator, mainURL, "proxy-us", logger) + require.NoError(t, err) - rr := httptest.NewRecorder() + rr := httptest.NewRecorder() - testServiceAlias := "testsvc2" + testServiceAlias := "testsvc2" - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://%s-%s-instanceid.example.com/example", workloadproxy.LegacyHostPrefix, testServiceAlias), nil) - require.NoError(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("https://%s-%s.proxy-us.example.com/example", testServiceAlias, instanceID), nil) + require.NoError(t, err) - testPublicKeyID := "test-public-key-id" - testPublicKeyIDSignatureBase64 := base64.StdEncoding.EncodeToString([]byte("test-signed-public-key-id")) + testPublicKeyID := "test-public-key-id" + testPublicKeyIDSignatureBase64 := base64.StdEncoding.EncodeToString([]byte("test-signed-public-key-id")) - req.AddCookie(&http.Cookie{Name: workloadproxy.PublicKeyIDCookie, Value: testPublicKeyID}) - req.AddCookie(&http.Cookie{Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: testPublicKeyIDSignatureBase64}) + req.AddCookie(&http.Cookie{Name: workloadproxy.PublicKeyIDCookie, Value: testPublicKeyID}) + req.AddCookie(&http.Cookie{Name: workloadproxy.PublicKeyIDSignatureBase64Cookie, Value: testPublicKeyIDSignatureBase64}) - handler.ServeHTTP(rr, req) + handler.ServeHTTP(rr, req) - require.Equal(t, []string{testServiceAlias}, proxyProvider.aliases) + require.Equal(t, []string{testServiceAlias}, proxyProvider.aliases) - require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, http.StatusOK, rr.Code) - require.Equal(t, []string{testPublicKeyID}, accessValidator.publicKeyIDs) - require.Equal(t, []string{testPublicKeyIDSignatureBase64}, accessValidator.publicKeyIDSignatureBase64s) - require.Equal(t, []resource.ID{"test-cluster"}, accessValidator.clusterIDs) - }) + require.Equal(t, []string{testPublicKeyID}, accessValidator.publicKeyIDs) + require.Equal(t, []string{testPublicKeyIDSignatureBase64}, accessValidator.publicKeyIDSignatureBase64s) + require.Equal(t, []resource.ID{"test-cluster"}, accessValidator.clusterIDs) }