keel/net/http/middleware/jwt.go
2021-09-06 22:13:07 +02:00

146 lines
4.5 KiB
Go

package middleware
import (
"context"
"net/http"
jwt2 "github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"go.uber.org/zap"
"github.com/foomo/keel/jwt"
httputils "github.com/foomo/keel/utils/net/http"
)
// JWTOptions configures the JWT middleware: where the raw token is read from,
// which claims value it is parsed into, and how missing, unparsable or invalid
// tokens are handled.
type JWTOptions struct {
	TokenProvider       TokenProvider
	ClaimsProvider      JWTClaimsProvider
	MissingTokenHandler JWTMissingTokenHandler
	InvalidTokenHandler JWTInvalidTokenHandler
	ErrorHandler        JWTErrorHandler
}

// JWTOption mutates a JWTOptions value; see the JWTWith* constructors.
type JWTOption func(*JWTOptions)

// JWTClaimsProvider returns the value a token's claims are parsed into. The
// middleware calls it from inside the request handler, so returning a new
// value on each call keeps requests from sharing claims state.
type JWTClaimsProvider func() jwt2.Claims

// JWTErrorHandler is called when parsing the token returns an error. Returning
// true lets the request continue to the next handler (without claims).
type JWTErrorHandler func(*zap.Logger, http.ResponseWriter, *http.Request, error) bool

// JWTMissingTokenHandler is called when the token provider yields an empty
// token. Returning true lets the request continue; non-nil claims returned
// alongside are stored in the request context.
type JWTMissingTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request) (jwt2.Claims, bool)

