diff --git a/client.go b/client.go index 51c9a55..3e8e431 100644 --- a/client.go +++ b/client.go @@ -18,13 +18,12 @@ var _ Client = &bufferedClient{} type Client interface { Call(url string, endpoint string, method string, args []interface{}, reply []interface{}) (err error) + SetEncoding(encoding ClientEncoding) + SetHttpClient(client *http.Client) } -func NewClient(httpClient *http.Client) Client { - if httpClient == nil { - httpClient = http.DefaultClient - } - return &bufferedClient{client: httpClient} +func NewClient() Client { + return &bufferedClient{client: http.DefaultClient, handle: getHandleForEncoding(EncodingMsgpack)} } func newRequest(url string, contentType string, reader io.Reader) (r *http.Request, err error) { @@ -32,21 +31,31 @@ func newRequest(url string, contentType string, reader io.Reader) (r *http.Reque if errRequest != nil { return nil, errors.Wrap(errRequest, "could not create a request") } + request.Header.Set("Content-Type", contentType) request.Header.Set("Accept", contentType) + return request, nil } type bufferedClient struct { client *http.Client + handle *clientHandle +} + +func (c *bufferedClient) SetEncoding(encoding ClientEncoding) { + c.handle = getHandleForEncoding(encoding) +} + +func (c *bufferedClient) SetHttpClient(client *http.Client) { + c.client = client } // CallClient calls a method on the remove service func (c *bufferedClient) Call(url string, endpoint string, method string, args []interface{}, reply []interface{}) (err error) { // Marshall args - b := new(bytes.Buffer) - errEncode := codec.NewEncoder(b, msgpackHandle).Encode(args) + errEncode := codec.NewEncoder(b, c.handle.handle).Encode(args) if errEncode != nil { return errors.Wrap(errEncode, "could not encode argument") } @@ -56,7 +65,7 @@ func (c *bufferedClient) Call(url string, endpoint string, method string, args [ postURL := fmt.Sprintf("%s%s/%s", url, endpoint, method) // Post - request, errRequest := newRequest(postURL, msgpackContentType, b) + request, errRequest := newRequest(postURL, c.handle.contentType, b) if errRequest != nil { return errRequest } @@ -77,14 +86,8 @@ func (c *bufferedClient) Call(url string, endpoint string, method string, args [ } var errDecode error - switch resp.Header.Get("Content-Type") { - case msgpackContentType: - errDecode = codec.NewDecoder(resp.Body, msgpackHandle).Decode(reply) - case jsonContentType: - errDecode = codec.NewDecoder(resp.Body, jsonHandle).Decode(reply) - default: - errDecode = codec.NewDecoder(resp.Body, jsonHandle).Decode(reply) - } + responseHandle := getHandlerForContentType(resp.Header.Get("Content-Type")).handle + errDecode = codec.NewDecoder(resp.Body, responseHandle).Decode(reply) // Unmarshal reply if errDecode != nil { diff --git a/demo/gorpc.go b/demo/gorpc.go index 9ef207f..2cab6a0 100644 --- a/demo/gorpc.go +++ b/demo/gorpc.go @@ -11,7 +11,7 @@ import ( time "time" gotsrpc "github.com/foomo/gotsrpc" - nested "github.com/foomo/gotsrpc/demo/nested" + github_com_foomo_gotsrpc_demo_nested "github.com/foomo/gotsrpc/demo/nested" gorpc "github.com/valyala/gorpc" ) @@ -111,8 +111,8 @@ type ( DemoGiveMeAScalarRequest struct { } DemoGiveMeAScalarResponse struct { - Amount nested.Amount - Wahr nested.True + Amount github_com_foomo_gotsrpc_demo_nested.Amount + Wahr github_com_foomo_gotsrpc_demo_nested.True Hier ScalarInPlace } @@ -147,7 +147,7 @@ type ( DemoNestRequest struct { } DemoNestResponse struct { - RetNest_0 []*nested.Nested + RetNest_0 []*github_com_foomo_gotsrpc_demo_nested.Nested } DemoTestScalarInPlaceRequest struct { diff --git a/demo/gorpcclient.go b/demo/gorpcclient.go index 843c418..241e170 100644 --- a/demo/gorpcclient.go +++ b/demo/gorpcclient.go @@ -5,7 +5,7 @@ package demo import ( tls "crypto/tls" - nested "github.com/foomo/gotsrpc/demo/nested" + github_com_foomo_gotsrpc_demo_nested "github.com/foomo/gotsrpc/demo/nested" gorpc "github.com/valyala/gorpc" ) @@ -75,7 +75,7 @@ func (tsc *DemoGoRPCClient) ExtractAddress(person *Person) (addr *Address, e *Er return response.Addr, response.E, nil } -func (tsc *DemoGoRPCClient) GiveMeAScalar() (amount nested.Amount, wahr nested.True, hier ScalarInPlace, clientErr error) { +func (tsc *DemoGoRPCClient) GiveMeAScalar() (amount github_com_foomo_gotsrpc_demo_nested.Amount, wahr github_com_foomo_gotsrpc_demo_nested.True, hier ScalarInPlace, clientErr error) { req := DemoGiveMeAScalarRequest{} rpcCallRes, rpcCallErr := tsc.Client.Call(req) if rpcCallErr != nil { @@ -129,7 +129,7 @@ func (tsc *DemoGoRPCClient) MapCrap() (crap map[string][]int, clientErr error) { return response.Crap, nil } -func (tsc *DemoGoRPCClient) Nest() (retNest_0 []*nested.Nested, clientErr error) { +func (tsc *DemoGoRPCClient) Nest() (retNest_0 []*github_com_foomo_gotsrpc_demo_nested.Nested, clientErr error) { req := DemoNestRequest{} rpcCallRes, rpcCallErr := tsc.Client.Call(req) if rpcCallErr != nil { diff --git a/demo/gotsrpcclient.go b/demo/gotsrpcclient.go index d1d6313..3d7526f 100644 --- a/demo/gotsrpcclient.go +++ b/demo/gotsrpcclient.go @@ -4,7 +4,7 @@ package demo import ( gotsrpc "github.com/foomo/gotsrpc" - nested "github.com/foomo/gotsrpc/demo/nested" + github_com_foomo_gotsrpc_demo_nested "github.com/foomo/gotsrpc/demo/nested" ) type FooGoTSRPCClient interface { @@ -25,7 +25,7 @@ func NewFooGoTSRPCClient(url string, endpoint string) FooGoTSRPCClient { return &tsrpcFooGoTSRPCClient{ URL: url, EndPoint: endpoint, - Client: gotsrpc.NewClient(nil), + Client: gotsrpc.NewClient(), } } @@ -41,17 +41,15 @@ func (tsc *tsrpcFooGoTSRPCClient) Hello(number int64) (retHello_0 int, clientErr type DemoGoTSRPCClient interface { ExtractAddress(person *Person) (addr *Address, e *Err, clientErr error) - GiveMeAScalar() (amount nested.Amount, wahr nested.True, hier ScalarInPlace, clientErr error) + GiveMeAScalar() (amount github_com_foomo_gotsrpc_demo_nested.Amount, wahr github_com_foomo_gotsrpc_demo_nested.True, hier ScalarInPlace, clientErr error) Hello(name string) (retHello_0 string, retHello_1 *Err, clientErr error) HelloInterface(anything interface{}, anythingMap map[string]interface{}, anythingSlice []interface{}) (clientErr error) HelloScalarError() (err *ScalarError, clientErr error) MapCrap() (crap map[string][]int, clientErr error) - Nest() (retNest_0 []*nested.Nested, clientErr error) + Nest() (retNest_0 []*github_com_foomo_gotsrpc_demo_nested.Nested, clientErr error) TestScalarInPlace() (retTestScalarInPlace_0 ScalarInPlace, clientErr error) } -var _ DemoGoTSRPCClient = &tsrpcDemoGoTSRPCClient{} - type tsrpcDemoGoTSRPCClient struct { URL string EndPoint string @@ -66,7 +64,7 @@ func NewDemoGoTSRPCClient(url string, endpoint string) DemoGoTSRPCClient { return &tsrpcDemoGoTSRPCClient{ URL: url, EndPoint: endpoint, - Client: gotsrpc.NewClient(nil), + Client: gotsrpc.NewClient(), } } @@ -80,7 +78,7 @@ func (tsc *tsrpcDemoGoTSRPCClient) ExtractAddress(person *Person) (addr *Address return } -func (tsc *tsrpcDemoGoTSRPCClient) GiveMeAScalar() (amount nested.Amount, wahr nested.True, hier ScalarInPlace, clientErr error) { +func (tsc *tsrpcDemoGoTSRPCClient) GiveMeAScalar() (amount github_com_foomo_gotsrpc_demo_nested.Amount, wahr github_com_foomo_gotsrpc_demo_nested.True, hier ScalarInPlace, clientErr error) { args := []interface{}{} reply := []interface{}{&amount, &wahr, &hier} clientErr = tsc.Client.Call(tsc.URL, tsc.EndPoint, "GiveMeAScalar", args, reply) @@ -115,7 +113,7 @@ func (tsc *tsrpcDemoGoTSRPCClient) MapCrap() (crap map[string][]int, clientErr e return } -func (tsc *tsrpcDemoGoTSRPCClient) Nest() (retNest_0 []*nested.Nested, clientErr error) { +func (tsc *tsrpcDemoGoTSRPCClient) Nest() (retNest_0 []*github_com_foomo_gotsrpc_demo_nested.Nested, clientErr error) { args := []interface{}{} reply := []interface{}{&retNest_0} clientErr = tsc.Client.Call(tsc.URL, tsc.EndPoint, "Nest", args, reply) diff --git a/go.go b/go.go index aa20902..d39eea8 100644 --- a/go.go +++ b/go.go @@ -433,7 +433,7 @@ func renderTSRPCServiceClients(services ServiceList, fullPackageName string, pac return &` + clientName + `{ URL: url, EndPoint: endpoint, - Client: gotsrpc.NewClient(nil), + Client: gotsrpc.NewClient(), } }`) diff --git a/gotsrpc.go b/gotsrpc.go index 30b3714..f0527e1 100644 --- a/gotsrpc.go +++ b/gotsrpc.go @@ -39,15 +39,9 @@ func ErrorMethodNotAllowed(w http.ResponseWriter) { func LoadArgs(args interface{}, callStats *CallStats, r *http.Request) error { start := time.Now() - var errDecode error - switch r.Header.Get("Content-Type") { - case msgpackContentType: - errDecode = codec.NewDecoder(r.Body, msgpackHandle).Decode(args) - default: - errDecode = codec.NewDecoder(r.Body, jsonHandle).Decode(args) - } - if errDecode != nil { + handle := getHandlerForContentType(r.Header.Get("Content-Type")).handle + if errDecode := codec.NewDecoder(r.Body, handle).Decode(args); errDecode != nil { return errors.Wrap(errDecode, "could not decode arguments") } if callStats != nil { @@ -85,21 +79,12 @@ func ClearStats(r *http.Request) { func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.ResponseWriter) { writer := newResponseWriterWithLength(w) serializationStart := time.Now() - var errEncode error - switch r.Header.Get("Accept") { - case msgpackContentType: - writer.Header().Set("Content-Type", msgpackContentType) - errEncode = codec.NewEncoder(writer, msgpackHandle).Encode(response) - case jsonContentType: - writer.Header().Set("Content-Type", jsonContentType) - errEncode = codec.NewEncoder(writer, jsonHandle).Encode(response) - default: - writer.Header().Set("Content-Type", jsonContentType) - errEncode = codec.NewEncoder(writer, jsonHandle).Encode(response) - } + clientHandle := getHandlerForContentType(r.Header.Get("Content-Type")) - if errEncode != nil { + writer.Header().Set("Content-Type", clientHandle.contentType) + + if errEncode := codec.NewEncoder(writer, clientHandle.handle).Encode(response); errEncode != nil { fmt.Println(errEncode) http.Error(w, "could not encode data to accepted format", http.StatusInternalServerError) return @@ -109,7 +94,6 @@ func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.Res stats.ResponseSize = writer.length stats.Marshalling = time.Now().Sub(serializationStart) } - //writer.WriteHeader(http.StatusOK) } func parseDir(goPaths []string, packageName string) (map[string]*ast.Package, error) { diff --git a/transport.go b/transport.go index 07cc296..597dbf0 100644 --- a/transport.go +++ b/transport.go @@ -6,19 +6,53 @@ import ( "github.com/ugorji/go/codec" ) -var ( - msgpackHandle = &codec.MsgpackHandle{ - RawToString: true, - } - msgpackContentType = "application/msgpack; charset=utf-8" +type ClientEncoding int + +const ( + EncodingMsgpack = ClientEncoding(0) + EncodingJson = ClientEncoding(1) ) -var ( - jsonHandle = &codec.JsonHandle{ +type clientHandle struct { + handle codec.Handle + contentType string +} + +var msgpackClientHandle = &clientHandle{ + handle: &codec.MsgpackHandle{ + RawToString: true, + }, + contentType: "application/msgpack; charset=utf-8", +} + +var jsonClientHandle = &clientHandle{ + handle: &codec.JsonHandle{ MapKeyAsString: true, + }, + contentType: "application/json; charset=utf-8", +} + +func getHandleForEncoding(encoding ClientEncoding) *clientHandle { + switch encoding { + case EncodingMsgpack: + return msgpackClientHandle + case EncodingJson: + return jsonClientHandle + default: + return jsonClientHandle } - jsonContentType = "application/json; charset=utf-8" -) +} + +func getHandlerForContentType(contentType string) *clientHandle { + switch contentType { + case msgpackClientHandle.contentType: + return msgpackClientHandle + case jsonClientHandle.contentType: + return jsonClientHandle + default: + return jsonClientHandle + } +} type responseWriterWithLength struct { http.ResponseWriter