Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from kubeflow:main #165

Merged
merged 5 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/fossa-license-scanning.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
uses: actions/checkout@v4

- name: Run FOSSA scan and upload build data
uses: fossas/fossa-action@v1.4.0
uses: fossas/fossa-action@v1.5.0
with:
api-key: ${{ env.FOSSA_API_KEY }}
project: "github.com/kubeflow/model-registry"
86 changes: 86 additions & 0 deletions api/openapi/model-registry.yaml

Large diffs are not rendered by default.

148 changes: 111 additions & 37 deletions cmd/proxy.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
package cmd

import (
"context"
"fmt"
"net/http"
"time"

"github.com/golang/glog"
"github.com/kubeflow/model-registry/internal/mlmdtypes"
"github.com/kubeflow/model-registry/internal/proxy"
"github.com/kubeflow/model-registry/internal/server/openapi"
"github.com/kubeflow/model-registry/pkg/core"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)

const (
// mlmdUnavailableMessage is the message returned to HTTP clients when the MLMD server is down or unavailable.
mlmdUnavailableMessage = "MLMD server is down or unavailable. Please check that the database is reachable and try again later."
// maxGRPCRetryAttempts is the maximum number of times a gRPC request to the MLMD server is attempted before giving up.
maxGRPCRetryAttempts = 25 // 25 attempts with incremental backoff (1s, 2s, 3s, ..., 25s), i.e. roughly 5 minutes of waiting in total
)

// proxyCmd represents the proxy command
Expand All @@ -27,43 +36,108 @@ hostname and port where it listens.'`,
}

// runProxyServer connects to the MLMD gRPC server, registers the MLMD types,
// builds the model registry service and serves its OpenAPI router over HTTP on
// cfg.Hostname:cfg.Port. Presumably it is wired as the run function of
// proxyCmd; the command definition is not visible here.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost,
// so two implementations are interleaved. The lines that dial with
// grpc.DialContext and the trailing lines ending in
// glog.Fatal(http.ListenAndServe(...)) look like the previous, blocking
// implementation; the lines from `var conn *grpc.ClientConn` down to the
// `for { select { ... } }` loop look like the new, asynchronous one. Confirm
// against the repository before relying on any single line.
func runProxyServer(cmd *cobra.Command, args []string) error {
glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port)

ctxTimeout, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()

mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort)
glog.Infof("connecting to MLMD server %s..", mlmdAddr)
conn, err := grpc.DialContext( // nolint:staticcheck
ctxTimeout,
mlmdAddr,
grpc.WithReturnConnectionError(), // nolint:staticcheck
grpc.WithBlock(), // nolint:staticcheck
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return fmt.Errorf("error dialing connection to mlmd server %s: %v", mlmdAddr, err)
}
defer conn.Close()
glog.Infof("connected to MLMD server")

mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()
_, err = mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating MLMD types: %v", err)
// conn and err are declared in the function scope so that the goroutine below
// can assign them and the deferred cleanup can close the connection.
var conn *grpc.ClientConn
var err error

// Each channel is buffered with capacity 1 and is closed by the goroutine that
// owns it, so a goroutine can report a single error without blocking.
errMLMDChan := make(chan error, 1)
errProxyChan := make(chan error, 1)

router := proxy.NewDynamicRouter()

// Until the MLMD-backed router is installed, every request gets a 503 with
// mlmdUnavailableMessage.
router.SetRouter(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, mlmdUnavailableMessage, http.StatusServiceUnavailable)
}))

