Skip to content

Commit

Permalink
Merge pull request #165 from kubeflow/main
Browse files Browse the repository at this point in the history
[pull] main from kubeflow:main
  • Loading branch information
openshift-merge-bot[bot] authored Feb 3, 2025
2 parents 9c9a392 + 9a47415 commit df2ec33
Show file tree
Hide file tree
Showing 9 changed files with 702 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/fossa-license-scanning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
uses: actions/checkout@v4

- name: Run FOSSA scan and upload build data
uses: fossas/fossa-action@v1.4.0
uses: fossas/fossa-action@v1.5.0
with:
api-key: ${{ env.FOSSA_API_KEY }}
project: "github.com/kubeflow/model-registry"
86 changes: 86 additions & 0 deletions api/openapi/model-registry.yaml

Large diffs are not rendered by default.

148 changes: 111 additions & 37 deletions cmd/proxy.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package cmd

import (
"context"
"fmt"
"net/http"
"time"

"github.com/golang/glog"
"github.com/kubeflow/model-registry/internal/mlmdtypes"
"github.com/kubeflow/model-registry/internal/proxy"
"github.com/kubeflow/model-registry/internal/server/openapi"
"github.com/kubeflow/model-registry/pkg/core"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)

const (
// mlmdUnavailableMessage is the body of the HTTP 503 response served by the
// placeholder handler until the MLMD-backed router has been installed.
mlmdUnavailableMessage = "MLMD server is down or unavailable. Please check that the database is reachable and try again later."
// maxGRPCRetryAttempts caps how many times the MLMD type registration is
// attempted while the server keeps answering with codes.Unavailable.
// The wait after attempt k is k seconds (1s, 2s, 3s, ..., 25s), so the waits
// add up to 325s, i.e. roughly 5.5 minutes in total.
maxGRPCRetryAttempts = 25
)

// proxyCmd represents the proxy command
Expand All @@ -27,43 +36,108 @@ hostname and port where it listens.'`,
}

func runProxyServer(cmd *cobra.Command, args []string) error {
glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port)

ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort)
glog.Infof("connecting to MLMD server %s..", mlmdAddr)
conn, err := grpc.DialContext( // nolint:staticcheck
ctxTimeout,
mlmdAddr,
grpc.WithReturnConnectionError(), // nolint:staticcheck
grpc.WithBlock(), // nolint:staticcheck
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return fmt.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err)
}
defer conn.Close()
glog.Infof("connected to MLMD server")

mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()
_, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating MLMD types: %v", err)
var conn *grpc.ClientConn
var err error

errMLMDChan := make(chan error, 1)
errProxyChan := make(chan error, 1)

router := proxy.NewDynamicRouter()

router.SetRouter(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, mlmdUnavailableMessage, http.StatusServiceUnavailable)
}))

// Start the connection to the MLMD server in a separate goroutine, so that
// we can start the proxy server and start serving requests while we wait
// for the connection to be established.
go func() {
defer close(errMLMDChan)

mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort)
glog.Infof("connecting to MLMD server %s..", mlmdAddr)
conn, err = grpc.NewClient(mlmdAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
errMLMDChan <- fmt.Errorf("error dialing connection to mlmd server %s: %w", mlmdAddr, err)

return
}

mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()

// Backoff and retry GRPC requests to the MLMD server, until the server
// becomes available or the maximum number of attempts is reached.
for i := 0; i < maxGRPCRetryAttempts; i++ {
_, err := mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig)
if err == nil {
break
}

st, ok := status.FromError(err)
if !ok || st.Code() != codes.Unavailable {
errMLMDChan <- fmt.Errorf("error creating MLMD types: %w", err)

return
}

time.Sleep(time.Duration(i+1) * time.Second)
}

service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig)
if err != nil {
errMLMDChan <- fmt.Errorf("error creating core service: %w", err)

return
}

ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service)
ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService)

router.SetRouter(openapi.NewRouter(ModelRegistryServiceAPIController))

glog.Infof("connected to MLMD server")
}()

// Start the proxy server in a separate goroutine so that we can handle
// errors from both the proxy server and the connection to the MLMD server.
go func() {
defer close(errProxyChan)

glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port)

err := http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
if err != nil {
errProxyChan <- fmt.Errorf("error starting proxy server: %w", err)
}
}()

defer func() {
if conn != nil {
glog.Info("closing connection to MLMD server")

conn.Close()
}
}()

// Wait for either the MLMD server connection or the proxy server to return an error
// or for both to finish successfully.
for {
select {
case err := <-errMLMDChan:
if err != nil {
return err
}

case err := <-errProxyChan:
if err != nil {
return err
}
}

if errMLMDChan == nil && errProxyChan == nil {
return nil
}
}
service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating core service: %v", err)
}

ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service)
ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService)

router := openapi.NewRouter(ModelRegistryServiceAPIController)

