diff --git a/flyteadmin/auth/handler_utils.go b/flyteadmin/auth/handler_utils.go index e6fd1a7236..a6b4031ca8 100644 --- a/flyteadmin/auth/handler_utils.go +++ b/flyteadmin/auth/handler_utils.go @@ -8,6 +8,7 @@ import ( "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/flyteorg/flyte/flyteadmin/auth/config" + "github.com/flyteorg/flyte/flytestdlib/logger" ) const ( @@ -146,3 +147,32 @@ func GetPublicURL(ctx context.Context, req *http.Request, cfg *config.Config) *u return u } + +func isAuthorizedRedirectURL(url *url.URL, authorizedURL *url.URL) bool { + return url.Hostname() == authorizedURL.Hostname() && url.Port() == authorizedURL.Port() && url.Scheme == authorizedURL.Scheme +} + +func GetRedirectURLAllowed(ctx context.Context, urlRedirectParam string, cfg *config.Config) bool { + if len(urlRedirectParam) == 0 { + logger.Debugf(ctx, "not validating whether empty redirect url is authorized") + return true + } + redirectURL, err := url.Parse(urlRedirectParam) + if err != nil { + logger.Debugf(ctx, "failed to parse user-supplied redirect url: %s with err: %v", urlRedirectParam, err) + return false + } + if redirectURL.Host == "" { + logger.Debugf(ctx, "not validating whether relative redirect url is authorized") + return true + } + logger.Debugf(ctx, "validating whether redirect url: %s is authorized", redirectURL) + for _, authorizedURI := range cfg.AuthorizedURIs { + if isAuthorizedRedirectURL(redirectURL, &authorizedURI.URL) { + logger.Debugf(ctx, "authorizing redirect url: %s against authorized uri: %s", redirectURL.String(), authorizedURI.String()) + return true + } + } + logger.Debugf(ctx, "not authorizing redirect url: %s", redirectURL.String()) + return false +} diff --git a/flyteadmin/auth/handler_utils_test.go b/flyteadmin/auth/handler_utils_test.go index 441f83dbbb..c44a57b934 100644 --- a/flyteadmin/auth/handler_utils_test.go +++ b/flyteadmin/auth/handler_utils_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/flyteorg/flyte/flyteadmin/auth/config" - config2 "github.com/flyteorg/flyte/flytestdlib/config" + flytestdconfig "github.com/flyteorg/flyte/flytestdlib/config" ) func TestGetPublicURL(t *testing.T) { @@ -16,7 +16,7 @@ func TestGetPublicURL(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "https://abc", nil) assert.NoError(t, err) u := GetPublicURL(context.Background(), req, &config.Config{ - AuthorizedURIs: []config2.URL{ + AuthorizedURIs: []flytestdconfig.URL{ {URL: *config.MustParseURL("https://xyz")}, {URL: *config.MustParseURL("https://abc")}, }, @@ -28,7 +28,7 @@ func TestGetPublicURL(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "https://abc", nil) assert.NoError(t, err) u := GetPublicURL(context.Background(), req, &config.Config{ - AuthorizedURIs: []config2.URL{ + AuthorizedURIs: []flytestdconfig.URL{ {URL: *config.MustParseURL("https://xyz")}, {URL: *config.MustParseURL("http://abc")}, }, @@ -40,7 +40,7 @@ func TestGetPublicURL(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "https://abc", nil) assert.NoError(t, err) u := GetPublicURL(context.Background(), req, &config.Config{ - AuthorizedURIs: []config2.URL{ + AuthorizedURIs: []flytestdconfig.URL{ {URL: *config.MustParseURL("https://xyz")}, {URL: *config.MustParseURL("http://xyz")}, }, @@ -61,7 +61,7 @@ func TestGetPublicURL(t *testing.T) { assert.NoError(t, err) u := GetPublicURL(context.Background(), req, &config.Config{ - AuthorizedURIs: []config2.URL{ + AuthorizedURIs: []flytestdconfig.URL{ {URL: *config.MustParseURL("http://flyteadmin:80")}, {URL: *config.MustParseURL("http://localhost:30081")}, {URL: *config.MustParseURL("http://localhost:8089")}, @@ -72,3 +72,28 @@ func TestGetPublicURL(t *testing.T) { assert.Equal(t, "http://localhost:30081", u.String()) }) } + +func TestGetRedirectURLAllowed(t *testing.T) { + ctx := context.TODO() + t.Run("relative url", func(t *testing.T) { + assert.True(t, GetRedirectURLAllowed(ctx, "/console", &config.Config{})) + }) + t.Run("no redirect url", func(t *testing.T) { + assert.True(t, GetRedirectURLAllowed(ctx, "", &config.Config{})) + }) + cfg := &config.Config{ + AuthorizedURIs: []flytestdconfig.URL{ + {URL: *config.MustParseURL("https://example.com")}, + {URL: *config.MustParseURL("http://localhost:3008")}, + }, + } + t.Run("authorized url", func(t *testing.T) { + assert.True(t, GetRedirectURLAllowed(ctx, "https://example.com", cfg)) + }) + t.Run("authorized localhost url", func(t *testing.T) { + assert.True(t, GetRedirectURLAllowed(ctx, "http://localhost:3008", cfg)) + }) + t.Run("unauthorized url", func(t *testing.T) { + assert.False(t, GetRedirectURLAllowed(ctx, "https://flyte.com", cfg)) + }) +} diff --git a/flyteadmin/auth/handlers.go b/flyteadmin/auth/handlers.go index 26e6428df3..a6c2e3b122 100644 --- a/flyteadmin/auth/handlers.go +++ b/flyteadmin/auth/handlers.go @@ -141,6 +141,11 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte logger.Debugf(ctx, "Setting CSRF state cookie to %s and state to %s\n", csrfToken, state) url := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).AuthCodeURL(state) queryParams := request.URL.Query() + if !GetRedirectURLAllowed(ctx, queryParams.Get(RedirectURLParameter), authCtx.Options()) { + logger.Infof(ctx, "unauthorized redirect URI") + writer.WriteHeader(http.StatusForbidden) + return + } if flowEndRedirectURL := queryParams.Get(RedirectURLParameter); flowEndRedirectURL != "" { redirectCookie := NewRedirectCookie(ctx, flowEndRedirectURL) if redirectCookie != nil {