keel/net/http/middleware/jwt.go
2024-04-16 09:38:28 +02:00

201 lines
5.5 KiB
Go

package middleware
import (
"context"
"net/http"
jwt2 "github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"go.uber.org/zap"
"github.com/foomo/keel/jwt"
httputils "github.com/foomo/keel/utils/net/http"
)
type (
JWTOptions struct {
SetContext bool
TokenProvider TokenProvider
ClaimsProvider JWTClaimsProvider
ClaimsHandler JWTClaimsHandler
MissingTokenHandler JWTMissingTokenHandler
InvalidTokenHandler JWTInvalidTokenHandler
ErrorHandler JWTErrorHandler
}
JWTOption func(*JWTOptions)
JWTClaimsProvider func() jwt2.Claims
JWTClaimsHandler func(*zap.Logger, http.ResponseWriter, *http.Request, jwt2.Claims) bool
JWTErrorHandler func(*zap.Logger, http.ResponseWriter, *http.Request, error) bool
JWTMissingTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request) (jwt2.Claims, bool)
JWTInvalidTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request, *jwt2.Token) bool
)
// DefaultJWTErrorHandler function
func DefaultJWTErrorHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, err error) bool {
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed parse claims"))
return false
}
// DefaultJWTMissingTokenHandler function
func DefaultJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
return nil, true
}
// RequiredJWTMissingTokenHandler function
func RequiredJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
httputils.BadRequestServerError(l, w, r, errors.New("missing jwt token"))
return nil, false
}
// DefaultJWTInvalidTokenHandler function
func DefaultJWTInvalidTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, token *jwt2.Token) bool {
httputils.BadRequestServerError(l, w, r, errors.New("invalid jwt token"))
return false
}
// DefaultJWTClaimsProvider function
func DefaultJWTClaimsProvider() jwt2.Claims {
return &jwt2.StandardClaims{}
}
// DefaultJWTClaimsHandler function
func DefaultJWTClaimsHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, claims jwt2.Claims) bool {
return true
}
// GetDefaultJWTOptions returns the default options
func GetDefaultJWTOptions() JWTOptions {
return JWTOptions{
SetContext: true,
TokenProvider: HeaderTokenProvider(),
ClaimsProvider: DefaultJWTClaimsProvider,
ClaimsHandler: DefaultJWTClaimsHandler,
ErrorHandler: DefaultJWTErrorHandler,
InvalidTokenHandler: DefaultJWTInvalidTokenHandler,
MissingTokenHandler: DefaultJWTMissingTokenHandler,
}
}
// JWTWithTokenProvider middleware option
func JWTWithTokenProvider(v TokenProvider) JWTOption {
return func(o *JWTOptions) {
o.TokenProvider = v
}
}
// JWTWithClaimsProvider middleware option
func JWTWithClaimsProvider(v JWTClaimsProvider) JWTOption {
return func(o *JWTOptions) {
o.ClaimsProvider = v
}
}
// JWTWithClaimsHandler middleware option
func JWTWithClaimsHandler(v JWTClaimsHandler) JWTOption {
return func(o *JWTOptions) {
o.ClaimsHandler = v
}
}
// JWTWithInvalidTokenHandler middleware option
func JWTWithInvalidTokenHandler(v JWTInvalidTokenHandler) JWTOption {
return func(o *JWTOptions) {
o.InvalidTokenHandler = v
}
}
// JWTWithMissingTokenHandler middleware option
func JWTWithMissingTokenHandler(v JWTMissingTokenHandler) JWTOption {
return func(o *JWTOptions) {
o.MissingTokenHandler = v
}
}
// JWTWithErrorHandler middleware option
func JWTWithErrorHandler(v JWTErrorHandler) JWTOption {
return func(o *JWTOptions) {
o.ErrorHandler = v
}
}
func JWTWithSetContext(v bool) JWTOption {
return func(o *JWTOptions) {
o.SetContext = v
}
}
// JWT middleware
func JWT(v *jwt.JWT, contextKey interface{}, opts ...JWTOption) Middleware {
options := GetDefaultJWTOptions()
for _, opt := range opts {
if opt != nil {
opt(&options)
}
}
return JWTWithOptions(v, contextKey, options)
}
// JWTWithOptions middleware
func JWTWithOptions(v *jwt.JWT, contextKey interface{}, opts JWTOptions) Middleware {
return func(l *zap.Logger, name string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims := opts.ClaimsProvider()
// check existing claims from context
if value := r.Context().Value(contextKey); value != nil {
next.ServeHTTP(w, r)
return
}
// retrieve token from provider
token, err := opts.TokenProvider(r)
if err != nil {
httputils.BadRequestServerError(l, w, r, errors.Wrap(err, "failed to retrieve token"))
return
}
// handle missing token
if token == "" {
if claims, resume := opts.MissingTokenHandler(l, w, r); claims != nil && resume && opts.SetContext {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
return
} else if resume {
next.ServeHTTP(w, r)
return
} else {
return
}
}
// don't validate if not required
if !opts.SetContext {
next.ServeHTTP(w, r)
return
}
// handle existing token
jwtToken, err := v.ParseWithClaims(token, claims)
if err != nil {
if resume := opts.ErrorHandler(l, w, r, err); resume {
next.ServeHTTP(w, r)
return
} else {
return
}
} else if !jwtToken.Valid {
if resume := opts.InvalidTokenHandler(l, w, r, jwtToken); resume {
next.ServeHTTP(w, r)
return
} else {
return
}
} else if resume := opts.ClaimsHandler(l, w, r, claims); !resume {
return
} else {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
return
}
})
}
}