diff --git a/examples/sessionid/main.go b/examples/sessionid/main.go new file mode 100644 index 0000000..5836b60 --- /dev/null +++ b/examples/sessionid/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "net/http" + "os" + + "github.com/foomo/keel" + "github.com/foomo/keel/log" + "github.com/foomo/keel/net/http/middleware" +) + +func main() { + svr := keel.NewServer() + l := svr.Logger() + + domains := []string{"*.example.com"} + + domainMapping := map[string]string{ + "foo.example.com": "bar.example.com", + } + + domainProvider := middleware.MappingDomainProvider(domains, domainMapping) + + svr.AddService( + keel.NewServiceHTTP( + log.WithServiceName(l, "demo"), + ":8080", + newService(), + middleware.SessionID( + middleware.SessionIDWithCookieDomainProvider( + domainProvider, + ) + ) + ), + ) + + svr.Run() +} + +func newService() *http.ServeMux { + s := http.NewServeMux() + s.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Hello World!")) + }) + return s +} diff --git a/log/fields_http.go b/log/fields_http.go index 748d18e..5281003 100644 --- a/log/fields_http.go +++ b/log/fields_http.go @@ -39,12 +39,19 @@ const ( // HTTPRequestIDKey represents the HTTP request id if known (e.g from X-Request-ID). HTTPRequestIDKey = "http_request_id" + + // HTTPSessionIDKey represents the HTTP session id if known (e.g from X-Session-ID). + HTTPSessionIDKey = "http_session_id" ) func FHTTPRequestID(id string) zap.Field { return zap.String(HTTPRequestIDKey, id) } +func FHTTPSessionID(id string) zap.Field { + return zap.String(HTTPSessionIDKey, id) +} + func FHTTPRequestContentLength(bytes int64) zap.Field { return zap.Int64(HTTPRequestContentLengthKey, bytes) } diff --git a/log/with.go b/log/with.go index 646f3c1..17a8d75 100644 --- a/log/with.go +++ b/log/with.go @@ -42,6 +42,9 @@ func WithHTTPRequest(l *zap.Logger, r *http.Request) *zap.Logger { if id := r.Header.Get(httputils.HeaderXRequestID); id != "" { fields = append(fields, FHTTPRequestID(id)) } + if id := r.Header.Get(httputils.HeaderXSessionID); id != "" { + fields = append(fields, FHTTPSessionID(id)) + } if r.TLS != nil { fields = append(fields, FHTTPScheme("https")) } else { diff --git a/net/http/header.go b/net/http/header.go index b8fa846..20f5437 100644 --- a/net/http/header.go +++ b/net/http/header.go @@ -26,6 +26,7 @@ const ( HeaderXRealIP = "X-Real-IP" HeaderXRequestID = "X-Request-ID" HeaderXRequestedWith = "X-Requested-With" + HeaderXSessionID = "X-Session-ID" HeaderServer = "Server" HeaderOrigin = "Origin" diff --git a/net/http/middleware/domainprovider.go b/net/http/middleware/domainprovider.go new file mode 100644 index 0000000..1829f04 --- /dev/null +++ b/net/http/middleware/domainprovider.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" +) + +var ( + ErrDomainNotAllowed = errors.New("domain not allowed") +) + +type DomainProvider func(r *http.Request) (string, error) + +var DefaultDomainProvider = func(domains []string) DomainProvider { + return func(r *http.Request) (string, error) { + domain := getDomainFromHTTPRequest(r) + if !isDomainAllowed(domain, domains) { + return "", ErrDomainNotAllowed + } + return domain, nil + } +} + +var MappingDomainProvider = func(domains []string, mapping map[string]string) DomainProvider { + return func(r *http.Request) (string, error) { + domain := getDomainFromHTTPRequest(r) + if value, ok := mapping[domain]; ok { + domain = value + } + if !isDomainAllowed(domain, domains) { + return "", errors.New("invalid domain: " + domain) + } + return domain, nil + } +} + +// getDomainFromHTTPRequest helper +func getDomainFromHTTPRequest(r *http.Request) string { + var domain string + if r.Header.Get("X-Forwarded-Host") != "" { + domain = r.Header.Get("X-Forwarded-Host") + } else if !r.URL.IsAbs() { + domain = r.Host + } else { + domain = r.URL.Host + } + + // right trim port + portIndex := strings.Index(domain, ":") + if portIndex != -1 { + domain = domain[:portIndex] + } + + return domain +} + +// isDomainAllowed helper +func isDomainAllowed(domain string, domains []string) bool { + if domains == nil || len(domains) == 0 { + return true + } + for _, value := range domains { + if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) { + return true + } + } + return false +} diff --git a/net/http/middleware/request_id.go b/net/http/middleware/requestid.go similarity index 100% rename from net/http/middleware/request_id.go rename to net/http/middleware/requestid.go diff --git a/net/http/middleware/response_writer.go b/net/http/middleware/responsewriter.go similarity index 100% rename from net/http/middleware/response_writer.go rename to net/http/middleware/responsewriter.go diff --git a/net/http/middleware/sessionid.go b/net/http/middleware/sessionid.go new file mode 100644 index 0000000..ac0c8ae --- /dev/null +++ b/net/http/middleware/sessionid.go @@ -0,0 +1,141 @@ +package middleware + +import ( + "net/http" + + "github.com/pkg/errors" + "go.uber.org/zap" + + httputils "github.com/foomo/keel/net/http" +) + +type ( + SessionIDConfig struct { + SetCookie bool // only create a session cookie if enabled + CookieName string + CookieSecure bool + CookieHttpOnly bool + CookiePath string + CookieDomain string + CookieDomains []string + CookieDomainProvider DomainProvider + Generator SessionIDGenerator + } + SessionIDOption func(*SessionIDConfig) error +) + +var DefaultSessionIDConfig = SessionIDConfig{ + SetCookie: false, + CookieName: "keel-session", + CookieSecure: true, + CookieHttpOnly: true, + CookiePath: "/", +} + +func SessionIDWithSetCookie(v bool) SessionIDOption { + return func(c *SessionIDConfig) error { + c.SetCookie = v + } +} + +func SessionIDWithCookieName(v string) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookieName = v + } +} + +func SessionIDWithCookieSecure(v bool) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookieSecure = v + } +} + +func SessionIDWithCookieHttpOnly(v bool) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookieHttpOnly = v + } +} + +func SessionIDWithCookiePath(v string) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookiePath = v + } +} + +func SessionIDWithCookieDomain(v string) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookieDomain = v + } +} + +func SessionIDWithCookieDomains(v []string) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookieDomains = v + } +} + +func SessionIDWithCookieDomainProvider(v DomainProvider) SessionIDOption { + return func(c *SessionIDConfig) error { + c.CookieDomainProvider = v + } +} + +func SessionIDWithGenerator(v SessionIDGenerator) SessionIDOption { + return func(c *SessionIDConfig) error { + c.Generator = v + } +} + +func SessionID(opts ...SessionIDOption) Middleware { + config = DefaultSessionIDConfig + for _, opt := range opts { + if opt != nil { + if err := opt(&opts); err != nil { + return nil, err + } + } + } + if config.Generator == nil { + config.Generator = DefaultSessionIDGenerator + } + if config.DomainProvider == nil { + config.DomainProvider = DefaultDomainProvider(config.CookieDomains) + } + + return func(l *zap.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if id := r.Header.Get(httputils.HeaderXSessionID); id == "" { + if cookie, err := r.Cookie(config.CookieName); errors.Is(err, http.ErrNoCookie) { + if config.SetCookie { + + domain, err := config.DomainProvider(r) + if err != nil { + httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to resolve domain")) + return + } + + id = config.Generator() + + http.SetCookie(w, &http.Cookie{ + Name: config.CookieName, + Value: id, + Path: config.CookiePath, + HttpOnly: config.CookieHttpOnly, + Secure: config.CookieSecure, + Domain: domain, + }) + } + r.Header.Set(httputils.HeaderXSessionID, id) + } else if err != nil { + httputils.InternalServerError(l, w, r, errors.Wrap(err, "failed to read cookie")) + return + } else { + r.Header.Set(httputils.HeaderXSessionID, cookie.Value) + } + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/net/http/middleware/sessioniogenerator.go b/net/http/middleware/sessioniogenerator.go new file mode 100644 index 0000000..5b3826d --- /dev/null +++ b/net/http/middleware/sessioniogenerator.go @@ -0,0 +1,9 @@ +package middleware + +import "github.com/google/uuid" + +type SessionIDGenerator func() string + +func DefaultSessionIDGenerator() string { + return uuid.New().String() +} diff --git a/net/http/middleware/tmp.go b/net/http/middleware/tmp.go new file mode 100644 index 0000000..ada6317 --- /dev/null +++ b/net/http/middleware/tmp.go @@ -0,0 +1,139 @@ +package cmrcsession + +import ( + "net/http" + "strings" + "time" + + cmrcjwt "github.com/bestbytes/commerce/pkg/jwt" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" +) + +type Session struct { + jwt *cmrcjwt.JWT + secure bool + domains []string + domainMapping map[string]string +} + +func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session { + return &Session{ + jwt: jwt, + secure: secure, + domains: domains, + domainMapping: domainMapping, + } +} + +func (s *Session) HasDomain(domain string) bool { + for _, value := range s.domains { + if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) { + return true + } + } + return false +} + +func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) { + if cookie, err := r.Cookie(name); err != nil { + return nil, err + } else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil { + return nil, err + } else if !token.Valid { + return nil, errors.New("invalid token") + } else { + return token.Claims, nil + } +} + +func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error { + domain, err := s.getDomain(r) + if err != nil { + return err + } + + token, err := s.jwt.GetSignedToken(claims) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: token, + Path: "/", + MaxAge: int(lifetime.Seconds()), + HttpOnly: true, + Secure: s.secure, + Domain: domain, + }) + + return nil +} + +func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error { + domain, err := s.getDomain(r) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: value, + Path: "/", + HttpOnly: true, + Secure: s.secure, + Domain: domain, + }) + + return nil +} + +func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error { + domain, err := s.getDomain(r) + if err != nil { + return err + } + + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Path: "/", + HttpOnly: true, + Secure: s.secure, + Domain: domain, + MaxAge: -1, + Expires: time.Now().AddDate(0, 0, -1), + }) + + return nil +} + +func (s *Session) getDomain(r *http.Request) (string, error) { + var domain string + if r.Header.Get("X-Forwarded-Host") != "" { + domain = r.Header.Get("X-Forwarded-Host") + } else if !r.URL.IsAbs() { + domain = r.Host + } else { + domain = r.URL.Host + } + + // right trim port + portIndex := strings.Index(domain, ":") + if portIndex != -1 { + domain = domain[:portIndex] + } + + if s.domainMapping != nil { + if value, ok := s.domainMapping[domain]; ok { + domain = value + } + } + + if !s.HasDomain(domain) { + return "", errors.New("invalid domain: " + domain) + } + + return domain, nil +}