diff --git a/net/http/roundtripware/circuitbreaker.go b/net/http/roundtripware/circuitbreaker.go index f8d7d62..e060c61 100644 --- a/net/http/roundtripware/circuitbreaker.go +++ b/net/http/roundtripware/circuitbreaker.go @@ -176,8 +176,17 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R } } - if err = o.IsSuccessful(err, reqCopy, respCopy); !errors.Is(err, ErrIgnoreSuccessfulness) { - done(err == nil) + if errSuccess := o.IsSuccessful(err, reqCopy, respCopy); errors.Is(errSuccess, errNoBody) { + l.Error("encountered read from not previously copied request/response body", + zap.Bool("copy_request", o.CopyReqBody), + zap.Bool("copy_response", o.CopyRespBody), + ) + // we actually want to return an error instead of the original request and error since the user + // should be made aware that there is a misconfiguration + resp = nil + err = errSuccess + } else if !errors.Is(errSuccess, ErrIgnoreSuccessfulness) { + done(errSuccess == nil) } } @@ -200,15 +209,12 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R attributes := append(attributes, attribute.Bool("error", true)) o.Counter.Add(r.Context(), 1, attributes...) } - return nil, err - } - - if o.Counter != nil { + } else if o.Counter != nil { attributes := append(attributes, attribute.Bool("error", false)) o.Counter.Add(r.Context(), 1, attributes...) } - return resp, nil + return resp, err } } } diff --git a/net/http/roundtripware/circuitbreaker_test.go b/net/http/roundtripware/circuitbreaker_test.go index 627abda..758810f 100644 --- a/net/http/roundtripware/circuitbreaker_test.go +++ b/net/http/roundtripware/circuitbreaker_test.go @@ -3,6 +3,7 @@ package roundtripware_test import ( "context" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -23,7 +24,7 @@ var cbSettings = &roundtripware.CircuitBreakerSettings{ // MaxRequests is the maximum number of requests allowed to pass through // when the CircuitBreaker is half-open. // If MaxRequests is 0, the CircuitBreaker allows only 1 request. - MaxRequests: 1, + MaxRequests: 2, // Interval is the cyclic period of the closed state // for the CircuitBreaker to clear the internal Counts. // If Interval is less than or equal to 0, the CircuitBreaker doesn't clear internal Counts during the closed state. @@ -39,6 +40,9 @@ var cbSettings = &roundtripware.CircuitBreakerSettings{ ReadyToTrip: func(counts gobreaker.Counts) bool { return counts.ConsecutiveFailures > 3 }, + OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) { + fmt.Printf("\n\nstate changed from %s to %s\n\n", from, to) + }, } func TestCircuitBreaker(t *testing.T) { @@ -96,7 +100,7 @@ func TestCircuitBreaker(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) @@ -118,7 +122,7 @@ func TestCircuitBreaker(t *testing.T) { req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err = client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NoError(t, err) @@ -170,7 +174,7 @@ func TestCircuitBreakerCopyBodies(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NoError(t, err) @@ -222,7 +226,7 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.Error(t, err) @@ -250,7 +254,7 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) require.NoError(t, err) resp, err = client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.Error(t, err) @@ -296,7 +300,7 @@ func TestCircuitBreakerInterval(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) @@ -311,7 +315,7 @@ func TestCircuitBreakerInterval(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) @@ -321,7 +325,7 @@ func TestCircuitBreakerInterval(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker) @@ -363,11 +367,12 @@ func TestCircuitBreakerIgnore(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) - require.Error(t, err) + require.NoError(t, err) + require.NotNil(t, resp) } } @@ -397,7 +402,7 @@ func TestCircuitBreakerTimeout(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) @@ -411,7 +416,7 @@ func TestCircuitBreakerTimeout(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, svr.URL, nil) require.NoError(t, err) resp, err := client.Do(req) - if err == nil { + if resp != nil { defer resp.Body.Close() } require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker)