feat: add middlewares

This commit is contained in:
franklin 2021-09-06 22:13:07 +02:00
parent 6a6d96947d
commit af277d3a63
12 changed files with 472 additions and 313 deletions

5
net/http/headervalues.go Normal file
View File

@ -0,0 +1,5 @@
package http
const (
HeaderValueAuthorizationPrefix = "Bearer "
)

View File

@ -1,70 +0,0 @@
package middleware
import (
"net/http"
"strings"
"github.com/pkg/errors"
)
var (
ErrDomainNotAllowed = errors.New("domain not allowed")
)
type DomainProvider func(r *http.Request) (string, error)
var DefaultDomainProvider = func(domains []string) DomainProvider {
return func(r *http.Request) (string, error) {
domain := getDomainFromHTTPRequest(r)
if !isDomainAllowed(domain, domains) {
return "", ErrDomainNotAllowed
}
return domain, nil
}
}
var MappingDomainProvider = func(domains []string, mapping map[string]string) DomainProvider {
return func(r *http.Request) (string, error) {
domain := getDomainFromHTTPRequest(r)
if value, ok := mapping[domain]; ok {
domain = value
}
if !isDomainAllowed(domain, domains) {
return "", errors.New("invalid domain: " + domain)
}
return domain, nil
}
}
// getDomainFromHTTPRequest helper
func getDomainFromHTTPRequest(r *http.Request) string {
var domain string
if r.Header.Get("X-Forwarded-Host") != "" {
domain = r.Header.Get("X-Forwarded-Host")
} else if !r.URL.IsAbs() {
domain = r.Host
} else {
domain = r.URL.Host
}
// right trim port
portIndex := strings.Index(domain, ":")
if portIndex != -1 {
domain = domain[:portIndex]
}
return domain
}
// isDomainAllowed helper
func isDomainAllowed(domain string, domains []string) bool {
if domains == nil || len(domains) == 0 {
return true
}
for _, value := range domains {
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
return true
}
}
return false
}

145
net/http/middleware/jwt.go Normal file
View File

@ -0,0 +1,145 @@
package middleware
import (
"context"
"net/http"
jwt2 "github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"go.uber.org/zap"
"github.com/foomo/keel/jwt"
httputils "github.com/foomo/keel/utils/net/http"
)
type (
JWTOptions struct {
TokenProvider TokenProvider
ClaimsProvider JWTClaimsProvider
MissingTokenHandler JWTMissingTokenHandler
InvalidTokenHandler JWTInvalidTokenHandler
ErrorHandler JWTErrorHandler
}
JWTOption func(*JWTOptions)
JWTClaimsProvider func() jwt2.Claims
JWTErrorHandler func(*zap.Logger, http.ResponseWriter, *http.Request, error) bool
JWTMissingTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request) (jwt2.Claims, bool)
JWTInvalidTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request, *jwt2.Token) bool
)
// DefaultJWTErrorHandler function
func DefaultJWTErrorHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, err error) bool {
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed parse claims"))
return false
}
// DefaultJWTMissingTokenHandler function
func DefaultJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
return nil, true
}
// RequiredJWTMissingTokenHandler function
func RequiredJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
httputils.BadRequestServerError(l, w, r, errors.New("missing jwt token"))
return nil, false
}
// DefaultJWTInvalidTokenHandler function
func DefaultJWTInvalidTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, token *jwt2.Token) bool {
httputils.BadRequestServerError(l, w, r, errors.New("invalid jwt token"))
return false
}
// DefaultJWTClaimsProvider function
func DefaultJWTClaimsProvider() jwt2.Claims {
return &jwt2.StandardClaims{}
}
// GetDefaultJWTOptions returns the default options
func GetDefaultJWTOptions() JWTOptions {
return JWTOptions{
TokenProvider: HeaderTokenProvider(),
ClaimsProvider: DefaultJWTClaimsProvider,
ErrorHandler: DefaultJWTErrorHandler,
InvalidTokenHandler: DefaultJWTInvalidTokenHandler,
MissingTokenHandler: DefaultJWTMissingTokenHandler,
}
}
// JWTWithTokenProvider middleware option
func JWTWithTokenProvider(v TokenProvider) JWTOption {
return func(o *JWTOptions) {
o.TokenProvider = v
}
}
// JWTWithClaimsProvider middleware option
func JWTWithClaimsProvider(v JWTClaimsProvider) JWTOption {
return func(o *JWTOptions) {
o.ClaimsProvider = v
}
}
// JWTWithInvalidTokenHandler middleware option
func JWTWithInvalidTokenHandler(v JWTInvalidTokenHandler) JWTOption {
return func(o *JWTOptions) {
o.InvalidTokenHandler = v
}
}
// JWTWithMissingTokenHandler middleware option
func JWTWithMissingTokenHandler(v JWTMissingTokenHandler) JWTOption {
return func(o *JWTOptions) {
o.MissingTokenHandler = v
}
}
// JWTWithErrorHandler middleware option
func JWTWithErrorHandler(v JWTErrorHandler) JWTOption {
return func(o *JWTOptions) {
o.ErrorHandler = v
}
}
// JWT middleware
func JWT(jwt *jwt.JWT, contextKey interface{}, opts ...JWTOption) Middleware {
options := GetDefaultJWTOptions()
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
return JWTWithOptions(jwt, contextKey, options)
}
// JWTWithOptions middleware
func JWTWithOptions(jwt *jwt.JWT, contextKey interface{}, opts JWTOptions) Middleware {
return func(l *zap.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := opts.ClaimsProvider()
if value := r.Context().Value(contextKey); value != nil {
// TODO check if type matches the existing
next.ServeHTTP(w, r)
} else if value, err := opts.TokenProvider(r); err != nil {
httputils.BadRequestServerError(l, w, r, errors.Wrap(err, "failed to retrieve token"))
} else if value == "" {
if claims, resume := opts.MissingTokenHandler(l, w, r); resume && claims != nil {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
} else if resume {
next.ServeHTTP(w, r)
}
} else if token, err := jwt.ParseWithClaims(value, claims); err != nil {
// TODO check if type matches the existing
if opts.ErrorHandler(l, w, r, err) {
next.ServeHTTP(w, r)
}
} else if !token.Valid {
if opts.InvalidTokenHandler(l, w, r, token) {
next.ServeHTTP(w, r)
}
} else {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
}
})
}
}

