mirror of
https://github.com/foomo/keel.git
synced 2025-10-16 12:35:34 +00:00
148 lines
3.7 KiB
Go
148 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/pkg/errors"
|
|
"go.uber.org/zap"
|
|
|
|
keelhttp "github.com/foomo/keel/net/http"
|
|
"github.com/foomo/keel/net/http/cookie"
|
|
httputils "github.com/foomo/keel/utils/net/http"
|
|
)
|
|
|
|
type (
|
|
contextKey string
|
|
SessionIDOptions struct {
|
|
// Header to look up the session id
|
|
Header string
|
|
// Cookie how to set the cookie
|
|
Cookie cookie.Cookie
|
|
// Generator for the session ids
|
|
Generator SessionIDGenerator
|
|
// SetCookie if true it will create a cookie if not exists
|
|
SetCookie bool
|
|
// SetHeader if true it will set add a request header
|
|
SetHeader bool
|
|
// SetContext if true it will set the context key
|
|
SetContext bool
|
|
}
|
|
SessionIDOption func(*SessionIDOptions)
|
|
SessionIDGenerator func() string
|
|
)
|
|
|
|
const (
|
|
ContextKeySessionID contextKey = "sessionId"
|
|
|
|
DefaultSessionIDCookieName = "keel-session"
|
|
)
|
|
|
|
// DefaultSessionIDGenerator function
|
|
func DefaultSessionIDGenerator() string {
|
|
return uuid.New().String()
|
|
}
|
|
|
|
// GetDefaultSessionIDOptions returns the default options
|
|
func GetDefaultSessionIDOptions() SessionIDOptions {
|
|
return SessionIDOptions{
|
|
Header: keelhttp.HeaderXSessionID,
|
|
Cookie: cookie.New(DefaultSessionIDCookieName),
|
|
Generator: DefaultSessionIDGenerator,
|
|
SetCookie: false,
|
|
SetHeader: true,
|
|
SetContext: true,
|
|
}
|
|
}
|
|
|
|
func SessionIDWithHeader(v string) SessionIDOption {
|
|
return func(o *SessionIDOptions) {
|
|
o.Header = v
|
|
}
|
|
}
|
|
|
|
// SessionIDWithSetCookie middleware option
|
|
func SessionIDWithSetCookie(v bool) SessionIDOption {
|
|
return func(o *SessionIDOptions) {
|
|
o.SetCookie = v
|
|
}
|
|
}
|
|
|
|
// SessionIDWithSetHeader middleware option
|
|
func SessionIDWithSetHeader(v bool) SessionIDOption {
|
|
return func(o *SessionIDOptions) {
|
|
o.SetHeader = v
|
|
}
|
|
}
|
|
|
|
// SessionIDWithSetContext middleware option
|
|
func SessionIDWithSetContext(v bool) SessionIDOption {
|
|
return func(o *SessionIDOptions) {
|
|
o.SetContext = v
|
|
}
|
|
}
|
|
|
|
// SessionIDWithCookie middleware option
|
|
func SessionIDWithCookie(v cookie.Cookie) SessionIDOption {
|
|
return func(o *SessionIDOptions) {
|
|
o.Cookie = v
|
|
}
|
|
}
|
|
|
|
// SessionIDWithGenerator middleware option
|
|
func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption {
|
|
return func(o *SessionIDOptions) {
|
|
o.Generator = v
|
|
}
|
|
}
|
|
|
|
// SessionID middleware
|
|
func SessionID(opts ...SessionIDOption) Middleware {
|
|
options := GetDefaultSessionIDOptions()
|
|
for _, opt := range opts {
|
|
if opt != nil {
|
|
opt(&options)
|
|
}
|
|
}
|
|
return SessionIDWithOptions(options)
|
|
}
|
|
|
|
// SessionIDWithOptions middleware
|
|
func SessionIDWithOptions(opts SessionIDOptions) Middleware {
|
|
return func(l *zap.Logger, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
var sessionID string
|
|
if c, err := opts.Cookie.Get(r); errors.Is(err, http.ErrNoCookie) && !opts.SetCookie {
|
|
// do nothing
|
|
} else if errors.Is(err, http.ErrNoCookie) && opts.SetCookie {
|
|
sessionID = opts.Generator()
|
|
if err := opts.Cookie.Set(w, r, sessionID); err != nil {
|
|
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to set session id cookie"))
|
|
return
|
|
}
|
|
} else if err != nil {
|
|
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read session id cookie"))
|
|
return
|
|
} else {
|
|
sessionID = c.Value
|
|
}
|
|
if sessionID != "" && opts.SetHeader {
|
|
r.Header.Set(opts.Header, sessionID)
|
|
}
|
|
if sessionID != "" && opts.SetContext {
|
|
r = r.WithContext(context.WithValue(r.Context(), ContextKeySessionID, sessionID))
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// SessionIDFromContext helper
|
|
func SessionIDFromContext(ctx context.Context) string {
|
|
if value, ok := ctx.Value(ContextKeySessionID).(string); ok {
|
|
return value
|
|
}
|
|
return ""
|
|
}
|