mirror of
https://github.com/foomo/keel.git
synced 2025-10-16 12:35:34 +00:00
140 lines
2.9 KiB
Go
140 lines
2.9 KiB
Go
package cmrcsession
|
|
|
|
import (
	"net"
	"net/http"
	"strings"
	"time"

	cmrcjwt "github.com/bestbytes/commerce/pkg/jwt"
	"github.com/golang-jwt/jwt"
	"github.com/pkg/errors"
)
|
|
|
|
type Session struct {
|
|
jwt *cmrcjwt.JWT
|
|
secure bool
|
|
domains []string
|
|
domainMapping map[string]string
|
|
}
|
|
|
|
func NewSession(jwt *cmrcjwt.JWT, domains []string, domainMapping map[string]string, secure bool) *Session {
|
|
return &Session{
|
|
jwt: jwt,
|
|
secure: secure,
|
|
domains: domains,
|
|
domainMapping: domainMapping,
|
|
}
|
|
}
|
|
|
|
func (s *Session) HasDomain(domain string) bool {
|
|
for _, value := range s.domains {
|
|
if domain == value || (strings.HasPrefix(value, "*.") && strings.HasSuffix(domain, value[2:])) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s *Session) GetCookieClaims(r *http.Request, name string, claims jwt.Claims) (jwt.Claims, error) {
|
|
if cookie, err := r.Cookie(name); err != nil {
|
|
return nil, err
|
|
} else if token, err := s.jwt.ParseWithClaims(cookie.Value, claims); err != nil {
|
|
return nil, err
|
|
} else if !token.Valid {
|
|
return nil, errors.New("invalid token")
|
|
} else {
|
|
return token.Claims, nil
|
|
}
|
|
}
|
|
|
|
func (s *Session) SetCookieClaims(r *http.Request, w http.ResponseWriter, name string, claims jwt.Claims, lifetime time.Duration) error {
|
|
domain, err := s.getDomain(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
token, err := s.jwt.GetSignedToken(claims)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: name,
|
|
Value: token,
|
|
Path: "/",
|
|
MaxAge: int(lifetime.Seconds()),
|
|
HttpOnly: true,
|
|
Secure: s.secure,
|
|
Domain: domain,
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) SetSessionCookie(r *http.Request, w http.ResponseWriter, name string, value string) error {
|
|
domain, err := s.getDomain(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: name,
|
|
Value: value,
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: s.secure,
|
|
Domain: domain,
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) DeleteCookie(r *http.Request, w http.ResponseWriter, name string) error {
|
|
domain, err := s.getDomain(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: name,
|
|
Value: "",
|
|
Path: "/",
|
|
HttpOnly: true,
|
|
Secure: s.secure,
|
|
Domain: domain,
|
|
MaxAge: -1,
|
|
Expires: time.Now().AddDate(0, 0, -1),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) getDomain(r *http.Request) (string, error) {
|
|
var domain string
|
|
if r.Header.Get("X-Forwarded-Host") != "" {
|
|
domain = r.Header.Get("X-Forwarded-Host")
|
|
} else if !r.URL.IsAbs() {
|
|
domain = r.Host
|
|
} else {
|
|
domain = r.URL.Host
|
|
}
|
|
|
|
// right trim port
|
|
portIndex := strings.Index(domain, ":")
|
|
if portIndex != -1 {
|
|
domain = domain[:portIndex]
|
|
}
|
|
|
|
if s.domainMapping != nil {
|
|
if value, ok := s.domainMapping[domain]; ok {
|
|
domain = value
|
|
}
|
|
}
|
|
|
|
if !s.HasDomain(domain) {
|
|
return "", errors.New("invalid domain: " + domain)
|
|
}
|
|
|
|
return domain, nil
|
|
}
|