mirror of
https://github.com/foomo/keel.git
synced 2025-10-16 12:35:34 +00:00
feat: add middlewares
This commit is contained in:
parent
6a6d96947d
commit
af277d3a63
5
net/http/headervalues.go
Normal file
5
net/http/headervalues.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
const (
|
||||||
|
HeaderValueAuthorizationPrefix = "Bearer "
|
||||||
|
)
|
||||||
@ -1,70 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrDomainNotAllowed = errors.New("domain not allowed")
|
|
||||||
)
|
|
||||||
|
|
||||||
type DomainProvider func(r *http.Request) (string, error)
|
|
||||||
|
|
||||||
var DefaultDomainProvider = func(domains []string) DomainProvider {
|
|
||||||
return func(r *http.Request) (string, error) {
|
|
||||||
domain := getDomainFromHTTPRequest(r)
|
|
||||||
if !isDomainAllowed(domain, domains) {
|
|
||||||
return "", ErrDomainNotAllowed
|
|
||||||
}
|
|
||||||
return domain, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var MappingDomainProvider = func(domains []string, mapping map[string]string) DomainProvider {
|
|
||||||
return func(r *http.Request) (string, error) {
|
|
||||||
domain := getDomainFromHTTPRequest(r)
|
|
||||||
if value, ok := mapping[domain]; ok {
|
|
||||||
domain = value
|
|
||||||
}
|
|
||||||
if !isDomainAllowed(domain, domains) {
|
|
||||||
return "", errors.New("invalid domain: " + domain)
|
|
||||||
}
|
|
||||||
return domain, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getDomainFromHTTPRequest helper
|
|
||||||
func getDomainFromHTTPRequest(r *http.Request) string {
|
|
||||||
var domain string
|
|
||||||
if r.Header.Get("X-Forwarded-Host") != "" {
|
|
||||||
domain = r.Header.Get("X-Forwarded-Host")
|
|
||||||
} else if !r.URL.IsAbs() {
|
|
||||||
domain = r.Host
|
|
||||||
} else {
|
|
||||||
domain = r.URL.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// right trim port
|
|
||||||
portIndex := strings.Index(domain, ":")
|
|
||||||
if portIndex != -1 {
|
|
||||||
domain = domain[:portIndex]
|
|
||||||
}
|
|
||||||
|
|
||||||
return domain
|
|
||||||
}
|
|
||||||
|
|
||||||
// isDomainAllowed helper
|
|
||||||
func isDomainAllowed(domain string, domains []string) bool {
|
|
||||||
if domains == nil || len(domains) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for _, value := range domains {
|
|
||||||
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
145
net/http/middleware/jwt.go
Normal file
145
net/http/middleware/jwt.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
jwt2 "github.com/golang-jwt/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/foomo/keel/jwt"
|
||||||
|
httputils "github.com/foomo/keel/utils/net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
JWTOptions struct {
|
||||||
|
TokenProvider TokenProvider
|
||||||
|
ClaimsProvider JWTClaimsProvider
|
||||||
|
MissingTokenHandler JWTMissingTokenHandler
|
||||||
|
InvalidTokenHandler JWTInvalidTokenHandler
|
||||||
|
ErrorHandler JWTErrorHandler
|
||||||
|
}
|
||||||
|
JWTOption func(*JWTOptions)
|
||||||
|
JWTClaimsProvider func() jwt2.Claims
|
||||||
|
JWTErrorHandler func(*zap.Logger, http.ResponseWriter, *http.Request, error) bool
|
||||||
|
JWTMissingTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request) (jwt2.Claims, bool)
|
||||||
|
JWTInvalidTokenHandler func(*zap.Logger, http.ResponseWriter, *http.Request, *jwt2.Token) bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultJWTErrorHandler function
|
||||||
|
func DefaultJWTErrorHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, err error) bool {
|
||||||
|
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed parse claims"))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultJWTMissingTokenHandler function
|
||||||
|
func DefaultJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequiredJWTMissingTokenHandler function
|
||||||
|
func RequiredJWTMissingTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request) (jwt2.Claims, bool) {
|
||||||
|
httputils.BadRequestServerError(l, w, r, errors.New("missing jwt token"))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultJWTInvalidTokenHandler function
|
||||||
|
func DefaultJWTInvalidTokenHandler(l *zap.Logger, w http.ResponseWriter, r *http.Request, token *jwt2.Token) bool {
|
||||||
|
httputils.BadRequestServerError(l, w, r, errors.New("invalid jwt token"))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultJWTClaimsProvider function
|
||||||
|
func DefaultJWTClaimsProvider() jwt2.Claims {
|
||||||
|
return &jwt2.StandardClaims{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultJWTOptions returns the default options
|
||||||
|
func GetDefaultJWTOptions() JWTOptions {
|
||||||
|
return JWTOptions{
|
||||||
|
TokenProvider: HeaderTokenProvider(),
|
||||||
|
ClaimsProvider: DefaultJWTClaimsProvider,
|
||||||
|
ErrorHandler: DefaultJWTErrorHandler,
|
||||||
|
InvalidTokenHandler: DefaultJWTInvalidTokenHandler,
|
||||||
|
MissingTokenHandler: DefaultJWTMissingTokenHandler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTWithTokenProvider middleware option
|
||||||
|
func JWTWithTokenProvider(v TokenProvider) JWTOption {
|
||||||
|
return func(o *JWTOptions) {
|
||||||
|
o.TokenProvider = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTWithClaimsProvider middleware option
|
||||||
|
func JWTWithClaimsProvider(v JWTClaimsProvider) JWTOption {
|
||||||
|
return func(o *JWTOptions) {
|
||||||
|
o.ClaimsProvider = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTWithInvalidTokenHandler middleware option
|
||||||
|
func JWTWithInvalidTokenHandler(v JWTInvalidTokenHandler) JWTOption {
|
||||||
|
return func(o *JWTOptions) {
|
||||||
|
o.InvalidTokenHandler = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTWithMissingTokenHandler middleware option
|
||||||
|
func JWTWithMissingTokenHandler(v JWTMissingTokenHandler) JWTOption {
|
||||||
|
return func(o *JWTOptions) {
|
||||||
|
o.MissingTokenHandler = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTWithErrorHandler middleware option
|
||||||
|
func JWTWithErrorHandler(v JWTErrorHandler) JWTOption {
|
||||||
|
return func(o *JWTOptions) {
|
||||||
|
o.ErrorHandler = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT middleware
|
||||||
|
func JWT(jwt *jwt.JWT, contextKey interface{}, opts ...JWTOption) Middleware {
|
||||||
|
options := GetDefaultJWTOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return JWTWithOptions(jwt, contextKey, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTWithOptions middleware
|
||||||
|
func JWTWithOptions(jwt *jwt.JWT, contextKey interface{}, opts JWTOptions) Middleware {
|
||||||
|
return func(l *zap.Logger, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := opts.ClaimsProvider()
|
||||||
|
if value := r.Context().Value(contextKey); value != nil {
|
||||||
|
// TODO check if type matches the existing
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
} else if value, err := opts.TokenProvider(r); err != nil {
|
||||||
|
httputils.BadRequestServerError(l, w, r, errors.Wrap(err, "failed to retrieve token"))
|
||||||
|
} else if value == "" {
|
||||||
|
if claims, resume := opts.MissingTokenHandler(l, w, r); resume && claims != nil {
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
|
||||||
|
} else if resume {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
} else if token, err := jwt.ParseWithClaims(value, claims); err != nil {
|
||||||
|
// TODO check if type matches the existing
|
||||||
|
if opts.ErrorHandler(l, w, r, err) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
} else if !token.Valid {
|
||||||
|
if opts.InvalidTokenHandler(l, w, r, token) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey, claims)))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
44
net/http/middleware/server.go
Normal file
44
net/http/middleware/server.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
httputils "github.com/foomo/keel/net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
ServerOptions struct {
|
||||||
|
Header string
|
||||||
|
}
|
||||||
|
ServerOption func(*ServerOptions)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDefaultServerOptions returns the default options
|
||||||
|
func GetDefaultServerOptions() ServerOptions {
|
||||||
|
return ServerOptions{
|
||||||
|
Header: httputils.HeaderServer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server middleware
|
||||||
|
func Server(name string, opts ...ServerOption) Middleware {
|
||||||
|
options := GetDefaultServerOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ServerWithOptions(name, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerWithOptions middleware
|
||||||
|
func ServerWithOptions(name string, opts ServerOptions) Middleware {
|
||||||
|
return func(l *zap.Logger, next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Add(opts.Header, name)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,141 +1,147 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
httputils "github.com/foomo/keel/net/http"
|
keelhttp "github.com/foomo/keel/net/http"
|
||||||
|
"github.com/foomo/keel/net/http/cookie"
|
||||||
|
httputils "github.com/foomo/keel/utils/net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
SessionIDConfig struct {
|
contextKey string
|
||||||
SetCookie bool // only create a session cookie if enabled
|
SessionIDOptions struct {
|
||||||
CookieName string
|
// Header to look up the session id
|
||||||
CookieSecure bool
|
Header string
|
||||||
CookieHttpOnly bool
|
// Cookie how to set the cookie
|
||||||
CookiePath string
|
Cookie cookie.Cookie
|
||||||
CookieDomain string
|
// Generator for the session ids
|
||||||
CookieDomains []string
|
|
||||||
CookieDomainProvider DomainProvider
|
|
||||||
Generator SessionIDGenerator
|
Generator SessionIDGenerator
|
||||||
|
// SetCookie if true it will create a cookie if not exists
|
||||||
|
SetCookie bool
|
||||||
|
// SetHeader if true it will set add a request header
|
||||||
|
SetHeader bool
|
||||||
|
// SetContext if true it will set the context key
|
||||||
|
SetContext bool
|
||||||
}
|
}
|
||||||
SessionIDOption func(*SessionIDConfig) error
|
SessionIDOption func(*SessionIDOptions)
|
||||||
|
SessionIDGenerator func() string
|
||||||
)
|
)
|
||||||
|
|
||||||
var DefaultSessionIDConfig = SessionIDConfig{
|
const (
|
||||||
|
ContextKeySessionID contextKey = "sessionId"
|
||||||
|
|
||||||
|
DefaultSessionIDCookieName = "keel-session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultSessionIDGenerator function
|
||||||
|
func DefaultSessionIDGenerator() string {
|
||||||
|
return uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultSessionIDOptions returns the default options
|
||||||
|
func GetDefaultSessionIDOptions() SessionIDOptions {
|
||||||
|
return SessionIDOptions{
|
||||||
|
Header: keelhttp.HeaderXSessionID,
|
||||||
|
Cookie: cookie.New(DefaultSessionIDCookieName),
|
||||||
|
Generator: DefaultSessionIDGenerator,
|
||||||
SetCookie: false,
|
SetCookie: false,
|
||||||
CookieName: "keel-session",
|
SetHeader: true,
|
||||||
CookieSecure: true,
|
SetContext: true,
|
||||||
CookieHttpOnly: true,
|
}
|
||||||
CookiePath: "/",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SessionIDWithHeader(v string) SessionIDOption {
|
||||||
|
return func(o *SessionIDOptions) {
|
||||||
|
o.Header = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionIDWithSetCookie middleware option
|
||||||
func SessionIDWithSetCookie(v bool) SessionIDOption {
|
func SessionIDWithSetCookie(v bool) SessionIDOption {
|
||||||
return func(c *SessionIDConfig) error {
|
return func(o *SessionIDOptions) {
|
||||||
c.SetCookie = v
|
o.SetCookie = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionIDWithCookieName(v string) SessionIDOption {
|
// SessionIDWithSetHeader middleware option
|
||||||
return func(c *SessionIDConfig) error {
|
func SessionIDWithSetHeader(v bool) SessionIDOption {
|
||||||
c.CookieName = v
|
return func(o *SessionIDOptions) {
|
||||||
|
o.SetHeader = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionIDWithCookieSecure(v bool) SessionIDOption {
|
// SessionIDWithSetContext middleware option
|
||||||
return func(c *SessionIDConfig) error {
|
func SessionIDWithSetContext(v bool) SessionIDOption {
|
||||||
c.CookieSecure = v
|
return func(o *SessionIDOptions) {
|
||||||
|
o.SetContext = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SessionIDWithCookieHttpOnly(v bool) SessionIDOption {
|
// SessionIDWithCookie middleware option
|
||||||
return func(c *SessionIDConfig) error {
|
func SessionIDWithCookie(v cookie.Cookie) SessionIDOption {
|
||||||
c.CookieHttpOnly = v
|
return func(o *SessionIDOptions) {
|
||||||
}
|
o.Cookie = v
|
||||||
}
|
|
||||||
|
|
||||||
func SessionIDWithCookiePath(v string) SessionIDOption {
|
|
||||||
return func(c *SessionIDConfig) error {
|
|
||||||
c.CookiePath = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SessionIDWithCookieDomain(v string) SessionIDOption {
|
|
||||||
return func(c *SessionIDConfig) error {
|
|
||||||
c.CookieDomain = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SessionIDWithCookieDomains(v []string) SessionIDOption {
|
|
||||||
return func(c *SessionIDConfig) error {
|
|
||||||
c.CookieDomains = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func SessionIDWithCookieDomainProvider(v DomainProvider) SessionIDOption {
|
|
||||||
return func(c *SessionIDConfig) error {
|
|
||||||
c.CookieDomainProvider = v
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionIDWithGenerator middleware option
|
||||||
func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption {
|
func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption {
|
||||||
return func(c *SessionIDConfig) error {
|
return func(o *SessionIDOptions) {
|
||||||
c.Generator = v
|
o.Generator = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionID middleware
|
||||||
func SessionID(opts ...SessionIDOption) Middleware {
|
func SessionID(opts ...SessionIDOption) Middleware {
|
||||||
config = DefaultSessionIDConfig
|
options := GetDefaultSessionIDOptions()
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if opt != nil {
|
if opt != nil {
|
||||||
if err := opt(&opts); err != nil {
|
opt(&options)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
return SessionIDWithOptions(options)
|
||||||
if config.Generator == nil {
|
|
||||||
config.Generator = DefaultSessionIDGenerator
|
|
||||||
}
|
|
||||||
if config.DomainProvider == nil {
|
|
||||||
config.DomainProvider = DefaultDomainProvider(config.CookieDomains)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionIDWithOptions middleware
|
||||||
|
func SessionIDWithOptions(opts SessionIDOptions) Middleware {
|
||||||
return func(l *zap.Logger, next http.Handler) http.Handler {
|
return func(l *zap.Logger, next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var sessionID string
|
||||||
if id := r.Header.Get(httputils.HeaderXSessionID); id == "" {
|
if c, err := opts.Cookie.Get(r); errors.Is(err, http.ErrNoCookie) && !opts.SetCookie {
|
||||||
if cookie, err := r.Cookie(config.CookieName); errors.Is(err, http.ErrNoCookie) {
|
// do nothing
|
||||||
if config.SetCookie {
|
} else if errors.Is(err, http.ErrNoCookie) && opts.SetCookie {
|
||||||
|
sessionID = opts.Generator()
|
||||||
domain, err := config.DomainProvider(r)
|
if err := opts.Cookie.Set(w, r, sessionID); err != nil {
|
||||||
if err != nil {
|
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to set session id cookie"))
|
||||||
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to resolve domain"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id = config.Generator()
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: config.CookieName,
|
|
||||||
Value: id,
|
|
||||||
Path: config.CookiePath,
|
|
||||||
HttpOnly: config.CookieHttpOnly,
|
|
||||||
Secure: config.CookieSecure,
|
|
||||||
Domain: domain,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
r.Header.Set(httputils.HeaderXSessionID, id)
|
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read cookie"))
|
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read session id cookie"))
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
r.Header.Set(httputils.HeaderXSessionID, cookie.Value)
|
sessionID = c.Value
|
||||||
}
|
}
|
||||||
|
if sessionID != "" && opts.SetHeader {
|
||||||
|
r.Header.Set(opts.Header, sessionID)
|
||||||
|
}
|
||||||
|
if sessionID != "" && opts.SetContext {
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), ContextKeySessionID, sessionID))
|
||||||
}
|
}
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionIDFromContext helper
|
||||||
|
func SessionIDFromContext(ctx context.Context) string {
|
||||||
|
if value, ok := ctx.Value(ContextKeySessionID).(string); ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@ -1,9 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
|
||||||
|
|
||||||
type SessionIDGenerator func() string
|
|
||||||
|
|
||||||
func DefaultSessionIDGenerator() string {
|
|
||||||
return uuid.New().String()
|
|
||||||
}
|
|
||||||
22
net/http/middleware/skip.go
Normal file
22
net/http/middleware/skip.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Skip(mw Middleware, skippers ...Skipper) Middleware {
|
||||||
|
return func(l *zap.Logger, next http.Handler) http.Handler {
|
||||||
|
wrapped := mw(l, next)
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
for _, skipper := range skippers {
|
||||||
|
if skipper(r) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wrapped.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
35
net/http/middleware/skipper.go
Normal file
35
net/http/middleware/skipper.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Skipper func(*http.Request) bool
|
||||||
|
|
||||||
|
// RequestURIWhitelistSkipper returns a HTTPMiddlewareConfig.Skipper which skips all but the given uris
|
||||||
|
func RequestURIWhitelistSkipper(uris ...string) Skipper {
|
||||||
|
urisMap := make(map[string]bool, len(uris))
|
||||||
|
for _, uri := range uris {
|
||||||
|
urisMap[uri] = true
|
||||||
|
}
|
||||||
|
return func(r *http.Request) bool {
|
||||||
|
if _, ok := urisMap[r.RequestURI]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestURIBlacklistSkipper returns a HTTPMiddlewareConfig.Skipper which skips the given uris
|
||||||
|
func RequestURIBlacklistSkipper(uris ...string) Skipper {
|
||||||
|
urisMap := make(map[string]bool, len(uris))
|
||||||
|
for _, uri := range uris {
|
||||||
|
urisMap[uri] = true
|
||||||
|
}
|
||||||
|
return func(r *http.Request) bool {
|
||||||
|
if _, ok := urisMap[r.RequestURI]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,139 +0,0 @@
|
|||||||
package cmrcsession
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
cmrcjwt "github.com/bestbytes/commerce/pkg/jwt"
|
|
||||||
"github.com/golang-jwt/jwt"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Session struct {
|
|
||||||
jwt *cmrcjwt.JWT
|
|
||||||
secure bool
|
|
||||||
domains []string
|
|
||||||
domainMapping map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session {
|
|
||||||
return &Session{
|
|
||||||
jwt: jwt,
|
|
||||||
secure: secure,
|
|
||||||
domains: domains,
|
|
||||||
domainMapping: domainMapping,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) HasDomain(domain string) bool {
|
|
||||||
for _, value := range s.domains {
|
|
||||||
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) {
|
|
||||||
if cookie, err := r.Cookie(name); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if !token.Valid {
|
|
||||||
return nil, errors.New("invalid token")
|
|
||||||
} else {
|
|
||||||
return token.Claims, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error {
|
|
||||||
domain, err := s.getDomain(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := s.jwt.GetSignedToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: token,
|
|
||||||
Path: "/",
|
|
||||||
MaxAge: int(lifetime.Seconds()),
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: s.secure,
|
|
||||||
Domain: domain,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
|
|
||||||
domain, err := s.getDomain(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: s.secure,
|
|
||||||
Domain: domain,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error {
|
|
||||||
domain, err := s.getDomain(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: name,
|
|
||||||
Value: "",
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: s.secure,
|
|
||||||
Domain: domain,
|
|
||||||
MaxAge: -1,
|
|
||||||
Expires: time.Now().AddDate(0, 0, -1),
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) getDomain(r *http.Request) (string, error) {
|
|
||||||
var domain string
|
|
||||||
if r.Header.Get("X-Forwarded-Host") != "" {
|
|
||||||
domain = r.Header.Get("X-Forwarded-Host")
|
|
||||||
} else if !r.URL.IsAbs() {
|
|
||||||
domain = r.Host
|
|
||||||
} else {
|
|
||||||
domain = r.URL.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// right trim port
|
|
||||||
portIndex := strings.Index(domain, ":")
|
|
||||||
if portIndex != -1 {
|
|
||||||
domain = domain[:portIndex]
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.domainMapping != nil {
|
|
||||||
if value, ok := s.domainMapping[domain]; ok {
|
|
||||||
domain = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.HasDomain(domain) {
|
|
||||||
return "", errors.New("invalid domain: " + domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
return domain, nil
|
|
||||||
}
|
|
||||||
88
net/http/middleware/tokenprovider.go
Normal file
88
net/http/middleware/tokenprovider.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
keelhttp "github.com/foomo/keel/net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenProvider func(r *http.Request) (string, error)
|
||||||
|
|
||||||
|
type (
|
||||||
|
HeaderTokenProviderOptions struct {
|
||||||
|
Prefix string
|
||||||
|
Header string
|
||||||
|
}
|
||||||
|
HeaderTokenProviderOption func(*HeaderTokenProviderOptions)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDefaultHeaderTokenOptions returns the default options
|
||||||
|
func GetDefaultHeaderTokenOptions() HeaderTokenProviderOptions {
|
||||||
|
return HeaderTokenProviderOptions{
|
||||||
|
Prefix: keelhttp.HeaderValueAuthorizationPrefix,
|
||||||
|
Header: keelhttp.HeaderAuthorization,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeaderTokenProviderWithPrefix middleware option
|
||||||
|
func HeaderTokenProviderWithPrefix(v string) HeaderTokenProviderOption {
|
||||||
|
return func(o *HeaderTokenProviderOptions) {
|
||||||
|
o.Prefix = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeaderTokenProviderWithHeader middleware option
|
||||||
|
func HeaderTokenProviderWithHeader(v string) HeaderTokenProviderOption {
|
||||||
|
return func(o *HeaderTokenProviderOptions) {
|
||||||
|
o.Header = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HeaderTokenProvider(opts ...HeaderTokenProviderOption) TokenProvider {
|
||||||
|
options := GetDefaultHeaderTokenOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return func(r *http.Request) (string, error) {
|
||||||
|
if value := r.Header.Get(options.Header); value == "" {
|
||||||
|
return "", nil
|
||||||
|
} else if !strings.HasPrefix(value, options.Prefix) {
|
||||||
|
return "", errors.New("malformed bearer token")
|
||||||
|
} else {
|
||||||
|
return strings.TrimPrefix(value, options.Prefix), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
CookieTokenProviderOptions struct{}
|
||||||
|
CookieTokenProviderOption func(*CookieTokenProviderOptions)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDefaultCookieTokenOptions returns the default options
|
||||||
|
func GetDefaultCookieTokenOptions() CookieTokenProviderOptions {
|
||||||
|
return CookieTokenProviderOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func CookieTokenProvider(cookieName string, opts ...CookieTokenProviderOption) TokenProvider {
|
||||||
|
options := GetDefaultCookieTokenOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return func(r *http.Request) (string, error) {
|
||||||
|
if cookie, err := r.Cookie(cookieName); errors.Is(err, http.ErrNoCookie) {
|
||||||
|
return "", nil
|
||||||
|
} else if err != nil {
|
||||||
|
return "", errors.New("failed to retrieve cookie")
|
||||||
|
} else {
|
||||||
|
return cookie.Value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/foomo/keel/log"
|
"github.com/foomo/keel/log"
|
||||||
|
keelhttp "github.com/foomo/keel/net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// InternalServerError http response
|
// InternalServerError http response
|
||||||
@ -31,7 +32,8 @@ func NotFoundServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request,
|
|||||||
// ServerError http response
|
// ServerError http response
|
||||||
func ServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request, code int, err error) {
|
func ServerError(l *zap.Logger, w http.ResponseWriter, r *http.Request, code int, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Configure(l).HTTPRequest(r).Error(err).Logger().Error("http server error", zap.Int("code", code))
|
log.WithHTTPRequest(l, r).Error("http server error", log.FError(err), log.FHTTPStatusCode(code))
|
||||||
http.Error(w, http.StatusText(code), code) // TODO enrich headers
|
w.Header().Set(keelhttp.HeaderXError, err.Error())
|
||||||
|
http.Error(w, http.StatusText(code), code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
30
utils/net/http/request.go
Normal file
30
utils/net/http/request.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package httputils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetRequestHost returns the request's host
|
||||||
|
func GetRequestHost(r *http.Request) string {
|
||||||
|
var host string
|
||||||
|
switch {
|
||||||
|
case r.Header.Get("X-Forwarded-Host") != "":
|
||||||
|
host = r.Header.Get("X-Forwarded-Host")
|
||||||
|
case !r.URL.IsAbs():
|
||||||
|
host = r.Host
|
||||||
|
default:
|
||||||
|
host = r.URL.Host
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestDomain returns the request's domain
|
||||||
|
func GetRequestDomain(r *http.Request) string {
|
||||||
|
domain := GetRequestHost(r)
|
||||||
|
// right trim port
|
||||||
|
if portIndex := strings.Index(domain, ":"); portIndex != -1 {
|
||||||
|
domain = domain[:portIndex]
|
||||||
|
}
|
||||||
|
return domain
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user