View File

@ -0,0 +1,44 @@
package middleware
import (
"net/http"
"go.uber.org/zap"
httputils "github.com/foomo/keel/net/http"
)
type (
ServerOptions struct {
Header string
}
ServerOption func(*ServerOptions)
)
// GetDefaultServerOptions returns the default options
func GetDefaultServerOptions() ServerOptions {
return ServerOptions{
Header: httputils.HeaderServer,
}
}
// Server middleware
func Server(name string, opts ...ServerOption) Middleware {
options := GetDefaultServerOptions()
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
return ServerWithOptions(name, options)
}
// ServerWithOptions middleware
func ServerWithOptions(name string, opts ServerOptions) Middleware {
return func(l *zap.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(opts.Header, name)
next.ServeHTTP(w, r)
})
}
}

View File

@ -1,141 +1,147 @@
package middleware package middleware
import ( import (
"context"
"net/http" "net/http"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.uber.org/zap" "go.uber.org/zap"
httputils "github.com/foomo/keel/net/http" keelhttp "github.com/foomo/keel/net/http"
"github.com/foomo/keel/net/http/cookie"
httputils "github.com/foomo/keel/utils/net/http"
) )
type ( type (
SessionIDConfig struct { contextKey string
SetCookie bool // only create a session cookie if enabled SessionIDOptions struct {
CookieName string // Header to look up the session id
CookieSecure bool Header string
CookieHttpOnly bool // Cookie how to set the cookie
CookiePath string Cookie cookie.Cookie
CookieDomain string // Generator for the session ids
CookieDomains []string
CookieDomainProvider DomainProvider
Generator SessionIDGenerator Generator SessionIDGenerator
// SetCookie if true it will create a cookie if not exists
SetCookie bool
// SetHeader if true it will set add a request header
SetHeader bool
// SetContext if true it will set the context key
SetContext bool
} }
SessionIDOption func(*SessionIDConfig) error SessionIDOption func(*SessionIDOptions)
SessionIDGenerator func() string
) )
var DefaultSessionIDConfig = SessionIDConfig{ const (
ContextKeySessionID contextKey = "sessionId"
DefaultSessionIDCookieName = "keel-session"
)
// DefaultSessionIDGenerator function
func DefaultSessionIDGenerator() string {
return uuid.New().String()
}
// GetDefaultSessionIDOptions returns the default options
func GetDefaultSessionIDOptions() SessionIDOptions {
return SessionIDOptions{
Header: keelhttp.HeaderXSessionID,
Cookie: cookie.New(DefaultSessionIDCookieName),
Generator: DefaultSessionIDGenerator,
SetCookie: false, SetCookie: false,
CookieName: "keel-session", SetHeader: true,
CookieSecure: true, SetContext: true,
CookieHttpOnly: true, }
CookiePath: "/",
} }
func SessionIDWithHeader(v string) SessionIDOption {
return func(o *SessionIDOptions) {
o.Header = v
}
}
// SessionIDWithSetCookie middleware option
func SessionIDWithSetCookie(v bool) SessionIDOption { func SessionIDWithSetCookie(v bool) SessionIDOption {
return func(c *SessionIDConfig) error { return func(o *SessionIDOptions) {
c.SetCookie = v o.SetCookie = v
} }
} }
func SessionIDWithCookieName(v string) SessionIDOption { // SessionIDWithSetHeader middleware option
return func(c *SessionIDConfig) error { func SessionIDWithSetHeader(v bool) SessionIDOption {
c.CookieName = v return func(o *SessionIDOptions) {
o.SetHeader = v
} }
} }
func SessionIDWithCookieSecure(v bool) SessionIDOption { // SessionIDWithSetContext middleware option
return func(c *SessionIDConfig) error { func SessionIDWithSetContext(v bool) SessionIDOption {
c.CookieSecure = v return func(o *SessionIDOptions) {
o.SetContext = v
} }
} }
func SessionIDWithCookieHttpOnly(v bool) SessionIDOption { // SessionIDWithCookie middleware option
return func(c *SessionIDConfig) error { func SessionIDWithCookie(v cookie.Cookie) SessionIDOption {
c.CookieHttpOnly = v return func(o *SessionIDOptions) {
} o.Cookie = v
}
func SessionIDWithCookiePath(v string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookiePath = v
}
}
func SessionIDWithCookieDomain(v string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieDomain = v
}
}
func SessionIDWithCookieDomains(v []string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieDomains = v
}
}
func SessionIDWithCookieDomainProvider(v DomainProvider) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieDomainProvider = v
} }
} }
// SessionIDWithGenerator middleware option
func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption { func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption {
return func(c *SessionIDConfig) error { return func(o *SessionIDOptions) {
c.Generator = v o.Generator = v
} }
} }
// SessionID middleware
func SessionID(opts ...SessionIDOption) Middleware { func SessionID(opts ...SessionIDOption) Middleware {
config = DefaultSessionIDConfig options := GetDefaultSessionIDOptions()
for _, opt := range opts { for _, opt := range opts {
if opt != nil { if opt != nil {
if err := opt(&opts); err != nil { opt(&options)
return nil, err
} }
} }
} return SessionIDWithOptions(options)
if config.Generator == nil { }
config.Generator = DefaultSessionIDGenerator
}
if config.DomainProvider == nil {
config.DomainProvider = DefaultDomainProvider(config.CookieDomains)
}
// SessionIDWithOptions middleware
func SessionIDWithOptions(opts SessionIDOptions) Middleware {
return func(l *zap.Logger, next http.Handler) http.Handler { return func(l *zap.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var sessionID string
if id := r.Header.Get(httputils.HeaderXSessionID); id == "" { if c, err := opts.Cookie.Get(r); errors.Is(err, http.ErrNoCookie) && !opts.SetCookie {
if cookie, err := r.Cookie(config.CookieName); errors.Is(err, http.ErrNoCookie) { // do nothing
if config.SetCookie { } else if errors.Is(err, http.ErrNoCookie) && opts.SetCookie {
sessionID = opts.Generator()
domain, err := config.DomainProvider(r) if err := opts.Cookie.Set(w, r, sessionID); err != nil {
if err != nil { httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to set session id cookie"))
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to resolve domain"))
return return
} }
id = config.Generator()
http.SetCookie(w, &http.Cookie{
Name: config.CookieName,
Value: id,
Path: config.CookiePath,
HttpOnly: config.CookieHttpOnly,
Secure: config.CookieSecure,
Domain: domain,
})
}
r.Header.Set(httputils.HeaderXSessionID, id)
} else if err != nil { } else if err != nil {
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read cookie")) httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read session id cookie"))
return return
} else { } else {
r.Header.Set(httputils.HeaderXSessionID, cookie.Value) sessionID = c.Value
} }
if sessionID != "" && opts.SetHeader {
r.Header.Set(opts.Header, sessionID)
}
if sessionID != "" && opts.SetContext {
r = r.WithContext(context.WithValue(r.Context(), ContextKeySessionID, sessionID))
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }
} }
// SessionIDFromContext helper
func SessionIDFromContext(ctx context.Context) string {
if value, ok := ctx.Value(ContextKeySessionID).(string); ok {
return value
}
return ""
}

View File

@ -1,9 +0,0 @@
package middleware
import "github.com/google/uuid"
type SessionIDGenerator func() string
func DefaultSessionIDGenerator() string {
return uuid.New().String()
}

View File

@ -0,0 +1,22 @@
package middleware
import (
"net/http"
"go.uber.org/zap"
)
func Skip(mw Middleware, skippers ...Skipper) Middleware {
return func(l *zap.Logger, next http.Handler) http.Handler {
wrapped := mw(l, next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, skipper := range skippers {
if skipper(r) {
next.ServeHTTP(w, r)
return
}
}
wrapped.ServeHTTP(w, r)
})
}
}

