mirror of
https://github.com/foomo/keel.git
synced 2025-10-16 12:35:34 +00:00
248 lines
6.0 KiB
Go
248 lines
6.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
keelhttp "github.com/foomo/keel/net/http"
|
|
)
|
|
|
|
// CORSOptions holds the configuration consumed by CORSWithOptions.
type CORSOptions struct {
	// AllowOrigins lists the origins that may access the resource. An entry
	// may be "*", an exact origin, or a pattern containing "*" / "?"
	// wildcards.
	AllowOrigins []string
	// AllowMethods is joined with "," and sent as the allowed methods of a
	// preflight response.
	AllowMethods []string
	// AllowHeaders is joined with "," and sent as the allowed headers of a
	// preflight response. When empty, the headers requested by the client
	// are echoed back instead.
	AllowHeaders []string
	// AllowCredentials makes the middleware send the allow-credentials
	// header with the value "true".
	AllowCredentials bool
	// ExposeHeaders is joined with "," and sent as the exposed headers of a
	// non-preflight response when not empty.
	ExposeHeaders []string
	// MaxAge is sent as the max-age of a preflight response when it is
	// greater than zero.
	MaxAge int
}

// CORSOption mutates a CORSOptions value; it is the functional option type
// accepted by CORS.
type CORSOption func(*CORSOptions)
|
|
|
|
// GetDefaultCORSOptions returns the default options
|
|
func GetDefaultCORSOptions() CORSOptions {
|
|
return CORSOptions{
|
|
AllowOrigins: []string{"*"},
|
|
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
|
}
|
|
}
|
|
|
|
// CORSWithAllowOrigins middleware option
|
|
func CORSWithAllowOrigins(v ...string) CORSOption {
|
|
return func(o *CORSOptions) {
|
|
o.AllowOrigins = v
|
|
}
|
|
}
|
|
|
|
// CORSWithAllowMethods middleware option
|
|
func CORSWithAllowMethods(v ...string) CORSOption {
|
|
return func(o *CORSOptions) {
|
|
o.AllowMethods = v
|
|
}
|
|
}
|
|
|
|
// CORSWithAllowHeaders middleware option
|
|
func CORSWithAllowHeaders(v ...string) CORSOption {
|
|
return func(o *CORSOptions) {
|
|
o.AllowHeaders = v
|
|
}
|
|
}
|
|
|
|
// CORSWithAllowCredentials middleware option
|
|
func CORSWithAllowCredentials(v bool) CORSOption {
|
|
return func(o *CORSOptions) {
|
|
o.AllowCredentials = v
|
|
}
|
|
}
|
|
|
|
// CORSWithExposeHeaders middleware option
|
|
func CORSWithExposeHeaders(v ...string) CORSOption {
|
|
return func(o *CORSOptions) {
|
|
o.ExposeHeaders = v
|
|
}
|
|
}
|
|
|
|
// CORSWithMaxAge middleware option
|
|
func CORSWithMaxAge(v int) CORSOption {
|
|
return func(o *CORSOptions) {
|
|
o.MaxAge = v
|
|
}
|
|
}
|
|
|
|
// CORS middleware
|
|
func CORS(opts ...CORSOption) Middleware {
|
|
options := GetDefaultCORSOptions()
|
|
for _, opt := range opts {
|
|
if opt != nil {
|
|
opt(&options)
|
|
}
|
|
}
|
|
return CORSWithOptions(options)
|
|
}
|
|
|
|
// CORSWithOptions middleware
|
|
func CORSWithOptions(opts CORSOptions) Middleware {
|
|
allowOriginPatterns := make([]string, len(opts.AllowOrigins))
|
|
for i, origin := range opts.AllowOrigins {
|
|
pattern := regexp.QuoteMeta(origin)
|
|
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
|
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
|
pattern = "^" + pattern + "$"
|
|
allowOriginPatterns[i] = pattern
|
|
}
|
|
|
|
allowMethods := strings.Join(opts.AllowMethods, ",")
|
|
allowHeaders := strings.Join(opts.AllowHeaders, ",")
|
|
exposeHeaders := strings.Join(opts.ExposeHeaders, ",")
|
|
maxAge := strconv.Itoa(opts.MaxAge)
|
|
|
|
return func(l *zap.Logger, name string, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get(keelhttp.HeaderOrigin)
|
|
allowOrigin := ""
|
|
|
|
w.Header().Add(keelhttp.HeaderVary, keelhttp.HeaderOrigin)
|
|
|
|
preflight := r.Method == http.MethodOptions
|
|
|
|
// No Origin provided
|
|
if origin == "" {
|
|
if !preflight {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
// Check allowed origins
|
|
for _, value := range opts.AllowOrigins {
|
|
if value == "*" && opts.AllowCredentials {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
if value == "*" || value == origin {
|
|
allowOrigin = value
|
|
break
|
|
}
|
|
if matchSubdomain(origin, value) {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
}
|
|
|
|
// Check allowed origin patterns
|
|
for _, re := range allowOriginPatterns {
|
|
if allowOrigin == "" {
|
|
index := strings.Index(origin, "://")
|
|
if index == -1 {
|
|
continue
|
|
}
|
|
|
|
if len(origin[index+3:]) > 253 {
|
|
break
|
|
}
|
|
|
|
if match, _ := regexp.MatchString(re, origin); match {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Origin not allowed
|
|
if allowOrigin == "" && !preflight {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
} else if allowOrigin == "" {
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
// Simple request
|
|
if !preflight {
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowOrigin, allowOrigin)
|
|
if opts.AllowCredentials {
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
if exposeHeaders != "" {
|
|
w.Header().Set(keelhttp.HeaderAccessControlExposeHeaders, exposeHeaders)
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Preflight request
|
|
w.Header().Add(keelhttp.HeaderVary, keelhttp.HeaderAccessControlRequestMethod)
|
|
w.Header().Add(keelhttp.HeaderVary, keelhttp.HeaderAccessControlRequestHeaders)
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowOrigin, allowOrigin)
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowMethods, allowMethods)
|
|
if opts.AllowCredentials {
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
if allowHeaders != "" {
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowHeaders, allowHeaders)
|
|
} else if h := r.Header.Get(keelhttp.HeaderAccessControlRequestHeaders); h != "" {
|
|
w.Header().Set(keelhttp.HeaderAccessControlAllowHeaders, h)
|
|
}
|
|
if opts.MaxAge > 0 {
|
|
w.Header().Set(keelhttp.HeaderAccessControlMaxAge, maxAge)
|
|
}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
})
|
|
}
|
|
}
|
|
|
|
// matchScheme reports whether domain and pattern both contain a ":" and the
// text before the first ":" (the scheme) is identical in both. The comparison
// is case-sensitive.
func matchScheme(domain, pattern string) bool {
	domainColon := strings.IndexByte(domain, ':')
	if domainColon == -1 {
		return false
	}
	patternColon := strings.IndexByte(pattern, ':')
	if patternColon == -1 {
		return false
	}
	return domain[:domainColon] == pattern[:patternColon]
}
|
|
|
|
// matchSubdomain compares authority with wildcard
|
|
func matchSubdomain(domain, pattern string) bool {
|
|
if !matchScheme(domain, pattern) {
|
|
return false
|
|
}
|
|
didx := strings.Index(domain, "://")
|
|
pidx := strings.Index(pattern, "://")
|
|
if didx == -1 || pidx == -1 {
|
|
return false
|
|
}
|
|
domAuth := domain[didx+3:]
|
|
// to avoid long loop by invalid long domain
|
|
if len(domAuth) > 253 {
|
|
return false
|
|
}
|
|
patAuth := pattern[pidx+3:]
|
|
|
|
domComp := strings.Split(domAuth, ".")
|
|
patComp := strings.Split(patAuth, ".")
|
|
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
|
opp := len(domComp) - 1 - i
|
|
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
|
}
|
|
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
|
opp := len(patComp) - 1 - i
|
|
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
|
}
|
|
|
|
for i, v := range domComp {
|
|
if len(patComp) <= i {
|
|
return false
|
|
}
|
|
p := patComp[i]
|
|
if p == "*" {
|
|
return true
|
|
}
|
|
if p != v {
|
|
return false
|
|
}
|
|
}
|
|
return false
|
|
}
|