diff --git a/.gitignore b/.gitignore index f24947dc55..0178cf5feb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ node_modules/ .virtualgo boilerplate/lyft/end2end/tmp +dist diff --git a/cmd/entrypoints/clusterresource.go b/cmd/entrypoints/clusterresource.go index ab63b3377d..4091ce25d3 100644 --- a/cmd/entrypoints/clusterresource.go +++ b/cmd/entrypoints/clusterresource.go @@ -3,22 +3,12 @@ package entrypoints import ( "context" - "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + errors2 "github.com/pkg/errors" - "github.com/flyteorg/flyteadmin/pkg/clusterresource/impl" - "github.com/flyteorg/flyteadmin/pkg/clusterresource/interfaces" - execClusterIfaces "github.com/flyteorg/flyteadmin/pkg/executioncluster/interfaces" - "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" - "github.com/flyteorg/flyteadmin/pkg/repositories" "github.com/flyteorg/flytestdlib/promutils" - "github.com/flyteorg/flyteidl/clients/go/admin" - "github.com/flyteorg/flyteadmin/pkg/clusterresource" - "github.com/flyteorg/flyteadmin/pkg/config" - executioncluster "github.com/flyteorg/flyteadmin/pkg/executioncluster/impl" "github.com/flyteorg/flyteadmin/pkg/runtime" - runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flytestdlib/logger" "github.com/spf13/cobra" @@ -30,74 +20,40 @@ var parentClusterResourceCmd = &cobra.Command{ Short: "This command administers the ClusterResourceController. Please choose a subcommand.", } -func getClusterResourceController(ctx context.Context, scope promutils.Scope, configuration runtimeInterfaces.Configuration) clusterresource.Controller { - initializationErrorCounter := scope.MustNewCounter( - "flyteclient_initialization_error", - "count of errors encountered initializing a flyte client from kube config") - var listTargetsProvider execClusterIfaces.ListTargetsInterface - var err error - if len(configuration.ClusterConfiguration().GetClusterConfigs()) == 0 { - serverConfig := config.GetConfig() - listTargetsProvider, err = executioncluster.NewInCluster(initializationErrorCounter, serverConfig.KubeConfig, serverConfig.Master) - } else { - listTargetsProvider, err = executioncluster.NewListTargets(initializationErrorCounter, executioncluster.NewExecutionTargetProvider(), configuration.ClusterConfiguration()) - } - if err != nil { - panic(err) - } - - var adminDataProvider interfaces.FlyteAdminDataProvider - if configuration.ClusterResourceConfiguration().IsStandaloneDeployment() { - clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).Build(ctx) - if err != nil { - panic(err) - } - adminDataProvider = impl.NewAdminServiceDataProvider(clientSet.AdminClient()) - } else { - dbConfig := runtime.NewConfigurationProvider().ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, dbConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - dbScope := scope.NewSubScope("db") - - repo := repositories.NewGormRepo( - db, errors.NewPostgresErrorTransformer(dbScope.NewSubScope("errors")), dbScope) - - adminDataProvider = impl.NewDatabaseAdminDataProvider(repo, configuration, resources.NewResourceManager(repo, configuration.ApplicationConfiguration())) - } - - return clusterresource.NewClusterResourceController(adminDataProvider, listTargetsProvider, scope) -} - var controllerRunCmd = &cobra.Command{ Use: "run", Short: "This command will start a cluster resource controller to periodically sync cluster resources", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() configuration := runtime.NewConfigurationProvider() scope := promutils.NewScope(configuration.ApplicationConfiguration().GetTopLevelConfig().MetricsScope).NewSubScope("clusterresource") - clusterResourceController := getClusterResourceController(ctx, scope, configuration) + clusterResourceController, err := clusterresource.NewClusterResourceControllerFromConfig(ctx, scope, configuration) + if err != nil { + return err + } clusterResourceController.Run() logger.Infof(ctx, "ClusterResourceController started running successfully") + return nil }, } var controllerSyncCmd = &cobra.Command{ Use: "sync", Short: "This command will sync cluster resources", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() configuration := runtime.NewConfigurationProvider() scope := promutils.NewScope(configuration.ApplicationConfiguration().GetTopLevelConfig().MetricsScope).NewSubScope("clusterresource") - clusterResourceController := getClusterResourceController(ctx, scope, configuration) - err := clusterResourceController.Sync(ctx) + clusterResourceController, err := clusterresource.NewClusterResourceControllerFromConfig(ctx, scope, configuration) + if err != nil { + return err + } + err = clusterResourceController.Sync(ctx) if err != nil { - logger.Fatalf(ctx, "Failed to sync cluster resources [%+v]", err) + return errors2.Wrap(err, "Failed to sync cluster resources ") } logger.Infof(ctx, "ClusterResourceController synced successfully") + return nil }, } diff --git a/cmd/entrypoints/migrate.go b/cmd/entrypoints/migrate.go index df1393f6cc..030a39390e 100644 --- a/cmd/entrypoints/migrate.go +++ b/cmd/entrypoints/migrate.go @@ -3,12 +3,8 @@ package entrypoints import ( "context" - "github.com/flyteorg/flyteadmin/pkg/repositories" - "github.com/flyteorg/flyteadmin/pkg/repositories/config" - "github.com/flyteorg/flyteadmin/pkg/runtime" - "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flyteadmin/pkg/server" - "github.com/go-gormigrate/gormigrate/v2" "github.com/spf13/cobra" _ "gorm.io/driver/postgres" // Required to import database driver. ) @@ -22,35 +18,9 @@ var parentMigrateCmd = &cobra.Command{ var migrateCmd = &cobra.Command{ Use: "run", Short: "This command will run all the migrations for the database", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - sqlDB, err := db.DB() - if err != nil { - logger.Fatal(ctx, err) - } - - defer func(deferCtx context.Context) { - if err = sqlDB.Close(); err != nil { - logger.Fatal(deferCtx, err) - } - }(ctx) - - if err = sqlDB.Ping(); err != nil { - logger.Fatal(ctx, err) - } - m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) - if err = m.Migrate(); err != nil { - logger.Fatalf(ctx, "Could not migrate: %v", err) - } - logger.Infof(ctx, "Migration ran successfully") + return server.Migrate(ctx) }, } @@ -58,36 +28,9 @@ var migrateCmd = &cobra.Command{ var rollbackCmd = &cobra.Command{ Use: "rollback", Short: "This command will rollback one migration", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - sqlDB, err := db.DB() - if err != nil { - logger.Fatal(ctx, err) - } - defer func(deferCtx context.Context) { - if err = sqlDB.Close(); err != nil { - logger.Fatal(deferCtx, err) - } - }(ctx) - - if err = sqlDB.Ping(); err != nil { - logger.Fatal(ctx, err) - } - - m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) - err = m.RollbackLast() - if err != nil { - logger.Fatalf(ctx, "Could not rollback latest migration: %v", err) - } - logger.Infof(ctx, "Rolled back one migration successfully") + return server.Rollback(ctx) }, } @@ -95,36 +38,9 @@ var rollbackCmd = &cobra.Command{ var seedProjectsCmd = &cobra.Command{ Use: "seed-projects", Short: "Seed projects in the database.", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - - sqlDB, err := db.DB() - if err != nil { - logger.Fatal(ctx, err) - } - - defer func(deferCtx context.Context) { - if err = sqlDB.Close(); err != nil { - logger.Fatal(deferCtx, err) - } - }(ctx) - - if err = sqlDB.Ping(); err != nil { - logger.Fatal(ctx, err) - } - - if err = config.SeedProjects(db, args); err != nil { - logger.Fatalf(ctx, "Could not add projects to database with err: %v", err) - } - logger.Infof(ctx, "Successfully added projects to database") + return server.SeedProjects(ctx, args) }, } diff --git a/cmd/entrypoints/serve.go b/cmd/entrypoints/serve.go index 2028dfeb9c..86ae225bbc 100644 --- a/cmd/entrypoints/serve.go +++ b/cmd/entrypoints/serve.go @@ -2,66 +2,38 @@ package entrypoints import ( "context" - "crypto/tls" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/flyteorg/flytestdlib/profutils" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" - - authConfig "github.com/flyteorg/flyteadmin/auth/config" - - "github.com/flyteorg/flyteadmin/auth/authzserver" - - "github.com/gorilla/handlers" - - "github.com/flyteorg/flyteadmin/auth" - "github.com/flyteorg/flyteadmin/auth/interfaces" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" - - "net" - "net/http" _ "net/http/pprof" // Required to serve application. - "strings" - - "github.com/flyteorg/flyteadmin/pkg/server" - "github.com/pkg/errors" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/health" - "google.golang.org/grpc/health/grpc_health_v1" "github.com/flyteorg/flyteadmin/pkg/common" - flyteService "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyteadmin/pkg/server" "github.com/flyteorg/flytestdlib/logger" - "github.com/grpc-ecosystem/grpc-gateway/runtime" - - "github.com/flyteorg/flyteadmin/pkg/config" - "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" "github.com/spf13/cobra" + runtimeConfig "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled" - grpcPrometheus "github.com/grpc-ecosystem/go-grpc-prometheus" - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" ) -var defaultCorsHeaders = []string{"Content-Type"} - // serveCmd represents the serve command var serveCmd = &cobra.Command{ Use: "serve", Short: "Launches the Flyte admin server", RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - serverConfig := config.GetConfig() - - if serverConfig.Security.Secure { - return serveGatewaySecure(ctx, serverConfig, authConfig.GetConfig()) - } + // Serve profiling endpoints. + cfg := runtimeConfig.NewConfigurationProvider() + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers( + ctx, cfg.ApplicationConfiguration().GetTopLevelConfig().GetProfilerPort(), nil) + if err != nil { + logger.Panicf(ctx, "Failed to Start profiling and Metrics server. Error, %v", err) + } + }() - return serveGatewayInsecure(ctx, serverConfig, authConfig.GetConfig()) + return server.Serve(ctx, nil) }, } @@ -75,322 +47,3 @@ func init() { contextutils.ExecIDKey, contextutils.WorkflowIDKey, contextutils.NodeIDKey, contextutils.TaskIDKey, contextutils.TaskTypeKey, common.RuntimeTypeKey, common.RuntimeVersionKey) } - -func blanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( - resp interface{}, err error) { - - identityContext := auth.IdentityContextFromContext(ctx) - if identityContext.IsEmpty() { - return handler(ctx, req) - } - - if !identityContext.Scopes().Has(auth.ScopeAll) { - return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") - } - - return handler(ctx, req) -} - -// Creates a new gRPC Server with all the configuration -func newGRPCServer(ctx context.Context, cfg *config.ServerConfig, authCtx interfaces.AuthenticationContext, - opts ...grpc.ServerOption) (*grpc.Server, error) { - // Not yet implemented for streaming - var chainedUnaryInterceptors grpc.UnaryServerInterceptor - if cfg.Security.UseAuth { - logger.Infof(ctx, "Creating gRPC server with authentication") - chainedUnaryInterceptors = grpc_middleware.ChainUnaryServer(grpcPrometheus.UnaryServerInterceptor, - auth.GetAuthenticationCustomMetadataInterceptor(authCtx), - grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)), - auth.AuthenticationLoggingInterceptor, - blanketAuthorization, - ) - } else { - logger.Infof(ctx, "Creating gRPC server without authentication") - chainedUnaryInterceptors = grpc_middleware.ChainUnaryServer(grpcPrometheus.UnaryServerInterceptor) - } - - serverOpts := []grpc.ServerOption{ - grpc.StreamInterceptor(grpcPrometheus.StreamServerInterceptor), - grpc.UnaryInterceptor(chainedUnaryInterceptors), - } - if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { - serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)) - } - serverOpts = append(serverOpts, opts...) - grpcServer := grpc.NewServer(serverOpts...) - grpcPrometheus.Register(grpcServer) - flyteService.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(ctx, cfg.KubeConfig, cfg.Master)) - if cfg.Security.UseAuth { - flyteService.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) - flyteService.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService()) - } - - healthServer := health.NewServer() - healthServer.SetServingStatus("flyteadmin", grpc_health_v1.HealthCheckResponse_SERVING) - grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) - if cfg.GrpcConfig.ServerReflection || cfg.GrpcServerReflection { - reflection.Register(grpcServer) - } - return grpcServer, nil -} - -func GetHandleOpenapiSpec(ctx context.Context) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - swaggerBytes, err := flyteService.Asset("admin.swagger.json") - if err != nil { - logger.Warningf(ctx, "Err %v", err) - w.WriteHeader(http.StatusFailedDependency) - } else { - w.WriteHeader(http.StatusOK) - _, err := w.Write(swaggerBytes) - if err != nil { - logger.Errorf(ctx, "failed to write openAPI information, error: %s", err.Error()) - } - } - } -} - -func healthCheckFunc(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config, authCtx interfaces.AuthenticationContext, - grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) { - - // Register the server that will serve HTTP/REST Traffic - mux := http.NewServeMux() - - // Register healthcheck - mux.HandleFunc("/healthcheck", healthCheckFunc) - - // Register OpenAPI endpoint - // This endpoint will serve the OpenAPI2 spec generated by the swagger protoc plugin, and bundled by go-bindata - mux.HandleFunc("/api/v1/openapi", GetHandleOpenapiSpec(ctx)) - - var gwmuxOptions = make([]runtime.ServeMuxOption, 0) - // This option means that http requests are served with protobufs, instead of json. We always want this. - gwmuxOptions = append(gwmuxOptions, runtime.WithMarshalerOption("application/octet-stream", &runtime.ProtoMarshaller{})) - - if cfg.Security.UseAuth { - // Add HTTP handlers for OIDC endpoints - auth.RegisterHandlers(ctx, mux, authCtx) - - // Add HTTP handlers for OAuth2 endpoints - authzserver.RegisterHandlers(mux, authCtx) - - // This option translates HTTP authorization data (cookies) into a gRPC metadata field - gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPRequestCookieToMetadataHandler(authCtx))) - - // In an attempt to be able to selectively enforce whether or not authentication is required, we're going to tag - // the requests that come from the HTTP gateway. See the enforceHttp/Grpc options for more information. - gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPMetadataTaggingHandler())) - } - - // Create the grpc-gateway server with the options specified - gwmux := runtime.NewServeMux(gwmuxOptions...) - - err := flyteService.RegisterAdminServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) - if err != nil { - return nil, errors.Wrap(err, "error registering admin service") - } - - err = flyteService.RegisterAuthMetadataServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) - if err != nil { - return nil, errors.Wrap(err, "error registering auth service") - } - - err = flyteService.RegisterIdentityServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) - if err != nil { - return nil, errors.Wrap(err, "error registering identity service") - } - - mux.Handle("/", gwmux) - - return mux, nil -} - -func serveGatewayInsecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config) error { - logger.Infof(ctx, "Serving Flyte Admin Insecure") - - // This will parse configuration and create the necessary objects for dealing with auth - var authCtx interfaces.AuthenticationContext - var err error - // This code is here to support authentication without SSL. This setup supports a network topology where - // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. - // Warning: Running authentication without SSL in any other topology is a severe security flaw. - // See the auth.Config object for additional settings as well. - if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) - var oauth2Provider interfaces.OAuth2Provider - var oauth2ResourceServer interfaces.OAuth2ResourceServer - if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { - oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) - if err != nil { - logger.Errorf(ctx, "Error creating authorization server %s", err) - return err - } - - oauth2ResourceServer = oauth2Provider - } else { - oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) - if err != nil { - logger.Errorf(ctx, "Error creating resource server %s", err) - return err - } - } - - oauth2MetadataProvider := authzserver.NewService(authCfg) - oidcUserInfoProvider := auth.NewUserInfoProvider() - - authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) - if err != nil { - logger.Errorf(ctx, "Error creating auth context %s", err) - return err - } - } - - grpcServer, err := newGRPCServer(ctx, cfg, authCtx) - if err != nil { - return errors.Wrap(err, "failed to create GRPC server") - } - - logger.Infof(ctx, "Serving GRPC Traffic on: %s", cfg.GetGrpcHostAddress()) - lis, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) - if err != nil { - return errors.Wrapf(err, "failed to listen on GRPC port: %s", cfg.GetGrpcHostAddress()) - } - - go func() { - err := grpcServer.Serve(lis) - logger.Fatalf(ctx, "Failed to create GRPC Server, Err: ", err) - }() - - logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) - grpcOptions := []grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes), - } - if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { - grpcOptions = append(grpcOptions, - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) - } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetGrpcHostAddress(), grpcOptions...) - if err != nil { - return err - } - - var handler http.Handler - if cfg.Security.AllowCors { - handler = handlers.CORS( - handlers.AllowCredentials(), - handlers.AllowedOrigins(cfg.Security.AllowedOrigins), - handlers.AllowedHeaders(append(defaultCorsHeaders, cfg.Security.AllowedHeaders...)), - handlers.AllowedMethods([]string{"GET", "POST", "DELETE", "HEAD", "PUT", "PATCH"}), - )(httpServer) - } else { - handler = httpServer - } - - err = http.ListenAndServe(cfg.GetHostAddress(), handler) - if err != nil { - return errors.Wrapf(err, "failed to Start HTTP Server") - } - - return nil -} - -// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC -// connections or otherHandler otherwise. -// See https://github.com/philips/grpc-gateway-example/blob/master/cmd/serve.go for reference -func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // This is a partial recreation of gRPC's internal checks - if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { - grpcServer.ServeHTTP(w, r) - } else { - otherHandler.ServeHTTP(w, r) - } - }) -} - -func serveGatewaySecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config) error { - certPool, cert, err := server.GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) - if err != nil { - return err - } - // This will parse configuration and create the necessary objects for dealing with auth - var authCtx interfaces.AuthenticationContext - if cfg.Security.UseAuth { - sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) - var oauth2Provider interfaces.OAuth2Provider - var oauth2ResourceServer interfaces.OAuth2ResourceServer - if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { - oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) - if err != nil { - logger.Errorf(ctx, "Error creating authorization server %s", err) - return err - } - - oauth2ResourceServer = oauth2Provider - } else { - oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) - if err != nil { - logger.Errorf(ctx, "Error creating resource server %s", err) - return err - } - } - - oauth2MetadataProvider := authzserver.NewService(authCfg) - oidcUserInfoProvider := auth.NewUserInfoProvider() - - authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) - if err != nil { - logger.Errorf(ctx, "Error creating auth context %s", err) - return err - } - } - - grpcServer, err := newGRPCServer(ctx, cfg, authCtx, - grpc.Creds(credentials.NewServerTLSFromCert(cert))) - if err != nil { - return errors.Wrap(err, "failed to create GRPC server") - } - - // Whatever certificate is used, pass it along for easier development - dialCreds := credentials.NewTLS(&tls.Config{ - ServerName: cfg.GetHostAddress(), - RootCAs: certPool, - }) - serverOpts := []grpc.DialOption{ - grpc.WithTransportCredentials(dialCreds), - } - if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { - serverOpts = append(serverOpts, - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) - } - httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, cfg.GetHostAddress(), serverOpts...) - if err != nil { - return err - } - - conn, err := net.Listen("tcp", cfg.GetHostAddress()) - if err != nil { - panic(err) - } - - srv := &http.Server{ - Addr: cfg.GetHostAddress(), - Handler: grpcHandlerFunc(grpcServer, httpServer), - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{*cert}, - NextProtos: []string{"h2"}, - }, - } - - err = srv.Serve(tls.NewListener(conn, srv.TLSConfig)) - - if err != nil { - return errors.Wrapf(err, "failed to Start HTTP/2 Server") - } - return nil -} diff --git a/cmd/scheduler/entrypoints/scheduler.go b/cmd/scheduler/entrypoints/scheduler.go index b981a3452e..e7ebe9f09b 100644 --- a/cmd/scheduler/entrypoints/scheduler.go +++ b/cmd/scheduler/entrypoints/scheduler.go @@ -2,20 +2,13 @@ package entrypoints import ( "context" - "fmt" - "runtime/debug" - - "github.com/flyteorg/flyteadmin/pkg/repositories" - "github.com/flyteorg/flyteadmin/pkg/repositories/errors" "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/runtime" "github.com/flyteorg/flyteadmin/scheduler" - "github.com/flyteorg/flyteidl/clients/go/admin" "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/profutils" - "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/promutils/labeled" "github.com/spf13/cobra" @@ -27,45 +20,7 @@ var schedulerRunCmd = &cobra.Command{ Short: "This command will start the flyte native scheduler and periodically get new schedules from the db for scheduling", RunE: func(cmd *cobra.Command, args []string) error { ctx := context.Background() - configuration := runtime.NewConfigurationProvider() - applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() - schedulerConfiguration := configuration.ApplicationConfiguration().GetSchedulerConfig() - - // Define the schedulerScope for prometheus metrics - schedulerScope := promutils.NewScope(applicationConfiguration.MetricsScope).NewSubScope("flytescheduler") - schedulerPanics := schedulerScope.MustNewCounter("initialization_panic", - "panics encountered initializing the flyte native scheduler") - - defer func() { - if err := recover(); err != nil { - schedulerPanics.Inc() - logger.Fatalf(ctx, fmt.Sprintf("caught panic: %v [%+v]", err, string(debug.Stack()))) - } - }() - - databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() - logConfig := logger.GetConfig() - - db, err := repositories.GetDB(ctx, databaseConfig, logConfig) - if err != nil { - logger.Fatal(ctx, err) - } - dbScope := schedulerScope.NewSubScope("database") - repo := repositories.NewGormRepo( - db, errors.NewPostgresErrorTransformer(schedulerScope.NewSubScope("errors")), dbScope) - - clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).Build(ctx) - if err != nil { - logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err) - return err - } - adminServiceClient := clientSet.AdminClient() - - scheduleExecutor := scheduler.NewScheduledExecutor(repo, - configuration.ApplicationConfiguration().GetSchedulerConfig().GetWorkflowExecutorConfig(), schedulerScope, adminServiceClient) - - logger.Info(ctx, "Successfully initialized a native flyte scheduler") - + schedulerConfiguration := runtime.NewConfigurationProvider().ApplicationConfiguration().GetSchedulerConfig() // Serve profiling endpoints. go func() { err := profutils.StartProfilingServerWithDefaultHandlers( @@ -74,13 +29,7 @@ var schedulerRunCmd = &cobra.Command{ logger.Panicf(ctx, "Failed to Start profiling and Metrics server. Error, %v", err) } }() - - err = scheduleExecutor.Run(ctx) - if err != nil { - logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err) - return err - } - return nil + return scheduler.StartScheduler(ctx) }, } diff --git a/pkg/clusterresource/controller.go b/pkg/clusterresource/controller.go index 677d2894e5..1fb55e0588 100644 --- a/pkg/clusterresource/controller.go +++ b/pkg/clusterresource/controller.go @@ -12,6 +12,14 @@ import ( "strings" "time" + impl2 "github.com/flyteorg/flyteadmin/pkg/clusterresource/impl" + "github.com/flyteorg/flyteadmin/pkg/config" + "github.com/flyteorg/flyteadmin/pkg/executioncluster/impl" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/resources" + "github.com/flyteorg/flyteadmin/pkg/repositories" + errors2 "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + admin2 "github.com/flyteorg/flyteidl/clients/go/admin" + "google.golang.org/grpc/status" "github.com/flyteorg/flyteadmin/pkg/clusterresource/interfaces" @@ -633,13 +641,55 @@ func newMetrics(scope promutils.Scope) controllerMetrics { } func NewClusterResourceController(adminDataProvider interfaces.FlyteAdminDataProvider, listTargets executionclusterIfaces.ListTargetsInterface, scope promutils.Scope) Controller { - config := runtime.NewConfigurationProvider() + cfg := runtime.NewConfigurationProvider() return &controller{ adminDataProvider: adminDataProvider, - config: config, + config: cfg, listTargets: listTargets, poller: make(chan struct{}), metrics: newMetrics(scope), appliedTemplates: make(map[string]map[string]time.Time), } } + +func NewClusterResourceControllerFromConfig(ctx context.Context, scope promutils.Scope, configuration runtimeInterfaces.Configuration) (Controller, error) { + initializationErrorCounter := scope.MustNewCounter( + "flyteclient_initialization_error", + "count of errors encountered initializing a flyte client from kube config") + var listTargetsProvider executionclusterIfaces.ListTargetsInterface + var err error + if len(configuration.ClusterConfiguration().GetClusterConfigs()) == 0 { + serverConfig := config.GetConfig() + listTargetsProvider, err = impl.NewInCluster(initializationErrorCounter, serverConfig.KubeConfig, serverConfig.Master) + } else { + listTargetsProvider, err = impl.NewListTargets(initializationErrorCounter, impl.NewExecutionTargetProvider(), configuration.ClusterConfiguration()) + } + if err != nil { + return nil, err + } + + var adminDataProvider interfaces.FlyteAdminDataProvider + if configuration.ClusterResourceConfiguration().IsStandaloneDeployment() { + clientSet, err := admin2.ClientSetBuilder().WithConfig(admin2.GetConfig(ctx)).Build(ctx) + if err != nil { + return nil, err + } + adminDataProvider = impl2.NewAdminServiceDataProvider(clientSet.AdminClient()) + } else { + dbConfig := runtime.NewConfigurationProvider().ApplicationConfiguration().GetDbConfig() + logConfig := logger.GetConfig() + + db, err := repositories.GetDB(ctx, dbConfig, logConfig) + if err != nil { + return nil, err + } + dbScope := scope.NewSubScope("db") + + repo := repositories.NewGormRepo( + db, errors2.NewPostgresErrorTransformer(dbScope.NewSubScope("errors")), dbScope) + + adminDataProvider = impl2.NewDatabaseAdminDataProvider(repo, configuration, resources.NewResourceManager(repo, configuration.ApplicationConfiguration())) + } + + return NewClusterResourceController(adminDataProvider, listTargetsProvider, scope), nil +} diff --git a/pkg/repositories/database.go b/pkg/repositories/database.go index 61c1616966..74f72a2bb2 100644 --- a/pkg/repositories/database.go +++ b/pkg/repositories/database.go @@ -9,9 +9,10 @@ import ( "strings" repoErrors "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "gorm.io/driver/sqlite" + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" "github.com/flyteorg/flytestdlib/logger" - "github.com/jackc/pgconn" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -66,7 +67,7 @@ func resolvePassword(ctx context.Context, passwordVal, passwordPath string) stri } // Produces the DSN (data source name) for opening a postgres db connection. -func getPostgresDsn(ctx context.Context, pgConfig runtimeInterfaces.PostgresConfig) string { +func getPostgresDsn(ctx context.Context, pgConfig *runtimeInterfaces.PostgresConfig) string { password := resolvePassword(ctx, pgConfig.Password, pgConfig.PasswordPath) if len(password) == 0 { // The password-less case is included for development environments. @@ -80,7 +81,7 @@ func getPostgresDsn(ctx context.Context, pgConfig runtimeInterfaces.PostgresConf // GetDB uses the dbConfig to create gorm DB object. If the db doesn't exist for the dbConfig then a new one is created // using the default db for the provider. eg : postgres has default dbName as postgres func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig *logger.Config) ( - gormDb *gorm.DB, err error) { + *gorm.DB, error) { if dbConfig == nil { panic("Cannot initialize database repository from empty db config") } @@ -89,17 +90,26 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig DisableForeignKeyConstraintWhenMigrating: !dbConfig.EnableForeignKeyConstraintWhenMigrating, } - // TODO: add other gorm-supported db type handling in further case blocks. + var gormDb *gorm.DB + var err error + switch { - // TODO: Figure out a better proxy for a non-empty postgres config - case len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0: + case dbConfig.SQLiteConfig != nil: + if dbConfig.SQLiteConfig.File == "" { + return nil, fmt.Errorf("illegal sqlite database configuration. `file` is a required parameter and should be a path") + } + gormDb, err = gorm.Open(sqlite.Open(dbConfig.SQLiteConfig.File), gormConfig) + if err != nil { + return nil, err + } + case dbConfig.PostgresConfig != nil && (len(dbConfig.PostgresConfig.Host) > 0 || len(dbConfig.PostgresConfig.User) > 0 || len(dbConfig.PostgresConfig.DbName) > 0): gormDb, err = createPostgresDbIfNotExists(ctx, gormConfig, dbConfig.PostgresConfig) if err != nil { return nil, err } case len(dbConfig.DeprecatedHost) > 0 || len(dbConfig.DeprecatedUser) > 0 || len(dbConfig.DeprecatedDbName) > 0: - pgConfig := runtimeInterfaces.PostgresConfig{ + pgConfig := &runtimeInterfaces.PostgresConfig{ Host: dbConfig.DeprecatedHost, Port: dbConfig.DeprecatedPort, DbName: dbConfig.DeprecatedDbName, @@ -114,7 +124,7 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig return nil, err } default: - panic(fmt.Sprintf("Unrecognized database config %v", dbConfig)) + return nil, fmt.Errorf("unrecognized database config, %v. Supported only postgres and sqlite", dbConfig) } // Setup connection pool settings @@ -122,7 +132,7 @@ func GetDB(ctx context.Context, dbConfig *runtimeInterfaces.DbConfig, logConfig } // Creates DB if it doesn't exist for the passed in config -func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig runtimeInterfaces.PostgresConfig) (*gorm.DB, error) { +func createPostgresDbIfNotExists(ctx context.Context, gormConfig *gorm.Config, pgConfig *runtimeInterfaces.PostgresConfig) (*gorm.DB, error) { dialector := postgres.Open(getPostgresDsn(ctx, pgConfig)) gormDb, err := gorm.Open(dialector, gormConfig) diff --git a/pkg/repositories/database_test.go b/pkg/repositories/database_test.go index 5b1de474e7..af81088089 100644 --- a/pkg/repositories/database_test.go +++ b/pkg/repositories/database_test.go @@ -4,6 +4,7 @@ import ( "context" "io/ioutil" "os" + "path" "path/filepath" "testing" "time" @@ -58,7 +59,7 @@ func TestResolvePassword(t *testing.T) { } func TestGetPostgresDsn(t *testing.T) { - pgConfig := runtimeInterfaces.PostgresConfig{ + pgConfig := &runtimeInterfaces.PostgresConfig{ Host: "localhost", Port: 5432, DbName: "postgres", @@ -143,3 +144,25 @@ func TestSetupDbConnectionPool(t *testing.T) { assert.NotNil(t, err) }) } + +func TestGetDB(t *testing.T) { + ctx := context.TODO() + + t.Run("missing DB Config", func(t *testing.T) { + _, err := GetDB(ctx, &runtimeInterfaces.DbConfig{}, &logger.Config{}) + assert.Error(t, err) + }) + + t.Run("sqlite config", func(t *testing.T) { + dbFile := path.Join(t.TempDir(), "admin.db") + db, err := GetDB(ctx, &runtimeInterfaces.DbConfig{ + SQLiteConfig: &runtimeInterfaces.SQLiteConfig{ + File: dbFile, + }, + }, &logger.Config{}) + assert.NoError(t, err) + assert.NotNil(t, db) + assert.FileExists(t, dbFile) + assert.Equal(t, "sqlite", db.Name()) + }) +} diff --git a/pkg/rpc/adminservice/base.go b/pkg/rpc/adminservice/base.go index 891f8f9a5b..f3f7400602 100644 --- a/pkg/rpc/adminservice/base.go +++ b/pkg/rpc/adminservice/base.go @@ -24,7 +24,6 @@ import ( "github.com/flyteorg/flyteadmin/pkg/workflowengine" workflowengineImpl "github.com/flyteorg/flyteadmin/pkg/workflowengine/impl" "github.com/flyteorg/flytestdlib/logger" - "github.com/flyteorg/flytestdlib/profutils" "github.com/flyteorg/flytestdlib/promutils" "github.com/flyteorg/flytestdlib/storage" "github.com/golang/protobuf/proto" @@ -155,15 +154,6 @@ func NewAdminServer(ctx context.Context, kubeConfig, master string) *AdminServic scheduledWorkflowExecutor.Run() }() - // Serve profiling endpoints. - go func() { - err := profutils.StartProfilingServerWithDefaultHandlers( - ctx, applicationConfiguration.GetProfilerPort(), nil) - if err != nil { - logger.Panicf(ctx, "Failed to Start profiling and Metrics server. Error, %v", err) - } - }() - nodeExecutionEventWriter := eventWriter.NewNodeExecutionEventWriter(repo, applicationConfiguration.GetAsyncEventsBufferSize()) go func() { nodeExecutionEventWriter.Run() diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index 2eca76c813..9f2a5d0ab1 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -22,7 +22,13 @@ type DbConfig struct { MaxIdleConnections int `json:"maxIdleConnections" pflag:",maxIdleConnections sets the maximum number of connections in the idle connection pool."` MaxOpenConnections int `json:"maxOpenConnections" pflag:",maxOpenConnections sets the maximum number of open connections to the database."` ConnMaxLifeTime config.Duration `json:"connMaxLifeTime" pflag:",sets the maximum amount of time a connection may be reused"` - PostgresConfig PostgresConfig `json:"postgres"` + PostgresConfig *PostgresConfig `json:"postgres,omitempty"` + SQLiteConfig *SQLiteConfig `json:"sqlite,omitempty"` +} + +// SQLiteConfig can be used to configure +type SQLiteConfig struct { + File string `json:"file" pflag:",The path to the file (existing or new) where the DB should be created / stored. If existing, then this will be re-used, else a new will be created"` } // PostgresConfig includes specific config options for opening a connection to a postgres database. @@ -38,7 +44,7 @@ type PostgresConfig struct { Debug bool `json:"debug" pflag:" Whether or not to start the database connection with debug mode enabled."` } -// This configuration is the base configuration to start admin +// ApplicationConfig is the base configuration to start admin type ApplicationConfig struct { // The RoleName key inserted as an annotation (https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/) // in Flyte Workflow CRDs created in the CreateExecution flow. The corresponding role value is defined in the diff --git a/pkg/server/initialize.go b/pkg/server/initialize.go new file mode 100644 index 0000000000..69c3846769 --- /dev/null +++ b/pkg/server/initialize.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + "fmt" + + "github.com/flyteorg/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyteadmin/pkg/repositories/config" + "github.com/flyteorg/flyteadmin/pkg/runtime" + "github.com/flyteorg/flytestdlib/logger" + "github.com/go-gormigrate/gormigrate/v2" + "gorm.io/gorm" +) + +func withDB(ctx context.Context, do func(db *gorm.DB) error) error { + configuration := runtime.NewConfigurationProvider() + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + logConfig := logger.GetConfig() + + db, err := repositories.GetDB(ctx, databaseConfig, logConfig) + if err != nil { + logger.Fatal(ctx, err) + } + + sqlDB, err := db.DB() + if err != nil { + logger.Fatal(ctx, err) + } + + defer func(deferCtx context.Context) { + if err = sqlDB.Close(); err != nil { + logger.Fatal(deferCtx, err) + } + }(ctx) + + if err = sqlDB.Ping(); err != nil { + return err + } + + return do(db) +} + +// Migrate runs all configured migrations +func Migrate(ctx context.Context) error { + return withDB(ctx, func(db *gorm.DB) error { + m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) + if err := m.Migrate(); err != nil { + return fmt.Errorf("database migration failed: %v", err) + } + logger.Infof(ctx, "Migration ran successfully") + return nil + }) +} + +// Rollback rolls back the last migration +func Rollback(ctx context.Context) error { + return withDB(ctx, func(db *gorm.DB) error { + m := gormigrate.New(db, gormigrate.DefaultOptions, config.Migrations) + err := m.RollbackLast() + if err != nil { + return fmt.Errorf("could not rollback latest migration: %v", err) + } + logger.Infof(ctx, "Rolled back one migration successfully") + return nil + }) +} + +// SeedProjects creates a set of given projects in the DB +func SeedProjects(ctx context.Context, projects []string) error { + return withDB(ctx, func(db *gorm.DB) error { + if err := config.SeedProjects(db, projects); err != nil { + return fmt.Errorf("could not add projects to database with err: %v", err) + } + logger.Infof(ctx, "Successfully added projects to database") + return nil + }) +} diff --git a/pkg/server/service.go b/pkg/server/service.go new file mode 100644 index 0000000000..41ce924918 --- /dev/null +++ b/pkg/server/service.go @@ -0,0 +1,366 @@ +package server + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "strings" + + "github.com/flyteorg/flyteadmin/auth" + "github.com/flyteorg/flyteadmin/auth/authzserver" + authConfig "github.com/flyteorg/flyteadmin/auth/config" + "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/config" + "github.com/flyteorg/flyteadmin/pkg/rpc/adminservice" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flytestdlib/logger" + "github.com/gorilla/handlers" + grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth" + grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" + "github.com/grpc-ecosystem/grpc-gateway/runtime" + "github.com/pkg/errors" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/reflection" + "google.golang.org/grpc/status" +) + +var defaultCorsHeaders = []string{"Content-Type"} + +// Serve starts a server and blocks the calling goroutine +func Serve(ctx context.Context, additionalHandlers map[string]func(http.ResponseWriter, *http.Request)) error { + serverConfig := config.GetConfig() + + if serverConfig.Security.Secure { + return serveGatewaySecure(ctx, serverConfig, authConfig.GetConfig(), additionalHandlers) + } + + return serveGatewayInsecure(ctx, serverConfig, authConfig.GetConfig(), additionalHandlers) +} + +func blanketAuthorization(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + + identityContext := auth.IdentityContextFromContext(ctx) + if identityContext.IsEmpty() { + return handler(ctx, req) + } + + if !identityContext.Scopes().Has(auth.ScopeAll) { + return nil, status.Errorf(codes.Unauthenticated, "authenticated user doesn't have required scope") + } + + return handler(ctx, req) +} + +// Creates a new gRPC Server with all the configuration +func newGRPCServer(ctx context.Context, cfg *config.ServerConfig, authCtx interfaces.AuthenticationContext, + opts ...grpc.ServerOption) *grpc.Server { + // Not yet implemented for streaming + var chainedUnaryInterceptors grpc.UnaryServerInterceptor + if cfg.Security.UseAuth { + logger.Infof(ctx, "Creating gRPC server with authentication") + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor, + auth.GetAuthenticationCustomMetadataInterceptor(authCtx), + grpcauth.UnaryServerInterceptor(auth.GetAuthenticationInterceptor(authCtx)), + auth.AuthenticationLoggingInterceptor, + blanketAuthorization, + ) + } else { + logger.Infof(ctx, "Creating gRPC server without authentication") + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) + } + + serverOpts := []grpc.ServerOption{ + grpc.StreamInterceptor(grpcprometheus.StreamServerInterceptor), + grpc.UnaryInterceptor(chainedUnaryInterceptors), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes)) + } + serverOpts = append(serverOpts, opts...) + grpcServer := grpc.NewServer(serverOpts...) + grpcprometheus.Register(grpcServer) + service.RegisterAdminServiceServer(grpcServer, adminservice.NewAdminServer(ctx, cfg.KubeConfig, cfg.Master)) + if cfg.Security.UseAuth { + service.RegisterAuthMetadataServiceServer(grpcServer, authCtx.AuthMetadataService()) + service.RegisterIdentityServiceServer(grpcServer, authCtx.IdentityService()) + } + + healthServer := health.NewServer() + healthServer.SetServingStatus("flyteadmin", grpc_health_v1.HealthCheckResponse_SERVING) + grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) + if cfg.GrpcConfig.ServerReflection || cfg.GrpcServerReflection { + reflection.Register(grpcServer) + } + return grpcServer +} + +func GetHandleOpenapiSpec(ctx context.Context) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + swaggerBytes, err := service.Asset("admin.swagger.json") + if err != nil { + logger.Warningf(ctx, "Err %v", err) + w.WriteHeader(http.StatusFailedDependency) + } else { + w.WriteHeader(http.StatusOK) + _, err := w.Write(swaggerBytes) + if err != nil { + logger.Errorf(ctx, "failed to write openAPI information, error: %s", err.Error()) + } + } + } +} + +func healthCheckFunc(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) +} + +func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.Config, authCtx interfaces.AuthenticationContext, + additionalHandlers map[string]func(http.ResponseWriter, *http.Request), + grpcAddress string, grpcConnectionOpts ...grpc.DialOption) (*http.ServeMux, error) { + + // Register the server that will serve HTTP/REST Traffic + mux := http.NewServeMux() + + // Add any additional handlers that have been passed in for the main HTTP server + for p, f := range additionalHandlers { + mux.HandleFunc(p, f) + } + + // Register healthcheck + mux.HandleFunc("/healthcheck", healthCheckFunc) + + // Register OpenAPI endpoint + // This endpoint will serve the OpenAPI2 spec generated by the swagger protoc plugin, and bundled by go-bindata + mux.HandleFunc("/api/v1/openapi", GetHandleOpenapiSpec(ctx)) + + var gwmuxOptions = make([]runtime.ServeMuxOption, 0) + // This option means that http requests are served with protobufs, instead of json. We always want this. + gwmuxOptions = append(gwmuxOptions, runtime.WithMarshalerOption("application/octet-stream", &runtime.ProtoMarshaller{})) + + if cfg.Security.UseAuth { + // Add HTTP handlers for OIDC endpoints + auth.RegisterHandlers(ctx, mux, authCtx) + + // Add HTTP handlers for OAuth2 endpoints + authzserver.RegisterHandlers(mux, authCtx) + + // This option translates HTTP authorization data (cookies) into a gRPC metadata field + gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPRequestCookieToMetadataHandler(authCtx))) + + // In an attempt to be able to selectively enforce whether or not authentication is required, we're going to tag + // the requests that come from the HTTP gateway. See the enforceHttp/Grpc options for more information. + gwmuxOptions = append(gwmuxOptions, runtime.WithMetadata(auth.GetHTTPMetadataTaggingHandler())) + } + + // Create the grpc-gateway server with the options specified + gwmux := runtime.NewServeMux(gwmuxOptions...) + + err := service.RegisterAdminServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering admin service") + } + + err = service.RegisterAuthMetadataServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering auth service") + } + + err = service.RegisterIdentityServiceHandlerFromEndpoint(ctx, gwmux, grpcAddress, grpcConnectionOpts) + if err != nil { + return nil, errors.Wrap(err, "error registering identity service") + } + + mux.Handle("/", gwmux) + + return mux, nil +} + +func serveGatewayInsecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config, additionalHandlers map[string]func(http.ResponseWriter, *http.Request)) error { + logger.Infof(ctx, "Serving Flyte Admin Insecure") + + // This will parse configuration and create the necessary objects for dealing with auth + var authCtx interfaces.AuthenticationContext + var err error + // This code is here to support authentication without SSL. This setup supports a network topology where + // Envoy does the SSL termination. The final hop is made over localhost only on a trusted machine. + // Warning: Running authentication without SSL in any other topology is a severe security flaw. + // See the auth.Config object for additional settings as well. + if cfg.Security.UseAuth { + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + var oauth2Provider interfaces.OAuth2Provider + var oauth2ResourceServer interfaces.OAuth2ResourceServer + if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { + oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) + if err != nil { + logger.Errorf(ctx, "Error creating authorization server %s", err) + return err + } + + oauth2ResourceServer = oauth2Provider + } else { + oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) + if err != nil { + logger.Errorf(ctx, "Error creating resource server %s", err) + return err + } + } + + oauth2MetadataProvider := authzserver.NewService(authCfg) + oidcUserInfoProvider := auth.NewUserInfoProvider() + + authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) + if err != nil { + logger.Errorf(ctx, "Error creating auth context %s", err) + return err + } + } + + grpcServer := newGRPCServer(ctx, cfg, authCtx) + + logger.Infof(ctx, "Serving GRPC Traffic on: %s", cfg.GetGrpcHostAddress()) + lis, err := net.Listen("tcp", cfg.GetGrpcHostAddress()) + if err != nil { + return errors.Wrapf(err, "failed to listen on GRPC port: %s", cfg.GetGrpcHostAddress()) + } + + go func() { + err := grpcServer.Serve(lis) + logger.Fatalf(ctx, "Failed to create GRPC Server, Err: ", err) + }() + + logger.Infof(ctx, "Starting HTTP/1 Gateway server on %s", cfg.GetHostAddress()) + grpcOptions := []grpc.DialOption{ + grpc.WithInsecure(), + grpc.WithMaxHeaderListSize(common.MaxResponseStatusBytes), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + grpcOptions = append(grpcOptions, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) + } + httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetGrpcHostAddress(), grpcOptions...) + if err != nil { + return err + } + + var handler http.Handler + if cfg.Security.AllowCors { + handler = handlers.CORS( + handlers.AllowCredentials(), + handlers.AllowedOrigins(cfg.Security.AllowedOrigins), + handlers.AllowedHeaders(append(defaultCorsHeaders, cfg.Security.AllowedHeaders...)), + handlers.AllowedMethods([]string{"GET", "POST", "DELETE", "HEAD", "PUT", "PATCH"}), + )(httpServer) + } else { + handler = httpServer + } + + err = http.ListenAndServe(cfg.GetHostAddress(), handler) + if err != nil { + return errors.Wrapf(err, "failed to Start HTTP Server") + } + + return nil +} + +// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC +// connections or otherHandler otherwise. +// See https://github.com/philips/grpc-gateway-example/blob/master/cmd/serve.go for reference +func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This is a partial recreation of gRPC's internal checks + if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + grpcServer.ServeHTTP(w, r) + } else { + otherHandler.ServeHTTP(w, r) + } + }) +} + +func serveGatewaySecure(ctx context.Context, cfg *config.ServerConfig, authCfg *authConfig.Config, additionalHandlers map[string]func(http.ResponseWriter, *http.Request)) error { + certPool, cert, err := GetSslCredentials(ctx, cfg.Security.Ssl.CertificateFile, cfg.Security.Ssl.KeyFile) + if err != nil { + return err + } + // This will parse configuration and create the necessary objects for dealing with auth + var authCtx interfaces.AuthenticationContext + if cfg.Security.UseAuth { + sm := secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()) + var oauth2Provider interfaces.OAuth2Provider + var oauth2ResourceServer interfaces.OAuth2ResourceServer + if authCfg.AppAuth.AuthServerType == authConfig.AuthorizationServerTypeSelf { + oauth2Provider, err = authzserver.NewProvider(ctx, authCfg.AppAuth.SelfAuthServer, sm) + if err != nil { + logger.Errorf(ctx, "Error creating authorization server %s", err) + return err + } + + oauth2ResourceServer = oauth2Provider + } else { + oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) + if err != nil { + logger.Errorf(ctx, "Error creating resource server %s", err) + return err + } + } + + oauth2MetadataProvider := authzserver.NewService(authCfg) + oidcUserInfoProvider := auth.NewUserInfoProvider() + + authCtx, err = auth.NewAuthenticationContext(ctx, sm, oauth2Provider, oauth2ResourceServer, oauth2MetadataProvider, oidcUserInfoProvider, authCfg) + if err != nil { + logger.Errorf(ctx, "Error creating auth context %s", err) + return err + } + } + + grpcServer := newGRPCServer(ctx, cfg, authCtx, grpc.Creds(credentials.NewServerTLSFromCert(cert))) + + // Whatever certificate is used, pass it along for easier development + // #nosec G402 + dialCreds := credentials.NewTLS(&tls.Config{ + ServerName: cfg.GetHostAddress(), + RootCAs: certPool, + }) + serverOpts := []grpc.DialOption{ + grpc.WithTransportCredentials(dialCreds), + } + if cfg.GrpcConfig.MaxMessageSizeBytes > 0 { + serverOpts = append(serverOpts, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.GrpcConfig.MaxMessageSizeBytes))) + } + httpServer, err := newHTTPServer(ctx, cfg, authCfg, authCtx, additionalHandlers, cfg.GetHostAddress(), serverOpts...) + if err != nil { + return err + } + + conn, err := net.Listen("tcp", cfg.GetHostAddress()) + if err != nil { + panic(err) + } + + srv := &http.Server{ + Addr: cfg.GetHostAddress(), + Handler: grpcHandlerFunc(grpcServer, httpServer), + // #nosec G402 + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{*cert}, + NextProtos: []string{"h2"}, + }, + } + + err = srv.Serve(tls.NewListener(conn, srv.TLSConfig)) + + if err != nil { + return errors.Wrapf(err, "failed to Start HTTP/2 Server") + } + return nil +} diff --git a/scheduler/start.go b/scheduler/start.go new file mode 100644 index 0000000000..d93d09db8d --- /dev/null +++ b/scheduler/start.go @@ -0,0 +1,62 @@ +package scheduler + +import ( + "context" + "fmt" + "runtime/debug" + + "github.com/flyteorg/flyteadmin/pkg/repositories" + "github.com/flyteorg/flyteadmin/pkg/repositories/errors" + "github.com/flyteorg/flyteadmin/pkg/runtime" + "github.com/flyteorg/flyteidl/clients/go/admin" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" +) + +// StartScheduler creates and starts a new scheduler instance. This is a blocking call and will block the calling go-routine +func StartScheduler(ctx context.Context) error { + configuration := runtime.NewConfigurationProvider() + applicationConfiguration := configuration.ApplicationConfiguration().GetTopLevelConfig() + + // Define the schedulerScope for prometheus metrics + schedulerScope := promutils.NewScope(applicationConfiguration.MetricsScope).NewSubScope("flytescheduler") + schedulerPanics := schedulerScope.MustNewCounter("initialization_panic", + "panics encountered initializing the flyte native scheduler") + + defer func() { + if err := recover(); err != nil { + schedulerPanics.Inc() + logger.Fatalf(ctx, fmt.Sprintf("caught panic: %v [%+v]", err, string(debug.Stack()))) + } + }() + + databaseConfig := configuration.ApplicationConfiguration().GetDbConfig() + logConfig := logger.GetConfig() + + db, err := repositories.GetDB(ctx, databaseConfig, logConfig) + if err != nil { + logger.Fatal(ctx, err) + } + dbScope := schedulerScope.NewSubScope("database") + repo := repositories.NewGormRepo( + db, errors.NewPostgresErrorTransformer(schedulerScope.NewSubScope("errors")), dbScope) + + clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).Build(ctx) + if err != nil { + logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err) + return err + } + adminServiceClient := clientSet.AdminClient() + + scheduleExecutor := NewScheduledExecutor(repo, + configuration.ApplicationConfiguration().GetSchedulerConfig().GetWorkflowExecutorConfig(), schedulerScope, adminServiceClient) + + logger.Info(ctx, "Successfully initialized a native flyte scheduler") + + err = scheduleExecutor.Run(ctx) + if err != nil { + logger.Fatalf(ctx, "Flyte native scheduler failed to start due to %v", err) + return err + } + return nil +} diff --git a/tests/bootstrap.go b/tests/bootstrap.go index 57ca364163..44ca517cae 100644 --- a/tests/bootstrap.go +++ b/tests/bootstrap.go @@ -23,7 +23,7 @@ var adminScope = promutils.NewScope("flyteadmin") func getDbConfig() *runtimeInterfaces.DbConfig { return &runtimeInterfaces.DbConfig{ - PostgresConfig: runtimeInterfaces.PostgresConfig{ + PostgresConfig: &runtimeInterfaces.PostgresConfig{ Host: "postgres", Port: 5432, DbName: "postgres", @@ -34,7 +34,7 @@ func getDbConfig() *runtimeInterfaces.DbConfig { func getLocalDbConfig() *runtimeInterfaces.DbConfig { return &runtimeInterfaces.DbConfig{ - PostgresConfig: runtimeInterfaces.PostgresConfig{ + PostgresConfig: &runtimeInterfaces.PostgresConfig{ Host: "localhost", Port: 5432, DbName: "flyteadmin",