feat: fixed linter errors

This commit is contained in:
Michael Hegel 2022-11-22 17:20:08 +01:00
parent d2b2d58489
commit 09fe99f831
2 changed files with 59 additions and 33 deletions

View File

@ -28,7 +28,7 @@ var (
ErrCircuitBreaker = errors.New("circuit breaker triggered") ErrCircuitBreaker = errors.New("circuit breaker triggered")
) )
// CircuitBreakerSettings is a copy of the gobreaker.Settings, except that the IsSuccessful function is ommited since we // CircuitBreakerSettings is a copy of the gobreaker.Settings, except that the IsSuccessful function is omitted since we
// want to allow access to the request and response. See `CircuitBreakerWithIsSuccessful` for more. // want to allow access to the request and response. See `CircuitBreakerWithIsSuccessful` for more.
type CircuitBreakerSettings struct { type CircuitBreakerSettings struct {
// Name is the name of the CircuitBreaker. // Name is the name of the CircuitBreaker.
@ -128,16 +128,16 @@ func CircuitBreakerWithIsSuccessful(
// CircuitBreaker returns a RoundTripper which wraps all the following RoundTripwares and the Handler with a circuit // CircuitBreaker returns a RoundTripper which wraps all the following RoundTripwares and the Handler with a circuit
// breaker. This will prevent further request once a certain number of requests failed. // breaker. This will prevent further request once a certain number of requests failed.
// NOTE: It's strongly adviced to add this Roundripware before the metric middleware (if both are used). As the measure- // NOTE: It's strongly advised to add this Roundripware before the metric middleware (if both are used). As the measure-
// ments of the execution time will otherwise be falsified // ments of the execution time will otherwise be falsified
func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) RoundTripware { func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) RoundTripware {
// intialize the options // intitialize the options
o := newDefaultCircuitBreakerOptions() o := newDefaultCircuitBreakerOptions()
for _, opt := range opts { for _, opt := range opts {
opt(o) opt(o)
} }
// intialize the state change counter // intitialize the state change counter
var stateCounter syncint64.Counter var stateCounter syncint64.Counter
if o.stateMeter != nil { if o.stateMeter != nil {
var err error var err error
@ -150,7 +150,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R
} }
} }
// intialize the state (un-)success counter // intitialize the state (un-)success counter
var successCounter syncint64.Counter var successCounter syncint64.Counter
if o.successMeter != nil { if o.successMeter != nil {
var err error var err error
@ -163,7 +163,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R
} }
} }
// Initialize the gobreaker // intitialize the gobreaker
cbrSettings := gobreaker.Settings{ cbrSettings := gobreaker.Settings{
Name: set.Name, Name: set.Name,
MaxRequests: set.MaxRequests, MaxRequests: set.MaxRequests,
@ -200,7 +200,6 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R
// call the next handler enclosed in the circuit breaker. // call the next handler enclosed in the circuit breaker.
resp, err := circuitBreaker.Execute(func() (interface{}, error) { resp, err := circuitBreaker.Execute(func() (interface{}, error) {
resp, err := next(r) resp, err := next(r)
// clone the response and the body if wanted // clone the response and the body if wanted
@ -219,7 +218,7 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R
// detect and log a state change // detect and log a state change
toState := circuitBreaker.State() toState := circuitBreaker.State()
if fromState != toState { if fromState != toState {
l.Warn("state change occured", l.Warn("state change occurred",
zap.String("from", fromState.String()), zap.String("from", fromState.String()),
zap.String("to", toState.String()), zap.String("to", toState.String()),
) )
@ -258,7 +257,11 @@ func CircuitBreaker(set *CircuitBreakerSettings, opts ...CircuitBreakerOption) R
successCounter.Add(ctx, 1, attributes...) successCounter.Add(ctx, 1, attributes...)
} }
return resp.(*http.Response), nil if res, ok := resp.(*http.Response); ok {
return res, nil
} else {
return nil, errors.New("result is no *http.Response")
}
} }
} }
} }
@ -293,7 +296,6 @@ func copyRequest(req *http.Request, body bool) (*http.Request, error) {
// if it is attempted to read from the body in isSuccessful we actually want the read to fail // if it is attempted to read from the body in isSuccessful we actually want the read to fail
out.Body = failureToReadBody{} out.Body = failureToReadBody{}
} }
} else if req.Body == nil { } else if req.Body == nil {
req.Body = nil req.Body = nil
out.Body = nil out.Body = nil
@ -337,7 +339,8 @@ func copyResponse(resp *http.Response, body bool) (*http.Response, error) {
} }
// copied from httputil // copied from httputil
func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { func drainBody(b io.ReadCloser) (io.ReadCloser, io.ReadCloser, error) {
var err error
if b == nil || b == http.NoBody { if b == nil || b == http.NoBody {
// No copying needed. Preserve the magic sentinel meaning of NoBody. // No copying needed. Preserve the magic sentinel meaning of NoBody.
return http.NoBody, http.NoBody, nil return http.NoBody, http.NoBody, nil

View File

@ -42,7 +42,6 @@ var cbSettings = &roundtripware.CircuitBreakerSettings{
} }
func TestCircuitBreaker(t *testing.T) { func TestCircuitBreaker(t *testing.T) {
// create logger // create logger
l := zaptest.NewLogger(t) l := zaptest.NewLogger(t)
@ -63,7 +62,7 @@ func TestCircuitBreaker(t *testing.T) {
roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreaker(cbSettings,
roundtripware.CircuitBreakerWithIsSuccessful( roundtripware.CircuitBreakerWithIsSuccessful(
func(err error, req *http.Request, resp *http.Response) error { func(err error, req *http.Request, resp *http.Response) error {
if resp.StatusCode >= 500 { if resp.StatusCode >= http.StatusInternalServerError {
return errors.New("invalid status code") return errors.New("invalid status code")
} }
return nil return nil
@ -77,7 +76,10 @@ func TestCircuitBreaker(t *testing.T) {
for i := 0; i <= 3; i++ { for i := 0; i <= 3; i++ {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil)
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker)
} }
@ -85,7 +87,10 @@ func TestCircuitBreaker(t *testing.T) {
// this should result in a circuit breaker error // this should result in a circuit breaker error
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil)
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker) require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker)
// wait for the timeout to hit // wait for the timeout to hit
@ -93,13 +98,14 @@ func TestCircuitBreaker(t *testing.T) {
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err = client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.NoError(t, err) require.NoError(t, err)
resp.Body.Close()
} }
func TestCircuitBreakerCopyBodies(t *testing.T) { func TestCircuitBreakerCopyBodies(t *testing.T) {
requestData := "some request" requestData := "some request"
responseData := "some response" responseData := "some response"
@ -111,7 +117,10 @@ func TestCircuitBreakerCopyBodies(t *testing.T) {
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(data), requestData) require.Equal(t, string(data), requestData)
w.Write([]byte(responseData)) _, err = w.Write([]byte(responseData))
if err != nil {
panic(err)
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})) }))
defer svr.Close() defer svr.Close()
@ -122,7 +131,6 @@ func TestCircuitBreakerCopyBodies(t *testing.T) {
roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreaker(cbSettings,
roundtripware.CircuitBreakerWithIsSuccessful( roundtripware.CircuitBreakerWithIsSuccessful(
func(err error, req *http.Request, resp *http.Response) error { func(err error, req *http.Request, resp *http.Response) error {
// read the bodies // read the bodies
_, errRead := io.ReadAll(req.Body) _, errRead := io.ReadAll(req.Body)
require.NoError(t, errRead) require.NoError(t, errRead)
@ -144,9 +152,10 @@ func TestCircuitBreakerCopyBodies(t *testing.T) {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData))
require.NoError(t, err) require.NoError(t, err)
resp, err := client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close()
// make sure the correct data is returned // make sure the correct data is returned
data, err := io.ReadAll(resp.Body) data, err := io.ReadAll(resp.Body)
require.NoError(t, err) require.NoError(t, err)
@ -154,7 +163,6 @@ func TestCircuitBreakerCopyBodies(t *testing.T) {
} }
func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) { func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) {
requestData := "some request" requestData := "some request"
responseData := "some response" responseData := "some response"
@ -166,7 +174,10 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) {
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(data), requestData) require.Equal(t, string(data), requestData)
w.Write([]byte(responseData)) _, err = w.Write([]byte(responseData))
if err != nil {
panic(err)
}
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})) }))
defer svr.Close() defer svr.Close()
@ -177,7 +188,6 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) {
roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreaker(cbSettings,
roundtripware.CircuitBreakerWithIsSuccessful( roundtripware.CircuitBreakerWithIsSuccessful(
func(err error, req *http.Request, resp *http.Response) error { func(err error, req *http.Request, resp *http.Response) error {
// read the bodies // read the bodies
_, errRead := io.ReadAll(req.Body) _, errRead := io.ReadAll(req.Body)
if errRead != nil { if errRead != nil {
@ -194,7 +204,10 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) {
// do requests to trigger the circuit breaker // do requests to trigger the circuit breaker
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData))
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.Error(t, err) require.Error(t, err)
// same thing for the response // same thing for the response
@ -203,7 +216,6 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) {
roundtripware.CircuitBreaker(cbSettings, roundtripware.CircuitBreaker(cbSettings,
roundtripware.CircuitBreakerWithIsSuccessful( roundtripware.CircuitBreakerWithIsSuccessful(
func(err error, req *http.Request, resp *http.Response) error { func(err error, req *http.Request, resp *http.Response) error {
// read the bodies // read the bodies
_, errRead := io.ReadAll(resp.Body) _, errRead := io.ReadAll(resp.Body)
if errRead != nil { if errRead != nil {
@ -220,12 +232,14 @@ func TestCircuitBreakerReadFromNotCopiedBodies(t *testing.T) {
// do requests to trigger the circuit breaker // do requests to trigger the circuit breaker
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData)) req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, strings.NewReader(requestData))
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err = client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.Error(t, err) require.Error(t, err)
} }
func TestCircuitBreakerInterval(t *testing.T) { func TestCircuitBreakerInterval(t *testing.T) {
// create logger // create logger
l := zaptest.NewLogger(t) l := zaptest.NewLogger(t)
@ -250,7 +264,7 @@ func TestCircuitBreakerInterval(t *testing.T) {
}, },
roundtripware.CircuitBreakerWithIsSuccessful( roundtripware.CircuitBreakerWithIsSuccessful(
func(err error, req *http.Request, resp *http.Response) error { func(err error, req *http.Request, resp *http.Response) error {
if resp.StatusCode >= 500 { if resp.StatusCode >= http.StatusInternalServerError {
return errors.New("invalid status code") return errors.New("invalid status code")
} }
return nil return nil
@ -264,7 +278,10 @@ func TestCircuitBreakerInterval(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil)
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker)
} }
@ -276,13 +293,19 @@ func TestCircuitBreakerInterval(t *testing.T) {
for i := 0; i <= 3; i++ { for i := 0; i <= 3; i++ {
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil)
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker) require.NotErrorIs(t, err, roundtripware.ErrCircuitBreaker)
} }
// this request should now finally trigger the circuit breaker // this request should now finally trigger the circuit breaker
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, svr.URL, nil)
require.NoError(t, err) require.NoError(t, err)
_, err = client.Do(req) resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
}
require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker) require.ErrorIs(t, err, roundtripware.ErrCircuitBreaker)
} }