diff --git a/integration/gotsrpc/net/middleware/telemetry.go b/integration/gotsrpc/net/middleware/telemetry.go index 17c927a..d3c5b83 100644 --- a/integration/gotsrpc/net/middleware/telemetry.go +++ b/integration/gotsrpc/net/middleware/telemetry.go @@ -38,21 +38,20 @@ func Telemetry() middleware.Middleware { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { *r = *gotsrpc.RequestWithStatsContext(r) - // retrieve or inject labeler - r, labeler := middleware.LoggerLabelerFromRequest(r) - next.ServeHTTP(w, r) if stats, ok := gotsrpc.GetStatsForRequest(r); ok { - labeler.Add( - zap.String(defaultGOTSRPCFunctionLabel, stats.Func), - zap.String(defaultGOTSRPCServiceLabel, stats.Service), - zap.String(defaultGOTSRPCPackageLabel, stats.Package), - ) - if stats.ErrorCode != 0 { - labeler.Add(zap.Int(defaultGOTSRPCErrorCode, stats.ErrorCode)) - if stats.ErrorMessage != "" { - labeler.Add(zap.String(defaultGOTSRPCErrorMessage, stats.ErrorMessage)) + if labeler, ok := middleware.LoggerLabelerFromRequest(r); ok { + labeler.Add( + zap.String(defaultGOTSRPCFunctionLabel, stats.Func), + zap.String(defaultGOTSRPCServiceLabel, stats.Service), + zap.String(defaultGOTSRPCPackageLabel, stats.Package), + ) + if stats.ErrorCode != 0 { + labeler.Add(zap.Int(defaultGOTSRPCErrorCode, stats.ErrorCode)) + if stats.ErrorMessage != "" { + labeler.Add(zap.String(defaultGOTSRPCErrorMessage, stats.ErrorMessage)) + } } } diff --git a/log/labeler.go b/log/labeler.go index 0f98b6c..3532bb8 100644 --- a/log/labeler.go +++ b/log/labeler.go @@ -35,9 +35,9 @@ func InjectLabeler(ctx context.Context, key LabelerContextKey) (context.Context, return context.WithValue(ctx, key, l), l } -func LabelerFromContext(ctx context.Context, key LabelerContextKey) (context.Context, *Labeler) { +func LabelerFromContext(ctx context.Context, key LabelerContextKey) (*Labeler, bool) { if l, ok := ctx.Value(key).(*Labeler); ok { - return ctx, l + return l, true } - return InjectLabeler(ctx, key) + return nil, false } diff --git a/net/http/middleware/logger.go b/net/http/middleware/logger.go index 380ce05..8a59c39 100644 --- a/net/http/middleware/logger.go +++ b/net/http/middleware/logger.go @@ -15,9 +15,10 @@ const loggerLabelerContextKey log.LabelerContextKey = "github.com/foomo/keel/net type ( LoggerOptions struct { - Message string - MinWarnCode int - MinErrorCode int + Message string + MinWarnCode int + MinErrorCode int + InjectLabeler bool } LoggerOption func(*LoggerOptions) ) @@ -25,9 +26,10 @@ type ( // GetDefaultLoggerOptions returns the default options func GetDefaultLoggerOptions() LoggerOptions { return LoggerOptions{ - Message: "handled http request", - MinWarnCode: 400, - MinErrorCode: 500, + Message: "handled http request", + MinWarnCode: 400, + MinErrorCode: 500, + InjectLabeler: true, } } @@ -63,6 +65,13 @@ func LoggerWithMinErrorCode(v int) LoggerOption { } } +// LoggerWithInjectLabeler middleware option +func LoggerWithInjectLabeler(v bool) LoggerOption { + return func(o *LoggerOptions) { + o.InjectLabeler = v + } +} + // LoggerWithOptions middleware func LoggerWithOptions(opts LoggerOptions) Middleware { return func(l *zap.Logger, name string, next http.Handler) http.Handler { @@ -74,8 +83,13 @@ func LoggerWithOptions(opts LoggerOptions) Middleware { l := log.WithHTTPRequest(l, r) - // retrieve or inject labeler - r, labeler := LoggerLabelerFromRequest(r) + var labeler *log.Labeler + + if labeler == nil && opts.InjectLabeler { + var labelerCtx context.Context + labelerCtx, labeler = log.InjectLabeler(r.Context(), loggerLabelerContextKey) + r = r.WithContext(labelerCtx) + } next.ServeHTTP(wr, r) @@ -101,11 +115,10 @@ func LoggerWithOptions(opts LoggerOptions) Middleware { } } -func LoggerLabelerFromContext(ctx context.Context) (context.Context, *log.Labeler) { +func LoggerLabelerFromContext(ctx context.Context) (*log.Labeler, bool) { return log.LabelerFromContext(ctx, loggerLabelerContextKey) } -func LoggerLabelerFromRequest(r *http.Request) (*http.Request, *log.Labeler) { - ctx, l := log.LabelerFromContext(r.Context(), loggerLabelerContextKey) - return r.WithContext(ctx), l +func LoggerLabelerFromRequest(r *http.Request) (*log.Labeler, bool) { + return log.LabelerFromContext(r.Context(), loggerLabelerContextKey) } diff --git a/net/http/middleware/logger_test.go b/net/http/middleware/logger_test.go index da47b2f..b86c473 100644 --- a/net/http/middleware/logger_test.go +++ b/net/http/middleware/logger_test.go @@ -7,6 +7,7 @@ import ( "github.com/foomo/keel/log" "github.com/foomo/keel/net/http/middleware" keeltest "github.com/foomo/keel/test" + "go.uber.org/zap" ) func ExampleLogger() { @@ -37,3 +38,42 @@ func ExampleLogger() { // Output: ok } + +func ExampleLoggerWithInjectLabeler() { + svr := keeltest.NewServer() + + // get logger + l := svr.Logger() + + // create demo service + svs := http.NewServeMux() + svs.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + fmt.Println("ok") + }) + + svr.AddService( + keeltest.NewServiceHTTP(l, "demo", svs, + func(l *zap.Logger, s string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if labeler, ok := middleware.LoggerLabelerFromRequest(r); ok { + labeler.Add(zap.String("injected", "message")) + } + next.ServeHTTP(w, r) + }) + }, + middleware.Logger( + middleware.LoggerWithInjectLabeler(true), + ), + ), + ) + + svr.Start() + + resp, err := http.Get(svr.GetService("demo").URL() + "/") //nolint:noctx + log.Must(l, err) + defer resp.Body.Close() + + // Output: ok +}