View File

@ -0,0 +1,35 @@
package middleware
import (
"net/http"
)
type Skipper func(*http.Request) bool
// RequestURIWhitelistSkipper returns a HTTPMiddlewareConfig.Skipper which skips all but the given uris
func RequestURIWhitelistSkipper(uris ...string) Skipper {
urisMap := make(map[string]bool, len(uris))
for _, uri := range uris {
urisMap[uri] = true
}
return func(r *http.Request) bool {
if _, ok := urisMap[r.RequestURI]; ok {
return true
}
return false
}
}
// RequestURIBlacklistSkipper returns a HTTPMiddlewareConfig.Skipper which skips the given uris
func RequestURIBlacklistSkipper(uris ...string) Skipper {
urisMap := make(map[string]bool, len(uris))
for _, uri := range uris {
urisMap[uri] = true
}
return func(r *http.Request) bool {
if _, ok := urisMap[r.RequestURI]; ok {
return false
}
return true
}
}

View File

@ -1,139 +0,0 @@
package cmrcsession
import (
"net/http"
"strings"
"time"
cmrcjwt "github.com/bestbytes/commerce/pkg/jwt"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
)
type Session struct {
jwt *cmrcjwt.JWT
secure bool
domains []string
domainMapping map[string]string
}
func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session {
return &Session{
jwt: jwt,
secure: secure,
domains: domains,
domainMapping: domainMapping,
}
}
func (s *Session) HasDomain(domain string) bool {
for _, value := range s.domains {
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
return true
}
}
return false
}
func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) {
if cookie, err := r.Cookie(name); err != nil {
return nil, err
} else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil {
return nil, err
} else if !token.Valid {
return nil, errors.New("invalid token")
} else {
return token.Claims, nil
}
}
func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error {
domain, err := s.getDomain(r)
if err != nil {
return err
}
token, err := s.jwt.GetSignedToken(claims)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: token,
Path: "/",
MaxAge: int(lifetime.Seconds()),
HttpOnly: true,
Secure: s.secure,
Domain: domain,
})
return nil
}
func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
domain, err := s.getDomain(r)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
Path: "/",
HttpOnly: true,
Secure: s.secure,
Domain: domain,
})
return nil
}
func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error {
domain, err := s.getDomain(r)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: "/",
HttpOnly: true,
Secure: s.secure,
Domain: domain,
MaxAge: -1,
Expires: time.Now().AddDate(0, 0, -1),
})
return nil
}
func (s *Session) getDomain(r *http.Request) (string, error) {
var domain string
if r.Header.Get("X-Forwarded-Host") != "" {
domain = r.Header.Get("X-Forwarded-Host")
} else if !r.URL.IsAbs() {
domain = r.Host
} else {
domain = r.URL.Host
}
// right trim port
portIndex := strings.Index(domain, ":")
if portIndex != -1 {
domain = domain[:portIndex]
}
if s.domainMapping != nil {
if value, ok := s.domainMapping[domain]; ok {
domain = value
}
}
if !s.HasDomain(domain) {
return "", errors.New("invalid domain: " + domain)
}
return domain, nil
}

