diff --git a/net/http/roundtripware/circuitbreaker.go b/net/http/roundtripware/circuitbreaker.go index 6066fe1..82ae6d1 100644 --- a/net/http/roundtripware/circuitbreaker.go +++ b/net/http/roundtripware/circuitbreaker.go @@ -51,7 +51,7 @@ type CircuitBreakerSettings struct { type CircuitBreakerOptions struct { Counter syncint64.Counter - IsSuccessful func(err error, req *http.Request, resp *http.Response) error + IsSuccessful func(err error, req *http.Request, resp *http.Response) (e error, ignore bool) CopyReqBody bool CopyRespBody bool } @@ -60,8 +60,8 @@ func NewDefaultCircuitBreakerOptions() *CircuitBreakerOptions { return &CircuitBreakerOptions{ Counter: nil, - IsSuccessful: func(err error, req *http.Request, resp *http.Response) error { - return err + IsSuccessful: func(err error, req *http.Request, resp *http.Response) (e error, ignore bool) { + return err, false }, CopyReqBody: false, CopyRespBody: false, @@ -91,7 +91,7 @@ func CircuitBreakerWithMetric( } func CircuitBreakerWithIsSuccessful( - isSuccessful func(err error, req *http.Request, resp *http.Response) error, + isSuccessful func(err error, req *http.Request, resp *http.Response) (e error, ignore bool), copyReqBody bool, copyRespBody bool, ) CircuitBreakerOption { @@ -122,7 +122,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R ReadyToTrip: set.ReadyToTrip, OnStateChange: set.OnStateChange, } - circuitBreaker := gobreaker.NewCircuitBreaker(cbrSettings) + circuitBreaker := gobreaker.NewTwoStepCircuitBreaker(cbrSettings) return func(l *zap.Logger, next Handler) Handler { return func(r *http.Request) (*http.Response, error) { @@ -139,9 +139,15 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R defer reqCopy.Body.Close() } - // call the next handler enclosed in the circuit breaker. - resp, err := circuitBreaker.Execute(func() (interface{}, error) { - resp, err := next(r) + // check whether the circuit breaker is closed (an error is returned if not) + done, err := circuitBreaker.Allow() + + var resp *http.Response + // wrap the error in case it was produced because of the circuit breaker being (half-)open + if errors.Is(gobreaker.ErrTooManyRequests, err) || errors.Is(gobreaker.ErrOpenState, err) { + err = keelerrors.NewWrappedError(ErrCircuitBreaker, err) + } else if err == nil { + resp, err = next(r) // clone the response and the body if wanted respCopy, errCopy := copyResponse(resp, o.CopyRespBody) @@ -153,8 +159,12 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R defer respCopy.Body.Close() } - return resp, o.IsSuccessful(err, reqCopy, respCopy) - }) + var ignore bool + err, ignore = o.IsSuccessful(err, reqCopy, respCopy) + if !ignore { + done(err == nil) + } + } // detect and log a state change toState := circuitBreaker.State() @@ -165,11 +175,6 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R ) } - // wrap the error in case it was produced because of the circuit breaker being (half-)open - if errors.Is(gobreaker.ErrTooManyRequests, err) || errors.Is(gobreaker.ErrOpenState, err) { - err = keelerrors.NewWrappedError(ErrCircuitBreaker, err) - } - attributes := []attribute.KeyValue{ attribute.String("current_state", toState.String()), attribute.String("previous_state", fromState.String()), @@ -188,11 +193,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R o.Counter.Add(r.Context(), 1, attributes...) } - if res, ok := resp.(*http.Response); ok { - return res, nil - } else { - return nil, errors.New("result is no *http.Response") - } + return resp, nil } } } diff --git a/net/http/roundtripware/circuitbreaker_test.go b/net/http/roundtripware/circuitbreaker_test.go index 784a881..f5bc285 100644 --- a/net/http/roundtripware/circuitbreaker_test.go +++ b/net/http/roundtripware/circuitbreaker_test.go @@ -61,17 +61,37 @@ func TestCircuitBreaker(t *testing.T) { keelhttp.HTTPClientWithRoundTripware(l, roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( - func(err error, req *http.Request, resp *http.Response) error { + func(err error, req *http.Request, resp *http.Response) (error, bool) { if resp.StatusCode >= http.StatusInternalServerError { - return errors.New("invalid status code") + return errors.New("invalid status code"), false } - return nil + return nil, false }, true, true, ), ), ), ) + { + client := keelhttp.NewHTTPClient( + keelhttp.HTTPClientWithRoundTripware(l, + roundtripware.CircuitBreaker( + &roundtripware.CircuitBreakerSettings{ + Name: "my little circuit breakerâ„¢", + MaxRequests: 1, + Interval: time.Minute, + Timeout: 30 * time.Second, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures > 3 + }, + }, + ), + ), + ) + + _ = client + } + // do requests to trigger the circuit breaker for i := 0; i <= 3; i++ { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) @@ -130,7 +150,7 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { keelhttp.HTTPClientWithRoundTripware(l, roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( - func(err error, req *http.Request, resp *http.Response) error { + func(err error, req *http.Request, resp *http.Response) (error, bool) { // read the bodies _, errRead := io.ReadAll(req.Body) require.NoError(t, errRead) @@ -141,7 +161,7 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { // also try to close one of the bodies (should also be handled by the RoundTripware) req.Body.Close() - return err + return err, false }, true, true, ), ), @@ -187,14 +207,14 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { keelhttp.HTTPClientWithRoundTripware(l, roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( - func(err error, req *http.Request, resp *http.Response) error { + func(err error, req *http.Request, resp *http.Response) (error, bool) { // read the bodies _, errRead := io.ReadAll(req.Body) if errRead != nil { - return errRead + return errRead, false } - return err + return err, false }, false, true, ), ), @@ -215,14 +235,14 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { keelhttp.HTTPClientWithRoundTripware(l, roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( - func(err error, req *http.Request, resp *http.Response) error { + func(err error, req *http.Request, resp *http.Response) (error, bool) { // read the bodies _, errRead := io.ReadAll(resp.Body) if errRead != nil { - return errRead + return errRead, false } - return err + return err, false }, true, false, ), ), @@ -263,11 +283,11 @@ func TestCircuitBreakerInterval(t *testing.T) { }, }, roundtripware.CircuitBreakerWithIsSuccessful( - func(err error, req *http.Request, resp *http.Response) error { + func(err error, req *http.Request, resp *http.Response) (error, bool) { if resp.StatusCode >= http.StatusInternalServerError { - return errors.New("invalid status code") + return errors.New("invalid status code"), false } - return nil + return nil, false }, true, true, ), ), @@ -309,3 +329,47 @@ func TestCircuitBreakerInterval(t *testing.T) { } require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker) } + +func TestCircuitBreakerIgnore(t *testing.T) { + // create logger + l := zaptest.NewLogger(t) + + // create http server with handler + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // always return an invalid status code + w.WriteHeader(http.StatusInternalServerError) + })) + defer svr.Close() + + // create http client + client := keelhttp.NewHTTPClient( + keelhttp.HTTPClientWithRoundTripware(l, + roundtripware.CircuitBreaker(cbSettings, + roundtripware.CircuitBreakerWithIsSuccessful( + func(err error, req *http.Request, resp *http.Response) (error, bool) { + if req.Method == http.MethodGet { + return errors.New("some ignored error"), true + } + if resp.StatusCode >= http.StatusInternalServerError { + return errors.New("invalid status code"), false + } + return nil, false + }, true, true, + ), + ), + ), + ) + + // send 4 requests (lower than the maximum amount of allowed consecutive failures), but they are ignored + // -> circuit breaker should remain open + for i := 0; i < 5; i++ { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } + require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) + require.Error(t, err) + } +}