refactor: use referer

This commit is contained in:
franklin 2023-07-07 08:31:14 +02:00
parent a7833b3d78
commit 80dc2187e6
6 changed files with 77 additions and 79 deletions

View File

@ -4,16 +4,16 @@ import (
"context"
)
const ContextKeyReferrer contextKey = "referrer"
const ContextKeyReferer contextKey = "referer"
func GetReferrer(ctx context.Context) (string, bool) {
if value, ok := ctx.Value(ContextKeyReferrer).(string); ok {
func GetReferer(ctx context.Context) (string, bool) {
if value, ok := ctx.Value(ContextKeyReferer).(string); ok {
return value, true
} else {
return "", false
}
}
func SetReferrer(ctx context.Context, referer string) context.Context {
return context.WithValue(ctx, ContextKeyReferrer, referer)
func SetReferer(ctx context.Context, referer string) context.Context {
return context.WithValue(ctx, ContextKeyReferer, referer)
}

View File

@ -15,7 +15,6 @@ const (
HeaderLastModified = "Last-Modified"
HeaderLocation = "Location"
HeaderUpgrade = "Upgrade"
HeaderReferrer = "Referer"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXForwardedHost = "X-Forwarded-Host"
@ -37,7 +36,7 @@ const (
HeaderServer = "Server"
HeaderTrueClientIP = "True-Client-Ip"
HeaderOrigin = "Origin"
HeaderXReferrer = "X-Referrer"
HeaderXReferer = "X-Referer"
HeaderUserAgent = "User-Agent"
// Cloudflare

View File

@ -18,7 +18,7 @@ type (
// GetDefaultRefererOptions returns the default options
func GetDefaultRefererOptions() RefererOptions {
return RefererOptions{
RequestHeader: []string{"X-Referer", "X-Referrer", "Referer", "Referrer"},
RequestHeader: []string{"X-Referer", "Referer"},
SetContext: true,
}
}
@ -59,7 +59,7 @@ func RefererWithOptions(opts RefererOptions) Middleware {
}
}
if referer != "" && opts.SetContext {
r = r.WithContext(context.SetReferrer(r.Context(), referer))
r = r.WithContext(context.SetReferer(r.Context(), referer))
}
next.ServeHTTP(w, r)
})

View File

@ -0,0 +1,50 @@
package roundtripware
import (
"net/http"
"go.uber.org/zap"
keelhttpcontext "github.com/foomo/keel/net/http/context"
)
type (
RefererOptions struct {
Header string
}
RefererOption func(*RefererOptions)
)
// GetDefaultRefererOptions returns the default options
func GetDefaultRefererOptions() RefererOptions {
return RefererOptions{
Header: "X-Referer",
}
}
// RefererWithHeader middleware option
func RefererWithHeader(v string) RefererOption {
return func(o *RefererOptions) {
o.Header = v
}
}
// Referer returns a RoundTripper which prints out the request & response object
func Referer(opts ...RefererOption) RoundTripware {
o := GetDefaultRefererOptions()
for _, opt := range opts {
if opt != nil {
opt(&o)
}
}
return func(l *zap.Logger, next Handler) Handler {
return func(r *http.Request) (*http.Response, error) {
if value := r.Header.Get(o.Header); value == "" {
if value, ok := keelhttpcontext.GetReferer(r.Context()); ok && value != "" {
r.Header.Set(o.Header, value)
}
}
return next(r)
}
}
}

View File

@ -15,16 +15,16 @@ import (
"github.com/foomo/keel/net/http/roundtripware"
)
func TestReferrer(t *testing.T) {
var testReferrer string
func TestReferer(t *testing.T) {
var testReferer string
// create logger
l := zaptest.NewLogger(t)
// create http server with handler
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
testReferrer = r.Header.Get(keelhttp.HeaderXReferrer)
assert.Empty(t, testReferrer)
testReferer = r.Header.Get(keelhttp.HeaderXReferer)
assert.Empty(t, testReferer)
w.WriteHeader(http.StatusOK)
}))
defer svr.Close()
@ -32,7 +32,7 @@ func TestReferrer(t *testing.T) {
// create http client
client := keelhttp.NewHTTPClient(
keelhttp.HTTPClientWithRoundTripware(l,
roundtripware.Referrer(),
roundtripware.Referer(),
),
)
@ -46,18 +46,18 @@ func TestReferrer(t *testing.T) {
defer resp.Body.Close()
// validate
assert.Equal(t, testReferrer, req.Header.Get(keelhttp.HeaderXReferrer))
assert.Equal(t, testReferer, req.Header.Get(keelhttp.HeaderXReferer))
}
func TestReferrer_Context(t *testing.T) {
testReferrer := "https://foomo.org/"
func TestReferer_Context(t *testing.T) {
testReferer := "https://foomo.org/"
// create logger
l := zaptest.NewLogger(t)
// create http server with handler
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, testReferrer, r.Header.Get(keelhttp.HeaderXReferrer))
assert.Equal(t, testReferer, r.Header.Get(keelhttp.HeaderXReferer))
w.WriteHeader(http.StatusOK)
}))
defer svr.Close()
@ -65,12 +65,12 @@ func TestReferrer_Context(t *testing.T) {
// create http client
client := keelhttp.NewHTTPClient(
keelhttp.HTTPClientWithRoundTripware(l,
roundtripware.Referrer(),
roundtripware.Referer(),
),
)
// set request id on context
ctx := keelhttpcontext.SetReferrer(context.Background(), testReferrer)
ctx := keelhttpcontext.SetReferer(context.Background(), testReferer)
// create request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, svr.URL, nil)
@ -82,11 +82,11 @@ func TestReferrer_Context(t *testing.T) {
defer resp.Body.Close()
// validate
assert.Equal(t, testReferrer, req.Header.Get(keelhttp.HeaderXReferrer))
assert.Equal(t, testReferer, req.Header.Get(keelhttp.HeaderXReferer))
}
func TestReferrer_WithHeader(t *testing.T) {
testReferrer := "https://foomo.org/"
func TestReferer_WithHeader(t *testing.T) {
testReferer := "https://foomo.org/"
testHeader := "X-Custom-Header"
// create logger
@ -94,7 +94,7 @@ func TestReferrer_WithHeader(t *testing.T) {
// create http server with handler
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, testReferrer, r.Header.Get(testHeader))
assert.Equal(t, testReferer, r.Header.Get(testHeader))
w.WriteHeader(http.StatusOK)
}))
defer svr.Close()
@ -102,14 +102,14 @@ func TestReferrer_WithHeader(t *testing.T) {
// create http client
client := keelhttp.NewHTTPClient(
keelhttp.HTTPClientWithRoundTripware(l,
roundtripware.Referrer(
roundtripware.ReferrerWithHeader(testHeader),
roundtripware.Referer(
roundtripware.RefererWithHeader(testHeader),
),
),
)
// set request id on context
ctx := keelhttpcontext.SetReferrer(context.Background(), testReferrer)
ctx := keelhttpcontext.SetReferer(context.Background(), testReferer)
// create request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, svr.URL, nil)
@ -121,5 +121,5 @@ func TestReferrer_WithHeader(t *testing.T) {
defer resp.Body.Close()
// validate
assert.Equal(t, testReferrer, req.Header.Get(testHeader))
assert.Equal(t, testReferer, req.Header.Get(testHeader))
}

View File

@ -1,51 +0,0 @@
package roundtripware
import (
"net/http"
"go.uber.org/zap"
keelhttpcontext "github.com/foomo/keel/net/http/context"
)
type (
ReferrerOptions struct {
Header string
}
ReferrerOption func(*ReferrerOptions)
ReferrerGenerator func() string
)
// GetDefaultReferrerOptions returns the default options
func GetDefaultReferrerOptions() ReferrerOptions {
return ReferrerOptions{
Header: "X-Referrer",
}
}
// ReferrerWithHeader middleware option
func ReferrerWithHeader(v string) ReferrerOption {
return func(o *ReferrerOptions) {
o.Header = v
}
}
// Referrer returns a RoundTripper which prints out the request & response object
func Referrer(opts ...ReferrerOption) RoundTripware {
o := GetDefaultReferrerOptions()
for _, opt := range opts {
if opt != nil {
opt(&o)
}
}
return func(l *zap.Logger, next Handler) Handler {
return func(r *http.Request) (*http.Response, error) {
if value := r.Header.Get(o.Header); value == "" {
if value, ok := keelhttpcontext.GetReferrer(r.Context()); ok && value != "" {
r.Header.Set(o.Header, value)
}
}
return next(r)
}
}
}