View File

@ -0,0 +1,88 @@
package middleware
import (
"net/http"
"strings"
"github.com/pkg/errors"
keelhttp "github.com/foomo/keel/net/http"
)
type TokenProvider func(r *http.Request) (string, error)
type (
HeaderTokenProviderOptions struct {
Prefix string
Header string
}
HeaderTokenProviderOption func(*HeaderTokenProviderOptions)
)
// GetDefaultHeaderTokenOptions returns the default options
func GetDefaultHeaderTokenOptions() HeaderTokenProviderOptions {
return HeaderTokenProviderOptions{
Prefix: keelhttp.HeaderValueAuthorizationPrefix,
Header: keelhttp.HeaderAuthorization,
}
}
// HeaderTokenProviderWithPrefix middleware option
func HeaderTokenProviderWithPrefix(v string) HeaderTokenProviderOption {
return func(o *HeaderTokenProviderOptions) {
o.Prefix = v
}
}
// HeaderTokenProviderWithHeader middleware option
func HeaderTokenProviderWithHeader(v string) HeaderTokenProviderOption {
return func(o *HeaderTokenProviderOptions) {
o.Header = v
}
}
func HeaderTokenProvider(opts ...HeaderTokenProviderOption) TokenProvider {
options := GetDefaultHeaderTokenOptions()
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
return func(r *http.Request) (string, error) {
if value := r.Header.Get(options.Header); value == "" {
return "", nil
} else if !strings.HasPrefix(value, options.Prefix) {
return "", errors.New("malformed bearer token")
} else {
return strings.TrimPrefix(value, options.Prefix), nil
}
}
}
type (
CookieTokenProviderOptions struct{}
CookieTokenProviderOption func(*CookieTokenProviderOptions)
)
// GetDefaultCookieTokenOptions returns the default options
func GetDefaultCookieTokenOptions() CookieTokenProviderOptions {
return CookieTokenProviderOptions{}
}
func CookieTokenProvider(cookieName string, opts ...CookieTokenProviderOption) TokenProvider {
options := GetDefaultCookieTokenOptions()
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
return func(r *http.Request) (string, error) {
if cookie, err := r.Cookie(cookieName); errors.Is(err, http.ErrNoCookie) {
return "", nil
} else if err != nil {
return "", errors.New("failed to retrieve cookie")
} else {
return cookie.Value, nil
}
}
}

