diff --git a/net/http/roundtripware/circuitbreaker.go b/net/http/roundtripware/circuitbreaker.go index e66b3b1..0b62d00 100644 --- a/net/http/roundtripware/circuitbreaker.go +++ b/net/http/roundtripware/circuitbreaker.go @@ -28,7 +28,7 @@ var ( ErrCircuitBreaker = errors.New("circuit breaker triggered") ) -// CircuitBreakerSettings is a copy of the gobreaker.Settings, except that the IsSuccessful function is ommited since we +// CircuitBreakerSettings is a copy of the gobreaker.Settings, except that the IsSuccessful function is omitted since we // want to allow access to the request and response. See `CircuitBreakerWithIsSuccessful` for more. type CircuitBreakerSettings struct { // Name is the name of the CircuitBreaker. @@ -128,16 +128,16 @@ func CircuitBreakerWithIsSuccessful( // CircuitBreaker returns a RoundTripper which wraps all the following RoundTripwares and the Handler with a circuit // breaker. This will prevent further request once a certain number of requests failed. -// NOTE: It's strongly adviced to add this Roundripware before the metric middleware (if both are used). As the measure- +// NOTE: It's strongly advised to add this Roundripware before the metric middleware (if both are used). As the measure- // ments of the execution time will otherwise be falsified func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) RoundTripware { - // intialize the options + // intitialize the options o := newDefaultCircuitBreakerOptions() for _, opt := range opts { opt(o) } - // intialize the state change counter + // intitialize the state change counter var stateCounter syncint64.Counter if o.stateMeter != nil { var err error @@ -150,7 +150,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R } } - // intialize the state (un-)success counter + // intitialize the state (un-)success counter var successCounter syncint64.Counter if o.successMeter != nil { var err error @@ -163,7 +163,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R } } - // Initialize the gobreaker + // intitialize the gobreaker cbrSettings := gobreaker.Settings{ Name: set.Name, MaxRequests: set.MaxRequests, @@ -200,7 +200,6 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R // call the next handler enclosed in the circuit breaker. resp, err := circuitBreaker.Execute(func() (interface{}, error) { - resp, err := next(r) // clone the response and the body if wanted @@ -219,7 +218,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R // detect and log a state change toState := circuitBreaker.State() if fromState != toState { - l.Warn("state change occured", + l.Warn("state change occurred", zap.String("from", fromState.String()), zap.String("to", toState.String()), ) @@ -258,7 +257,11 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R successCounter.Add(ctx, 1, attributes...) } - return resp.(*http.Response), nil + if res, ok := resp.(*http.Response); ok { + return res, nil + } else { + return nil, errors.New("result is no *http.Response") + } } } } @@ -293,7 +296,6 @@ func copyRequest(req *http.Request, body bool) (*http.Request, error) { // if it is attempted to read from the body in isSuccessful we actually want the read to fail out.Body = failureToReadBody{} } - } else if req.Body == nil { req.Body = nil out.Body = nil @@ -337,7 +339,8 @@ func copyResponse(resp *http.Response, body bool) (*http.Response, error) { } // copied from httputil -func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { +func drainBody(b io.ReadCloser) (io.ReadCloser, io.ReadCloser, error) { + var err error if b == nil || b == http.NoBody { // No copying needed. Preserve the magic sentinel meaning of NoBody. return http.NoBody, http.NoBody, nil diff --git a/net/http/roundtripware/circuitbreaker_test.go b/net/http/roundtripware/circuitbreaker_test.go index 0777b74..784a881 100644 --- a/net/http/roundtripware/circuitbreaker_test.go +++ b/net/http/roundtripware/circuitbreaker_test.go @@ -42,7 +42,6 @@ var cbSettings = &roundtripware.CircuitBreakerSettings{ } func TestCircuitBreaker(t *testing.T) { - // create logger l := zaptest.NewLogger(t) @@ -63,7 +62,7 @@ func TestCircuitBreaker(t *testing.T) { roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( func(err error, req *http.Request, resp *http.Response) error { - if resp.StatusCode >= 500 { + if resp.StatusCode >= http.StatusInternalServerError { return errors.New("invalid status code") } return nil @@ -77,7 +76,10 @@ func TestCircuitBreaker(t *testing.T) { for i := 0; i <= 3; i++ { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) - _, err = client.Do(req) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) } @@ -85,7 +87,10 @@ func TestCircuitBreaker(t *testing.T) { // this should result in a circuit breaker error req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) - _, err = client.Do(req) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker) // wait for the timeout to hit @@ -93,13 +98,14 @@ func TestCircuitBreaker(t *testing.T) { req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) - resp, err := client.Do(req) + resp, err = client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.NoError(t, err) - resp.Body.Close() } func TestCircuitBreakerCopyBodies(t *testing.T) { - requestData := "some request" responseData := "some response" @@ -111,7 +117,10 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { data, err := io.ReadAll(r.Body) require.NoError(t, err) require.Equal(t, string(data), requestData) - w.Write([]byte(responseData)) + _, err = w.Write([]byte(responseData)) + if err != nil { + panic(err) + } w.WriteHeader(http.StatusOK) })) defer svr.Close() @@ -122,7 +131,6 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( func(err error, req *http.Request, resp *http.Response) error { - // read the bodies _, errRead := io.ReadAll(req.Body) require.NoError(t, errRead) @@ -144,9 +152,10 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) require.NoError(t, err) resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.NoError(t, err) - defer resp.Body.Close() - // make sure the correct data is returned data, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -154,7 +163,6 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { } func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { - requestData := "some request" responseData := "some response" @@ -166,7 +174,10 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { data, err := io.ReadAll(r.Body) require.NoError(t, err) require.Equal(t, string(data), requestData) - w.Write([]byte(responseData)) + _, err = w.Write([]byte(responseData)) + if err != nil { + panic(err) + } w.WriteHeader(http.StatusOK) })) defer svr.Close() @@ -177,7 +188,6 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( func(err error, req *http.Request, resp *http.Response) error { - // read the bodies _, errRead := io.ReadAll(req.Body) if errRead != nil { @@ -194,7 +204,10 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { // do requests to trigger the circuit breaker req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) require.NoError(t, err) - _, err = client.Do(req) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.Error(t, err) // same thing for the response @@ -203,7 +216,6 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreakerWithIsSuccessful( func(err error, req *http.Request, resp *http.Response) error { - // read the bodies _, errRead := io.ReadAll(resp.Body) if errRead != nil { @@ -220,12 +232,14 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { // do requests to trigger the circuit breaker req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) require.NoError(t, err) - _, err = client.Do(req) + resp, err = client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.Error(t, err) } func TestCircuitBreakerInterval(t *testing.T) { - // create logger l := zaptest.NewLogger(t) @@ -250,7 +264,7 @@ func TestCircuitBreakerInterval(t *testing.T) { }, roundtripware.CircuitBreakerWithIsSuccessful( func(err error, req *http.Request, resp *http.Response) error { - if resp.StatusCode >= 500 { + if resp.StatusCode >= http.StatusInternalServerError { return errors.New("invalid status code") } return nil @@ -264,7 +278,10 @@ func TestCircuitBreakerInterval(t *testing.T) { for i := 0; i < 3; i++ { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) - _, err = client.Do(req) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) } @@ -276,13 +293,19 @@ func TestCircuitBreakerInterval(t *testing.T) { for i := 0; i <= 3; i++ { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) - _, err = client.Do(req) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) } // this request should now finally trigger the circuit breaker req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) - _, err = client.Do(req) + resp, err := client.Do(req) + if err == nil { + defer resp.Body.Close() + } require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker) }