diff --git a/pkg/utils/client/clientset.go b/pkg/utils/client/clientset.go index 040f9fd92..ceb1df22f 100644 --- a/pkg/utils/client/clientset.go +++ b/pkg/utils/client/clientset.go @@ -101,6 +101,7 @@ func (g *clientset) ToRESTConfig() (*rest.Config, error) { restConfig.RateLimiter = flowcontrol.NewFakeAlwaysRateLimiter() restConfig.UserAgent = version.DefaultUserAgent() restConfig.NegotiatedSerializer = unstructuredscheme.NewUnstructuredNegotiatedSerializer() + restConfig.Wrap(newRoundTripperPool) g.restConfig = restConfig for _, opt := range g.opts { diff --git a/pkg/utils/client/pools.go b/pkg/utils/client/pools.go new file mode 100644 index 000000000..530dcf1a1 --- /dev/null +++ b/pkg/utils/client/pools.go @@ -0,0 +1,101 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "fmt" + "io" + "net/http" + "sync" + + "sigs.k8s.io/kwok/pkg/utils/pools" +) + +type roundTripperPool struct { + p *pools.Pool[http.RoundTripper] +} + +func newRoundTripperPool(rt http.RoundTripper) http.RoundTripper { + if rt == nil { + rt = http.DefaultTransport + } + + return &roundTripperPool{ + p: pools.NewPool(func() http.RoundTripper { + return cloneRoundTripper(rt) + }), + } +} + +func (p *roundTripperPool) RoundTrip(req *http.Request) (*http.Response, error) { + t := p.p.Get() + + resp, err := t.RoundTrip(req) + if err != nil { + p.p.Put(t) + return resp, err + } + + if resp.Body == nil { + p.p.Put(t) + } else { + resp.Body = &responseBody{ + fun: func() { + p.p.Put(t) + }, + rc: resp.Body, + } + } + + return resp, err +} + +func cloneRoundTripper(rt http.RoundTripper) http.RoundTripper { + transport, isTransport := rt.(*http.Transport) + if !isTransport { + panic(fmt.Sprintf("unexpected non-http transport %T", rt)) + } + + return transport.Clone() +} + +type responseBody struct { + o sync.Once + fun func() + rc io.ReadCloser + err error +} + +func (b *responseBody) cleanup() { + b.o.Do(func() { + b.err = b.rc.Close() + b.fun() + }) +} + +func (b *responseBody) Read(p []byte) (n int, err error) { + n, err = b.rc.Read(p) + if err != nil { + b.cleanup() + } + return n, err +} + +func (b *responseBody) Close() error { + b.cleanup() + return b.err +}