diff --git a/client/client.go b/client/client.go index 800b1f3..c67ae02 100644 --- a/client/client.go +++ b/client/client.go @@ -2,7 +2,6 @@ package client import ( "errors" - "fmt" "net/http" "net/url" "time" @@ -40,6 +39,16 @@ var ( ErrInvalidServerURL = errors.New("invalid contentserver url provided") ) +func isValidUrl(str string) bool { + u, err := url.Parse(str) + + if u.Scheme != "http" && u.Scheme != "https" { + return false + } + + return err == nil && u.Scheme != "" && u.Host != "" +} + // NewHTTPClient constructs a new client to talk to the contentserver. // It returns an error if the provided url is empty or invalid. func NewHTTPClient(server string) (c *Client, err error) { @@ -48,8 +57,9 @@ func NewHTTPClient(server string) (c *Client, err error) { return nil, ErrEmptyServerURL } - if _, err = url.Parse(server); err != nil { - return nil, fmt.Errorf("%w: %s", ErrInvalidServerURL, err.Error()) + // validate url + if !isValidUrl(server) { + return nil, ErrInvalidServerURL } return NewHTTPClientWithTransport(NewHTTPTransport(server, http.DefaultClient)) diff --git a/client/client_test.go b/client/client_test.go index 6bdb26e..b31b421 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -12,6 +12,7 @@ import ( "github.com/foomo/contentserver/repo/mock" "github.com/foomo/contentserver/requests" "github.com/foomo/contentserver/server" + "github.com/stretchr/testify/assert" ) const pathContentserver = "/contentserver" @@ -25,6 +26,28 @@ func init() { SetupLogging(true, "contentserver_client_test.log") } +func TestInvalidHTTPClientInit(t *testing.T) { + c, err := NewHTTPClient("") + assert.Nil(t, c) + assert.Error(t, err) + + c, err = NewHTTPClient("bogus") + assert.Nil(t, c) + assert.Error(t, err) + + c, err = NewHTTPClient("htt:/notaurl") + assert.Nil(t, c) + assert.Error(t, err) + + c, err = NewHTTPClient("htts://notaurl") + assert.Nil(t, c) + assert.Error(t, err) + + c, err = NewHTTPClient("/path/segment/only") + assert.Nil(t, c) + assert.Error(t, err) +} + func dump(t *testing.T, v interface{}) { jsonBytes, err := json.MarshalIndent(v, "", " ") if err != nil {