diff --git a/pkg/client/mpv2middleware.go b/pkg/client/mpv2middleware.go index b52db65..2f92860 100644 --- a/pkg/client/mpv2middleware.go +++ b/pkg/client/mpv2middleware.go @@ -71,3 +71,52 @@ func MPv2MiddlewareTimestamp(next MPv2Handler) MPv2Handler { return next(r, payload) } } + +func MiddlewareUserAgent(next MPv2Handler) MPv2Handler { + return func(r *http.Request, payload *mpv2.Payload[any]) error { + if userAgent := r.Header.Get("User-Agent"); userAgent != "" { + for i, event := range payload.Events { + if value, ok := event.Params.(map[string]any); ok { + value["user_agent"] = userAgent + } + payload.Events[i] = event + } + } + return next(r, payload) + } +} + +func MiddlewareIPOverride(next MPv2Handler) MPv2Handler { + return func(r *http.Request, payload *mpv2.Payload[any]) error { + var ipOverride string + for _, key := range []string{"X-Original-Forwarded-For", "X-Forwarded-For", "X-Real-Ip"} { + if value := r.Header.Get(key); value != "" { + ipOverride = value + break + } + } + if ipOverride != "" { + for i, event := range payload.Events { + if value, ok := event.Params.(map[string]any); ok { + value["ip_override"] = ipOverride + } + payload.Events[i] = event + } + } + return next(r, payload) + } +} + +func MiddlewarePageLocation(next MPv2Handler) MPv2Handler { + return func(r *http.Request, payload *mpv2.Payload[any]) error { + if referrer := r.Header.Get("Referer"); referrer != "" { + for i, event := range payload.Events { + if value, ok := event.Params.(map[string]any); ok { + value["page_location"] = referrer + } + payload.Events[i] = event + } + } + return next(r, payload) + } +}