fix: improve url validation

This commit is contained in:
Philipp Mieden 2021-10-07 12:00:14 +02:00
parent c2837eec07
commit 3440cbdc0e
2 changed files with 36 additions and 3 deletions

View File

@ -2,7 +2,6 @@ package client
import ( import (
"errors" "errors"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
@ -40,6 +39,16 @@ var (
ErrInvalidServerURL = errors.New("invalid contentserver url provided") ErrInvalidServerURL = errors.New("invalid contentserver url provided")
) )
func isValidUrl(str string) bool {
u, err := url.Parse(str)
if u.Scheme != "http" && u.Scheme != "https" {
return false
}
return err == nil && u.Scheme != "" && u.Host != ""
}
// NewHTTPClient constructs a new client to talk to the contentserver. // NewHTTPClient constructs a new client to talk to the contentserver.
// It returns an error if the provided url is empty or invalid. // It returns an error if the provided url is empty or invalid.
func NewHTTPClient(server string) (c *Client, err error) { func NewHTTPClient(server string) (c *Client, err error) {
@ -48,8 +57,9 @@ func NewHTTPClient(server string) (c *Client, err error) {
return nil, ErrEmptyServerURL return nil, ErrEmptyServerURL
} }
if _, err = url.Parse(server); err != nil { // validate url
return nil, fmt.Errorf("%w: %s", ErrInvalidServerURL, err.Error()) if !isValidUrl(server) {
return nil, ErrInvalidServerURL
} }
return NewHTTPClientWithTransport(NewHTTPTransport(server, http.DefaultClient)) return NewHTTPClientWithTransport(NewHTTPTransport(server, http.DefaultClient))

View File

@ -12,6 +12,7 @@ import (
"github.com/foomo/contentserver/repo/mock" "github.com/foomo/contentserver/repo/mock"
"github.com/foomo/contentserver/requests" "github.com/foomo/contentserver/requests"
"github.com/foomo/contentserver/server" "github.com/foomo/contentserver/server"
"github.com/stretchr/testify/assert"
) )
const pathContentserver = "/contentserver" const pathContentserver = "/contentserver"
@ -25,6 +26,28 @@ func init() {
SetupLogging(true, "contentserver_client_test.log") SetupLogging(true, "contentserver_client_test.log")
} }
func TestInvalidHTTPClientInit(t *testing.T) {
c, err := NewHTTPClient("")
assert.Nil(t, c)
assert.Error(t, err)
c, err = NewHTTPClient("bogus")
assert.Nil(t, c)
assert.Error(t, err)
c, err = NewHTTPClient("htt:/notaurl")
assert.Nil(t, c)
assert.Error(t, err)
c, err = NewHTTPClient("htts://notaurl")
assert.Nil(t, c)
assert.Error(t, err)
c, err = NewHTTPClient("/path/segment/only")
assert.Nil(t, c)
assert.Error(t, err)
}
func dump(t *testing.T, v interface{}) { func dump(t *testing.T, v interface{}) {
jsonBytes, err := json.MarshalIndent(v, "", " ") jsonBytes, err := json.MarshalIndent(v, "", " ")
if err != nil { if err != nil {