glog.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router))
return nil
}

func init() {
Expand Down
6 changes: 4 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ module github.com/kubeflow/model-registry

go 1.22

toolchain go1.22.11

require (
github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/cors v1.2.1
github.com/go-logr/logr v1.4.1
github.com/golang/glog v1.2.2
github.com/go-logr/logr v1.4.2
github.com/golang/glog v1.2.4
github.com/kserve/kserve v0.12.1
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.30.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ=
Expand All @@ -115,8 +115,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY=
github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc=
github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down
14 changes: 7 additions & 7 deletions internal/mlmdtypes/mlmdtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,37 +128,37 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface, nameConfig MLMDTypeNamesConfig

registeredModelResp, err := client.PutContextType(context.Background(), &registeredModelReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.RegisteredModelTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.RegisteredModelTypeName, err)
}

modelVersionResp, err := client.PutContextType(context.Background(), &modelVersionReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ModelVersionTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.ModelVersionTypeName, err)
}

docArtifactResp, err := client.PutArtifactType(context.Background(), &docArtifactReq)
if err != nil {
return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.DocArtifactTypeName, err)
return nil, fmt.Errorf("error setting up artifact type %s: %w", nameConfig.DocArtifactTypeName, err)
}

modelArtifactResp, err := client.PutArtifactType(context.Background(), &modelArtifactReq)
if err != nil {
return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.ModelArtifactTypeName, err)
return nil, fmt.Errorf("error setting up artifact type %s: %w", nameConfig.ModelArtifactTypeName, err)
}

servingEnvironmentResp, err := client.PutContextType(context.Background(), &servingEnvironmentReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ServingEnvironmentTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.ServingEnvironmentTypeName, err)
}

inferenceServiceResp, err := client.PutContextType(context.Background(), &inferenceServiceReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.InferenceServiceTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.InferenceServiceTypeName, err)
}

serveModelResp, err := client.PutExecutionType(context.Background(), &serveModelReq)
if err != nil {
return nil, fmt.Errorf("error setting up execution type %s: %v", nameConfig.ServeModelTypeName, err)
return nil, fmt.Errorf("error setting up execution type %s: %w", nameConfig.ServeModelTypeName, err)
}

typesMap := map[string]int64{
Expand Down
39 changes: 39 additions & 0 deletions internal/proxy/dynamic_router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Package proxy provides dynamic routing capabilities for HTTP servers.
//
// This file contains the implementation of a dynamic router that allows
// changing the HTTP handler at runtime in a thread-safe manner. It is
// particularly useful for proxy servers that need to update their routing
// logic without restarting the server.
package proxy

import (
"net/http"
"sync"
)

// dynamicRouter is an http.Handler that forwards every request to an inner
// handler which can be swapped at runtime with SetRouter. The inner handler is
// guarded by mu, so ServeHTTP and SetRouter may be called concurrently.
type dynamicRouter struct {
	mu     sync.RWMutex
	router http.Handler
}

// Compile-time check that dynamicRouter can be passed wherever an
// http.Handler is expected (e.g. http.ListenAndServe).
var _ http.Handler = (*dynamicRouter)(nil)

// NewDynamicRouter returns a dynamicRouter with no inner handler installed.
// Until SetRouter is called, every request is answered with
// 503 Service Unavailable.
func NewDynamicRouter() *dynamicRouter {
	return &dynamicRouter{}
}

// ServeHTTP dispatches the request to the currently installed handler.
//
// The handler is read under the read lock and invoked only after the lock has
// been released, so a long-running request never blocks SetRouter. When no
// handler is installed (SetRouter was never called, or was called with nil)
// the request is answered with 503 Service Unavailable instead of panicking
// on a nil interface call.
func (d *dynamicRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	d.mu.RLock()
	router := d.router
	d.mu.RUnlock()

	if router == nil {
		http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)

		return
	}

	router.ServeHTTP(w, r)
}

// SetRouter replaces the inner handler. Requests that already picked up the
// previous handler finish on it; requests arriving afterwards use router.
// Passing nil uninstalls the handler, after which ServeHTTP answers with
// 503 Service Unavailable.
func (d *dynamicRouter) SetRouter(router http.Handler) {
	d.mu.Lock()
	defer d.mu.Unlock()

	d.router = router
}
10 changes: 10 additions & 0 deletions pkg/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package api
import (
"errors"
"net/http"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
Expand All @@ -11,6 +14,13 @@ var (
)

func ErrToStatus(err error) int {
// If the error is a gRPC error, we can extract the status code.
if status, ok := status.FromError(err); ok {
if status.Code() == codes.Unavailable {
return http.StatusServiceUnavailable
}
}

switch errors.Unwrap(err) {
case ErrBadRequest:
return http.StatusBadRequest
Expand Down
Loading

0 comments on commit df2ec33

Please sign in to comment.