wip: generic session handler

Co-authored-by: franklin <franklinkim@users.noreply.github.com>
This commit is contained in:
Frederik Löffert 2021-09-03 16:36:41 +02:00
parent c321e7dff6
commit 1ed4f47228
10 changed files with 417 additions and 0 deletions

View File

@ -0,0 +1,47 @@
package main
import (
"net/http"
"os"
"github.com/foomo/keel"
"github.com/foomo/keel/log"
"github.com/foomo/keel/net/http/middleware"
)
func main() {
svr := keel.NewServer()
l := svr.Logger()
domains := []string{"*.example.com"}
domainMapping := map[string]string{
"foo.example.com": "bar.example.com",
}
domainProvider := middleware.MappingDomainProvider(domains, domainMapping)
svr.AddService(
keel.NewServiceHTTP(
log.WithServiceName(l, "demo"),
":8080",
newService(),
middleware.SessionID(
middleware.SessionIDWithCookieDomainProvider(
domainProvider,
)
)
),
)
svr.Run()
}
func newService() *http.ServeMux {
s := http.NewServeMux()
s.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello World!"))
})
return s
}

View File

@ -39,12 +39,19 @@ const (
// HTTPRequestIDKey represents the HTTP request id if known (e.g from X-Request-ID).
HTTPRequestIDKey = "http_request_id"
// HTTPSessionIDKey represents the HTTP session id if known (e.g from X-Session-ID).
HTTPSessionIDKey = "http_session_id"
)
func FHTTPRequestID(id string) zap.Field {
return zap.String(HTTPRequestIDKey, id)
}
func FHTTPSessionID(id string) zap.Field {
return zap.String(HTTPSessionIDKey, id)
}
func FHTTPRequestContentLength(bytes int64) zap.Field {
return zap.Int64(HTTPRequestContentLengthKey, bytes)
}

View File

@ -42,6 +42,9 @@ func WithHTTPRequest(l *zap.Logger, r *http.Request) *zap.Logger {
if id := r.Header.Get(httputils.HeaderXRequestID); id != "" {
fields = append(fields, FHTTPRequestID(id))
}
if id := r.Header.Get(httputils.HeaderXSessionID); id != "" {
fields = append(fields, FHTTPSessionID(id))
}
if r.TLS != nil {
fields = append(fields, FHTTPScheme("https"))
} else {

View File

@ -26,6 +26,7 @@ const (
HeaderXRealIP = "X-Real-IP"
HeaderXRequestID = "X-Request-ID"
HeaderXRequestedWith = "X-Requested-With"
HeaderXSessionID = "X-Session-ID"
HeaderServer = "Server"
HeaderOrigin = "Origin"

View File

@ -0,0 +1,70 @@
package middleware
import (
"net/http"
"strings"
"github.com/pkg/errors"
)
var (
ErrDomainNotAllowed = errors.New("domain not allowed")
)
type DomainProvider func(r *http.Request) (string, error)
var DefaultDomainProvider = func(domains []string) DomainProvider {
return func(r *http.Request) (string, error) {
domain := getDomainFromHTTPRequest(r)
if !isDomainAllowed(domain, domains) {
return "", ErrDomainNotAllowed
}
return domain, nil
}
}
var MappingDomainProvider = func(domains []string, mapping map[string]string) DomainProvider {
return func(r *http.Request) (string, error) {
domain := getDomainFromHTTPRequest(r)
if value, ok := mapping[domain]; ok {
domain = value
}
if !isDomainAllowed(domain, domains) {
return "", errors.New("invalid domain: " + domain)
}
return domain, nil
}
}
// getDomainFromHTTPRequest helper
func getDomainFromHTTPRequest(r *http.Request) string {
var domain string
if r.Header.Get("X-Forwarded-Host") != "" {
domain = r.Header.Get("X-Forwarded-Host")
} else if !r.URL.IsAbs() {
domain = r.Host
} else {
domain = r.URL.Host
}
// right trim port
portIndex := strings.Index(domain, ":")
if portIndex != -1 {
domain = domain[:portIndex]
}
return domain
}
// isDomainAllowed helper
func isDomainAllowed(domain string, domains []string) bool {
if domains == nil || len(domains) == 0 {
return true
}
for _, value := range domains {
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
return true
}
}
return false
}

View File

@ -0,0 +1,141 @@
package middleware
import (
"net/http"
"github.com/pkg/errors"
"go.uber.org/zap"
httputils "github.com/foomo/keel/net/http"
)
type (
SessionIDConfig struct {
SetCookie bool // only create a session cookie if enabled
CookieName string
CookieSecure bool
CookieHttpOnly bool
CookiePath string
CookieDomain string
CookieDomains []string
CookieDomainProvider DomainProvider
Generator SessionIDGenerator
}
SessionIDOption func(*SessionIDConfig) error
)
var DefaultSessionIDConfig = SessionIDConfig{
SetCookie: false,
CookieName: "keel-session",
CookieSecure: true,
CookieHttpOnly: true,
CookiePath: "/",
}
func SessionIDWithSetCookie(v bool) SessionIDOption {
return func(c *SessionIDConfig) error {
c.SetCookie = v
}
}
func SessionIDWithCookieName(v string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieName = v
}
}
func SessionIDWithCookieSecure(v bool) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieSecure = v
}
}
func SessionIDWithCookieHttpOnly(v bool) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieHttpOnly = v
}
}
func SessionIDWithCookiePath(v string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookiePath = v
}
}
func SessionIDWithCookieDomain(v string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieDomain = v
}
}
func SessionIDWithCookieDomains(v []string) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieDomains = v
}
}
func SessionIDWithCookieDomainProvider(v DomainProvider) SessionIDOption {
return func(c *SessionIDConfig) error {
c.CookieDomainProvider = v
}
}
func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption {
return func(c *SessionIDConfig) error {
c.Generator = v
}
}
func SessionID(opts ...SessionIDOption) Middleware {
config = DefaultSessionIDConfig
for _, opt := range opts {
if opt != nil {
if err := opt(&opts); err != nil {
return nil, err
}
}
}
if config.Generator == nil {
config.Generator = DefaultSessionIDGenerator
}
if config.DomainProvider == nil {
config.DomainProvider = DefaultDomainProvider(config.CookieDomains)
}
return func(l *zap.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if id := r.Header.Get(httputils.HeaderXSessionID); id == "" {
if cookie, err := r.Cookie(config.CookieName); errors.Is(err, http.ErrNoCookie) {
if config.SetCookie {
domain, err := config.DomainProvider(r)
if err != nil {
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to resolve domain"))
return
}
id = config.Generator()
http.SetCookie(w, &http.Cookie{
Name: config.CookieName,
Value: id,
Path: config.CookiePath,
HttpOnly: config.CookieHttpOnly,
Secure: config.CookieSecure,
Domain: domain,
})
}
r.Header.Set(httputils.HeaderXSessionID, id)
} else if err != nil {
httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read cookie"))
return
} else {
r.Header.Set(httputils.HeaderXSessionID, cookie.Value)
}
}
next.ServeHTTP(w, r)
})
}
}

