mirror of
https://github.com/foomo/keel.git
synced 2025-10-16 12:35:34 +00:00
wip: generic session handler
Co-authored-by: franklin <franklinkim@users.noreply.github.com>
This commit is contained in:
parent
c321e7dff6
commit
1ed4f47228
47
examples/sessionid/main.go
Normal file
47
examples/sessionid/main.go
Normal file
@ -0,0 +1,47 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/foomo/keel"
|
||||
"github.com/foomo/keel/log"
|
||||
"github.com/foomo/keel/net/http/middleware"
|
||||
)
|
||||
|
||||
func main() {
|
||||
svr := keel.NewServer()
|
||||
l := svr.Logger()
|
||||
|
||||
domains := []string{"*.example.com"}
|
||||
|
||||
domainMapping := map[string]string{
|
||||
"foo.example.com": "bar.example.com",
|
||||
}
|
||||
|
||||
domainProvider := middleware.MappingDomainProvider(domains, domainMapping)
|
||||
|
||||
svr.AddService(
|
||||
keel.NewServiceHTTP(
|
||||
log.WithServiceName(l, "demo"),
|
||||
":8080",
|
||||
newService(),
|
||||
middleware.SessionID(
|
||||
middleware.SessionIDWithCookieDomainProvider(
|
||||
domainProvider,
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
svr.Run()
|
||||
}
|
||||
|
||||
func newService() *http.ServeMux {
|
||||
s := http.NewServeMux()
|
||||
s.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Hello World!"))
|
||||
})
|
||||
return s
|
||||
}
|
||||
@ -39,12 +39,19 @@ const (
|
||||
|
||||
// HTTPRequestIDKey represents the HTTP request id if known (e.g from X-Request-ID).
|
||||
HTTPRequestIDKey = "http_request_id"
|
||||
|
||||
// HTTPSessionIDKey represents the HTTP session id if known (e.g from X-Session-ID).
|
||||
HTTPSessionIDKey = "http_session_id"
|
||||
)
|
||||
|
||||
func FHTTPRequestID(id string) zap.Field {
|
||||
return zap.String(HTTPRequestIDKey, id)
|
||||
}
|
||||
|
||||
func FHTTPSessionID(id string) zap.Field {
|
||||
return zap.String(HTTPSessionIDKey, id)
|
||||
}
|
||||
|
||||
func FHTTPRequestContentLength(bytes int64) zap.Field {
|
||||
return zap.Int64(HTTPRequestContentLengthKey, bytes)
|
||||
}
|
||||
|
||||
@ -42,6 +42,9 @@ func WithHTTPRequest(l *zap.Logger, r *http.Request) *zap.Logger {
|
||||
if id := r.Header.Get(httputils.HeaderXRequestID); id != "" {
|
||||
fields = append(fields, FHTTPRequestID(id))
|
||||
}
|
||||
if id := r.Header.Get(httputils.HeaderXSessionID); id != "" {
|
||||
fields = append(fields, FHTTPSessionID(id))
|
||||
}
|
||||
if r.TLS != nil {
|
||||
fields = append(fields, FHTTPScheme("https"))
|
||||
} else {
|
||||
|
||||
@ -26,6 +26,7 @@ const (
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
HeaderXRequestID = "X-Request-ID"
|
||||
HeaderXRequestedWith = "X-Requested-With"
|
||||
HeaderXSessionID = "X-Session-ID"
|
||||
HeaderServer = "Server"
|
||||
HeaderOrigin = "Origin"
|
||||
|
||||
|
||||
70
net/http/middleware/domainprovider.go
Normal file
70
net/http/middleware/domainprovider.go
Normal file
@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDomainNotAllowed = errors.New("domain not allowed")
|
||||
)
|
||||
|
||||
type DomainProvider func(r *http.Request) (string, error)
|
||||
|
||||
var DefaultDomainProvider = func(domains []string) DomainProvider {
|
||||
return func(r *http.Request) (string, error) {
|
||||
domain := getDomainFromHTTPRequest(r)
|
||||
if !isDomainAllowed(domain, domains) {
|
||||
return "", ErrDomainNotAllowed
|
||||
}
|
||||
return domain, nil
|
||||
}
|
||||
}
|
||||
|
||||
var MappingDomainProvider = func(domains []string, mapping map[string]string) DomainProvider {
|
||||
return func(r *http.Request) (string, error) {
|
||||
domain := getDomainFromHTTPRequest(r)
|
||||
if value, ok := mapping[domain]; ok {
|
||||
domain = value
|
||||
}
|
||||
if !isDomainAllowed(domain, domains) {
|
||||
return "", errors.New("invalid domain: " + domain)
|
||||
}
|
||||
return domain, nil
|
||||
}
|
||||
}
|
||||
|
||||
// getDomainFromHTTPRequest helper
|
||||
func getDomainFromHTTPRequest(r *http.Request) string {
|
||||
var domain string
|
||||
if r.Header.Get("X-Forwarded-Host") != "" {
|
||||
domain = r.Header.Get("X-Forwarded-Host")
|
||||
} else if !r.URL.IsAbs() {
|
||||
domain = r.Host
|
||||
} else {
|
||||
domain = r.URL.Host
|
||||
}
|
||||
|
||||
// right trim port
|
||||
portIndex := strings.Index(domain, ":")
|
||||
if portIndex != -1 {
|
||||
domain = domain[:portIndex]
|
||||
}
|
||||
|
||||
return domain
|
||||
}
|
||||
|
||||
// isDomainAllowed helper
|
||||
func isDomainAllowed(domain string, domains []string) bool {
|
||||
if domains == nil || len(domains) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, value := range domains {
|
||||
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
141
net/http/middleware/sessionid.go
Normal file
141
net/http/middleware/sessionid.go
Normal file
@ -0,0 +1,141 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
httputils "github.com/foomo/keel/net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
SessionIDConfig struct {
|
||||
SetCookie bool // only create a session cookie if enabled
|
||||
CookieName string
|
||||
CookieSecure bool
|
||||
CookieHttpOnly bool
|
||||
CookiePath string
|
||||
CookieDomain string
|
||||
CookieDomains []string
|
||||
CookieDomainProvider DomainProvider
|
||||
Generator SessionIDGenerator
|
||||
}
|
||||
SessionIDOption func(*SessionIDConfig) error
|
||||
)
|
||||
|
||||
var DefaultSessionIDConfig = SessionIDConfig{
|
||||
SetCookie: false,
|
||||
CookieName: "keel-session",
|
||||
CookieSecure: true,
|
||||
CookieHttpOnly: true,
|
||||
CookiePath: "/",
|
||||
}
|
||||
|
||||
func SessionIDWithSetCookie(v bool) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.SetCookie = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookieName(v string) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookieName = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookieSecure(v bool) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookieSecure = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookieHttpOnly(v bool) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookieHttpOnly = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookiePath(v string) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookiePath = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookieDomain(v string) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookieDomain = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookieDomains(v []string) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookieDomains = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithCookieDomainProvider(v DomainProvider) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.CookieDomainProvider = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption {
|
||||
return func(c *SessionIDConfig) error {
|
||||
c.Generator = v
|
||||
}
|
||||
}
|
||||
|
||||
func SessionID(opts ...SessionIDOption) Middleware {
|
||||
config = DefaultSessionIDConfig
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if err := opt(&opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.Generator == nil {
|
||||
config.Generator = DefaultSessionIDGenerator
|
||||
}
|
||||
if config.DomainProvider == nil {
|
||||
config.DomainProvider = DefaultDomainProvider(config.CookieDomains)
|
||||
}
|
||||
|
||||
return func(l *zap.Logger, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if id := r.Header.Get(httputils.HeaderXSessionID); id == "" {
|
||||
if cookie, err := r.Cookie(config.CookieName); errors.Is(err, http.ErrNoCookie) {
|
||||
if config.SetCookie {
|
||||
|
||||
domain, err := config.DomainProvider(r)
|
||||
if err != nil {
|
||||
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to resolve domain"))
|
||||
return
|
||||
}
|
||||
|
||||
id = config.Generator()
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: config.CookieName,
|
||||
Value: id,
|
||||
Path: config.CookiePath,
|
||||
HttpOnly: config.CookieHttpOnly,
|
||||
Secure: config.CookieSecure,
|
||||
Domain: domain,
|
||||
})
|
||||
}
|
||||
r.Header.Set(httputils.HeaderXSessionID, id)
|
||||
} else if err != nil {
|
||||
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read cookie"))
|
||||
return
|
||||
} else {
|
||||
r.Header.Set(httputils.HeaderXSessionID, cookie.Value)
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
9
net/http/middleware/sessioniogenerator.go
Normal file
9
net/http/middleware/sessioniogenerator.go
Normal file
@ -0,0 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type SessionIDGenerator func() string
|
||||
|
||||
func DefaultSessionIDGenerator() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
139
net/http/middleware/tmp.go
Normal file
139
net/http/middleware/tmp.go
Normal file
@ -0,0 +1,139 @@
|
||||
package cmrcsession
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cmrcjwt "github.com/bestbytes/commerce/pkg/jwt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
jwt *cmrcjwt.JWT
|
||||
secure bool
|
||||
domains []string
|
||||
domainMapping map[string]string
|
||||
}
|
||||
|
||||
func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session {
|
||||
return &Session{
|
||||
jwt: jwt,
|
||||
secure: secure,
|
||||
domains: domains,
|
||||
domainMapping: domainMapping,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) HasDomain(domain string) bool {
|
||||
for _, value := range s.domains {
|
||||
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) {
|
||||
if cookie, err := r.Cookie(name); err != nil {
|
||||
return nil, err
|
||||
} else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil {
|
||||
return nil, err
|
||||
} else if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
} else {
|
||||
return token.Claims, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error {
|
||||
domain, err := s.getDomain(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := s.jwt.GetSignedToken(claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
MaxAge: int(lifetime.Seconds()),
|
||||
HttpOnly: true,
|
||||
Secure: s.secure,
|
||||
Domain: domain,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
|
||||
domain, err := s.getDomain(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: s.secure,
|
||||
Domain: domain,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error {
|
||||
domain, err := s.getDomain(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: s.secure,
|
||||
Domain: domain,
|
||||
MaxAge: -1,
|
||||
Expires: time.Now().AddDate(0, 0, -1),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) getDomain(r *http.Request) (string, error) {
|
||||
var domain string
|
||||
if r.Header.Get("X-Forwarded-Host") != "" {
|
||||
domain = r.Header.Get("X-Forwarded-Host")
|
||||
} else if !r.URL.IsAbs() {
|
||||
domain = r.Host
|
||||
} else {
|
||||
domain = r.URL.Host
|
||||
}
|
||||
|
||||
// right trim port
|
||||
portIndex := strings.Index(domain, ":")
|
||||
if portIndex != -1 {
|
||||
domain = domain[:portIndex]
|
||||
}
|
||||
|
||||
if s.domainMapping != nil {
|
||||
if value, ok := s.domainMapping[domain]; ok {
|
||||
domain = value
|
||||
}
|
||||
}
|
||||
|
||||
if !s.HasDomain(domain) {
|
||||
return "", errors.New("invalid domain: " + domain)
|
||||
}
|
||||
|
||||
return domain, nil
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user