// Start the connection to the MLMD server in a separate goroutine, so that
// we can start the proxy server and start serving requests while we wait
// for the connection to be established.
go func() {
defer close(errMLMDChan)

mlmdAddr := fmt.Sprintf("%s:%d", proxyCfg.MLMDHostname, proxyCfg.MLMDPort)
glog.Infof("connecting to MLMD server %s..", mlmdAddr)
// NOTE(review): this assigns the outer conn and err from inside the goroutine,
// while the deferred cleanup further down reads conn when runProxyServer
// returns. No synchronization between the two is visible here; this looks
// like a potential data race and should be confirmed with `go test -race`.
conn, err = grpc.NewClient(mlmdAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
errMLMDChan <- fmt.Errorf("error dialing connection to mlmd server %s: %w", mlmdAddr, err)

return
}

mlmdTypeNamesConfig := mlmdtypes.NewMLMDTypeNamesConfigFromDefaults()

// Backoff and retry GRPC requests to the MLMD server, until the server
// becomes available or the maximum number of attempts is reached.
// NOTE(review): when every attempt fails with codes.Unavailable the loop simply
// ends; nothing is sent on errMLMDChan for the exhausted-retries case and
// execution continues to core.NewModelRegistryService below. Confirm whether
// that is intended.
for i := 0; i < maxGRPCRetryAttempts; i++ {
// This err is scoped to the loop body and shadows the outer err.
_, err := mlmdtypes.CreateMLMDTypes(conn, mlmdTypeNamesConfig)
if err == nil {
break
}

// Only codes.Unavailable is retried; any other error (or a non-gRPC error)
// aborts the goroutine and is reported on errMLMDChan.
st, ok := status.FromError(err)
if !ok || st.Code() != codes.Unavailable {
errMLMDChan <- fmt.Errorf("error creating MLMD types: %w", err)

return
}

// Linear backoff: 1s after the first failure, 2s after the second, and so on.
time.Sleep(time.Duration(i+1) * time.Second)
}

service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig)
if err != nil {
errMLMDChan <- fmt.Errorf("error creating core service: %w", err)

return
}

ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service)
ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService)

// Swap the placeholder 503 handler for the real OpenAPI router.
router.SetRouter(openapi.NewRouter(ModelRegistryServiceAPIController))

glog.Infof("connected to MLMD server")
}()

// Start the proxy server in a separate goroutine so that we can handle
// errors from both the proxy server and the connection to the MLMD server.
go func() {
defer close(errProxyChan)

glog.Infof("proxy server started at %s:%v", cfg.Hostname, cfg.Port)

err := http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router)
if err != nil {
errProxyChan <- fmt.Errorf("error starting proxy server: %w", err)
}
}()

// Close the MLMD connection, if one was created, when runProxyServer returns.
defer func() {
if conn != nil {
glog.Info("closing connection to MLMD server")

conn.Close()
}
}()

// Wait for either the MLMD server connection or the proxy server to return an error
// or for both to finish successfully.
// NOTE(review): neither errMLMDChan nor errProxyChan is ever set to nil in this
// function, so the `== nil` check at the bottom of the loop cannot become
// true. Once the MLMD goroutine finishes successfully its deferred close makes
// `<-errMLMDChan` return a nil error immediately on every iteration, so this
// loop appears to spin without blocking until the proxy goroutine reports an
// error. A receive of the form `err, ok := <-ch` that sets the channel
// variable to nil when !ok would match the stated intent; confirm and fix.
for {
select {
case err := <-errMLMDChan:
if err != nil {
return err
}

case err := <-errProxyChan:
if err != nil {
return err
}
}

if errMLMDChan == nil && errProxyChan == nil {
return nil
}
}
// NOTE(review): the remaining lines follow a `for` loop that has no break and
// only exits via return, so they are unreachable as listed; they look like the
// tail of the previous implementation shown by the diff.
service, err := core.NewModelRegistryService(conn, mlmdTypeNamesConfig)
if err != nil {
return fmt.Errorf("error creating core service: %v", err)
}

ModelRegistryServiceAPIService := openapi.NewModelRegistryServiceAPIService(service)
ModelRegistryServiceAPIController := openapi.NewModelRegistryServiceAPIController(ModelRegistryServiceAPIService)

router := openapi.NewRouter(ModelRegistryServiceAPIController)

glog.Fatal(http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router))
return nil
}