View File

@ -0,0 +1,9 @@
package middleware
import "github.com/google/uuid"
type SessionIDGenerator func() string
func DefaultSessionIDGenerator() string {
return uuid.New().String()
}

139
net/http/middleware/tmp.go Normal file
View File

@ -0,0 +1,139 @@
package cmrcsession
import (
"net/http"
"strings"
"time"
cmrcjwt "github.com/bestbytes/commerce/pkg/jwt"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
)
type Session struct {
jwt *cmrcjwt.JWT
secure bool
domains []string
domainMapping map[string]string
}
func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session {
return &Session{
jwt: jwt,
secure: secure,
domains: domains,
domainMapping: domainMapping,
}
}
func (s *Session) HasDomain(domain string) bool {
for _, value := range s.domains {
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
return true
}
}
return false
}
func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) {
if cookie, err := r.Cookie(name); err != nil {
return nil, err
} else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil {
return nil, err
} else if !token.Valid {
return nil, errors.New("invalid token")
} else {
return token.Claims, nil
}
}
func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error {
domain, err := s.getDomain(r)
if err != nil {
return err
}
token, err := s.jwt.GetSignedToken(claims)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: token,
Path: "/",
MaxAge: int(lifetime.Seconds()),
HttpOnly: true,
Secure: s.secure,
Domain: domain,
})
return nil
}
func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
domain, err := s.getDomain(r)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
Path: "/",
HttpOnly: true,
Secure: s.secure,
Domain: domain,
})
return nil
}
func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error {
domain, err := s.getDomain(r)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: "/",
HttpOnly: true,
Secure: s.secure,
Domain: domain,
MaxAge: -1,
Expires: time.Now().AddDate(0, 0, -1),
})
return nil
}
func (s *Session) getDomain(r *http.Request) (string, error) {
var domain string
if r.Header.Get("X-Forwarded-Host") != "" {
domain = r.Header.Get("X-Forwarded-Host")
} else if !r.URL.IsAbs() {
domain = r.Host
} else {
domain = r.URL.Host
}
// right trim port
portIndex := strings.Index(domain, ":")
if portIndex != -1 {
domain = domain[:portIndex]
}
if s.domainMapping != nil {
if value, ok := s.domainMapping[domain]; ok {
domain = value
}
}
if !s.HasDomain(domain) {
return "", errors.New("invalid domain: " + domain)
}
return domain, nil
}