Skip to content

[apiserver] Add retry and timeout to apiserver V2 #3869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions apiserversdk/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package apiserversdk

import "time"

// TODO: Make apiserver configs compatible with V1
const (
// Max retry times for HTTP Client
HTTPClientDefaultMaxRetry = 3

// Retry backoff settings
HTTPClientDefaultBackoffBase = float64(2)
HTTPClientDefaultInitBackoff = 500 * time.Millisecond
HTTPClientDefaultMaxBackoff = 10 * time.Second

// Overall timeout for retries
HTTPClientDefaultOverallTimeout = 30 * time.Second
)
124 changes: 121 additions & 3 deletions apiserversdk/proxy.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package apiserversdk

import (
"bytes"
"fmt"
"io"
"math"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/net"
Expand All @@ -22,12 +27,14 @@ type MuxConfig struct {
func NewMux(config MuxConfig) (*http.ServeMux, error) {
u, err := url.Parse(config.KubernetesConfig.Host) // parse the K8s API server URL from the KubernetesConfig.
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to parse url %s from config: %w", config.KubernetesConfig.Host, err)
}
proxy := httputil.NewSingleHostReverseProxy(u)
if proxy.Transport, err = rest.TransportFor(config.KubernetesConfig); err != nil { // rest.TransportFor provides the auth to the K8s API server.
return nil, err
baseTransport, err := rest.TransportFor(config.KubernetesConfig) // rest.TransportFor provides the auth to the K8s API server.
if err != nil {
return nil, fmt.Errorf("failed to get transport for config: %w", err)
}
proxy.Transport = newRetryRoundTripper(baseTransport)
var handler http.Handler = proxy
if config.Middleware != nil {
handler = config.Middleware(proxy)
Expand Down Expand Up @@ -84,3 +91,114 @@ func requireKubeRayService(handler http.Handler, k8sClient *kubernetes.Clientset
handler.ServeHTTP(w, r)
})
}

// retryRoundTripper is a custom implementation of http.RoundTripper that retries HTTP requests.
// It verifies retryable HTTP status codes and retries using exponential backoff.
type retryRoundTripper struct {
base http.RoundTripper

// Num of retries after the initial attempt
maxRetries int
}

func newRetryRoundTripper(base http.RoundTripper) http.RoundTripper {
return &retryRoundTripper{base: base, maxRetries: HTTPClientDefaultMaxRetry}
}

func (rrt *retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()

var resp *http.Response
var err error
for attempt := 0; attempt <= rrt.maxRetries; attempt++ {
/* Try up to (rrt.maxRetries + 1) times: initial attempt + retries */

if attempt == 0 && req.Body != nil && req.GetBody == nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you help me understand what these two if blocks are doing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see we are reusing the body. Should we add comments explaining a bit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments added. PTAL

/* Reuse request body in each attempt */
bodyBytes, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body for retry support: %w", err)
}
err = req.Body.Close()
if err != nil {
return nil, fmt.Errorf("failed to close request body: %w", err)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(bodyBytes)), nil
}
}

if attempt > 0 && req.GetBody != nil {
var bodyCopy io.ReadCloser
bodyCopy, err = req.GetBody()
if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err)
}
req.Body = bodyCopy
}

resp, err = rrt.base.RoundTrip(req)
if err != nil {
return resp, fmt.Errorf("request to %s %s failed with error: %w", req.Method, req.URL.String(), err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to return nil instead of respwhen error occured?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, if any infomation is return in resp on err != nil from K8s APIServer, we should also return them.

}

if isSuccessfulStatusCode(resp.StatusCode) {
return resp, nil
}

if !isRetryableHTTPStatusCodes(resp.StatusCode) {
return resp, nil
}

if attempt < rrt.maxRetries && resp.Body != nil {
/* If not last attempt, drain response body */
if _, err = io.Copy(io.Discard, resp.Body); err != nil {
return nil, fmt.Errorf("retryRoundTripper internal failure to drain response body: %w", err)
}
if err = resp.Body.Close(); err != nil {
return nil, fmt.Errorf("retryRoundTripper internal failure to close response body: %w", err)
}
}

// TODO: move to HTTP util function in independent util file
sleepDuration := HTTPClientDefaultInitBackoff * time.Duration(math.Pow(HTTPClientDefaultBackoffBase, float64(attempt)))
if sleepDuration > HTTPClientDefaultMaxBackoff {
sleepDuration = HTTPClientDefaultMaxBackoff
}

// TODO: merge common utils for apiserver v1 and v2
if deadline, ok := ctx.Deadline(); ok {
remaining := time.Until(deadline)
if remaining <= 0 {
return resp, fmt.Errorf("retry timeout exceeded context deadline")
}
if sleepDuration > remaining {
sleepDuration = remaining
}
Comment on lines +176 to +178
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd better not cap sleepDuration but return an error directly. Otherwise, we will just sleep to the deadline, and the request will still fail.

}

time.Sleep(sleepDuration)
}
return resp, err
}

// TODO: move HTTP util function into independent util file / folder
func isSuccessfulStatusCode(statusCode int) bool {
return 200 <= statusCode && statusCode < 300
}

// TODO: merge common utils for apiserver v1 and v2
func isRetryableHTTPStatusCodes(statusCode int) bool {
switch statusCode {
case http.StatusRequestTimeout, // 408
http.StatusTooManyRequests, // 429
http.StatusInternalServerError, // 500
http.StatusBadGateway, // 502
http.StatusServiceUnavailable, // 503
http.StatusGatewayTimeout: // 504
return true
default:
return false
}
}
120 changes: 120 additions & 0 deletions apiserversdk/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package apiserversdk
import (
"context"
"errors"
"io"
"net"
"net/http"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -325,3 +327,121 @@ var _ = Describe("kuberay service", Ordered, func() {
})
})
})

var _ = Describe("retryRoundTripper", func() {
It("should not retry on successful status OK", func() {
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&attempts, 1)
return &http.Response{ /* Always return OK status */
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("OK")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(attempts).To(Equal(int32(1)))
})

It("should retry failed requests and eventually succeed", func() {
const maxFailure = 2
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
count := atomic.AddInt32(&attempts, 1)
if count <= maxFailure {
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("ok")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
Expect(attempts).To(Equal(int32(maxFailure + 1)))
})

It("Retries exceed maximum retry counts", func() {
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&attempts, 1)
return &http.Response{ /* Always return retriable status */
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError))
Expect(attempts).To(Equal(int32(HTTPClientDefaultMaxRetry + 1)))
})

It("should not retry on non-retriable status", func() {
var attempts int32
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&attempts, 1)
return &http.Response{ /* Always return non-retriable status */
StatusCode: http.StatusNotFound,
Body: io.NopCloser(strings.NewReader("Not Found")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
req, err := http.NewRequest(http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusNotFound))
Expect(attempts).To(Equal(int32(1)))
})

It("should respect context timeout and stop retrying", func() {
mock := &mockRoundTripper{
fn: func(_ *http.Request) (*http.Response, error) {
time.Sleep(100 * time.Millisecond)
return &http.Response{
StatusCode: http.StatusInternalServerError,
Body: io.NopCloser(strings.NewReader("internal error")),
}, nil
},
}
retrier := newRetryRoundTripper(mock)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://test", nil)
Expect(err).ToNot(HaveOccurred())
resp, err := retrier.RoundTrip(req)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("retry timeout exceeded context deadline"))
Expect(resp).ToNot(BeNil())
})
})

type mockRoundTripper struct {
fn func(*http.Request) (*http.Response, error)
}

func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return m.fn(req)
}
Loading