View File

@ -6,6 +6,7 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/foomo/keel/log" "github.com/foomo/keel/log"
keelhttp "github.com/foomo/keel/net/http"
) )
// InternalServerError http response // InternalServerError http response
@ -31,7 +32,8 @@ func NotFoundServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request,
// ServerError http response // ServerError http response
func ServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request, code int, err error) { func ServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request, code int, err error) {
if err != nil { if err != nil {
log.Configure(l).HTTPRequest(r).Error(err).Logger().Error("http server error", zap.Int("code", code)) log.WithHTTPRequest(l, r).Error("http server error", log.FError(err), log.FHTTPStatusCode(code))
http.Error(w, http.StatusText(code), code) // TODO enrich headers w.Header().Set(keelhttp.HeaderXError, err.Error())
http.Error(w, http.StatusText(code), code)
} }
} }

30
utils/net/http/request.go Normal file
View File

@ -0,0 +1,30 @@
package httputils
import (
"net/http"
"strings"
)
// GetRequestHost returns the request's host
func GetRequestHost(r *http.Request) string {
var host string
switch {
case r.Header.Get("X-Forwarded-Host") != "":
host = r.Header.Get("X-Forwarded-Host")
case !r.URL.IsAbs():
host = r.Host
default:
host = r.URL.Host
}
return host
}
// GetRequestDomain returns the request's domain
func GetRequestDomain(r *http.Request) string {
domain := GetRequestHost(r)
// right trim port
if portIndex := strings.Index(domain, ":"); portIndex != -1 {
domain = domain[:portIndex]
}
return domain
}