// JWTInvalidTokenHandler is called when a token parsed without error but is
// not valid. Returning true lets the request continue (without claims).
type JWTInvalidTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request, *jwt2.Token) bool
// DefaultJWTErrorHandler is the default JWTErrorHandler. It answers the
// request via httputils.InternalServerError, with err wrapped as
// "failed to parse claims", and returns false so the request is not passed on
// to the next handler.
//
// NOTE(review): if jwt.ParseWithClaims reports client-side problems (e.g. a
// malformed or expired token) through err, a server-error response may be the
// wrong status here — confirm against the jwt wrapper.
func DefaultJWTErrorHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, err error) bool {
	httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to parse claims"))
	return false
}
// DefaultJWTMissingTokenHandler is the default JWTMissingTokenHandler. It makes
// the token optional: nothing is written to the response, no claims are
// returned, and true is returned so the request continues unchanged.
func DefaultJWTMissingTokenHandler(_ *zap.Logger, _ http.ResponseWriter, _ *http.Request) (jwt2.Claims, bool) {
	const resume = true
	return nil, resume
}
// RequiredJWTMissingTokenHandler is a JWTMissingTokenHandler for routes that
// must carry a token. It answers via httputils.BadRequestServerError with a
// "missing jwt token" error and returns no claims and false, so the request
// stops here.
func RequiredJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
	missing := errors.New("missing jwt token")
	httputils.BadRequestServerError(l, w, r, missing)
	return nil, false
}
// DefaultJWTInvalidTokenHandler is the default JWTInvalidTokenHandler. It
// answers via httputils.BadRequestServerError with an "invalid jwt token"
// error and returns false, so the request stops here. The token itself is not
// inspected.
func DefaultJWTInvalidTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, _ *jwt2.Token) bool {
	invalid := errors.New("invalid jwt token")
	httputils.BadRequestServerError(l, w, r, invalid)
	return false
}
// DefaultJWTClaimsProvider is the default JWTClaimsProvider. Each call returns
// a newly allocated, zero-valued *jwt2.StandardClaims.
func DefaultJWTClaimsProvider() jwt2.Claims {
	claims := new(jwt2.StandardClaims)
	return claims
}
// GetDefaultJWTOptions returns the options the JWT middleware starts from:
// tokens come from HeaderTokenProvider(), are parsed into
// DefaultJWTClaimsProvider claims, and the Default* handlers deal with
// missing, invalid and unparsable tokens.
func GetDefaultJWTOptions() JWTOptions {
	var defaults JWTOptions
	defaults.TokenProvider = HeaderTokenProvider()
	defaults.ClaimsProvider = DefaultJWTClaimsProvider
	defaults.ErrorHandler = DefaultJWTErrorHandler
	defaults.InvalidTokenHandler = DefaultJWTInvalidTokenHandler
	defaults.MissingTokenHandler = DefaultJWTMissingTokenHandler
	return defaults
}
// JWTWithTokenProvider returns a JWTOption that sets JWTOptions.TokenProvider
// to v.
func JWTWithTokenProvider(v TokenProvider) JWTOption {
	set := func(options *JWTOptions) {
		options.TokenProvider = v
	}
	return set
}
// JWTWithClaimsProvider returns a JWTOption that sets
// JWTOptions.ClaimsProvider to v.
func JWTWithClaimsProvider(v JWTClaimsProvider) JWTOption {
	set := func(options *JWTOptions) {
		options.ClaimsProvider = v
	}
	return set
}
// JWTWithInvalidTokenHandler returns a JWTOption that sets
// JWTOptions.InvalidTokenHandler to v.
func JWTWithInvalidTokenHandler(v JWTInvalidTokenHandler) JWTOption {
	set := func(options *JWTOptions) {
		options.InvalidTokenHandler = v
	}
	return set
}
// JWTWithMissingTokenHandler returns a JWTOption that sets
// JWTOptions.MissingTokenHandler to v (e.g. RequiredJWTMissingTokenHandler to
// reject requests without a token).
func JWTWithMissingTokenHandler(v JWTMissingTokenHandler) JWTOption {
	set := func(options *JWTOptions) {
		options.MissingTokenHandler = v
	}
	return set
}
// JWTWithErrorHandler returns a JWTOption that sets JWTOptions.ErrorHandler
// to v.
func JWTWithErrorHandler(v JWTErrorHandler) JWTOption {
	set := func(options *JWTOptions) {
		options.ErrorHandler = v
	}
	return set
}
// JWT builds the JWT middleware. It starts from GetDefaultJWTOptions, applies
// opts in order (nil entries are ignored), and delegates to JWTWithOptions.
func JWT(jwt *jwt.JWT, contextKey interface{}, opts ...JWTOption) Middleware {
	resolved := GetDefaultJWTOptions()
	for i := range opts {
		apply := opts[i]
		if apply == nil {
			continue
		}
		apply(&resolved)
	}
	return JWTWithOptions(jwt, contextKey, resolved)
}
// JWTWithOptions builds the JWT middleware from an explicit JWTOptions value.
// Any nil hook in opts is replaced by its GetDefaultJWTOptions counterpart.
//
// For each request the middleware:
//   - passes the request straight through if contextKey already holds a value;
//   - answers with httputils.BadRequestServerError if opts.TokenProvider fails;
//   - delegates to opts.MissingTokenHandler if the token is empty, continuing
//     (with any returned non-nil claims stored under contextKey) only when it
//     returns true;
//   - parses the token into a value from opts.ClaimsProvider and delegates
//     parse errors to opts.ErrorHandler and invalid tokens to
//     opts.InvalidTokenHandler, continuing without claims only when the
//     handler returns true;
//   - otherwise stores the parsed claims under contextKey and calls next.
func JWTWithOptions(jwt *jwt.JWT, contextKey interface{}, opts JWTOptions) Middleware {
	// A partially populated JWTOptions (or an option that set a hook to nil)
	// would otherwise panic with a nil function call on the first request.
	defaults := GetDefaultJWTOptions()
	if opts.TokenProvider == nil {
		opts.TokenProvider = defaults.TokenProvider
	}
	if opts.ClaimsProvider == nil {
		opts.ClaimsProvider = defaults.ClaimsProvider
	}
	if opts.MissingTokenHandler == nil {
		opts.MissingTokenHandler = defaults.MissingTokenHandler
	}
	if opts.InvalidTokenHandler == nil {
		opts.InvalidTokenHandler = defaults.InvalidTokenHandler
	}
	if opts.ErrorHandler == nil {
		opts.ErrorHandler = defaults.ErrorHandler
	}

	return func(l *zap.Logger, next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Something (e.g. an outer instance of this middleware) already
			// stored a value under contextKey; leave it alone.
			// TODO check if type matches the existing
			if r.Context().Value(contextKey) != nil {
				next.ServeHTTP(w, r)
				return
			}

			value, err := opts.TokenProvider(r)
			if err != nil {
				httputils.BadRequestServerError(l, w, r, errors.Wrap(err, "failed to retrieve token"))
				return
			}

			if value == "" {
				missingClaims, resume := opts.MissingTokenHandler(l, w, r)
				if !resume {
					return
				}
				if missingClaims != nil {
					r = r.WithContext(context.WithValue(r.Context(), contextKey, missingClaims))
				}
				next.ServeHTTP(w, r)
				return
			}

			// Only allocate a claims value once there is a token to parse into it.
			claims := opts.ClaimsProvider()
			token, err := jwt.ParseWithClaims(value, claims)
			if err != nil {
				if opts.ErrorHandler(l, w, r, err) {
					next.ServeHTTP(w, r)
				}
				return
			}

			if !token.Valid {
				if opts.InvalidTokenHandler(l, w, r, token) {
					next.ServeHTTP(w, r)
				}
				return
			}

			next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
		})
	}
}