func init() {
Expand Down
6 changes: 4 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ module github.com/kubeflow/model-registry

go 1.22

toolchain go1.22.11

require (
github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/cors v1.2.1
github.com/go-logr/logr v1.4.1
github.com/golang/glog v1.2.2
github.com/go-logr/logr v1.4.2
github.com/golang/glog v1.2.4
github.com/kserve/kserve v0.12.1
github.com/onsi/ginkgo v1.16.5
github.com/onsi/gomega v1.30.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.3.0/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ=
github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ=
Expand All @@ -115,8 +115,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY=
github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc=
github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down
14 changes: 7 additions & 7 deletions internal/mlmdtypes/mlmdtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,37 +128,37 @@ func CreateMLMDTypes(cc grpc.ClientConnInterface, nameConfig MLMDTypeNamesConfig

registeredModelResp, err := client.PutContextType(context.Background(), &registeredModelReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.RegisteredModelTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.RegisteredModelTypeName, err)
}

modelVersionResp, err := client.PutContextType(context.Background(), &modelVersionReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ModelVersionTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.ModelVersionTypeName, err)
}

docArtifactResp, err := client.PutArtifactType(context.Background(), &docArtifactReq)
if err != nil {
return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.DocArtifactTypeName, err)
return nil, fmt.Errorf("error setting up artifact type %s: %w", nameConfig.DocArtifactTypeName, err)
}

modelArtifactResp, err := client.PutArtifactType(context.Background(), &modelArtifactReq)
if err != nil {
return nil, fmt.Errorf("error setting up artifact type %s: %v", nameConfig.ModelArtifactTypeName, err)
return nil, fmt.Errorf("error setting up artifact type %s: %w", nameConfig.ModelArtifactTypeName, err)
}

servingEnvironmentResp, err := client.PutContextType(context.Background(), &servingEnvironmentReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.ServingEnvironmentTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.ServingEnvironmentTypeName, err)
}

inferenceServiceResp, err := client.PutContextType(context.Background(), &inferenceServiceReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", nameConfig.InferenceServiceTypeName, err)
return nil, fmt.Errorf("error setting up context type %s: %w", nameConfig.InferenceServiceTypeName, err)
}

serveModelResp, err := client.PutExecutionType(context.Background(), &serveModelReq)
if err != nil {
return nil, fmt.Errorf("error setting up execution type %s: %v", nameConfig.ServeModelTypeName, err)
return nil, fmt.Errorf("error setting up execution type %s: %w", nameConfig.ServeModelTypeName, err)
}

typesMap := map[string]int64{
Expand Down
39 changes: 39 additions & 0 deletions internal/proxy/dynamic_router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Package proxy provides dynamic routing capabilities for HTTP servers.
//
// This file contains the implementation of a dynamic router that allows
// changing the HTTP handler at runtime in a thread-safe manner. It is
// particularly useful for proxy servers that need to update their routing
// logic without restarting the server.
package proxy

import (
"net/http"
"sync"
)

// dynamicRouter is an http.Handler that forwards every request to an
// underlying handler which can be replaced at runtime with SetRouter.
//
// It is safe for concurrent use: mu guards router, so SetRouter may be called
// while requests are being served. A dynamicRouter must not be copied after
// first use because it contains a sync.RWMutex.
type dynamicRouter struct {
	mu     sync.RWMutex
	router http.Handler
}

// NewDynamicRouter returns a dynamicRouter with no handler installed. Until
// SetRouter is called, ServeHTTP answers every request with
// 503 Service Unavailable.
func NewDynamicRouter() *dynamicRouter {
	return &dynamicRouter{}
}

// ServeHTTP dispatches the request to the handler most recently installed with
// SetRouter. If no handler has been installed (or it was reset to nil), it
// responds with 503 Service Unavailable instead of dereferencing a nil handler.
func (d *dynamicRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Take a snapshot under the read lock and release it before serving, so a
	// slow request never blocks SetRouter.
	d.mu.RLock()
	router := d.router
	d.mu.RUnlock()

	if router == nil {
		http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)

		return
	}

	router.ServeHTTP(w, r)
}

// SetRouter atomically replaces the handler used for subsequent requests.
// Requests already in flight continue with the handler they started with.
func (d *dynamicRouter) SetRouter(router http.Handler) {
	d.mu.Lock()
	defer d.mu.Unlock()

	d.router = router
}
10 changes: 10 additions & 0 deletions pkg/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package api
import (
"errors"
"net/http"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var (
Expand All @@ -11,6 +14,13 @@ var (
)

func ErrToStatus(err error) int {
// If the error is a gRPC error, we can extract the status code.
if status, ok := status.FromError(err); ok {
if status.Code() == codes.Unavailable {
return http.StatusServiceUnavailable
}
}

switch errors.Unwrap(err) {
case ErrBadRequest:
return http.StatusBadRequest
Expand Down
Loading