diff --git a/net/http/middleware/basicauth.go b/net/http/middleware/basicauth.go index 81ac51b..4848890 100644 --- a/net/http/middleware/basicauth.go +++ b/net/http/middleware/basicauth.go @@ -47,7 +47,7 @@ func BasicAuth(username string, passwordHash []byte, opts ...BasicAuthOption) Mi // BasicAuthWithOptions middleware func BasicAuthWithOptions(username string, passwordHash []byte, opts BasicAuthOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // basic auth from request header u, p, ok := r.BasicAuth() diff --git a/net/http/middleware/cookietokenprovider.go b/net/http/middleware/cookietokenprovider.go new file mode 100644 index 0000000..0c89acf --- /dev/null +++ b/net/http/middleware/cookietokenprovider.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "net/http" + + "github.com/pkg/errors" +) + +type ( + CookieTokenProviderOptions struct{} + CookieTokenProviderOption func(*CookieTokenProviderOptions) +) + +// GetDefaultCookieTokenOptions returns the default options +func GetDefaultCookieTokenOptions() CookieTokenProviderOptions { + return CookieTokenProviderOptions{} +} + +func CookieTokenProvider(cookieName string, opts ...CookieTokenProviderOption) TokenProvider { + options := GetDefaultCookieTokenOptions() + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + return func(r *http.Request) (string, error) { + if cookie, err := r.Cookie(cookieName); errors.Is(err, http.ErrNoCookie) { + return "", nil + } else if err != nil { + return "", errors.New("failed to retrieve cookie") + } else { + return cookie.Value, nil + } + } +} diff --git a/net/http/middleware/cors.go b/net/http/middleware/cors.go index 6c4bf6f..1dd13f8 100644 --- a/net/http/middleware/cors.go +++ b/net/http/middleware/cors.go @@ -100,7 +100,7 @@ func CORSWithOptions(opts CORSOptions) Middleware { exposeHeaders := strings.Join(opts.ExposeHeaders, ",") maxAge := strconv.Itoa(opts.MaxAge) - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get(keelhttp.HeaderOrigin) allowOrigin := "" diff --git a/net/http/middleware/headertokenprovider.go b/net/http/middleware/headertokenprovider.go new file mode 100644 index 0000000..66f73ca --- /dev/null +++ b/net/http/middleware/headertokenprovider.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" + + keelhttp "github.com/foomo/keel/net/http" +) + +type ( + HeaderTokenProviderOptions struct { + Prefix string + Header string + } + HeaderTokenProviderOption func(*HeaderTokenProviderOptions) +) + +// GetDefaultHeaderTokenOptions returns the default options +func GetDefaultHeaderTokenOptions() HeaderTokenProviderOptions { + return HeaderTokenProviderOptions{ + Prefix: keelhttp.HeaderValueAuthorizationPrefix, + Header: keelhttp.HeaderAuthorization, + } +} + +// HeaderTokenProviderWithPrefix middleware option +func HeaderTokenProviderWithPrefix(v string) HeaderTokenProviderOption { + return func(o *HeaderTokenProviderOptions) { + o.Prefix = v + } +} + +// HeaderTokenProviderWithHeader middleware option +func HeaderTokenProviderWithHeader(v string) HeaderTokenProviderOption { + return func(o *HeaderTokenProviderOptions) { + o.Header = v + } +} + +func HeaderTokenProvider(opts ...HeaderTokenProviderOption) TokenProvider { + options := GetDefaultHeaderTokenOptions() + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + return func(r *http.Request) (string, error) { + if value := r.Header.Get(options.Header); value == "" { + return "", nil + } else if !strings.HasPrefix(value, options.Prefix) { + return "", errors.New("malformed bearer token") + } else { + return strings.TrimPrefix(value, options.Prefix), nil + } + } +} diff --git a/net/http/middleware/jwt.go b/net/http/middleware/jwt.go index 67c6559..02d80cb 100644 --- a/net/http/middleware/jwt.go +++ b/net/http/middleware/jwt.go @@ -122,7 +122,7 @@ func JWT(jwt *jwt.JWT, contextKey interface{}, opts ...JWTOption) Middleware { // JWTWithOptions middleware func JWTWithOptions(jwt *jwt.JWT, contextKey interface{}, opts JWTOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := opts.ClaimsProvider() if value := r.Context().Value(contextKey); value != nil { diff --git a/net/http/middleware/logger.go b/net/http/middleware/logger.go index e142cda..5ba4299 100644 --- a/net/http/middleware/logger.go +++ b/net/http/middleware/logger.go @@ -34,6 +34,7 @@ func Logger(opts ...LoggerOption) Middleware { return LoggerWithOptions(options) } +// LoggerWithMessage middleware option func LoggerWithMessage(v string) LoggerOption { return func(o *LoggerOptions) { o.Message = v @@ -42,7 +43,7 @@ func LoggerWithMessage(v string) LoggerOption { // LoggerWithOptions middleware func LoggerWithOptions(opts LoggerOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/net/http/middleware/middleware.go b/net/http/middleware/middleware.go index f28e097..eebb1d3 100644 --- a/net/http/middleware/middleware.go +++ b/net/http/middleware/middleware.go @@ -7,14 +7,14 @@ import ( ) // Middleware your way to handle requests -type Middleware func(*zap.Logger, http.Handler) http.Handler +type Middleware func(*zap.Logger, string, http.Handler) http.Handler -func Compose(l *zap.Logger, handler http.Handler, middlewares ...Middleware) http.Handler { - composed := func(l *zap.Logger, next http.Handler) http.Handler { +func Compose(l *zap.Logger, name string, handler http.Handler, middlewares ...Middleware) http.Handler { + composed := func(l *zap.Logger, name string, next http.Handler) http.Handler { for _, middleware := range middlewares { - next = middleware(l, next) + next = middleware(l, name, next) } return next } - return composed(l, handler) + return composed(l, name, handler) } diff --git a/net/http/middleware/poweredbyheader.go b/net/http/middleware/poweredbyheader.go index 18aea28..be7db7b 100644 --- a/net/http/middleware/poweredbyheader.go +++ b/net/http/middleware/poweredbyheader.go @@ -10,7 +10,7 @@ import ( type ( PoweredByHeaderOptions struct { - Header string + Header string Message string } PoweredByHeaderOption func(*PoweredByHeaderOptions) @@ -19,7 +19,7 @@ type ( // GetDefaultPoweredByHeaderOptions returns the default options func GetDefaultPoweredByHeaderOptions() PoweredByHeaderOptions { return PoweredByHeaderOptions{ - Header: httputils.HeaderXPoweredBy, + Header: httputils.HeaderXPoweredBy, Message: "a lot of LOVE", } } diff --git a/net/http/middleware/recover.go b/net/http/middleware/recover.go index fe73f2b..cc6ba71 100644 --- a/net/http/middleware/recover.go +++ b/net/http/middleware/recover.go @@ -45,7 +45,7 @@ func Recover(opts ...RecoverOption) Middleware { // RecoverWithOptions middleware func RecoverWithOptions(opts RecoverOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if e := recover(); e != nil { diff --git a/net/http/middleware/requestid.go b/net/http/middleware/requestid.go index b56b198..f3a3ab7 100644 --- a/net/http/middleware/requestid.go +++ b/net/http/middleware/requestid.go @@ -67,7 +67,7 @@ func RequestID(opts ...RequestIDOption) Middleware { // RequestIDWithOptions middleware func RequestIDWithOptions(opts RequestIDOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestID := r.Header.Get(opts.ResponseHeader) if requestID == "" { diff --git a/net/http/middleware/requesturiblacklistskipper.go b/net/http/middleware/requesturiblacklistskipper.go new file mode 100644 index 0000000..2c85f10 --- /dev/null +++ b/net/http/middleware/requesturiblacklistskipper.go @@ -0,0 +1,19 @@ +package middleware + +import ( + "net/http" +) + +// RequestURIBlacklistSkipper returns a HTTPMiddlewareConfig.Skipper which skips the given uris +func RequestURIBlacklistSkipper(uris ...string) Skipper { + urisMap := make(map[string]bool, len(uris)) + for _, uri := range uris { + urisMap[uri] = true + } + return func(r *http.Request) bool { + if _, ok := urisMap[r.RequestURI]; ok { + return false + } + return true + } +} diff --git a/net/http/middleware/requesturiwhitelistskipper.go b/net/http/middleware/requesturiwhitelistskipper.go new file mode 100644 index 0000000..004436e --- /dev/null +++ b/net/http/middleware/requesturiwhitelistskipper.go @@ -0,0 +1,19 @@ +package middleware + +import ( + "net/http" +) + +// RequestURIWhitelistSkipper returns a HTTPMiddlewareConfig.Skipper which skips all but the given uris +func RequestURIWhitelistSkipper(uris ...string) Skipper { + urisMap := make(map[string]bool, len(uris)) + for _, uri := range uris { + urisMap[uri] = true + } + return func(r *http.Request) bool { + if _, ok := urisMap[r.RequestURI]; ok { + return true + } + return false + } +} diff --git a/net/http/middleware/responsetime.go b/net/http/middleware/responsetime.go index e3df746..20b5aff 100644 --- a/net/http/middleware/responsetime.go +++ b/net/http/middleware/responsetime.go @@ -60,7 +60,7 @@ func ResponseTime(opts ...ResponseTimeOption) Middleware { // ResponseTimeWithOptions middleware func ResponseTimeWithOptions(opts ResponseTimeOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rw := WrapResponseWriter(w) diff --git a/net/http/middleware/serverheader.go b/net/http/middleware/serverheader.go index e7f0fbf..243cba5 100644 --- a/net/http/middleware/serverheader.go +++ b/net/http/middleware/serverheader.go @@ -11,6 +11,7 @@ import ( type ( ServerHeaderOptions struct { Header string + Name string } ServerHeaderOption func(*ServerHeaderOptions) ) @@ -33,6 +34,13 @@ func ServerHeader(opts ...ServerHeaderOption) Middleware { return ServerHeaderWithOptions(options) } +// ServerHeaderWithName middleware option +func ServerHeaderWithName(v string) ServerHeaderOption { + return func(o *ServerHeaderOptions) { + o.Name = v + } +} + // ServerHeaderWithHeader middleware option func ServerHeaderWithHeader(v string) ServerHeaderOption { return func(o *ServerHeaderOptions) { @@ -43,6 +51,9 @@ func ServerHeaderWithHeader(v string) ServerHeaderOption { // ServerHeaderWithOptions middleware func ServerHeaderWithOptions(opts ServerHeaderOptions) Middleware { return func(l *zap.Logger, name string, next http.Handler) http.Handler { + if opts.Name != "" { + name = opts.Name + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add(opts.Header, name) next.ServeHTTP(w, r) diff --git a/net/http/middleware/sessionid.go b/net/http/middleware/sessionid.go index 4b43efb..c948b4f 100644 --- a/net/http/middleware/sessionid.go +++ b/net/http/middleware/sessionid.go @@ -110,7 +110,7 @@ func SessionID(opts ...SessionIDOption) Middleware { // SessionIDWithOptions middleware func SessionIDWithOptions(opts SessionIDOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var sessionID string if value := r.Header.Get(opts.Header); value != "" { diff --git a/net/http/middleware/skip.go b/net/http/middleware/skip.go index 44bab7a..6e14b98 100644 --- a/net/http/middleware/skip.go +++ b/net/http/middleware/skip.go @@ -7,8 +7,8 @@ import ( ) func Skip(mw Middleware, skippers ...Skipper) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { - wrapped := mw(l, next) + return func(l *zap.Logger, name string, next http.Handler) http.Handler { + wrapped := mw(l, name, next) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for _, skipper := range skippers { if skipper(r) { diff --git a/net/http/middleware/skipper.go b/net/http/middleware/skipper.go index ba3f609..74de48d 100644 --- a/net/http/middleware/skipper.go +++ b/net/http/middleware/skipper.go @@ -5,31 +5,3 @@ import ( ) type Skipper func(*http.Request) bool - -// RequestURIWhitelistSkipper returns a HTTPMiddlewareConfig.Skipper which skips all but the given uris -func RequestURIWhitelistSkipper(uris ...string) Skipper { - urisMap := make(map[string]bool, len(uris)) - for _, uri := range uris { - urisMap[uri] = true - } - return func(r *http.Request) bool { - if _, ok := urisMap[r.RequestURI]; ok { - return true - } - return false - } -} - -// RequestURIBlacklistSkipper returns a HTTPMiddlewareConfig.Skipper which skips the given uris -func RequestURIBlacklistSkipper(uris ...string) Skipper { - urisMap := make(map[string]bool, len(uris)) - for _, uri := range uris { - urisMap[uri] = true - } - return func(r *http.Request) bool { - if _, ok := urisMap[r.RequestURI]; ok { - return false - } - return true - } -} diff --git a/net/http/middleware/telemetry.go b/net/http/middleware/telemetry.go index 641f109..50225a3 100644 --- a/net/http/middleware/telemetry.go +++ b/net/http/middleware/telemetry.go @@ -9,6 +9,7 @@ import ( type ( TelemetryOptions struct { + Name string OtelOpts []otelhttp.Option } TelemetryOption func(*TelemetryOptions) @@ -22,19 +23,35 @@ func GetDefaultTelemetryOptions() TelemetryOptions { } // Telemetry middleware -func Telemetry(name string, opts ...TelemetryOption) Middleware { +func Telemetry(opts ...TelemetryOption) Middleware { options := GetDefaultTelemetryOptions() for _, opt := range opts { if opt != nil { opt(&options) } } - return TelemetryWithOptions(name, options) + return TelemetryWithOptions(options) +} + +func TelemetryWithName(v string) TelemetryOption { + return func(o *TelemetryOptions) { + o.Name = v + } +} + +// TelemetryWithOtelOpts middleware options +func TelemetryWithOtelOpts(v ...otelhttp.Option) TelemetryOption { + return func(o *TelemetryOptions) { + o.OtelOpts = v + } } // TelemetryWithOptions middleware -func TelemetryWithOptions(name string, opts TelemetryOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { +func TelemetryWithOptions(opts TelemetryOptions) Middleware { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { + if opts.Name != "" { + name = opts.Name + } return otelhttp.NewHandler(next, name, opts.OtelOpts...) } } diff --git a/net/http/middleware/tokenauth.go b/net/http/middleware/tokenauth.go index e6b9e9f..5591fc3 100644 --- a/net/http/middleware/tokenauth.go +++ b/net/http/middleware/tokenauth.go @@ -45,7 +45,7 @@ func TokenAuth(token string, opts ...TokenAuthOption) Middleware { // TokenAuthWithOptions middleware func TokenAuthWithOptions(token string, opts TokenAuthOptions) Middleware { - return func(l *zap.Logger, next http.Handler) http.Handler { + return func(l *zap.Logger, name string, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if value, err := opts.TokenProvider(r); err != nil { httputils.UnauthorizedServerError(l, w, r, errors.Wrap(err, "failed to retrieve token")) diff --git a/net/http/middleware/tokenprovider.go b/net/http/middleware/tokenprovider.go index 1d5f3a2..f2fd292 100644 --- a/net/http/middleware/tokenprovider.go +++ b/net/http/middleware/tokenprovider.go @@ -2,87 +2,6 @@ package middleware import ( "net/http" - "strings" - - "github.com/pkg/errors" - - keelhttp "github.com/foomo/keel/net/http" ) type TokenProvider func(r *http.Request) (string, error) - -type ( - HeaderTokenProviderOptions struct { - Prefix string - Header string - } - HeaderTokenProviderOption func(*HeaderTokenProviderOptions) -) - -// GetDefaultHeaderTokenOptions returns the default options -func GetDefaultHeaderTokenOptions() HeaderTokenProviderOptions { - return HeaderTokenProviderOptions{ - Prefix: keelhttp.HeaderValueAuthorizationPrefix, - Header: keelhttp.HeaderAuthorization, - } -} - -// HeaderTokenProviderWithPrefix middleware option -func HeaderTokenProviderWithPrefix(v string) HeaderTokenProviderOption { - return func(o *HeaderTokenProviderOptions) { - o.Prefix = v - } -} - -// HeaderTokenProviderWithHeader middleware option -func HeaderTokenProviderWithHeader(v string) HeaderTokenProviderOption { - return func(o *HeaderTokenProviderOptions) { - o.Header = v - } -} - -func HeaderTokenProvider(opts ...HeaderTokenProviderOption) TokenProvider { - options := GetDefaultHeaderTokenOptions() - for _, opt := range opts { - if opt != nil { - opt(&options) - } - } - return func(r *http.Request) (string, error) { - if value := r.Header.Get(options.Header); value == "" { - return "", nil - } else if !strings.HasPrefix(value, options.Prefix) { - return "", errors.New("malformed bearer token") - } else { - return strings.TrimPrefix(value, options.Prefix), nil - } - } -} - -type ( - CookieTokenProviderOptions struct{} - CookieTokenProviderOption func(*CookieTokenProviderOptions) -) - -// GetDefaultCookieTokenOptions returns the default options -func GetDefaultCookieTokenOptions() CookieTokenProviderOptions { - return CookieTokenProviderOptions{} -} - -func CookieTokenProvider(cookieName string, opts ...CookieTokenProviderOption) TokenProvider { - options := GetDefaultCookieTokenOptions() - for _, opt := range opts { - if opt != nil { - opt(&options) - } - } - return func(r *http.Request) (string, error) { - if cookie, err := r.Cookie(cookieName); errors.Is(err, http.ErrNoCookie) { - return "", nil - } else if err != nil { - return "", errors.New("failed to retrieve cookie") - } else { - return cookie.Value, nil - } - } -} diff --git a/servicehttp.go b/servicehttp.go index 31136f8..ed79c11 100644 --- a/servicehttp.go +++ b/servicehttp.go @@ -32,7 +32,7 @@ func NewServiceHTTP(l *zap.Logger, name, addr string, handler http.Handler, midd server: &http.Server{ Addr: addr, ErrorLog: errorLog, - Handler: middleware.Compose(l, handler, middlewares...), + Handler: middleware.Compose(l, name, handler, middlewares...), }, l: l, }