From af277d3a63ea96cc7a4d6f61d6f6c57bc46a85c5 Mon Sep 17 00:00:00 2001 From: franklin Date: Mon, 6 Sep 2021 22:13:07 +0200 Subject: [PATCH] feat: add middlewares --- net/http/headervalues.go | 5 + net/http/middleware/domainprovider.go | 70 -------- net/http/middleware/jwt.go | 145 ++++++++++++++++ net/http/middleware/server.go | 44 +++++ net/http/middleware/sessionid.go | 192 +++++++++++----------- net/http/middleware/sessioniogenerator.go | 9 - net/http/middleware/skip.go | 22 +++ net/http/middleware/skipper.go | 35 ++++ net/http/middleware/tmp.go | 139 ---------------- net/http/middleware/tokenprovider.go | 88 ++++++++++ utils/net/http/errors.go | 6 +- utils/net/http/request.go | 30 ++++ 12 files changed, 472 insertions(+), 313 deletions(-) create mode 100644 net/http/headervalues.go delete mode 100644 net/http/middleware/domainprovider.go create mode 100644 net/http/middleware/jwt.go create mode 100644 net/http/middleware/server.go delete mode 100644 net/http/middleware/sessioniogenerator.go create mode 100644 net/http/middleware/skip.go create mode 100644 net/http/middleware/skipper.go delete mode 100644 net/http/middleware/tmp.go create mode 100644 net/http/middleware/tokenprovider.go create mode 100644 utils/net/http/request.go diff --git a/net/http/headervalues.go b/net/http/headervalues.go new file mode 100644 index 0000000..94d1de3 --- /dev/null +++ b/net/http/headervalues.go @@ -0,0 +1,5 @@ +package http + +const ( + HeaderValueAuthorizationPrefix = "Bearer " +) diff --git a/net/http/middleware/domainprovider.go b/net/http/middleware/domainprovider.go deleted file mode 100644 index 1829f04..0000000 --- a/net/http/middleware/domainprovider.go +++ /dev/null @@ -1,70 +0,0 @@ -package middleware - -import ( - "net/http" - "strings" - - "github.com/pkg/errors" -) - -var ( - ErrDomainNotAllowed = errors.New("domain not allowed") -) - -type DomainProvider func(r *http.Request) (string, error) - -var DefaultDomainProvider = func(domains []string) DomainProvider { - return func(r *http.Request) (string, error) { - domain := getDomainFromHTTPRequest(r) - if !isDomainAllowed(domain, domains) { - return "", ErrDomainNotAllowed - } - return domain, nil - } -} - -var MappingDomainProvider = func(domains []string, mapping map[string]string) DomainProvider { - return func(r *http.Request) (string, error) { - domain := getDomainFromHTTPRequest(r) - if value, ok := mapping[domain]; ok { - domain = value - } - if !isDomainAllowed(domain, domains) { - return "", errors.New("invalid domain: " + domain) - } - return domain, nil - } -} - -// getDomainFromHTTPRequest helper -func getDomainFromHTTPRequest(r *http.Request) string { - var domain string - if r.Header.Get("X-Forwarded-Host") != "" { - domain = r.Header.Get("X-Forwarded-Host") - } else if !r.URL.IsAbs() { - domain = r.Host - } else { - domain = r.URL.Host - } - - // right trim port - portIndex := strings.Index(domain, ":") - if portIndex != -1 { - domain = domain[:portIndex] - } - - return domain -} - -// isDomainAllowed helper -func isDomainAllowed(domain string, domains []string) bool { - if domains == nil || len(domains) == 0 { - return true - } - for _, value := range domains { - if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) { - return true - } - } - return false -} diff --git a/net/http/middleware/jwt.go b/net/http/middleware/jwt.go new file mode 100644 index 0000000..0e689ac --- /dev/null +++ b/net/http/middleware/jwt.go @@ -0,0 +1,145 @@ +package middleware + +import ( + "context" + "net/http" + + jwt2 "github.com/golang-jwt/jwt" + "github.com/pkg/errors" + "go.uber.org/zap" + + "github.com/foomo/keel/jwt" + httputils "github.com/foomo/keel/utils/net/http" +) + +type ( + JWTOptions struct { + TokenProvider TokenProvider + ClaimsProvider JWTClaimsProvider + MissingTokenHandler JWTMissingTokenHandler + InvalidTokenHandler JWTInvalidTokenHandler + ErrorHandler JWTErrorHandler + } + JWTOption func(*JWTOptions) + JWTClaimsProvider func() jwt2.Claims + JWTErrorHandler func(*zap.Logger, http.ResponseWriter, *http.Request, error) bool + JWTMissingTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request) (jwt2.Claims, bool) + JWTInvalidTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request, *jwt2.Token) bool +) + +// DefaultJWTErrorHandler function +func DefaultJWTErrorHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, err error) bool { + httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed parse claims")) + return false +} + +// DefaultJWTMissingTokenHandler function +func DefaultJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) { + return nil, true +} + +// RequiredJWTMissingTokenHandler function +func RequiredJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) { + httputils.BadRequestServerError(l, w, r, errors.New("missing jwt token")) + return nil, false +} + +// DefaultJWTInvalidTokenHandler function +func DefaultJWTInvalidTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, token *jwt2.Token) bool { + httputils.BadRequestServerError(l, w, r, errors.New("invalid jwt token")) + return false +} + +// DefaultJWTClaimsProvider function +func DefaultJWTClaimsProvider() jwt2.Claims { + return &jwt2.StandardClaims{} +} + +// GetDefaultJWTOptions returns the default options +func GetDefaultJWTOptions() JWTOptions { + return JWTOptions{ + TokenProvider: HeaderTokenProvider(), + ClaimsProvider: DefaultJWTClaimsProvider, + ErrorHandler: DefaultJWTErrorHandler, + InvalidTokenHandler: DefaultJWTInvalidTokenHandler, + MissingTokenHandler: DefaultJWTMissingTokenHandler, + } +} + +// JWTWithTokenProvider middleware option +func JWTWithTokenProvider(v TokenProvider) JWTOption { + return func(o *JWTOptions) { + o.TokenProvider = v + } +} + +// JWTWithClaimsProvider middleware option +func JWTWithClaimsProvider(v JWTClaimsProvider) JWTOption { + return func(o *JWTOptions) { + o.ClaimsProvider = v + } +} + +// JWTWithInvalidTokenHandler middleware option +func JWTWithInvalidTokenHandler(v JWTInvalidTokenHandler) JWTOption { + return func(o *JWTOptions) { + o.InvalidTokenHandler = v + } +} + +// JWTWithMissingTokenHandler middleware option +func JWTWithMissingTokenHandler(v JWTMissingTokenHandler) JWTOption { + return func(o *JWTOptions) { + o.MissingTokenHandler = v + } +} + +// JWTWithErrorHandler middleware option +func JWTWithErrorHandler(v JWTErrorHandler) JWTOption { + return func(o *JWTOptions) { + o.ErrorHandler = v + } +} + +// JWT middleware +func JWT(jwt *jwt.JWT, contextKey interface{}, opts ...JWTOption) Middleware { + options := GetDefaultJWTOptions() + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + return JWTWithOptions(jwt, contextKey, options) +} + +// JWTWithOptions middleware +func JWTWithOptions(jwt *jwt.JWT, contextKey interface{}, opts JWTOptions) Middleware { + return func(l *zap.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := opts.ClaimsProvider() + if value := r.Context().Value(contextKey); value != nil { + // TODO check if type matches the existing + next.ServeHTTP(w, r) + } else if value, err := opts.TokenProvider(r); err != nil { + httputils.BadRequestServerError(l, w, r, errors.Wrap(err, "failed to retrieve token")) + } else if value == "" { + if claims, resume := opts.MissingTokenHandler(l, w, r); resume && claims != nil { + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims))) + } else if resume { + next.ServeHTTP(w, r) + } + } else if token, err := jwt.ParseWithClaims(value, claims); err != nil { + // TODO check if type matches the existing + if opts.ErrorHandler(l, w, r, err) { + next.ServeHTTP(w, r) + } + } else if !token.Valid { + if opts.InvalidTokenHandler(l, w, r, token) { + next.ServeHTTP(w, r) + } + } else { + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims))) + } + }) + } +} diff --git a/net/http/middleware/server.go b/net/http/middleware/server.go new file mode 100644 index 0000000..34aab57 --- /dev/null +++ b/net/http/middleware/server.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "net/http" + + "go.uber.org/zap" + + httputils "github.com/foomo/keel/net/http" +) + +type ( + ServerOptions struct { + Header string + } + ServerOption func(*ServerOptions) +) + +// GetDefaultServerOptions returns the default options +func GetDefaultServerOptions() ServerOptions { + return ServerOptions{ + Header: httputils.HeaderServer, + } +} + +// Server middleware +func Server(name string, opts ...ServerOption) Middleware { + options := GetDefaultServerOptions() + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + return ServerWithOptions(name, options) +} + +// ServerWithOptions middleware +func ServerWithOptions(name string, opts ServerOptions) Middleware { + return func(l *zap.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(opts.Header, name) + next.ServeHTTP(w, r) + }) + } +} diff --git a/net/http/middleware/sessionid.go b/net/http/middleware/sessionid.go index ac0c8ae..5d5de45 100644 --- a/net/http/middleware/sessionid.go +++ b/net/http/middleware/sessionid.go @@ -1,141 +1,147 @@ package middleware import ( + "context" "net/http" + "github.com/google/uuid" "github.com/pkg/errors" "go.uber.org/zap" - httputils "github.com/foomo/keel/net/http" + keelhttp "github.com/foomo/keel/net/http" + "github.com/foomo/keel/net/http/cookie" + httputils "github.com/foomo/keel/utils/net/http" ) type ( - SessionIDConfig struct { - SetCookie bool // only create a session cookie if enabled - CookieName string - CookieSecure bool - CookieHttpOnly bool - CookiePath string - CookieDomain string - CookieDomains []string - CookieDomainProvider DomainProvider - Generator SessionIDGenerator + contextKey string + SessionIDOptions struct { + // Header to look up the session id + Header string + // Cookie how to set the cookie + Cookie cookie.Cookie + // Generator for the session ids + Generator SessionIDGenerator + // SetCookie if true it will create a cookie if not exists + SetCookie bool + // SetHeader if true it will set add a request header + SetHeader bool + // SetContext if true it will set the context key + SetContext bool } - SessionIDOption func(*SessionIDConfig) error + SessionIDOption func(*SessionIDOptions) + SessionIDGenerator func() string ) -var DefaultSessionIDConfig = SessionIDConfig{ - SetCookie: false, - CookieName: "keel-session", - CookieSecure: true, - CookieHttpOnly: true, - CookiePath: "/", +const ( + ContextKeySessionID contextKey = "sessionId" + + DefaultSessionIDCookieName = "keel-session" +) + +// DefaultSessionIDGenerator function +func DefaultSessionIDGenerator() string { + return uuid.New().String() } +// GetDefaultSessionIDOptions returns the default options +func GetDefaultSessionIDOptions() SessionIDOptions { + return SessionIDOptions{ + Header: keelhttp.HeaderXSessionID, + Cookie: cookie.New(DefaultSessionIDCookieName), + Generator: DefaultSessionIDGenerator, + SetCookie: false, + SetHeader: true, + SetContext: true, + } +} + +func SessionIDWithHeader(v string) SessionIDOption { + return func(o *SessionIDOptions) { + o.Header = v + } +} + +// SessionIDWithSetCookie middleware option func SessionIDWithSetCookie(v bool) SessionIDOption { - return func(c *SessionIDConfig) error { - c.SetCookie = v + return func(o *SessionIDOptions) { + o.SetCookie = v } } -func SessionIDWithCookieName(v string) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookieName = v +// SessionIDWithSetHeader middleware option +func SessionIDWithSetHeader(v bool) SessionIDOption { + return func(o *SessionIDOptions) { + o.SetHeader = v } } -func SessionIDWithCookieSecure(v bool) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookieSecure = v +// SessionIDWithSetContext middleware option +func SessionIDWithSetContext(v bool) SessionIDOption { + return func(o *SessionIDOptions) { + o.SetContext = v } } -func SessionIDWithCookieHttpOnly(v bool) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookieHttpOnly = v - } -} - -func SessionIDWithCookiePath(v string) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookiePath = v - } -} - -func SessionIDWithCookieDomain(v string) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookieDomain = v - } -} - -func SessionIDWithCookieDomains(v []string) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookieDomains = v - } -} - -func SessionIDWithCookieDomainProvider(v DomainProvider) SessionIDOption { - return func(c *SessionIDConfig) error { - c.CookieDomainProvider = v +// SessionIDWithCookie middleware option +func SessionIDWithCookie(v cookie.Cookie) SessionIDOption { + return func(o *SessionIDOptions) { + o.Cookie = v } } +// SessionIDWithGenerator middleware option func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption { - return func(c *SessionIDConfig) error { - c.Generator = v + return func(o *SessionIDOptions) { + o.Generator = v } } +// SessionID middleware func SessionID(opts ...SessionIDOption) Middleware { - config = DefaultSessionIDConfig + options := GetDefaultSessionIDOptions() for _, opt := range opts { if opt != nil { - if err := opt(&opts); err != nil { - return nil, err - } + opt(&options) } } - if config.Generator == nil { - config.Generator = DefaultSessionIDGenerator - } - if config.DomainProvider == nil { - config.DomainProvider = DefaultDomainProvider(config.CookieDomains) - } + return SessionIDWithOptions(options) +} +// SessionIDWithOptions middleware +func SessionIDWithOptions(opts SessionIDOptions) Middleware { return func(l *zap.Logger, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - if id := r.Header.Get(httputils.HeaderXSessionID); id == "" { - if cookie, err := r.Cookie(config.CookieName); errors.Is(err, http.ErrNoCookie) { - if config.SetCookie { - - domain, err := config.DomainProvider(r) - if err != nil { - httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to resolve domain")) - return - } - - id = config.Generator() - - http.SetCookie(w, &http.Cookie{ - Name: config.CookieName, - Value: id, - Path: config.CookiePath, - HttpOnly: config.CookieHttpOnly, - Secure: config.CookieSecure, - Domain: domain, - }) - } - r.Header.Set(httputils.HeaderXSessionID, id) - } else if err != nil { - httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read cookie")) + var sessionID string + if c, err := opts.Cookie.Get(r); errors.Is(err, http.ErrNoCookie) && !opts.SetCookie { + // do nothing + } else if errors.Is(err, http.ErrNoCookie) && opts.SetCookie { + sessionID = opts.Generator() + if err := opts.Cookie.Set(w, r, sessionID); err != nil { + httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to set session id cookie")) return - } else { - r.Header.Set(httputils.HeaderXSessionID, cookie.Value) } + } else if err != nil { + httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read session id cookie")) + return + } else { + sessionID = c.Value + } + if sessionID != "" && opts.SetHeader { + r.Header.Set(opts.Header, sessionID) + } + if sessionID != "" && opts.SetContext { + r = r.WithContext(context.WithValue(r.Context(), ContextKeySessionID, sessionID)) } - next.ServeHTTP(w, r) }) } } + +// SessionIDFromContext helper +func SessionIDFromContext(ctx context.Context) string { + if value, ok := ctx.Value(ContextKeySessionID).(string); ok { + return value + } + return "" +} diff --git a/net/http/middleware/sessioniogenerator.go b/net/http/middleware/sessioniogenerator.go deleted file mode 100644 index 5b3826d..0000000 --- a/net/http/middleware/sessioniogenerator.go +++ /dev/null @@ -1,9 +0,0 @@ -package middleware - -import "github.com/google/uuid" - -type SessionIDGenerator func() string - -func DefaultSessionIDGenerator() string { - return uuid.New().String() -} diff --git a/net/http/middleware/skip.go b/net/http/middleware/skip.go new file mode 100644 index 0000000..44bab7a --- /dev/null +++ b/net/http/middleware/skip.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net/http" + + "go.uber.org/zap" +) + +func Skip(mw Middleware, skippers ...Skipper) Middleware { + return func(l *zap.Logger, next http.Handler) http.Handler { + wrapped := mw(l, next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, skipper := range skippers { + if skipper(r) { + next.ServeHTTP(w, r) + return + } + } + wrapped.ServeHTTP(w, r) + }) + } +} diff --git a/net/http/middleware/skipper.go b/net/http/middleware/skipper.go new file mode 100644 index 0000000..ba3f609 --- /dev/null +++ b/net/http/middleware/skipper.go @@ -0,0 +1,35 @@ +package middleware + +import ( + "net/http" +) + +type Skipper func(*http.Request) bool + +// RequestURIWhitelistSkipper returns a HTTPMiddlewareConfig.Skipper which skips all but the given uris +func RequestURIWhitelistSkipper(uris ...string) Skipper { + urisMap := make(map[string]bool, len(uris)) + for _, uri := range uris { + urisMap[uri] = true + } + return func(r *http.Request) bool { + if _, ok := urisMap[r.RequestURI]; ok { + return true + } + return false + } +} + +// RequestURIBlacklistSkipper returns a HTTPMiddlewareConfig.Skipper which skips the given uris +func RequestURIBlacklistSkipper(uris ...string) Skipper { + urisMap := make(map[string]bool, len(uris)) + for _, uri := range uris { + urisMap[uri] = true + } + return func(r *http.Request) bool { + if _, ok := urisMap[r.RequestURI]; ok { + return false + } + return true + } +} diff --git a/net/http/middleware/tmp.go b/net/http/middleware/tmp.go deleted file mode 100644 index ada6317..0000000 --- a/net/http/middleware/tmp.go +++ /dev/null @@ -1,139 +0,0 @@ -package cmrcsession - -import ( - "net/http" - "strings" - "time" - - cmrcjwt "github.com/bestbytes/commerce/pkg/jwt" - "github.com/golang-jwt/jwt" - "github.com/pkg/errors" -) - -type Session struct { - jwt *cmrcjwt.JWT - secure bool - domains []string - domainMapping map[string]string -} - -func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session { - return &Session{ - jwt: jwt, - secure: secure, - domains: domains, - domainMapping: domainMapping, - } -} - -func (s *Session) HasDomain(domain string) bool { - for _, value := range s.domains { - if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) { - return true - } - } - return false -} - -func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) { - if cookie, err := r.Cookie(name); err != nil { - return nil, err - } else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil { - return nil, err - } else if !token.Valid { - return nil, errors.New("invalid token") - } else { - return token.Claims, nil - } -} - -func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error { - domain, err := s.getDomain(r) - if err != nil { - return err - } - - token, err := s.jwt.GetSignedToken(claims) - if err != nil { - return err - } - - http.SetCookie(w, &http.Cookie{ - Name: name, - Value: token, - Path: "/", - MaxAge: int(lifetime.Seconds()), - HttpOnly: true, - Secure: s.secure, - Domain: domain, - }) - - return nil -} - -func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error { - domain, err := s.getDomain(r) - if err != nil { - return err - } - - http.SetCookie(w, &http.Cookie{ - Name: name, - Value: value, - Path: "/", - HttpOnly: true, - Secure: s.secure, - Domain: domain, - }) - - return nil -} - -func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error { - domain, err := s.getDomain(r) - if err != nil { - return err - } - - http.SetCookie(w, &http.Cookie{ - Name: name, - Value: "", - Path: "/", - HttpOnly: true, - Secure: s.secure, - Domain: domain, - MaxAge: -1, - Expires: time.Now().AddDate(0, 0, -1), - }) - - return nil -} - -func (s *Session) getDomain(r *http.Request) (string, error) { - var domain string - if r.Header.Get("X-Forwarded-Host") != "" { - domain = r.Header.Get("X-Forwarded-Host") - } else if !r.URL.IsAbs() { - domain = r.Host - } else { - domain = r.URL.Host - } - - // right trim port - portIndex := strings.Index(domain, ":") - if portIndex != -1 { - domain = domain[:portIndex] - } - - if s.domainMapping != nil { - if value, ok := s.domainMapping[domain]; ok { - domain = value - } - } - - if !s.HasDomain(domain) { - return "", errors.New("invalid domain: " + domain) - } - - return domain, nil -} diff --git a/net/http/middleware/tokenprovider.go b/net/http/middleware/tokenprovider.go new file mode 100644 index 0000000..1d5f3a2 --- /dev/null +++ b/net/http/middleware/tokenprovider.go @@ -0,0 +1,88 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" + + keelhttp "github.com/foomo/keel/net/http" +) + +type TokenProvider func(r *http.Request) (string, error) + +type ( + HeaderTokenProviderOptions struct { + Prefix string + Header string + } + HeaderTokenProviderOption func(*HeaderTokenProviderOptions) +) + +// GetDefaultHeaderTokenOptions returns the default options +func GetDefaultHeaderTokenOptions() HeaderTokenProviderOptions { + return HeaderTokenProviderOptions{ + Prefix: keelhttp.HeaderValueAuthorizationPrefix, + Header: keelhttp.HeaderAuthorization, + } +} + +// HeaderTokenProviderWithPrefix middleware option +func HeaderTokenProviderWithPrefix(v string) HeaderTokenProviderOption { + return func(o *HeaderTokenProviderOptions) { + o.Prefix = v + } +} + +// HeaderTokenProviderWithHeader middleware option +func HeaderTokenProviderWithHeader(v string) HeaderTokenProviderOption { + return func(o *HeaderTokenProviderOptions) { + o.Header = v + } +} + +func HeaderTokenProvider(opts ...HeaderTokenProviderOption) TokenProvider { + options := GetDefaultHeaderTokenOptions() + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + return func(r *http.Request) (string, error) { + if value := r.Header.Get(options.Header); value == "" { + return "", nil + } else if !strings.HasPrefix(value, options.Prefix) { + return "", errors.New("malformed bearer token") + } else { + return strings.TrimPrefix(value, options.Prefix), nil + } + } +} + +type ( + CookieTokenProviderOptions struct{} + CookieTokenProviderOption func(*CookieTokenProviderOptions) +) + +// GetDefaultCookieTokenOptions returns the default options +func GetDefaultCookieTokenOptions() CookieTokenProviderOptions { + return CookieTokenProviderOptions{} +} + +func CookieTokenProvider(cookieName string, opts ...CookieTokenProviderOption) TokenProvider { + options := GetDefaultCookieTokenOptions() + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + return func(r *http.Request) (string, error) { + if cookie, err := r.Cookie(cookieName); errors.Is(err, http.ErrNoCookie) { + return "", nil + } else if err != nil { + return "", errors.New("failed to retrieve cookie") + } else { + return cookie.Value, nil + } + } +} diff --git a/utils/net/http/errors.go b/utils/net/http/errors.go index 8e265f1..b1de8e0 100644 --- a/utils/net/http/errors.go +++ b/utils/net/http/errors.go @@ -6,6 +6,7 @@ import ( "go.uber.org/zap" "github.com/foomo/keel/log" + keelhttp "github.com/foomo/keel/net/http" ) // InternalServerError http response @@ -31,7 +32,8 @@ func NotFoundServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request, // ServerError http response func ServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request, code int, err error) { if err != nil { - log.Configure(l).HTTPRequest(r).Error(err).Logger().Error("http server error", zap.Int("code", code)) - http.Error(w, http.StatusText(code), code) // TODO enrich headers + log.WithHTTPRequest(l, r).Error("http server error", log.FError(err), log.FHTTPStatusCode(code)) + w.Header().Set(keelhttp.HeaderXError, err.Error()) + http.Error(w, http.StatusText(code), code) } } diff --git a/utils/net/http/request.go b/utils/net/http/request.go new file mode 100644 index 0000000..b2f933a --- /dev/null +++ b/utils/net/http/request.go @@ -0,0 +1,30 @@ +package httputils + +import ( + "net/http" + "strings" +) + +// GetRequestHost returns the request's host +func GetRequestHost(r *http.Request) string { + var host string + switch { + case r.Header.Get("X-Forwarded-Host") != "": + host = r.Header.Get("X-Forwarded-Host") + case !r.URL.IsAbs(): + host = r.Host + default: + host = r.URL.Host + } + return host +} + +// GetRequestDomain returns the request's domain +func GetRequestDomain(r *http.Request) string { + domain := GetRequestHost(r) + // right trim port + if portIndex := strings.Index(domain, ":"); portIndex != -1 { + domain = domain[:portIndex] + } + return domain +}