From ede2a6603344fdfc5f9a2758728ca543e20e59fc Mon Sep 17 00:00:00 2001 From: Daneyon Hansen Date: Tue, 25 Feb 2025 08:08:49 -0800 Subject: [PATCH] Initial support for inference extension deployer Signed-off-by: Daneyon Hansen --- go.mod | 7 +- go.sum | 14 +- hack/utils/oss_compliance/osa_provided.md | 5 +- internal/kgateway/controller/controller.go | 147 +++++++++++- .../controller/controller_suite_test.go | 224 +++++++++++++++--- .../controller/inferencepool_controller.go | 51 ++++ internal/kgateway/controller/start.go | 21 +- internal/kgateway/crds/inferencepools.yaml | 206 ++++++++++++++++ internal/kgateway/deployer/deployer.go | 118 +++++++-- internal/kgateway/deployer/deployer_test.go | 115 ++++++++- internal/kgateway/deployer/values.go | 12 +- internal/kgateway/helm/embed.go | 3 + .../helm/inference-extension/.helmignore | 30 +++ .../helm/inference-extension/Chart.yaml | 24 ++ .../templates/_helpers.tpl | 6 + .../templates/endpoint-picker/resources.yaml | 113 +++++++++ .../helm/inference-extension/values.yaml | 12 + internal/kgateway/wellknown/gwapi.go | 3 + pkg/schemes/extended_scheme.go | 22 ++ 19 files changed, 1065 insertions(+), 68 deletions(-) create mode 100644 internal/kgateway/controller/inferencepool_controller.go create mode 100644 internal/kgateway/crds/inferencepools.yaml create mode 100644 internal/kgateway/helm/inference-extension/.helmignore create mode 100644 internal/kgateway/helm/inference-extension/Chart.yaml create mode 100644 internal/kgateway/helm/inference-extension/templates/_helpers.tpl create mode 100644 internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml create mode 100644 internal/kgateway/helm/inference-extension/values.yaml diff --git a/go.mod b/go.mod index 167f995d439..f11a6ff13a0 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/hashicorp/go-multierror v1.1.1 github.com/kelseyhightower/envconfig v1.4.0 - 
github.com/onsi/ginkgo/v2 v2.22.1 + github.com/onsi/ginkgo/v2 v2.22.2 github.com/onsi/gomega v1.36.2 github.com/pkg/errors v0.9.1 github.com/rotisserie/eris v0.5.4 @@ -52,9 +52,10 @@ require ( k8s.io/kube-openapi v0.0.0-20241212222426-2c72e554b1e7 k8s.io/utils v0.0.0-20241210054802-24370beab758 knative.dev/pkg v0.0.0-20211206113427-18589ac7627e - sigs.k8s.io/controller-runtime v0.20.0 + sigs.k8s.io/controller-runtime v0.20.2 sigs.k8s.io/controller-tools v0.16.5 sigs.k8s.io/gateway-api v1.2.1 + sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250219213427-2577f63f6a1c sigs.k8s.io/structured-merge-diff/v4 v4.5.0 sigs.k8s.io/yaml v1.4.0 ) @@ -98,7 +99,7 @@ require ( github.com/emicklei/go-restful/v3 v3.12.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch v5.9.0+incompatible // indirect - github.com/evanphx/json-patch/v5 v5.9.0 // indirect + github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/go.sum b/go.sum index 1f01996868b..ea9cecee8b3 100644 --- a/go.sum +++ b/go.sum @@ -284,8 +284,8 @@ github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLi github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls= github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.6.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= -github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= -github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= +github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= +github.com/evanphx/json-patch/v5 v5.9.11/go.mod h1:3j+LviiESTElxA4p3EMKAB9HXj3/XEtnUf6OZxqIQTM= 
github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f h1:Wl78ApPPB2Wvf/TIe2xdyJxTlb6obmF18d8QdkxNDu4= github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f/go.mod h1:OSYXu++VVOHnXeitef/D8n/6y4QV8uLHSFXX4NeXMGc= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -738,8 +738,8 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.22.1 h1:QW7tbJAUDyVDVOM5dFa7qaybo+CRfR7bemlQUN6Z8aM= -github.com/onsi/ginkgo/v2 v2.22.1/go.mod h1:S6aTpoRsSq2cZOd+pssHAlKW/Q/jZt6cPrPlnj4a1xM= +github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU= +github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= @@ -1627,12 +1627,14 @@ rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.22/go.mod h1:LEScyzhFmoF5pso/YSeBstl57mOzx9xlU9n85RGrDQg= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.1 h1:uOuSLOMBWkJH0TWa9X6l+mj5nZdm6Ay6Bli8HL8rNfk= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.1/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= -sigs.k8s.io/controller-runtime v0.20.0 h1:jjkMo29xEXH+02Md9qaVXfEIaMESSpy3TBWPrsfQkQs= -sigs.k8s.io/controller-runtime v0.20.0/go.mod h1:BrP3w158MwvB3ZbNpaAcIKkHQ7YGpYnzpoSTZ8E14WU= +sigs.k8s.io/controller-runtime v0.20.2 
h1:/439OZVxoEc02psi1h4QO3bHzTgu49bb347Xp4gW1pc= +sigs.k8s.io/controller-runtime v0.20.2/go.mod h1:xg2XB0K5ShQzAgsoujxuKN4LNXR2LfwwHsPj7Iaw+XY= sigs.k8s.io/controller-tools v0.16.5 h1:5k9FNRqziBPwqr17AMEPPV/En39ZBplLAdOwwQHruP4= sigs.k8s.io/controller-tools v0.16.5/go.mod h1:8vztuRVzs8IuuJqKqbXCSlXcw+lkAv/M2sTpg55qjMY= sigs.k8s.io/gateway-api v1.2.1 h1:fZZ/+RyRb+Y5tGkwxFKuYuSRQHu9dZtbjenblleOLHM= sigs.k8s.io/gateway-api v1.2.1/go.mod h1:EpNfEXNjiYfUJypf0eZ0P5iXA9ekSGWaS1WgPaM42X0= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250219213427-2577f63f6a1c h1:YyTNvnfjzdiHXFQdRzouvQO9SKFwZkgQffnbr9YADFE= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250219213427-2577f63f6a1c/go.mod h1:H2DbSVDbCxG2cNTTgYC+V3RiotW077Xkx3fA3mRAwXs= sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7npUVDpVXUUOFfW6CgAqE= sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/kustomize/api v0.18.0 h1:hTzp67k+3NEVInwz5BHyzc9rGxIauoXferXyjv5lWPo= diff --git a/hack/utils/oss_compliance/osa_provided.md b/hack/utils/oss_compliance/osa_provided.md index cde1b025ce8..8a6c55dfef7 100644 --- a/hack/utils/oss_compliance/osa_provided.md +++ b/hack/utils/oss_compliance/osa_provided.md @@ -17,7 +17,7 @@ Name|Version|License [google/go-cmp](https://github.com/google/go-cmp)|v0.6.0|BSD 3-clause "New" or "Revised" License [grpc-ecosystem/go-grpc-middleware](https://github.com/grpc-ecosystem/go-grpc-middleware)|v1.4.0|Apache License 2.0 [kelseyhightower/envconfig](https://github.com/kelseyhightower/envconfig)|v1.4.0|MIT License -[ginkgo/v2](https://github.com/onsi/ginkgo)|v2.22.1|MIT License +[ginkgo/v2](https://github.com/onsi/ginkgo)|v2.22.2|MIT License [onsi/gomega](https://github.com/onsi/gomega)|v1.36.2|MIT License [pkg/errors](https://github.com/pkg/errors)|v0.9.1|BSD 2-clause "Simplified" License [rotisserie/eris](https://github.com/rotisserie/eris)|v0.5.4|MIT License @@ -44,9 +44,10 @@ 
Name|Version|License [k8s.io/kube-openapi](https://k8s.io/kube-openapi)|v0.0.0-20241212222426-2c72e554b1e7|Apache License 2.0 [k8s.io/utils](https://k8s.io/utils)|v0.0.0-20241210054802-24370beab758|Apache License 2.0 [knative.dev/pkg](https://knative.dev/pkg)|v0.0.0-20211206113427-18589ac7627e|Apache License 2.0 -[sigs.k8s.io/controller-runtime](https://sigs.k8s.io/controller-runtime)|v0.20.0|Apache License 2.0 +[sigs.k8s.io/controller-runtime](https://sigs.k8s.io/controller-runtime)|v0.20.2|Apache License 2.0 [sigs.k8s.io/controller-tools](https://sigs.k8s.io/controller-tools)|v0.16.5|Apache License 2.0 [sigs.k8s.io/gateway-api](https://sigs.k8s.io/gateway-api)|v1.2.1|Apache License 2.0 +[sigs.k8s.io/gateway-api-inference-extension](https://sigs.k8s.io/gateway-api-inference-extension)|v0.0.0-20250219213427-2577f63f6a1c|Apache License 2.0 [structured-merge-diff/v4](https://sigs.k8s.io/structured-merge-diff/v4)|v4.5.0|Apache License 2.0 [sigs.k8s.io/yaml](https://sigs.k8s.io/yaml)|v1.4.0|MIT License [cmd/goimports](https://golang.org/x/tools/cmd/goimports)|latest|MIT License diff --git a/internal/kgateway/controller/controller.go b/internal/kgateway/controller/controller.go index d1c6a281547..f9d49cfd2e4 100644 --- a/internal/kgateway/controller/controller.go +++ b/internal/kgateway/controller/controller.go @@ -17,6 +17,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" apiv1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" @@ -29,6 +30,7 @@ const ( GatewayParamsField = "gateway-params" ) +// TODO [danehans]: Refactor so controller config is organized into shared and Gateway/InferencePool-specific controllers. 
type GatewayConfig struct { Mgr manager.Manager @@ -45,7 +47,7 @@ type GatewayConfig struct { func NewBaseGatewayController(ctx context.Context, cfg GatewayConfig) error { log := log.FromContext(ctx) - log.V(5).Info("starting controller", "controllerName", cfg.ControllerName) + log.V(5).Info("starting gateway controller", "controllerName", cfg.ControllerName) controllerBuilder := &controllerBuilder{ cfg: cfg, @@ -62,6 +64,58 @@ func NewBaseGatewayController(ctx context.Context, cfg GatewayConfig) error { ) } +type InferencePoolConfig struct { + Mgr manager.Manager + ControllerName string + InferenceExt *deployer.InferenceExtInfo + // OurPool determines whether a given InferencePool is managed by this controller. + OurPool func(pool *infextv1a1.InferencePool) bool +} + +func NewBaseInferencePoolController(ctx context.Context, poolCfg *InferencePoolConfig, gwCfg *GatewayConfig) error { + log := log.FromContext(ctx) + log.V(5).Info("starting inferencepool controller", "controllerName", poolCfg.ControllerName) + + // Set default OurPool if not provided. + if poolCfg.OurPool == nil { + poolCfg.OurPool = func(pool *infextv1a1.InferencePool) bool { + // List HTTPRoutes in the same namespace. + var routes apiv1.HTTPRouteList + if err := poolCfg.Mgr.GetClient().List(ctx, &routes, client.InNamespace(pool.Namespace)); err != nil { + log.Error(err, "failed to list HTTPRoutes", "namespace", pool.Namespace) + return false + } + // Iterate over each HTTPRoute to see if any rule references this pool and has a matching controllerName. 
+ for _, route := range routes.Items { + for _, rule := range route.Spec.Rules { + for _, ref := range rule.BackendRefs { + if ref.Kind != nil && *ref.Kind == wellknown.InferencePoolKind && string(ref.Name) == pool.Name { + for _, p := range route.Status.Parents { + if p.ControllerName == apiv1.GatewayController(poolCfg.ControllerName) { + return true + } + } + } + } + } + } + return false + } + } + + // TODO [danehans]: Make GatewayConfig optional since Gateway and InferencePool are independent controllers. + controllerBuilder := &controllerBuilder{ + cfg: *gwCfg, + poolCfg: poolCfg, + reconciler: &controllerReconciler{ + cli: poolCfg.Mgr.GetClient(), + scheme: poolCfg.Mgr.GetScheme(), + }, + } + + return run(ctx, controllerBuilder.watchInferencePool) +} + func run(ctx context.Context, funcs ...func(ctx context.Context) error) error { for _, f := range funcs { if err := f(ctx); err != nil { @@ -72,8 +126,8 @@ func run(ctx context.Context, funcs ...func(ctx context.Context) error) error { } type controllerBuilder struct { - cfg GatewayConfig - + cfg GatewayConfig + poolCfg *InferencePoolConfig reconciler *controllerReconciler } @@ -98,7 +152,7 @@ func (c *controllerBuilder) watchGw(ctx context.Context) error { // setup a deployer log := log.FromContext(ctx) - log.Info("creating deployer", "ctrlname", c.cfg.ControllerName, "server", c.cfg.ControlPlane.XdsHost, "port", c.cfg.ControlPlane.XdsPort) + log.Info("creating gateway deployer", "ctrlname", c.cfg.ControllerName, "server", c.cfg.ControlPlane.XdsHost, "port", c.cfg.ControlPlane.XdsPort) d, err := deployer.NewDeployer(c.cfg.Mgr.GetClient(), &deployer.Inputs{ ControllerName: c.cfg.ControllerName, Dev: c.cfg.Dev, @@ -181,6 +235,91 @@ func (c *controllerBuilder) watchGw(ctx context.Context) error { return nil } +// watchInferencePool adds a watch on InferencePool objects (filtered by OurPool) +// as well as on HTTPRoute objects to trigger reconciliation for referenced pools. 
+func (c *controllerBuilder) watchInferencePool(ctx context.Context) error { + log := log.FromContext(ctx) + log.Info("creating inference extension deployer", "controller", c.cfg.ControllerName) + + d, err := deployer.NewDeployer(c.cfg.Mgr.GetClient(), &deployer.Inputs{ + ControllerName: c.cfg.ControllerName, + InferenceExtension: c.poolCfg.InferenceExt, + }) + if err != nil { + return err + } + + buildr := ctrl.NewControllerManagedBy(c.cfg.Mgr). + // For InferencePool objects that satisfy our OurPool predicate. + For(&infextv1a1.InferencePool{}, builder.WithPredicates( + predicate.NewPredicateFuncs(func(object client.Object) bool { + if pool, ok := object.(*infextv1a1.InferencePool); ok { + return c.poolCfg.OurPool(pool) + } + return false + }), + predicate.Or( + predicate.AnnotationChangedPredicate{}, + predicate.GenerationChangedPredicate{}, + ), + )). + // Watch HTTPRoute objects so that changes there trigger a reconcile for referenced pools. + Watches(&apiv1.HTTPRoute{}, handler.EnqueueRequestsFromMapFunc(func(ctx context.Context, obj client.Object) []reconcile.Request { + var reqs []reconcile.Request + route, ok := obj.(*apiv1.HTTPRoute) + if !ok { + return nil + } + // For every backend ref in every rule of the route... + for _, rule := range route.Spec.Rules { + for _, ref := range rule.BackendRefs { + if ref.Kind != nil && *ref.Kind == wellknown.InferencePoolKind { + reqs = append(reqs, reconcile.Request{ + NamespacedName: client.ObjectKey{ + Namespace: route.Namespace, + Name: string(ref.Name), + }, + }) + } + } + } + return reqs + })) + + // Watch child objects, e.g. Deployments, created by the inference pool deployer. 
+ gvks, err := d.GetGvksToWatch(ctx) + if err != nil { + return err + } + for _, gvk := range gvks { + obj, err := c.cfg.Mgr.GetScheme().New(gvk) + if err != nil { + return err + } + clientObj, ok := obj.(client.Object) + if !ok { + return fmt.Errorf("object %T is not a client.Object", obj) + } + log.Info("watching gvk as inferencepool child", "gvk", gvk) + var opts []builder.OwnsOption + if shouldIgnoreStatusChild(gvk) { + opts = append(opts, builder.WithPredicates(predicate.GenerationChangedPredicate{})) + } + buildr.Owns(clientObj, opts...) + } + + r := &inferencePoolReconciler{ + cli: c.cfg.Mgr.GetClient(), + scheme: c.cfg.Mgr.GetScheme(), + deployer: d, + } + if err := buildr.Complete(r); err != nil { + return err + } + + return nil +} + func shouldIgnoreStatusChild(gvk schema.GroupVersionKind) bool { // avoid triggering on pod changes that update deployment status return gvk.Kind == "Deployment" diff --git a/internal/kgateway/controller/controller_suite_test.go b/internal/kgateway/controller/controller_suite_test.go index 8a0977b8166..502b4429093 100644 --- a/internal/kgateway/controller/controller_suite_test.go +++ b/internal/kgateway/controller/controller_suite_test.go @@ -15,7 +15,9 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/rest" @@ -28,11 +30,12 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/webhook" - api "sigs.k8s.io/gateway-api/apis/v1" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" apiv1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/controller" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/deployer" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" ) @@ -51,7 +54,7 @@ var ( const ( gatewayClassName = "clsname" altGatewayClassName = "clsname-alt" - gatewayControllerName = "controller/name" + gatewayControllerName = "kgateway.dev/kgateway" ) func getAssetsDir() string { @@ -72,6 +75,14 @@ var _ = BeforeSuite(func() { ctx, cancel = context.WithCancel(context.TODO()) By("bootstrapping test environment") + // Create a scheme and add both Gateway and InferencePool types. + scheme := schemes.GatewayScheme() + err := infextv1a1.AddToScheme(scheme) + Expect(err).NotTo(HaveOccurred()) + // Required to deploy endpoint picker RBAC resources. 
+ err = rbacv1.AddToScheme(scheme) + Expect(err).NotTo(HaveOccurred()) + testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{ filepath.Join("..", "crds"), @@ -80,14 +91,12 @@ var _ = BeforeSuite(func() { ErrorIfCRDPathMissing: true, // set assets dir so we can run without the makefile BinaryAssetsDirectory: getAssetsDir(), - // web hook to add cluster ips to services - } - var err error - cfg, err = testEnv.Start() - Expect(err).NotTo(HaveOccurred()) + var err2 error + cfg, err2 = testEnv.Start() + Expect(err2).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) - scheme := schemes.GatewayScheme() + k8sClient, err = client.New(cfg, client.Options{Scheme: scheme}) Expect(err).NotTo(HaveOccurred()) Expect(k8sClient).NotTo(BeNil()) @@ -114,9 +123,8 @@ var _ = BeforeSuite(func() { kubeconfig = generateKubeConfiguration(cfg) mgr.GetLogger().Info("starting manager", "kubeconfig", kubeconfig) - Expect(err).ToNot(HaveOccurred()) - - cfg := controller.GatewayConfig{ + // Start the Gateway controller. + gwCfg := controller.GatewayConfig{ Mgr: mgr, ControllerName: gatewayControllerName, OurGateway: func(gw *apiv1.Gateway) bool { @@ -124,27 +132,19 @@ var _ = BeforeSuite(func() { }, AutoProvision: true, } - err = controller.NewBaseGatewayController(ctx, cfg) + err = controller.NewBaseGatewayController(ctx, gwCfg) Expect(err).ToNot(HaveOccurred()) - for class := range gwClasses { - err = k8sClient.Create(ctx, &api.GatewayClass{ - ObjectMeta: metav1.ObjectMeta{ - Name: class, - }, - Spec: api.GatewayClassSpec{ - ControllerName: api.GatewayController(gatewayControllerName), - ParametersRef: &api.ParametersReference{ - Group: api.Group(v1alpha1.GroupVersion.Group), - Kind: api.Kind("GatewayParameters"), - Name: wellknown.DefaultGatewayParametersName, - Namespace: ptr.To(api.Namespace("default")), - }, - }, - }) - Expect(err).NotTo(HaveOccurred()) + // Start the inference pool controller. 
+ poolCfg := &controller.InferencePoolConfig{ + Mgr: mgr, + ControllerName: wellknown.GatewayControllerName, + InferenceExt: new(deployer.InferenceExtInfo), } + err = controller.NewBaseInferencePoolController(ctx, poolCfg, &gwCfg) + Expect(err).ToNot(HaveOccurred()) + // Create the default GatewayParameters and GatewayClass. err = k8sClient.Create(ctx, &v1alpha1.GatewayParameters{ ObjectMeta: metav1.ObjectMeta{ Name: wellknown.DefaultGatewayParametersName, @@ -161,6 +161,25 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(HaveOccurred()) + for class := range gwClasses { + err = k8sClient.Create(ctx, &apiv1.GatewayClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: class, + }, + Spec: apiv1.GatewayClassSpec{ + ControllerName: apiv1.GatewayController(gatewayControllerName), + ParametersRef: &apiv1.ParametersReference{ + Group: apiv1.Group(v1alpha1.GroupVersion.Group), + Kind: "GatewayParameters", + Name: wellknown.DefaultGatewayParametersName, + Namespace: ptr.To(apiv1.Namespace("default")), + }, + }, + }) + Expect(err).NotTo(HaveOccurred()) + } + + // Start the manager. go func() { defer GinkgoRecover() err = mgr.Start(ctx) @@ -204,11 +223,10 @@ func generateKubeConfiguration(restconfig *rest.Config) string { } clientConfig := clientcmdapi.Config{ - Kind: "Config", - APIVersion: "v1", - Clusters: clusters, - Contexts: contexts, - // current context must be mgmt cluster for now, as the api server doesn't have context configurable. + Kind: "Config", + APIVersion: "v1", + Clusters: clusters, + Contexts: contexts, CurrentContext: "cluster", AuthInfos: authinfos, } @@ -220,3 +238,143 @@ func generateKubeConfiguration(restconfig *rest.Config) string { Expect(err).NotTo(HaveOccurred()) return tmpfile.Name() } + +var _ = Describe("InferencePool controller", func() { + const defaultNamespace = "default" + + It("should reconcile an InferencePool referenced by an HTTPRoute managed by our controller", func() { + // Create a test Gateway that will be referenced by the HTTPRoute. 
+ testGw := &apiv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-gateway", + Namespace: defaultNamespace, + }, + Spec: apiv1.GatewaySpec{ + GatewayClassName: gatewayClassName, + Listeners: []apiv1.Listener{ + { + Name: "listener-1", + Protocol: apiv1.HTTPProtocolType, + Port: 80, + }, + }, + }, + } + err := k8sClient.Create(ctx, testGw) + Expect(err).NotTo(HaveOccurred()) + + // Create an HTTPRoute without a status. + httpRoute := &apiv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: defaultNamespace, + }, + Spec: apiv1.HTTPRouteSpec{ + Rules: []apiv1.HTTPRouteRule{ + { + BackendRefs: []apiv1.HTTPBackendRef{ + { + BackendRef: apiv1.BackendRef{ + BackendObjectReference: apiv1.BackendObjectReference{ + Group: ptr.To(apiv1.Group(infextv1a1.GroupVersion.Group)), + Kind: ptr.To(apiv1.Kind("InferencePool")), + Name: "pool1", + }, + }, + }, + }, + }, + }, + }, + } + err = k8sClient.Create(ctx, httpRoute) + Expect(err).NotTo(HaveOccurred()) + + // Now update the status to include a valid Parents field. + httpRoute.Status = apiv1.HTTPRouteStatus{ + RouteStatus: apiv1.RouteStatus{ + Parents: []apiv1.RouteParentStatus{ + { + ParentRef: apiv1.ParentReference{ + Group: ptr.To(apiv1.Group("gateway.networking.k8s.io")), + Kind: ptr.To(apiv1.Kind("Gateway")), + Name: apiv1.ObjectName(testGw.Name), + Namespace: ptr.To(apiv1.Namespace(defaultNamespace)), + }, + ControllerName: gatewayControllerName, + }, + }, + }, + } + err = k8sClient.Status().Update(ctx, httpRoute) + Expect(err).NotTo(HaveOccurred()) + + // Create an InferencePool resource that is referenced by the HTTPRoute. 
+ pool := &infextv1a1.InferencePool{ + TypeMeta: metav1.TypeMeta{ + Kind: "InferencePool", + APIVersion: infextv1a1.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "pool1", + Namespace: defaultNamespace, + UID: "pool-uid", + }, + Spec: infextv1a1.InferencePoolSpec{ + Selector: map[infextv1a1.LabelKey]infextv1a1.LabelValue{}, + TargetPortNumber: 1234, + EndpointPickerConfig: infextv1a1.EndpointPickerConfig{ + ExtensionRef: &infextv1a1.Extension{ + ExtensionReference: infextv1a1.ExtensionReference{ + Name: "doesnt-matter", + }, + }, + }, + }, + } + err = k8sClient.Create(ctx, pool) + Expect(err).NotTo(HaveOccurred()) + + // The secondary watch on HTTPRoute should now trigger reconciliation of pool "pool1". + // We expect the deployer to render and deploy an endpoint picker Deployment with name "pool1-endpoint-picker". + expectedName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + var deploy appsv1.Deployment + Eventually(func() error { + return k8sClient.Get(ctx, client.ObjectKey{Namespace: defaultNamespace, Name: expectedName}, &deploy) + }, "10s", "1s").Should(Succeed()) + }) + + It("should ignore an InferencePool not referenced by any HTTPRoute", func() { + // Create an InferencePool that is not referenced by any HTTPRoute. + pool := &infextv1a1.InferencePool{ + TypeMeta: metav1.TypeMeta{ + Kind: "InferencePool", + APIVersion: infextv1a1.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "pool2", + Namespace: defaultNamespace, + UID: "pool2-uid", + }, + Spec: infextv1a1.InferencePoolSpec{ + Selector: map[infextv1a1.LabelKey]infextv1a1.LabelValue{}, + TargetPortNumber: 1234, + EndpointPickerConfig: infextv1a1.EndpointPickerConfig{ + ExtensionRef: &infextv1a1.Extension{ + ExtensionReference: infextv1a1.ExtensionReference{ + Name: "doesnt-matter", + }, + }, + }, + }, + } + err := k8sClient.Create(ctx, pool) + Expect(err).NotTo(HaveOccurred()) + + // Consistently check that no endpoint picker deployment is created. 
+ Consistently(func() error { + var dep appsv1.Deployment + return k8sClient.Get(ctx, client.ObjectKey{Namespace: defaultNamespace, Name: fmt.Sprintf("%s-endpoint-picker", pool.Name)}, &dep) + }, "5s", "1s").ShouldNot(Succeed()) + }) +}) diff --git a/internal/kgateway/controller/inferencepool_controller.go b/internal/kgateway/controller/inferencepool_controller.go new file mode 100644 index 00000000000..e028a41805c --- /dev/null +++ b/internal/kgateway/controller/inferencepool_controller.go @@ -0,0 +1,51 @@ +package controller + +import ( + "context" + + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" + + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/deployer" +) + +type inferencePoolReconciler struct { + cli client.Client + scheme *runtime.Scheme + deployer *deployer.Deployer +} + +func (r *inferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := log.FromContext(ctx).WithValues("inferencepool", req.NamespacedName) + log.V(1).Info("reconciling request", "request", req) + + pool := new(infextv1a1.InferencePool) + if err := r.cli.Get(ctx, req.NamespacedName, pool); err != nil { + return ctrl.Result{}, client.IgnoreNotFound(err) + } + + if pool.GetDeletionTimestamp() != nil { + // no need to do anything as we have owner refs, so children will be deleted + log.Info("inferencepool deleted, no need for reconciling") + return ctrl.Result{}, nil + } + + objs, err := r.deployer.GetEndpointPickerObjs(pool) + if err != nil { + return ctrl.Result{}, err + } + + // TODO [danehans]: Manage inferencepool status conditions. 
+ + err = r.deployer.DeployObjs(ctx, objs) + if err != nil { + return ctrl.Result{}, err + } + + log.V(1).Info("reconciled request", "request", req) + + return ctrl.Result{}, nil +} diff --git a/internal/kgateway/controller/start.go b/internal/kgateway/controller/start.go index b4816223139..5536ecbb55e 100644 --- a/internal/kgateway/controller/start.go +++ b/internal/kgateway/controller/start.go @@ -23,6 +23,7 @@ import ( istiokube "istio.io/istio/pkg/kube" "istio.io/istio/pkg/kube/krt" istiolog "istio.io/istio/pkg/log" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" apiv1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/deployer" @@ -103,6 +104,11 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) (*ControllerBuil return nil, err } + // Extend the scheme if the InferencePool CRD exists. + if _, err := glooschemes.AddInferExtV1A1Scheme(cfg.RestConfig, scheme); err != nil { + return nil, err + } + mgrOpts := ctrl.Options{ BaseContext: func() context.Context { return ctx }, Scheme: scheme, @@ -227,9 +233,22 @@ func (c *ControllerBuilder) Start(ctx context.Context) error { } if err := NewBaseGatewayController(ctx, gwCfg); err != nil { - setupLog.Error(err, "unable to create controller") + setupLog.Error(err, "unable to create gateway controller") return err } + // Create the InferencePool controller if the inference extension API group is registered. 
+ if c.mgr.GetScheme().IsGroupRegistered(infextv1a1.GroupVersion.Group) { + poolCfg := &InferencePoolConfig{ + Mgr: c.mgr, + ControllerName: wellknown.GatewayControllerName, + InferenceExt: new(deployer.InferenceExtInfo), + } + if err := NewBaseInferencePoolController(ctx, poolCfg, &gwCfg); err != nil { + setupLog.Error(err, "unable to create inferencepool controller") + return err + } + } + return c.mgr.Start(ctx) } diff --git a/internal/kgateway/crds/inferencepools.yaml b/internal/kgateway/crds/inferencepools.yaml new file mode 100644 index 00000000000..9e6473b9e20 --- /dev/null +++ b/internal/kgateway/crds/inferencepools.yaml @@ -0,0 +1,206 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.16.1 + name: inferencepools.inference.networking.x-k8s.io +spec: + group: inference.networking.x-k8s.io + names: + kind: InferencePool + listKind: InferencePoolList + plural: inferencepools + singular: inferencepool + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: InferencePool is the Schema for the InferencePools API. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. 
+ More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: InferencePoolSpec defines the desired state of InferencePool + properties: + extensionRef: + description: Extension configures an endpoint picker as an extension + service. + properties: + failureMode: + default: FailClose + description: |- + Configures how the gateway handles the case when the extension is not responsive. + Defaults to failClose. + enum: + - FailOpen + - FailClose + type: string + group: + default: "" + description: |- + Group is the group of the referent. + When unspecified or empty string, core API group is inferred. + type: string + kind: + default: Service + description: |- + Kind is the Kubernetes resource kind of the referent. For example + "Service". + + Defaults to "Service" when not specified. + + ExternalName services can refer to CNAME DNS records that may live + outside of the cluster and as such are difficult to reason about in + terms of conformance. They also may not be safe to forward to (see + CVE-2021-25740 for more information). Implementations MUST NOT + support ExternalName Services. + type: string + name: + description: Name is the name of the referent. + type: string + targetPortNumber: + description: |- + The port number on the pods running the extension. When unspecified, implementations SHOULD infer a + default value of 9002 when the Kind is Service. + format: int32 + maximum: 65535 + minimum: 1 + type: integer + required: + - name + type: object + selector: + additionalProperties: + description: |- + LabelValue is the value of a label. This is used for validation + of maps. This matches the Kubernetes label validation rules: + * must be 63 characters or less (can be empty), + * unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]), + * could contain dashes (-), underscores (_), dots (.), and alphanumerics between. 
+ + Valid values include: + + * MyValue + * my.name + * 123-my-value + maxLength: 63 + minLength: 0 + pattern: ^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?$ + type: string + description: |- + Selector defines a map of labels to watch model server pods + that should be included in the InferencePool. + In some cases, implementations may translate this field to a Service selector, so this matches the simple + map used for Service selectors instead of the full Kubernetes LabelSelector type. + type: object + targetPortNumber: + description: |- + TargetPortNumber defines the port number to access the selected model servers. + The number must be in the range 1 to 65535. + format: int32 + maximum: 65535 + minimum: 1 + type: integer + required: + - extensionRef + - selector + - targetPortNumber + type: object + status: + description: InferencePoolStatus defines the observed state of InferencePool + properties: + conditions: + default: + - lastTransitionTime: "1970-01-01T00:00:00Z" + message: Waiting for controller + reason: Pending + status: Unknown + type: Ready + description: |- + Conditions track the state of the InferencePool. + + Known condition types are: + + * "Ready" + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. 
+ For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. + maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + maxItems: 8 + type: array + x-kubernetes-list-map-keys: + - type + x-kubernetes-list-type: map + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/internal/kgateway/deployer/deployer.go b/internal/kgateway/deployer/deployer.go index 572342c82c1..764aa1569f6 100644 --- a/internal/kgateway/deployer/deployer.go +++ b/internal/kgateway/deployer/deployer.go @@ -25,6 +25,7 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" api "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" @@ -62,25 +63,39 @@ type AwsInfo struct { StsUri string } -// Inputs is the set of options used to configure the gateway deployer 
deployment +// InferenceExtInfo defines the runtime state of the Gateway API inference extension. +type InferenceExtInfo struct{} + +// Inputs is the set of options used to configure the deployer deployment type Inputs struct { ControllerName string Dev bool IstioIntegrationEnabled bool ControlPlane ControlPlaneInfo Aws *AwsInfo + InferenceExtension *InferenceExtInfo } // NewDeployer creates a new gateway deployer +// TODO [danehans]: Reloading the chart for every reconciliation is inefficient. +// See https://github.com/kgateway-dev/kgateway/issues/10672 for details. func NewDeployer(cli client.Client, inputs *Inputs) (*Deployer, error) { if inputs == nil { return nil, NilDeployerInputsErr } - helmChart, err := loadFs(helm.GlooGatewayHelmChart) - if err != nil { - return nil, err + var err error + helmChart := new(chart.Chart) + if inputs.InferenceExtension != nil { + if helmChart, err = loadFs(helm.InferenceExtensionHelmChart); err != nil { + return nil, err + } + } else { + if helmChart, err = loadFs(helm.GlooGatewayHelmChart); err != nil { + return nil, err + } } + // simulate what `helm package` in the Makefile does if version.Version != version.UndefinedVersion { helmChart.Metadata.AppVersion = version.Version @@ -109,12 +124,7 @@ func (d *Deployer) GetGvksToWatch(ctx context.Context) ([]schema.GroupVersionKin // _slightly_ more dynamic way of getting the GVKs. It isn't a perfect solution since if // we add more resources to the helm chart that are gated by a flag, we may forget to // update the values here to enable them. - emptyGw := &api.Gateway{ - ObjectMeta: metav1.ObjectMeta{ - Name: "default", - Namespace: "default", - }, - } + // TODO(Law): these must be set explicitly as we don't have defaults for them // and the internal template isn't robust enough. 
// This should be empty eventually -- the template must be resilient against nil-pointers @@ -128,7 +138,16 @@ func (d *Deployer) GetGvksToWatch(ctx context.Context) ([]schema.GroupVersionKin }, } - objs, err := d.renderChartToObjects(emptyGw, vals) + if d.inputs.InferenceExtension != nil { + vals = map[string]any{ + "inferenceExtension": map[string]any{ + "endpointPicker": map[string]any{}, + }, + } + } + + // The namespace and name do not matter since we only care about the GVKs of the rendered resources. + objs, err := d.renderChartToObjects("default", "default", vals) if err != nil { return nil, err } @@ -152,14 +171,14 @@ func jsonConvert(in *helmConfig, out interface{}) error { return json.Unmarshal(b, out) } -func (d *Deployer) renderChartToObjects(gw *api.Gateway, vals map[string]any) ([]client.Object, error) { - objs, err := d.Render(gw.Name, gw.Namespace, vals) +func (d *Deployer) renderChartToObjects(ns, name string, vals map[string]any) ([]client.Object, error) { + objs, err := d.Render(name, ns, vals) if err != nil { return nil, err } for _, obj := range objs { - obj.SetNamespace(gw.Namespace) + obj.SetNamespace(ns) } return objs, nil @@ -365,6 +384,25 @@ func (d *Deployer) getValues(gw *api.Gateway, gwParam *v1alpha1.GatewayParameter return vals, nil } +func (d *Deployer) getInferExtVals(pool *infextv1a1.InferencePool) (*helmConfig, error) { + if d.inputs.InferenceExtension == nil { + return nil, fmt.Errorf("inference extension input not defined for deployer") + + } + + // construct the default values + vals := &helmConfig{ + InferenceExtension: &helmInferenceExtension{ + EndpointPicker: &helmEndpointPickerExtension{ + PoolName: pool.Name, + PoolNamespace: pool.Namespace, + }, + }, + } + + return vals, nil +} + // Render relies on a `helm install` to render the Chart with the injected values // It returns the list of Objects that are rendered, and an optional error if rendering failed, // or converting the rendered manifests to objects failed. 
@@ -384,14 +422,19 @@ func (d *Deployer) Render(name, ns string, vals map[string]any) ([]client.Object install.ClientOnly = true installCtx := context.Background() + chartType := "gateway" + if d.inputs.InferenceExtension != nil { + chartType = "inference extension" + } + release, err := install.RunWithContext(installCtx, d.chart, vals) if err != nil { - return nil, fmt.Errorf("failed to render helm chart for gateway %s.%s: %w", ns, name, err) + return nil, fmt.Errorf("failed to render helm chart for %s %s.%s: %w", chartType, ns, name, err) } objs, err := ConvertYAMLToObjects(d.cli.Scheme(), []byte(release.Manifest)) if err != nil { - return nil, fmt.Errorf("failed to convert helm manifest yaml to objects for gateway %s.%s: %w", ns, name, err) + return nil, fmt.Errorf("failed to convert helm manifest yaml to objects for %s %s.%s: %w", chartType, ns, name, err) } return objs, nil } @@ -432,7 +475,7 @@ func (d *Deployer) GetObjsToDeploy(ctx context.Context, gw *api.Gateway) ([]clie if err != nil { return nil, fmt.Errorf("failed to convert helm values for gateway %s.%s: %w", gw.GetNamespace(), gw.GetName(), err) } - objs, err := d.renderChartToObjects(gw, convertedVals) + objs, err := d.renderChartToObjects(gw.Namespace, gw.Name, convertedVals) if err != nil { return nil, fmt.Errorf("failed to get objects to deploy for gateway %s.%s: %w", gw.GetNamespace(), gw.GetName(), err) } @@ -451,6 +494,47 @@ func (d *Deployer) GetObjsToDeploy(ctx context.Context, gw *api.Gateway) ([]clie return objs, nil } +// GetEndpointPickerObjs renders endpoint picker objects using the helm chart. +// It builds helm values from the given InferencePool and makes the pool the +// controlling owner of every rendered object. +func (d *Deployer) GetEndpointPickerObjs(pool *infextv1a1.InferencePool) ([]client.Object, error) { + // Build the helm values for the inference extension.
+ vals, err := d.getInferExtVals(pool) + if err != nil { + return nil, err + } + + // Convert the helm values struct. + var convertedVals map[string]any + if err := jsonConvert(vals, &convertedVals); err != nil { + return nil, fmt.Errorf("failed to convert inference extension helm values: %w", err) + } + + // Use a unique release name for the endpoint picker child objects. + releaseName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + objs, err := d.Render(releaseName, pool.Namespace, convertedVals) + if err != nil { + return nil, fmt.Errorf("failed to render inference extension objects: %w", err) + } + + // Ensure that each rendered object has its namespace set. + for _, obj := range objs { + if obj.GetNamespace() == "" { + obj.SetNamespace(pool.Namespace) + } + // Set owner references so that these objects are tied to the InferencePool. + obj.SetOwnerReferences([]metav1.OwnerReference{{ + APIVersion: pool.APIVersion, + Kind: pool.Kind, + Name: pool.Name, + UID: pool.UID, + Controller: ptr.To(true), + }}) + } + + return objs, nil +} + func (d *Deployer) DeployObjs(ctx context.Context, objs []client.Object) error { logger := log.FromContext(ctx) for _, obj := range objs { diff --git a/internal/kgateway/deployer/deployer_test.go b/internal/kgateway/deployer/deployer_test.go index 50bc0bfa5d8..cbfa9a4cdae 100644 --- a/internal/kgateway/deployer/deployer_test.go +++ b/internal/kgateway/deployer/deployer_test.go @@ -15,12 +15,14 @@ import ( "google.golang.org/protobuf/proto" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" api "sigs.k8s.io/gateway-api/apis/v1" gw2_v1alpha1 "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" @@ 
-1439,12 +1441,123 @@ var _ = Describe("Deployer", func() { }), ) }) + + Context("Inference Extension endpoint picker", func() { + const defaultNamespace = "default" + + It("should deploy endpoint picker resources for an InferencePool", func() { + // Create a fake InferencePool resource. + pool := &infextv1a1.InferencePool{ + TypeMeta: metav1.TypeMeta{ + Kind: wellknown.InferencePoolKind, + APIVersion: fmt.Sprintf("%s/%s", infextv1a1.GroupVersion.Group, infextv1a1.GroupVersion.Version), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "pool1", + Namespace: defaultNamespace, + UID: "pool-uid", + }, + } + + // Initialize a new deployer with InferenceExtension inputs. + d, err := deployer.NewDeployer(newFakeClientWithObjs(pool), &deployer.Inputs{ + ControllerName: wellknown.GatewayControllerName, + InferenceExtension: &deployer.InferenceExtInfo{}, + }) + Expect(err).NotTo(HaveOccurred()) + + // Get the endpoint picker objects for the InferencePool. + objs, err := d.GetEndpointPickerObjs(pool) + Expect(err).NotTo(HaveOccurred()) + Expect(objs).NotTo(BeEmpty(), "expected non-empty objects for endpoint picker deployment") + Expect(objs).To(HaveLen(5)) + + // Find the child objects. 
+ var sa *corev1.ServiceAccount + var role *rbacv1.Role + var rb *rbacv1.RoleBinding + var dep *appsv1.Deployment + var svc *corev1.Service + for _, obj := range objs { + switch t := obj.(type) { + case *corev1.ServiceAccount: + sa = t + case *rbacv1.Role: + role = t + case *rbacv1.RoleBinding: + rb = t + case *appsv1.Deployment: + dep = t + case *corev1.Service: + svc = t + } + } + Expect(sa).NotTo(BeNil(), "expected a ServiceAccount to be rendered") + Expect(role).NotTo(BeNil(), "expected a Role to be rendered") + Expect(rb).NotTo(BeNil(), "expected a RoleBinding to be rendered") + Expect(dep).NotTo(BeNil(), "expected a Deployment to be rendered") + Expect(svc).NotTo(BeNil(), "expected a Service to be rendered") + + // Check that owner references are set on all rendered objects to the InferencePool. + for _, obj := range objs { + ownerRefs := obj.GetOwnerReferences() + Expect(ownerRefs).To(HaveLen(1)) + ref := ownerRefs[0] + Expect(ref.Name).To(Equal(pool.Name)) + Expect(ref.UID).To(Equal(pool.UID)) + Expect(ref.Kind).To(Equal(pool.Kind)) + Expect(ref.APIVersion).To(Equal(pool.APIVersion)) + Expect(*ref.Controller).To(BeTrue()) + } + + // Validate that every rendered object is named after the helm release. + // (The template names each object "{{ .Release.Name }}", i.e. "<pool name>-endpoint-picker".) + expectedName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + Expect(sa.Name).To(Equal(expectedName)) + Expect(role.Name).To(Equal(expectedName)) + Expect(rb.Name).To(Equal(expectedName)) + Expect(dep.Name).To(Equal(expectedName)) + Expect(svc.Name).To(Equal(expectedName)) + + // Check the container args for the expected poolName.
+ Expect(dep.Spec.Template.Spec.Containers).To(HaveLen(1)) + pickerContainer := dep.Spec.Template.Spec.Containers[0] + Expect(pickerContainer.Args).To(Equal([]string{ + "-poolName", + pool.Name, + "-v", + "3", + "-grpcPort", + "9002", + "-grpcHealthPort", + "9003", + })) + }) + }) }) // initialize a fake controller-runtime client with the given list of objects func newFakeClientWithObjs(objs ...client.Object) client.Client { + scheme := schemes.GatewayScheme() + + // Ensure the rbac types are registered. + if err := rbacv1.AddToScheme(scheme); err != nil { + panic(fmt.Sprintf("failed to add rbacv1 scheme: %v", err)) + } + + // Check if any object is an InferencePool, and add its scheme if needed. + for _, obj := range objs { + gvk := obj.GetObjectKind().GroupVersionKind() + if gvk.Kind == wellknown.InferencePoolKind { + if err := infextv1a1.AddToScheme(scheme); err != nil { + panic(fmt.Sprintf("failed to add InferenceExtension scheme: %v", err)) + } + break + } + } + return fake.NewClientBuilder(). - WithScheme(schemes.GatewayScheme()). + WithScheme(scheme). WithObjects(objs...). Build() } diff --git a/internal/kgateway/deployer/values.go b/internal/kgateway/deployer/values.go index c46bb028aad..7b3505b46f5 100644 --- a/internal/kgateway/deployer/values.go +++ b/internal/kgateway/deployer/values.go @@ -8,7 +8,8 @@ import ( // The top-level helm values used by the deployer. 
type helmConfig struct { - Gateway *helmGateway `json:"gateway,omitempty"` + Gateway *helmGateway `json:"gateway,omitempty"` + InferenceExtension *helmInferenceExtension `json:"inferenceExtension,omitempty"` } type helmGateway struct { @@ -161,3 +162,12 @@ type helmAws struct { StsClusterName *string `json:"stsClusterName,omitempty"` StsUri *string `json:"stsUri,omitempty"` } + +type helmInferenceExtension struct { + EndpointPicker *helmEndpointPickerExtension `json:"endpointPicker,omitempty"` +} + +type helmEndpointPickerExtension struct { + PoolName string `json:"poolName"` + PoolNamespace string `json:"poolNamespace"` +} diff --git a/internal/kgateway/helm/embed.go b/internal/kgateway/helm/embed.go index a3c20aec616..0993f28b063 100644 --- a/internal/kgateway/helm/embed.go +++ b/internal/kgateway/helm/embed.go @@ -6,3 +6,6 @@ import ( //go:embed all:gloo-gateway var GlooGatewayHelmChart embed.FS + +//go:embed all:inference-extension +var InferenceExtensionHelmChart embed.FS diff --git a/internal/kgateway/helm/inference-extension/.helmignore b/internal/kgateway/helm/inference-extension/.helmignore new file mode 100644 index 00000000000..ede6884aa6b --- /dev/null +++ b/internal/kgateway/helm/inference-extension/.helmignore @@ -0,0 +1,30 @@ +# Patterns to ignore when building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. 
+.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ + +# template files +*-template.yaml + +# generator files +*.go +generate/ diff --git a/internal/kgateway/helm/inference-extension/Chart.yaml b/internal/kgateway/helm/inference-extension/Chart.yaml new file mode 100644 index 00000000000..4476b5e761f --- /dev/null +++ b/internal/kgateway/helm/inference-extension/Chart.yaml @@ -0,0 +1,24 @@ +apiVersion: v2 +name: inference-extension +description: A Helm chart for managing Gateway API Inference Extensions + +# A chart can be either an 'application' or a 'library' chart. +# +# Application charts are a collection of templates that can be packaged into versioned archives +# to be deployed. +# +# Library charts provide useful utilities or functions for the chart developer. They're included as +# a dependency of application charts to inject those utilities and functions into the rendering +# pipeline. Library charts do not define any templates and therefore cannot be deployed. +type: application + +# This is the chart version. This version number should be incremented each time you make changes +# to the chart and its templates, including the app version. +# Versions are expected to follow Semantic Versioning (https://semver.org/) +version: 0.0.1-alpha1 + +# This is the version number of the application being deployed. This version number should be +# incremented each time you make changes to the application. Versions are not expected to +# follow Semantic Versioning. They should reflect the version the application is using. +# It is recommended to use it with quotes. 
+appVersion: "0.1.0" diff --git a/internal/kgateway/helm/inference-extension/templates/_helpers.tpl b/internal/kgateway/helm/inference-extension/templates/_helpers.tpl new file mode 100644 index 00000000000..3282a1b2a3c --- /dev/null +++ b/internal/kgateway/helm/inference-extension/templates/_helpers.tpl @@ -0,0 +1,6 @@ +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "inference-extension.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} diff --git a/internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml b/internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml new file mode 100644 index 00000000000..0c8bac5dddc --- /dev/null +++ b/internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml @@ -0,0 +1,113 @@ +{{- $endpointPicker := .Values.inferenceExtension.endpointPicker }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ .Release.Name }} +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ .Release.Name }} +rules: +- apiGroups: ["inference.networking.x-k8s.io"] + resources: ["inferencemodels"] + verbs: ["get", "watch", "list"] +- apiGroups: [""] + resources: ["pods"] + verbs: ["get", "watch", "list"] +- apiGroups: ["inference.networking.x-k8s.io"] + resources: ["inferencepools"] + verbs: ["get", "watch", "list"] +- apiGroups: ["discovery.k8s.io"] + resources: ["endpointslices"] + verbs: ["get", "watch", "list"] +- apiGroups: + - authentication.k8s.io + resources: + - tokenreviews + verbs: + - create +- apiGroups: + - authorization.k8s.io + resources: + - subjectaccessreviews + verbs: + - create +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ .Release.Name }} +subjects: +- kind: ServiceAccount + name: {{ .Release.Name }} +roleRef: + kind: Role + name: {{ .Release.Name }} +--- +apiVersion: apps/v1 +kind: 
Deployment +metadata: + name: {{ .Release.Name }} + labels: + app.kubernetes.io/component: endpoint-picker + app.kubernetes.io/name: {{ .Release.Name }} + app.kubernetes.io/instance: {{ .Release.Name }} +spec: + replicas: 1 + selector: + matchLabels: + app: {{ .Release.Name }} + template: + metadata: + labels: + app: {{ .Release.Name }} + spec: + serviceAccountName: {{ .Release.Name }} + containers: + - name: endpoint-picker + args: + - -poolName + - {{ $endpointPicker.poolName }} + - -v + - "3" + - -grpcPort + - "9002" + - -grpcHealthPort + - "9003" + image: "registry.k8s.io/gateway-api-inference-extension/epp:v0.1.0" + imagePullPolicy: IfNotPresent + ports: + - containerPort: 9002 + - containerPort: 9003 + - name: metrics + containerPort: 9090 + livenessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 + readinessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 +--- +apiVersion: v1 +kind: Service +metadata: + name: {{ .Release.Name }} + labels: + app.kubernetes.io/component: endpoint-picker + app.kubernetes.io/name: {{ .Release.Name }} + app.kubernetes.io/instance: {{ .Release.Name }} +spec: + selector: + app: {{ .Release.Name }} + ports: + - protocol: TCP + port: 9002 + targetPort: 9002 + type: ClusterIP diff --git a/internal/kgateway/helm/inference-extension/values.yaml b/internal/kgateway/helm/inference-extension/values.yaml new file mode 100644 index 00000000000..6cb93c4c50a --- /dev/null +++ b/internal/kgateway/helm/inference-extension/values.yaml @@ -0,0 +1,12 @@ +# These values represent configurable values for the dynamic inference extension chart +# They are not intended to be actual "defaults," rather they are just placeholder values +# meant to allow rendering of the chart/template, as the real values will come from: +# * The `InferencePool` resource driving the inference extension provisioning +# * A (possibly merged) GatewayParameters object 
translated to helm values +# The actual defaults for these values should come from the "default GatewayParameters" object +# See: (install/helm/kgateway/templates/gatewayparameters.yaml) + +inferenceExtension: + endpointPicker: + poolName: default + poolNamespace: default diff --git a/internal/kgateway/wellknown/gwapi.go b/internal/kgateway/wellknown/gwapi.go index 4dc308bb109..b430ddbad7e 100644 --- a/internal/kgateway/wellknown/gwapi.go +++ b/internal/kgateway/wellknown/gwapi.go @@ -32,6 +32,9 @@ const ( // Kind string for ReferenceGrant resource ReferenceGrantKind = "ReferenceGrant" + // Kind string for InferencePool resource + InferencePoolKind = "InferencePool" + // Kind strings for Gateway API list types HTTPRouteListKind = "HTTPRouteList" GatewayListKind = "GatewayList" diff --git a/pkg/schemes/extended_scheme.go b/pkg/schemes/extended_scheme.go index 26a2938d04d..dc4e1b23afc 100644 --- a/pkg/schemes/extended_scheme.go +++ b/pkg/schemes/extended_scheme.go @@ -3,6 +3,7 @@ package schemes import ( "fmt" + rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/runtime" @@ -11,6 +12,7 @@ import ( "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" gwv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2" ) @@ -30,6 +32,26 @@ func AddGatewayV1A2Scheme(restConfig *rest.Config, scheme *runtime.Scheme) error return nil } +// AddInferExtV1A1Scheme adds the Inference Extension v1alpha1 scheme to the provided scheme if the InferencePool CRD exists. 
+func AddInferExtV1A1Scheme(restConfig *rest.Config, scheme *runtime.Scheme) (bool, error) { + exists, err := CRDExists(restConfig, infextv1a1.GroupVersion.Group, infextv1a1.GroupVersion.Version, wellknown.InferencePoolKind) + if err != nil { + return false, fmt.Errorf("error checking if %s CRD exists: %w", wellknown.InferencePoolKind, err) + } + + if exists { + // Required to deploy RBAC resources for endpoint picker extension. + if err := rbacv1.AddToScheme(scheme); err != nil { + return false, fmt.Errorf("error adding RBAC v1 to scheme: %w", err) + } + if err := infextv1a1.AddToScheme(scheme); err != nil { + return false, fmt.Errorf("error adding Gateway API Inference Extension v1alpha1 to scheme: %w", err) + } + } + + return exists, nil +} + // Helper function to check if a CRD exists func CRDExists(restConfig *rest.Config, group, version, kind string) (bool, error) { discoveryClient, err := discovery.NewDiscoveryClientForConfig(restConfig)