From 10642615d11d873676794e381ebeff99ee4f6796 Mon Sep 17 00:00:00 2001 From: Andrea Lamparelli Date: Mon, 11 Dec 2023 21:06:27 +0100 Subject: [PATCH] Mock model registry service Signed-off-by: Andrea Lamparelli --- controllers/modelregistry_controller.go | 121 +++--- controllers/modelregistry_controller_test.go | 368 +++++------------- .../mr_servingenvironment_reconciler.go | 1 - controllers/suite_test.go | 51 ++- .../modelregistry-inference-service-1.yaml | 18 + .../modelregistry/modelregistry-mock.go | 205 ++++++++++ go.mod | 3 + go.sum | 2 + 8 files changed, 425 insertions(+), 344 deletions(-) create mode 100644 controllers/testdata/deploy/modelregistry-inference-service-1.yaml create mode 100644 controllers/testdata/modelregistry/modelregistry-mock.go diff --git a/controllers/modelregistry_controller.go b/controllers/modelregistry_controller.go index e53448d2..1d3db3d9 100644 --- a/controllers/modelregistry_controller.go +++ b/controllers/modelregistry_controller.go @@ -8,6 +8,7 @@ import ( "github.com/go-logr/logr" kservev1alpha1 "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" + "github.com/opendatahub-io/model-registry/pkg/api" "github.com/opendatahub-io/model-registry/pkg/core" "github.com/opendatahub-io/odh-model-controller/controllers/constants" "github.com/opendatahub-io/odh-model-controller/controllers/reconcilers" @@ -29,6 +30,7 @@ type ModelRegistryReconciler struct { Scheme *runtime.Scheme Log logr.Logger Period time.Duration + mrService api.ModelRegistryApi mrISReconciler *reconcilers.ModelRegistryInferenceServiceReconciler mrSEReconciler *reconcilers.ModelRegistryServingEnvironmentReconciler } @@ -49,77 +51,80 @@ func (r *ModelRegistryReconciler) Reconcile(ctx context.Context, req ctrl.Reques log := r.Log.WithValues("ServingRuntime", req.Name, "namespace", req.Namespace) log.Info("Reconciling ModelRegistry serving for ServingRuntime: " + req.Name) - mlmdAddr := os.Getenv(constants.MLMDAddressEnv) - if mlmdAddr == "" { - // Env variable not set, look for existing model registry service - opts := []client.ListOption{client.InNamespace(req.Namespace), client.MatchingLabels{ - "component": "model-registry", - }} - mrServiceList := &corev1.ServiceList{} - err := r.Client.List(ctx, mrServiceList, opts...) - if err != nil && apierrs.IsNotFound(err) { - // No model registry deployed in the provided namespace, skipping serving reconciliation - log.Info("Stop ModelRegistry serving reconciliation") - return ctrl.Result{}, nil - } + mr := r.mrService + if mr == nil { + mlmdAddr := os.Getenv(constants.MLMDAddressEnv) + if mlmdAddr == "" { + // Env variable not set, look for existing model registry service + opts := []client.ListOption{client.InNamespace(req.Namespace), client.MatchingLabels{ + "component": "model-registry", + }} + mrServiceList := &corev1.ServiceList{} + err := r.Client.List(ctx, mrServiceList, opts...) + if err != nil && apierrs.IsNotFound(err) { + // No model registry deployed in the provided namespace, skipping serving reconciliation + log.Info("Stop ModelRegistry serving reconciliation") + return ctrl.Result{}, nil + } - if len(mrServiceList.Items) == 0 { - log.Info("No Model Registry service found for Namespace: " + req.Namespace) - log.Info("Stop ModelRegistry serving reconciliation") - return ctrl.Result{}, nil - } + if len(mrServiceList.Items) == 0 { + log.Info("No Model Registry service found for Namespace: " + req.Namespace) + log.Info("Stop ModelRegistry serving reconciliation") + return ctrl.Result{}, nil + } - // Actually we could iterate over every mrService, as nothing prevents to setup multiple MR in the same namespace - if len(mrServiceList.Items) > 1 { - log.Error(fmt.Errorf("multiple services with component=model-registry for Namespace %s", req.Namespace), "Stop ModelRegistry serving reconciliation") - return ctrl.Result{}, nil - } + // Actually we could iterate over every mrService, as nothing prevents to setup multiple MR in the same namespace + if len(mrServiceList.Items) > 1 { + log.Error(fmt.Errorf("multiple services with component=model-registry for Namespace %s", req.Namespace), "Stop ModelRegistry serving reconciliation") + return ctrl.Result{}, nil + } + + mrService := mrServiceList.Items[0] - mrService := mrServiceList.Items[0] + var grpcPort *int32 + for _, port := range mrService.Spec.Ports { + if port.Name == "grpc-api" { + grpcPort = &port.Port + break + } + } - var grpcPort *int32 - for _, port := range mrService.Spec.Ports { - if port.Name == "grpc-api" { - grpcPort = &port.Port - break + if grpcPort == nil { + log.Error(fmt.Errorf("cannot find grpc-api port for service %s", mrService.Name), "Stop ModelRegistry serving reconciliation") + return ctrl.Result{}, nil } + + mlmdAddr = fmt.Sprintf("%s.%s.svc.cluster.local:%d", mrService.Name, req.Namespace, *grpcPort) } - if grpcPort == nil { - log.Error(fmt.Errorf("cannot find grpc-api port for service %s", mrService.Name), "Stop ModelRegistry serving reconciliation") + // setup grpc connection to ml-metadata + ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + // Setup model registry service + log.Info("Connecting to " + mlmdAddr) + conn, err := grpc.DialContext( + ctxTimeout, + mlmdAddr, + grpc.WithReturnConnectionError(), + grpc.WithBlock(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + log.Error(err, "Stop ModelRegistry serving reconciliation") return ctrl.Result{}, nil } + defer conn.Close() - mlmdAddr = fmt.Sprintf("%s.%s.svc.cluster.local:%d", mrService.Name, req.Namespace, *grpcPort) - } - - // setup grpc connection to ml-metadata - ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - // Setup model registry service - log.Info("Connecting to " + mlmdAddr) - conn, err := grpc.DialContext( - ctxTimeout, - mlmdAddr, - grpc.WithReturnConnectionError(), - grpc.WithBlock(), - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - log.Error(err, "Stop ModelRegistry serving reconciliation") - return ctrl.Result{}, nil - } - defer conn.Close() - - mr, err := core.NewModelRegistryService(conn) - if err != nil { - log.Error(err, "Stop ModelRegistry serving reconciliation") - return ctrl.Result{}, nil + mr, err = core.NewModelRegistryService(conn) + if err != nil { + log.Error(err, "Stop ModelRegistry serving reconciliation") + return ctrl.Result{}, nil + } } // Reconcile the ServingEnvironment from Model Registry - err = r.mrSEReconciler.Reconcile(ctx, log, mr, req.Namespace) + err := r.mrSEReconciler.Reconcile(ctx, log, mr, req.Namespace) if err != nil { log.Error(err, "Stop ModelRegistry serving reconciliation") return ctrl.Result{}, nil diff --git a/controllers/modelregistry_controller_test.go b/controllers/modelregistry_controller_test.go index 0eba084b..365c4e4e 100644 --- a/controllers/modelregistry_controller_test.go +++ b/controllers/modelregistry_controller_test.go @@ -17,311 +17,145 @@ package controllers import ( "context" - "fmt" - "os" - "path/filepath" - "strings" - "time" kservev1alpha1 "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" kservev1beta1 "github.com/kserve/kserve/pkg/apis/serving/v1beta1" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "github.com/opendatahub-io/model-registry/pkg/api" - "github.com/opendatahub-io/model-registry/pkg/core" "github.com/opendatahub-io/model-registry/pkg/openapi" - "github.com/opendatahub-io/odh-model-controller/controllers/constants" - "github.com/testcontainers/testcontainers-go" - "github.com/testcontainers/testcontainers-go/wait" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "github.com/stretchr/testify/mock" "k8s.io/apimachinery/pkg/types" ) var ( - mlmdAddr string - mlmdTeardown func() error // mr content - registeredModel *openapi.RegisteredModel - modelVersion *openapi.ModelVersion - modelArtifact *openapi.ModelArtifact - // data - modelName = "dummy-model" - versionName = "dummy-version" - modelFormatName = "onnx" - modelFormatVersion = "1" - storagePath = "path/to/model" - storageKey = "aws-connection-models" inferenceServiceName = "dummy-inference-service" // filled at runtime servingRuntime = &kservev1alpha1.ServingRuntime{} ) -const ( - useProvider = testcontainers.ProviderDefault // or explicit to testcontainers.ProviderPodman if needed - mlmdImage = "gcr.io/tfx-oss-public/ml_metadata_store_server:1.14.0" - modelRegistryImage = "quay.io/opendatahub/model-registry:latest" - connectionConfigContent = `connection_config { - sqlite { - filename_uri: '/tmp/mlmd/odh_metadata.sqlite-%s.db' - connection_mode: READWRITE_OPENCREATE - } -} -` +var ( + // ids + servingEnvironmentId = "1" + registeredModelId = "2" + inferenceServiceId = "4" + serveModelId = "1" ) var _ = Describe("ModelRegistry controller", func() { ctx := context.Background() - var mrService api.ModelRegistryApi BeforeEach(func() { - var err error - - ctxTimeout, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - // Setup testcontainer with ml-metadata - mlmdAddr, mlmdTeardown, err = setupMlMetadataContainer() - Expect(err).NotTo(HaveOccurred()) - // override the mlmd address setting the env variable - os.Setenv(constants.MLMDAddressEnv, mlmdAddr) - - conn, err := grpc.DialContext( - ctxTimeout, - mlmdAddr, - grpc.WithReturnConnectionError(), - grpc.WithBlock(), - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - Expect(err).ToNot(HaveOccurred()) - - mrService, err = core.NewModelRegistryService(conn) - Expect(err).ToNot(HaveOccurred()) - Expect(mrService).ToNot(BeNil()) - servingRuntime = &kservev1alpha1.ServingRuntime{} - err = convertToStructuredResource(KserveServingRuntimePath1, servingRuntime) + err := convertToStructuredResource(KserveServingRuntimePath1, servingRuntime) Expect(err).NotTo(HaveOccurred()) Expect(cli.Create(ctx, servingRuntime)).Should(Succeed()) - - // fill mr with some models - registeredModel, modelVersion, modelArtifact, err = setupModels(mrService) - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - By("Tearing down model registry test container") - Expect(mlmdTeardown()).NotTo(HaveOccurred()) + resetCalls() }) When("when a ServingRuntime is applied in WorkingNamespace", func() { - It("should create and delete InferenceService CRs based on model registry content", func() { - By("by checking that the controller has created the corresponding ServingEnvironment in model registry") - envId := "" + It("should create an InferenceService CR when model registry state is DEPLOYED", func() { + // simulate existing IS in model registry in DEPLOYED state for all calls + modelRegistryMock.On("GetInferenceServices").Return(&openapi.InferenceServiceList{ + PageSize: 1, + Size: 1, + Items: []openapi.InferenceService{ + { + Id: &inferenceServiceId, + Name: &inferenceServiceName, + RegisteredModelId: registeredModelId, + ServingEnvironmentId: servingEnvironmentId, + Runtime: &servingRuntime.Name, + State: openapi.INFERENCESERVICESTATE_DEPLOYED.Ptr(), + }, + }, + }, nil) + + isvc := &kservev1beta1.InferenceService{} Eventually(func() error { - ns := WorkingNamespace - se, err := mrService.GetServingEnvironmentByParams(&ns, nil) - if err != nil { - return err - } - envId = *se.Id - return nil + key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} + return cli.Get(ctx, key, isvc) }, timeout, interval).ShouldNot(HaveOccurred()) - - When("an InferenceService is created in the model registry with existing model runtime", func() { - By("by checking that the controller has created the InferenceService CR in WorkingNamespace") - // create the IS in the model registry - is, err := mrService.UpsertInferenceService(&openapi.InferenceService{ - Name: &inferenceServiceName, - RegisteredModelId: registeredModel.GetId(), - ServingEnvironmentId: envId, - Runtime: &servingRuntime.Name, - State: openapi.INFERENCESERVICESTATE_DEPLOYED.Ptr(), - }) - Expect(err).NotTo(HaveOccurred()) - isvc := &kservev1beta1.InferenceService{} - Eventually(func() error { - key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} - return cli.Get(ctx, key, isvc) - }, timeout, interval).ShouldNot(HaveOccurred()) - - By("by checking that the controller has removed the InferenceService CR from WorkingNamespace") - // set state to UNDEPLOYED for IS in the model registry - is.SetState(openapi.INFERENCESERVICESTATE_UNDEPLOYED) - _, err = mrService.UpsertInferenceService(is) - Expect(err).NotTo(HaveOccurred()) - isvc = &kservev1beta1.InferenceService{} - Eventually(func() error { - key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} - return cli.Get(ctx, key, isvc) - }, timeout, interval).Should(HaveOccurred()) - - By("by checking that the controller has created new ServeModel in the model registry") - Eventually(func() error { - sm, err := mrService.GetServeModels(api.ListOptions{}, is.Id) - if err != nil { - return err - } - if sm.Size == 0 { - return fmt.Errorf("empty serve models list") - } - return nil - }, timeout, interval).ShouldNot(HaveOccurred()) - }) }) - It("should update InferenceService CR based on model registry content", func() { - By("by checking that the controller has created the corresponding ServingEnvironment in model registry") - envId := "" + It("should delete InferenceService CR when model registry state is UNDEPLOYED", func() { + // simulate existing IS in model registry in UNDEPLOYED state for all calls + modelRegistryMock.On("GetInferenceServices").Return(&openapi.InferenceServiceList{ + PageSize: 1, + Size: 1, + Items: []openapi.InferenceService{ + { + Id: &inferenceServiceId, + Name: &inferenceServiceName, + RegisteredModelId: registeredModelId, + ServingEnvironmentId: servingEnvironmentId, + Runtime: &servingRuntime.Name, + State: openapi.INFERENCESERVICESTATE_UNDEPLOYED.Ptr(), + }, + }, + }, nil) + + // create an existing ISVC + inferenceService := &kservev1beta1.InferenceService{} + err := convertToStructuredResource(ModelRegistryInferenceServicePath1, inferenceService) + Expect(err).NotTo(HaveOccurred()) + Expect(cli.Create(ctx, inferenceService)).Should(Succeed()) + By("by checking that the InferenceService CR is in WorkingNamespace") + isvc := &kservev1beta1.InferenceService{} Eventually(func() error { - ns := WorkingNamespace - se, err := mrService.GetServingEnvironmentByParams(&ns, nil) - if err != nil { - return err - } - envId = *se.Id - return nil + key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} + return cli.Get(ctx, key, isvc) }, timeout, interval).ShouldNot(HaveOccurred()) - When("an InferenceService is created in the model registry with existing model runtime", func() { - By("by checking that the controller has created the InferenceService CR in WorkingNamespace") - // create the IS in the model registry - is, err := mrService.UpsertInferenceService(&openapi.InferenceService{ - Name: &inferenceServiceName, - RegisteredModelId: registeredModel.GetId(), - ServingEnvironmentId: envId, - Runtime: &servingRuntime.Name, - State: openapi.INFERENCESERVICESTATE_DEPLOYED.Ptr(), - }) - Expect(err).NotTo(HaveOccurred()) - isvc := &kservev1beta1.InferenceService{} - Eventually(func() error { - key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} - return cli.Get(ctx, key, isvc) - }, timeout, interval).ShouldNot(HaveOccurred()) - - By("by checking that the controller has correctly updated the InferenceService CR in WorkingNamespace") - // update model artifact content - newStoragePath := "/new/path" - modelArtifact.SetStoragePath(newStoragePath) - modelArtifact, err = mrService.UpsertModelArtifact(modelArtifact, modelVersion.Id) - Expect(err).NotTo(HaveOccurred()) - - isvc = &kservev1beta1.InferenceService{} - Eventually(func() string { - key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} - err := cli.Get(ctx, key, isvc) - if err != nil { - return "" - } - return *isvc.Spec.Predictor.Model.Storage.Path - }, timeout, interval).Should(Equal(newStoragePath)) - - By("by checking that the controller has created new ServeModel in the model registry") - Eventually(func() error { - sm, err := mrService.GetServeModels(api.ListOptions{}, is.Id) - if err != nil { - return err - } - if sm.Size == 0 { - return fmt.Errorf("empty serve models list") - } - return nil - }, timeout, interval).ShouldNot(HaveOccurred()) - }) - }) - }) -}) - -func setupModels(mr api.ModelRegistryApi) (*openapi.RegisteredModel, *openapi.ModelVersion, *openapi.ModelArtifact, error) { - model, err := mr.GetRegisteredModelByParams(&modelName, nil) - if err != nil { - // register a new model - model, err = mr.UpsertRegisteredModel(&openapi.RegisteredModel{ - Name: &modelName, + By("by checking that the controller has removed the InferenceService CR from WorkingNamespace") + isvc = &kservev1beta1.InferenceService{} + Eventually(func() error { + key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} + return cli.Get(ctx, key, isvc) + }, timeout, interval).Should(HaveOccurred()) }) - if err != nil { - return nil, nil, nil, err - } - } - - version, err := mr.GetModelVersionByParams(&versionName, model.Id, nil) - if err != nil { - version, err = mr.UpsertModelVersion(&openapi.ModelVersion{ - Name: &versionName, - }, model.Id) - if err != nil { - return nil, nil, nil, err - } - } - modelArtifactName := fmt.Sprintf("%s-artifact", versionName) - artifact, err := mr.GetModelArtifactByParams(&modelArtifactName, version.Id, nil) - if err != nil { - artifact, err = mr.UpsertModelArtifact(&openapi.ModelArtifact{ - Name: &modelArtifactName, - ModelFormatName: &modelFormatName, - ModelFormatVersion: &modelFormatVersion, - StorageKey: &storageKey, - StoragePath: &storagePath, - }, version.Id) - if err != nil { - return nil, nil, nil, err - } - } - - return model, version, artifact, nil -} - -func setupMlMetadataContainer() (string, func() error, error) { - mlmdCtx := context.TODO() - connConfigFile, err := os.CreateTemp("", "odh_mlmd_conn_config-*.pb") - if err != nil { - return "", nil, err - } - uid := strings.Replace(strings.Split(filepath.Base(connConfigFile.Name()), "-")[1], ".pb", "", 1) - _, err = connConfigFile.WriteString(fmt.Sprintf(connectionConfigContent, uid)) - if err != nil { - return "", nil, err - } - - req := testcontainers.ContainerRequest{ - Image: mlmdImage, - ExposedPorts: []string{"8080/tcp"}, - Env: map[string]string{ - "METADATA_STORE_SERVER_CONFIG_FILE": fmt.Sprintf("/tmp/mlmd/%s", filepath.Base(connConfigFile.Name())), - }, - User: "1000:1000", - Mounts: testcontainers.ContainerMounts{ - testcontainers.ContainerMount{ - Source: testcontainers.GenericBindMountSource{ - HostPath: os.TempDir(), + It("should update InferenceService CR when model registry state is DEPLOYED and there is delta", func() { + // simulate existing IS in model registry in DEPLOYED state for all other calls + modelRegistryMock.On("GetInferenceServices").Return(&openapi.InferenceServiceList{ + PageSize: 1, + Size: 1, + Items: []openapi.InferenceService{ + { + Id: &inferenceServiceId, + Name: &inferenceServiceName, + RegisteredModelId: registeredModelId, + ServingEnvironmentId: servingEnvironmentId, + Runtime: &servingRuntime.Name, + State: openapi.INFERENCESERVICESTATE_DEPLOYED.Ptr(), + }, }, - Target: "/tmp/mlmd", - }, - }, - WaitingFor: wait.ForLog("Server listening on"), - } - - mlmdgrpc, err := testcontainers.GenericContainer(mlmdCtx, testcontainers.GenericContainerRequest{ - ProviderType: useProvider, - ContainerRequest: req, - Started: true, + }, nil) + + // create an existing ISVC + inferenceService := &kservev1beta1.InferenceService{} + // storage path is /path/to/old/model + err := convertToStructuredResource(ModelRegistryInferenceServicePath1, inferenceService) + Expect(err).NotTo(HaveOccurred()) + Expect(cli.Create(ctx, inferenceService)).Should(Succeed()) + + By("by checking that the controller has correctly updated the InferenceService CR storage path") + isvc := &kservev1beta1.InferenceService{} + Eventually(func() string { + key := types.NamespacedName{Name: inferenceServiceName, Namespace: WorkingNamespace} + err := cli.Get(ctx, key, isvc) + if err != nil { + return "" + } + return *isvc.Spec.Predictor.Model.Storage.Path + }, timeout, interval).Should(Equal("path/to/model")) + }) }) - if err != nil { - return "", nil, err - } - - mappedHost, err := mlmdgrpc.Host(mlmdCtx) - if err != nil { - return "", nil, err - } - mappedPort, err := mlmdgrpc.MappedPort(mlmdCtx, "8080") - if err != nil { - return "", nil, err - } +}) - return fmt.Sprintf("%s:%s", mappedHost, mappedPort.Port()), func() (err error) { - return mlmdgrpc.Terminate(mlmdCtx) - }, nil +// resetCalls reset expected and existing calls on mocked model registry +// this is a workaround as this functionality is not exported at the moment +func resetCalls() { + modelRegistryMock.ExpectedCalls = []*mock.Call{} + modelRegistryMock.Calls = []mock.Call{} } diff --git a/controllers/reconcilers/mr_servingenvironment_reconciler.go b/controllers/reconcilers/mr_servingenvironment_reconciler.go index df6eb46f..33c091e2 100644 --- a/controllers/reconcilers/mr_servingenvironment_reconciler.go +++ b/controllers/reconcilers/mr_servingenvironment_reconciler.go @@ -24,7 +24,6 @@ func (r *ModelRegistryServingEnvironmentReconciler) Reconcile(ctx context.Contex _, err := mrClient.GetServingEnvironmentByParams(&namespace, nil) if err != nil { // Create new ServingEnvironment as not already existing - // TODO: we could fetch additional custom props from the ServingRuntime CR, needed? log.Info("Creating new ServingEnvironment for " + namespace) _, err = mrClient.UpsertServingEnvironment(&openapi.ServingEnvironment{ Name: &namespace, diff --git a/controllers/suite_test.go b/controllers/suite_test.go index 2664ad51..3d89eae9 100644 --- a/controllers/suite_test.go +++ b/controllers/suite_test.go @@ -27,7 +27,9 @@ import ( kservev1alpha1 "github.com/kserve/kserve/pkg/apis/serving/v1alpha1" kservev1beta1 "github.com/kserve/kserve/pkg/apis/serving/v1beta1" + "github.com/opendatahub-io/model-registry/pkg/openapi" "github.com/opendatahub-io/odh-model-controller/controllers/reconcilers" + "github.com/opendatahub-io/odh-model-controller/controllers/testdata/modelregistry" monitoringv1 "github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring/v1" "go.uber.org/zap/zapcore" k8srbacv1 "k8s.io/api/rbac/v1" @@ -57,27 +59,29 @@ import ( // +kubebuilder:docs-gen:collapse=Imports var ( - cli client.Client - envTest *envtest.Environment - ctx context.Context - cancel context.CancelFunc + cli client.Client + envTest *envtest.Environment + ctx context.Context + cancel context.CancelFunc + modelRegistryMock *modelregistry.ModelRegistryServiceMocked ) const ( - WorkingNamespace = "default" - MonitoringNS = "monitoring-ns" - RoleBindingPath = "./testdata/results/model-server-ns-role.yaml" - ServingRuntimePath1 = "./testdata/deploy/test-openvino-serving-runtime-1.yaml" - KserveServingRuntimePath1 = "./testdata/deploy/kserve-openvino-serving-runtime-1.yaml" - ServingRuntimePath2 = "./testdata/deploy/test-openvino-serving-runtime-2.yaml" - InferenceService1 = "./testdata/deploy/openvino-inference-service-1.yaml" - InferenceServiceNoRuntime = "./testdata/deploy/openvino-inference-service-no-runtime.yaml" - KserveInferenceServicePath1 = "./testdata/deploy/kserve-openvino-inference-service-1.yaml" - InferenceServiceConfigPath1 = "./testdata/configmaps/inferenceservice-config.yaml" - ExpectedRoutePath = "./testdata/results/example-onnx-mnist-route.yaml" - ExpectedRouteNoRuntimePath = "./testdata/results/example-onnx-mnist-no-runtime-route.yaml" - timeout = time.Second * 20 - interval = time.Millisecond * 10 + WorkingNamespace = "default" + MonitoringNS = "monitoring-ns" + RoleBindingPath = "./testdata/results/model-server-ns-role.yaml" + ServingRuntimePath1 = "./testdata/deploy/test-openvino-serving-runtime-1.yaml" + KserveServingRuntimePath1 = "./testdata/deploy/kserve-openvino-serving-runtime-1.yaml" + ServingRuntimePath2 = "./testdata/deploy/test-openvino-serving-runtime-2.yaml" + InferenceService1 = "./testdata/deploy/openvino-inference-service-1.yaml" + InferenceServiceNoRuntime = "./testdata/deploy/openvino-inference-service-no-runtime.yaml" + KserveInferenceServicePath1 = "./testdata/deploy/kserve-openvino-inference-service-1.yaml" + ModelRegistryInferenceServicePath1 = "./testdata/deploy/modelregistry-inference-service-1.yaml" + InferenceServiceConfigPath1 = "./testdata/configmaps/inferenceservice-config.yaml" + ExpectedRoutePath = "./testdata/results/example-onnx-mnist-route.yaml" + ExpectedRouteNoRuntimePath = "./testdata/results/example-onnx-mnist-no-runtime-route.yaml" + timeout = time.Second * 20 + interval = time.Millisecond * 10 ) func TestAPIs(t *testing.T) { @@ -161,11 +165,14 @@ var _ = BeforeSuite(func() { }).SetupWithManager(mgr) Expect(err).ToNot(HaveOccurred()) + modelRegistryMock = &modelregistry.ModelRegistryServiceMocked{} + err = (&ModelRegistryReconciler{ Client: cli, Log: ctrl.Log.WithName("controllers").WithName("Model-Registry-Controller"), Scheme: scheme.Scheme, Period: time.Duration(1) * time.Second, + mrService: modelRegistryMock, mrISReconciler: reconcilers.NewModelRegistryInferenceServiceReconciler(cli), mrSEReconciler: reconcilers.NewModelRegistryServingEnvironmentReconciler(cli), }).SetupWithManager(mgr) @@ -202,6 +209,14 @@ var _ = AfterSuite(func() { } }) +var _ = BeforeEach(func() { + // ensure all other tests do not trigger mr reconcilation by forcing to retun empty IS from model registry + modelRegistryMock.On("GetInferenceServices").Return(&openapi.InferenceServiceList{ + Size: 0, + Items: []openapi.InferenceService{}, + }, nil) +}) + // Cleanup resources to not contaminate between tests var _ = AfterEach(func() { inNamespace := client.InNamespace(WorkingNamespace) diff --git a/controllers/testdata/deploy/modelregistry-inference-service-1.yaml b/controllers/testdata/deploy/modelregistry-inference-service-1.yaml new file mode 100644 index 00000000..2f70718f --- /dev/null +++ b/controllers/testdata/deploy/modelregistry-inference-service-1.yaml @@ -0,0 +1,18 @@ +apiVersion: serving.kserve.io/v1beta1 +kind: InferenceService +metadata: + name: dummy-inference-service + namespace: default + labels: + "model-registry-is-id": "4" + "model-registry-sm-id": "1" +spec: + predictor: + model: + modelFormat: + name: onnx + version: 1 + runtime: ovms-1.x + storage: + key: aws-connection-models + path: /path/to/old/model diff --git a/controllers/testdata/modelregistry/modelregistry-mock.go b/controllers/testdata/modelregistry/modelregistry-mock.go new file mode 100644 index 00000000..7dc45510 --- /dev/null +++ b/controllers/testdata/modelregistry/modelregistry-mock.go @@ -0,0 +1,205 @@ +package modelregistry + +import ( + "fmt" + + "github.com/opendatahub-io/model-registry/pkg/api" + "github.com/opendatahub-io/model-registry/pkg/openapi" + "github.com/stretchr/testify/mock" +) + +func NewModelRegistryServiceMocked() api.ModelRegistryApi { + return &ModelRegistryServiceMocked{} +} + +var ( + // ids + servingEnvironmentId = "1" + registeredModelId = "2" + modelVersionId = "3" + modelArtifactId = "1" + serveModelId = "1" + // data + modelName = "dummy-model" + versionName = "dummy-version" + modelFormatName = "onnx" + modelFormatVersion = "1" + storagePath = "path/to/model" + storageKey = "aws-connection-models" +) + +type ModelRegistryServiceMocked struct { + mock.Mock +} + +// GetInferenceServiceById implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetInferenceServiceById(id string) (*openapi.InferenceService, error) { + panic("unimplemented") +} + +// GetInferenceServiceByParams implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetInferenceServiceByParams(name *string, parentResourceId *string, externalId *string) (*openapi.InferenceService, error) { + panic("unimplemented") +} + +// GetInferenceServices implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetInferenceServices(listOptions api.ListOptions, servingEnvironmentId *string, runtime *string) (*openapi.InferenceServiceList, error) { + args := m.MethodCalled("GetInferenceServices") + return args[0].(*openapi.InferenceServiceList), asError(args[1]) +} + +// GetModelArtifactById implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelArtifactById(id string) (*openapi.ModelArtifact, error) { + panic("unimplemented") +} + +// GetModelArtifactByInferenceService implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelArtifactByInferenceService(inferenceServiceId string) (*openapi.ModelArtifact, error) { + modelArtifactName := fmt.Sprintf("%s-artifact", versionName) + return &openapi.ModelArtifact{ + Id: &modelArtifactId, + Name: &modelArtifactName, + ModelFormatName: &modelFormatName, + ModelFormatVersion: &modelFormatVersion, + StorageKey: &storageKey, + StoragePath: &storagePath, + }, nil +} + +// GetModelArtifactByParams implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelArtifactByParams(artifactName *string, modelVersionId *string, externalId *string) (*openapi.ModelArtifact, error) { + panic("unimplemented") +} + +// GetModelArtifacts implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelArtifacts(listOptions api.ListOptions, modelVersionId *string) (*openapi.ModelArtifactList, error) { + panic("unimplemented") +} + +// GetModelVersionById implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelVersionById(id string) (*openapi.ModelVersion, error) { + panic("unimplemented") +} + +// GetModelVersionByInferenceService implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelVersionByInferenceService(inferenceServiceId string) (*openapi.ModelVersion, error) { + return &openapi.ModelVersion{ + Id: &modelVersionId, + Name: &versionName, + }, nil +} + +// GetModelVersionByParams implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelVersionByParams(versionName *string, registeredModelId *string, externalId *string) (*openapi.ModelVersion, error) { + panic("unimplemented") +} + +// GetModelVersions implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetModelVersions(listOptions api.ListOptions, registeredModelId *string) (*openapi.ModelVersionList, error) { + panic("unimplemented") +} + +// GetRegisteredModelById implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) { + panic("unimplemented") +} + +// GetRegisteredModelByInferenceService implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetRegisteredModelByInferenceService(inferenceServiceId string) (*openapi.RegisteredModel, error) { + panic("unimplemented") +} + +// GetRegisteredModelByParams implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetRegisteredModelByParams(name *string, externalId *string) (*openapi.RegisteredModel, error) { + return &openapi.RegisteredModel{ + Id: ®isteredModelId, + Name: &modelName, + }, nil +} + +// GetRegisteredModels implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetRegisteredModels(listOptions api.ListOptions) (*openapi.RegisteredModelList, error) { + panic("unimplemented") +} + +// GetServeModelById implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetServeModelById(id string) (*openapi.ServeModel, error) { + return &openapi.ServeModel{ + Id: &id, + LastKnownState: openapi.EXECUTIONSTATE_RUNNING.Ptr(), + ModelVersionId: modelVersionId, + }, nil +} + +// GetServeModels implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetServeModels(listOptions api.ListOptions, inferenceServiceId *string) (*openapi.ServeModelList, error) { + return &openapi.ServeModelList{ + PageSize: 1, + Size: 1, + Items: []openapi.ServeModel{ + openapi.ServeModel{ + Id: &serveModelId, + LastKnownState: openapi.EXECUTIONSTATE_RUNNING.Ptr(), + ModelVersionId: modelVersionId, + }, + }}, nil +} + +// GetServingEnvironmentById implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetServingEnvironmentById(id string) (*openapi.ServingEnvironment, error) { + panic("unimplemented") +} + +// GetServingEnvironmentByParams implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetServingEnvironmentByParams(name *string, externalId *string) (*openapi.ServingEnvironment, error) { + return &openapi.ServingEnvironment{ + Id: &servingEnvironmentId, + Name: name, + }, nil +} + +// GetServingEnvironments implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) GetServingEnvironments(listOptions api.ListOptions) (*openapi.ServingEnvironmentList, error) { + panic("unimplemented") +} + +// UpsertInferenceService implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) UpsertInferenceService(inferenceService *openapi.InferenceService) (*openapi.InferenceService, error) { + panic("unimplemented") +} + +// UpsertModelArtifact implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) UpsertModelArtifact(modelArtifact *openapi.ModelArtifact, modelVersionId *string) (*openapi.ModelArtifact, error) { + panic("unimplemented") +} + +// UpsertModelVersion implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) UpsertModelVersion(modelVersion *openapi.ModelVersion, registeredModelId *string) (*openapi.ModelVersion, error) { + panic("unimplemented") +} + +// UpsertRegisteredModel implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) { + panic("unimplemented") +} + +// UpsertServeModel implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) UpsertServeModel(serveModel *openapi.ServeModel, inferenceServiceId *string) (*openapi.ServeModel, error) { + return &openapi.ServeModel{ + Id: &serveModelId, + LastKnownState: openapi.EXECUTIONSTATE_RUNNING.Ptr(), + ModelVersionId: modelVersionId, + }, nil +} + +// UpsertServingEnvironment implements api.ModelRegistryApi. +func (m *ModelRegistryServiceMocked) UpsertServingEnvironment(registeredModel *openapi.ServingEnvironment) (*openapi.ServingEnvironment, error) { + panic("unimplemented") +} + +func asError(arg interface{}) (err error) { + if arg != nil { + err = arg.(error) + } + return +} diff --git a/go.mod b/go.mod index cd07f91e..ed540745 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/opendatahub-io/model-registry v0.0.0-20231201084346-63807ee17566 github.com/openshift/api v3.9.0+incompatible github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring v0.64.1 + github.com/stretchr/testify v1.8.4 github.com/testcontainers/testcontainers-go v0.26.0 go.uber.org/zap v1.24.0 google.golang.org/grpc v1.59.0 @@ -90,6 +91,7 @@ require ( github.com/opencontainers/image-spec v1.1.0-rc5 // indirect github.com/opencontainers/runc v1.1.5 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/prometheus/client_golang v1.15.1 // indirect github.com/prometheus/client_model v0.4.0 // indirect @@ -99,6 +101,7 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect diff --git a/go.sum b/go.sum index 724ddda8..eba44977 100644 --- a/go.sum +++ b/go.sum @@ -271,6 +271,7 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus-operator/prometheus-operator/pkg/apis/monitoring v0.64.1 h1:bvntWler8vOjDJtxBwGDakGNC6srSZmgawGM9Jf7HC8= @@ -305,6 +306,7 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=