From 8e6bb7cf0dafa193acab875b51aa0523e23758e0 Mon Sep 17 00:00:00 2001 From: Stefan Martinov Date: Fri, 17 Jan 2025 15:39:43 +0100 Subject: [PATCH] chore: add support for encoding/decoding requests and responses --- client.go | 112 ++++++++++++---------- client_test.go | 97 +++++++++---------- doc/adr/0002-streaming-and-compression.md | 38 +++----- gotsrpc.go | 65 ++++--------- response.go | 92 ++++++++++++++++++ transport.go | 11 +++ 6 files changed, 244 insertions(+), 171 deletions(-) create mode 100644 response.go diff --git a/client.go b/client.go index d74ba8a..a4120f3 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package gotsrpc import ( + "bytes" "compress/gzip" "context" "fmt" @@ -11,7 +12,6 @@ import ( "github.com/golang/snappy" "github.com/pkg/errors" "github.com/ugorji/go/codec" - "golang.org/x/sync/errgroup" ) const ( @@ -26,6 +26,19 @@ const ( CompressorSnappy ) +func (c Compressor) String() string { + switch c { + case CompressorNone: + return "none" + case CompressorGZIP: + return "gzip" + case CompressorSnappy: + return "snappy" + default: + return "unknown" + } +} + // ClientTransport to use for calls // var ClientTransport = &http.Transport{} @@ -76,10 +89,9 @@ func WithHTTPClient(c *http.Client) ClientOption { } } -// WithClientHandle allows you to specify a custom clientHandle. -func WithClientHandle(h *clientHandle) ClientOption { +func WithClientEncoding(encoding ClientEncoding) ClientOption { return func(bc *BufferedClient) { - bc.handle = h + bc.handle = getHandleForType(encoding) } } @@ -102,7 +114,7 @@ func NewBufferedClient(opts ...ClientOption) *BufferedClient { bc := &BufferedClient{ client: defaultHttpFactory(), headers: make(http.Header), - handle: getHandleForType(EncodingJson), + handle: getHandleForType(EncodingMsgpack), compressor: CompressorNone, writerPoolMap: map[Compressor]*sync.Pool{ CompressorGZIP: { @@ -124,52 +136,45 @@ func NewBufferedClient(opts ...ClientOption) *BufferedClient { // Call calls a method on the remove service func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string, method string, args []interface{}, reply []interface{}) error { // Marshall args - reader, writer := io.Pipe() - defer reader.Close() + buffer := &bytes.Buffer{} + // If no arguments are set, remove - g, _ := errgroup.WithContext(ctx) - if len(args) != 0 { - g.Go(func() error { + var encodeWriter io.Writer + switch c.compressor { + case CompressorGZIP: + gzipWriter := c.writerPoolMap[CompressorGZIP].Get().(*gzip.Writer) + gzipWriter.Reset(buffer) - // Close piped writer after encoding - defer writer.Close() + defer c.writerPoolMap[CompressorGZIP].Put(gzipWriter) - var encodeWriter io.Writer - switch c.compressor { - case CompressorGZIP: - gzipWriter := c.writerPoolMap[CompressorGZIP].Get().(*gzip.Writer) - gzipWriter.Reset(writer) + encodeWriter = gzipWriter + case CompressorSnappy: + snappyWriter := c.writerPoolMap[CompressorSnappy].Get().(*snappy.Writer) + snappyWriter.Reset(buffer) - defer c.writerPoolMap[CompressorGZIP].Put(gzipWriter) + defer c.writerPoolMap[CompressorSnappy].Put(snappyWriter) + encodeWriter = snappyWriter + case CompressorNone: + encodeWriter = buffer + default: + encodeWriter = buffer + } - encodeWriter = gzipWriter - defer gzipWriter.Close() - case CompressorSnappy: - snappyWriter := c.writerPoolMap[CompressorSnappy].Get().(*snappy.Writer) - snappyWriter.Reset(writer) + err := codec.NewEncoder(encodeWriter, c.handle.handle).Encode(args) + if err != nil { + return errors.Wrap(err, "could not encode data") + } - defer c.writerPoolMap[CompressorSnappy].Put(snappyWriter) - - encodeWriter = snappyWriter - defer snappyWriter.Close() - case CompressorNone: - encodeWriter = writer - default: - encodeWriter = writer - } - - return codec.NewEncoder(encodeWriter, c.handle.handle).Encode(args) - }) - } else { - // Without arguments, skip the piping altogether - writer.Close() + if writer, ok := encodeWriter.(io.Closer); ok { + if err = writer.Close(); err != nil { + return errors.Wrap(err, "failed to write to request body") + } } // Create post url postURL := fmt.Sprintf("%s%s/%s", url, endpoint, method) - - req, err := newRequest(ctx, postURL, c.handle.contentType, reader, c.headers.Clone()) + req, err := newRequest(ctx, postURL, c.handle.contentType, buffer, c.headers.Clone()) if err != nil { return NewClientError(errors.Wrap(err, "failed to create request")) } @@ -177,8 +182,10 @@ func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string, switch c.compressor { case CompressorGZIP: req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Accept-Encoding", "gzip") case CompressorSnappy: req.Header.Set("Content-Encoding", "snappy") + req.Header.Set("Accept-Encoding", "snappy") case CompressorNone: // uncompressed, nothing to do default: @@ -191,13 +198,6 @@ func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string, } defer resp.Body.Close() - if len(args) != 0 { - err = g.Wait() - if err != nil { - return NewClientError(errors.Wrap(err, "failed to send request data")) - } - } - // Check status if resp.StatusCode != http.StatusOK { var msg string @@ -219,7 +219,23 @@ func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string, } } - if err := codec.NewDecoder(resp.Body, clientHandle.handle).Decode(wrappedReply); err != nil { + var responseBodyReader io.Reader + + switch resp.Header.Get("Content-Encoding") { + case "snappy": + responseBodyReader = snappy.NewReader(resp.Body) + case "gzip": + gzipReader, err := gzip.NewReader(resp.Body) + if err != nil { + return NewClientError(errors.Wrap(err, "could not create gzip reader")) + } + responseBodyReader = gzipReader + defer gzipReader.Close() + default: + responseBodyReader = resp.Body + } + + if err := codec.NewDecoder(responseBodyReader, clientHandle.handle).Decode(wrappedReply); err != nil { return NewClientError(errors.Wrap(err, "failed to decode response")) } diff --git a/client_test.go b/client_test.go index 629ad31..60188b5 100644 --- a/client_test.go +++ b/client_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" "os" @@ -33,67 +32,63 @@ func Test_newRequest(t *testing.T) { } func TestNewBufferedClient(t *testing.T) { + contentTypeHeaderMap := map[ClientEncoding]string{ + EncodingMsgpack: "application/msgpack; charset=utf-8", + EncodingJson: "application/json; charset=utf-8", + } + + contentEncodingHeaderMap := map[Compressor]string{ + CompressorGZIP: "gzip", + CompressorSnappy: "snappy", + } + var testRequestData []interface{} data, err := os.ReadFile("testdata/request.json") require.NoError(t, err) err = json.Unmarshal(data, &testRequestData) require.NoError(t, err) - t.Run("gzip", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - //require.Equal(t, "application/msgpack; charset=utf-8", request.Header.Get("Content-Type")) - require.Equal(t, "gzip", request.Header.Get("Content-Encoding")) - data, _ := io.ReadAll(request.Body) - fmt.Println(string(data)) - _, _ = writer.Write([]byte("[]")) + testClient := func( + encoding ClientEncoding, + compressor Compressor, + t *testing.T, + ) { + requiredResponseMessage := "Fake Response Message" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var args []map[string]interface{} + err := LoadArgs(&args, nil, r) + require.NoError(t, err) + + require.Equal(t, contentTypeHeaderMap[encoding], r.Header.Get("Content-Type")) + require.Equal(t, contentEncodingHeaderMap[compressor], r.Header.Get("Content-Encoding")) + + _ = Reply([]interface{}{requiredResponseMessage}, nil, r, w) })) defer server.Close() client := NewBufferedClient( - WithCompressor(CompressorGZIP), + WithCompressor(compressor), + WithHTTPClient(server.Client()), + WithClientEncoding(encoding), ) - assert.NotNil(t, client) - err := client.Call(context.Background(), server.URL, "/test", "test", testRequestData, nil) - assert.NoError(t, err) - }) - t.Run("snappy", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - //require.Equal(t, "application/msgpack; charset=utf-8", request.Header.Get("Content-Type")) - require.Equal(t, "snappy", request.Header.Get("Content-Encoding")) - data, _ := io.ReadAll(request.Body) - fmt.Println(string(data)) + require.NotNil(t, client) - _, _ = writer.Write([]byte("[]")) - })) - defer server.Close() + var actualResponseMessage string + response := []interface{}{&actualResponseMessage} - client := NewBufferedClient( - WithCompressor(CompressorSnappy), - ) - - assert.NotNil(t, client) - err := client.Call(context.Background(), server.URL, "/test", "test", testRequestData, nil) - assert.NoError(t, err) - }) - t.Run("plain", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - //require.Equal(t, "application/msgpack; charset=utf-8", request.Header.Get("Content-Type")) - require.Empty(t, request.Header.Get("Content-Encoding")) - data, _ := io.ReadAll(request.Body) - fmt.Println(string(data)) - - _, _ = writer.Write([]byte("[]")) - })) - defer server.Close() - - client := NewBufferedClient() - - assert.NotNil(t, client) - err := client.Call(context.Background(), server.URL, "/test", "test", testRequestData, nil) - assert.NoError(t, err) - }) + err := client.Call(context.Background(), server.URL, "/Example", "Example", testRequestData, response) + require.NoError(t, err) + require.Equal(t, requiredResponseMessage, actualResponseMessage) + } + for _, encoding := range []ClientEncoding{EncodingMsgpack, EncodingJson} { + for _, compressor := range []Compressor{CompressorNone, CompressorGZIP, CompressorSnappy} { + t.Run(fmt.Sprintf("%s/%s", encoding, compressor), func(t *testing.T) { + testClient(encoding, compressor, t) + }) + } + } } func BenchmarkBufferedClient(b *testing.B) { @@ -105,12 +100,12 @@ func BenchmarkBufferedClient(b *testing.B) { require.NoError(b, err) benchClient := func(b *testing.B, client Client) { - server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - writer.Write([]byte("[]")) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = Reply([]interface{}{"HI"}, nil, r, w) })) defer server.Close() b.ReportAllocs() - b.ResetTimer() + if bc, ok := client.(*BufferedClient); ok { bc.client = server.Client() } @@ -125,7 +120,7 @@ func BenchmarkBufferedClient(b *testing.B) { "gzip": CompressorGZIP, "snappy": CompressorSnappy, } - runs := 5 + runs := 3 for name, compressor := range benchmarks { b.Run(name, func(b *testing.B) { diff --git a/doc/adr/0002-streaming-and-compression.md b/doc/adr/0002-streaming-and-compression.md index 373a98f..3ca9291 100644 --- a/doc/adr/0002-streaming-and-compression.md +++ b/doc/adr/0002-streaming-and-compression.md @@ -36,37 +36,25 @@ BenchmarkBufferedClient/deprecated/3 BenchmarkBufferedClient/deprecated/3-10 24580 48177 ns/op 26138 B/op 108 allocs/op BenchmarkBufferedClient/deprecated/4 BenchmarkBufferedClient/deprecated/4-10 24999 57772 ns/op 26154 B/op 108 allocs/op -BenchmarkBufferedClient/snappy -BenchmarkBufferedClient/snappy/0 -BenchmarkBufferedClient/snappy/0-10 16392 69553 ns/op 15282 B/op 114 allocs/op -BenchmarkBufferedClient/snappy/1 -BenchmarkBufferedClient/snappy/1-10 17702 72923 ns/op 14944 B/op 114 allocs/op -BenchmarkBufferedClient/snappy/2 -BenchmarkBufferedClient/snappy/2-10 17932 67446 ns/op 14819 B/op 114 allocs/op -BenchmarkBufferedClient/snappy/3 -BenchmarkBufferedClient/snappy/3-10 16640 69216 ns/op 15155 B/op 114 allocs/op -BenchmarkBufferedClient/snappy/4 -BenchmarkBufferedClient/snappy/4-10 16767 66247 ns/op 15010 B/op 114 allocs/op BenchmarkBufferedClient/none BenchmarkBufferedClient/none/0 -BenchmarkBufferedClient/none/0-10 17706 68516 ns/op 12280 B/op 112 allocs/op +BenchmarkBufferedClient/none/0-10 22604 49513 ns/op 28943 B/op 119 allocs/op BenchmarkBufferedClient/none/1 -BenchmarkBufferedClient/none/1-10 17593 68580 ns/op 12308 B/op 112 allocs/op +BenchmarkBufferedClient/none/1-10 23991 55106 ns/op 28862 B/op 119 allocs/op BenchmarkBufferedClient/none/2 -BenchmarkBufferedClient/none/2-10 17292 67673 ns/op 12208 B/op 112 allocs/op -BenchmarkBufferedClient/none/3 -BenchmarkBufferedClient/none/3-10 17086 71715 ns/op 12285 B/op 112 allocs/op -BenchmarkBufferedClient/none/4 -BenchmarkBufferedClient/none/4-10 17067 68955 ns/op 12295 B/op 112 allocs/op +BenchmarkBufferedClient/none/2-10 22976 50808 ns/op 28946 B/op 119 allocs/op BenchmarkBufferedClient/gzip BenchmarkBufferedClient/gzip/0 -BenchmarkBufferedClient/gzip/0-10 7190 153284 ns/op 22024 B/op 113 allocs/op +BenchmarkBufferedClient/gzip/0-10 9292 124257 ns/op 15796 B/op 117 allocs/op BenchmarkBufferedClient/gzip/1 -BenchmarkBufferedClient/gzip/1-10 6808 158344 ns/op 20757 B/op 113 allocs/op +BenchmarkBufferedClient/gzip/1-10 10520 112287 ns/op 15601 B/op 117 allocs/op BenchmarkBufferedClient/gzip/2 -BenchmarkBufferedClient/gzip/2-10 6889 156492 ns/op 19680 B/op 113 allocs/op -BenchmarkBufferedClient/gzip/3 -BenchmarkBufferedClient/gzip/3-10 6927 148146 ns/op 18912 B/op 113 allocs/op -BenchmarkBufferedClient/gzip/4 -BenchmarkBufferedClient/gzip/4-10 8340 146697 ns/op 20207 B/op 113 allocs/op +BenchmarkBufferedClient/gzip/2-10 9838 125777 ns/op 15751 B/op 117 allocs/op +BenchmarkBufferedClient/snappy +BenchmarkBufferedClient/snappy/0 +BenchmarkBufferedClient/snappy/0-10 21604 62413 ns/op 15189 B/op 114 allocs/op +BenchmarkBufferedClient/snappy/1 +BenchmarkBufferedClient/snappy/1-10 21208 54509 ns/op 15242 B/op 114 allocs/op +BenchmarkBufferedClient/snappy/2 +BenchmarkBufferedClient/snappy/2-10 24153 51172 ns/op 15253 B/op 114 allocs/op ``` diff --git a/gotsrpc.go b/gotsrpc.go index ec71394..ac6c1f3 100644 --- a/gotsrpc.go +++ b/gotsrpc.go @@ -1,21 +1,23 @@ package gotsrpc import ( + "compress/gzip" "context" "encoding/json" "fmt" "go/ast" "go/parser" "go/token" + "io" "net/http" "os" "path" "path/filepath" - "reflect" "sort" "strings" "time" + "github.com/golang/snappy" "github.com/pkg/errors" "github.com/ugorji/go/codec" @@ -48,11 +50,23 @@ func ErrorMethodNotAllowed(w http.ResponseWriter) { func LoadArgs(args interface{}, callStats *CallStats, r *http.Request) error { start := time.Now() + var bodyReader io.Reader = r.Body + switch r.Header.Get("Content-Encoding") { + case "snappy": + bodyReader = snappy.NewReader(r.Body) + case "gzip": + gzipReader, err := gzip.NewReader(r.Body) + if err != nil { + return errors.Wrap(err, "could not create gzip reader") + } + bodyReader = gzipReader + defer gzipReader.Close() + } handle := getHandlerForContentType(r.Header.Get("Content-Type")).handle - if errDecode := codec.NewDecoder(r.Body, handle).Decode(args); errDecode != nil { - _, _ = fmt.Fprintln(os.Stderr, errDecode.Error()) - return errors.Wrap(errDecode, "could not decode arguments") + + if err := codec.NewDecoder(bodyReader, handle).Decode(args); err != nil { + return errors.Wrap(err, "could not decode arguments") } if callStats != nil { callStats.Unmarshalling = time.Since(start) @@ -85,49 +99,6 @@ func ClearStats(r *http.Request) { *r = *r.WithContext(context.WithValue(r.Context(), contextStatsKey, nil)) } -// Reply despite the fact, that this is a public method - do not call it, it will be called by generated code -func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.ResponseWriter) error { - writer := newResponseWriterWithLength(w) - serializationStart := time.Now() - - clientHandle := getHandlerForContentType(r.Header.Get("Content-Type")) - - writer.Header().Set("Content-Type", clientHandle.contentType) - - if clientHandle.beforeEncodeReply != nil { - if err := clientHandle.beforeEncodeReply(&response); err != nil { - _, _ = fmt.Fprintln(os.Stderr, err.Error()) - return errors.Wrap(err, "error during before encoder reply") - } - } - - if err := codec.NewEncoder(writer, clientHandle.handle).Encode(response); err != nil { - _, _ = fmt.Fprintln(os.Stderr, err.Error()) - return errors.Wrap(err, "could not encode data to accepted format") - } - - if stats != nil { - stats.ResponseSize = writer.length - stats.Marshalling = time.Since(serializationStart) - if len(response) > 0 { - errResp := response[len(response)-1] - if v, ok := errResp.(error); ok && v != nil { - if !reflect.ValueOf(v).IsNil() { - stats.ErrorCode = 1 - stats.ErrorType = fmt.Sprintf("%T", v) - stats.ErrorMessage = v.Error() - if v, ok := v.(interface { - ErrorCode() int - }); ok { - stats.ErrorCode = v.ErrorCode() - } - } - } - } - } - return nil -} - func parserExcludeFiles(info os.FileInfo) bool { return !strings.HasSuffix(info.Name(), "_test.go") } diff --git a/response.go b/response.go new file mode 100644 index 0000000..f043a39 --- /dev/null +++ b/response.go @@ -0,0 +1,92 @@ +package gotsrpc + +import ( + "compress/gzip" + "fmt" + "io" + "net/http" + "reflect" + "slices" + "sync" + "time" + + "github.com/golang/snappy" + "github.com/pkg/errors" + "github.com/ugorji/go/codec" +) + +var ( + responseCompressors = map[Compressor]*sync.Pool{ + CompressorGZIP: {New: func() interface{} { return gzip.NewWriter(io.Discard) }}, + CompressorSnappy: {New: func() interface{} { return snappy.NewBufferedWriter(io.Discard) }}, + } +) + +// Reply despite the fact, that this is a public method - do not call it, it will be called by generated code +func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.ResponseWriter) error { + + responseWriter := newResponseWriterWithLength(w) + serializationStart := time.Now() + var responseBody io.Writer = responseWriter + + clientHandle := getHandlerForContentType(r.Header.Get("Content-Type")) + responseWriter.Header().Set("Content-Type", clientHandle.contentType) + // TODO: Add weighted compression support based on Accepted-Encoding header + switch { + case slices.Contains(r.Header.Values("Accept-Encoding"), "snappy"): + responseWriter.Header().Set("Content-Encoding", "snappy") + responseWriter.Header().Set("Vary", "Accept-Encoding") + + snappyWriter := responseCompressors[CompressorSnappy].Get().(*snappy.Writer) + snappyWriter.Reset(responseWriter) + + defer responseCompressors[CompressorSnappy].Put(snappyWriter) + responseBody = snappyWriter + case slices.Contains(r.Header.Values("Accept-Encoding"), "gzip"): + responseWriter.Header().Set("Content-Encoding", "gzip") + responseWriter.Header().Set("Vary", "Accept-Encoding") + + gzipWriter := responseCompressors[CompressorGZIP].Get().(*gzip.Writer) + gzipWriter.Reset(responseWriter) + + defer responseCompressors[CompressorGZIP].Put(gzipWriter) + responseBody = gzipWriter + } + + if clientHandle.beforeEncodeReply != nil { + if err := clientHandle.beforeEncodeReply(&response); err != nil { + return fmt.Errorf("error during before encoder reply: %w", err) + } + } + + if err := codec.NewEncoder(responseBody, clientHandle.handle).Encode(response); err != nil { + return fmt.Errorf("could not encode data to accepted format: %w", err) + } + + if writer, ok := responseBody.(io.Closer); ok { + if err := writer.Close(); err != nil { + return errors.Wrap(err, "failed to write to response body") + } + } + + if stats != nil { + stats.ResponseSize = responseWriter.length + stats.Marshalling = time.Since(serializationStart) + if len(response) > 0 { + errResp := response[len(response)-1] + if v, ok := errResp.(error); ok && v != nil { + if !reflect.ValueOf(v).IsNil() { + stats.ErrorCode = 1 + stats.ErrorType = fmt.Sprintf("%T", v) + stats.ErrorMessage = v.Error() + if v, ok := v.(interface { + ErrorCode() int + }); ok { + stats.ErrorCode = v.ErrorCode() + } + } + } + } + } + return nil +} diff --git a/transport.go b/transport.go index 75e84ab..773c8ed 100644 --- a/transport.go +++ b/transport.go @@ -12,6 +12,17 @@ import ( type ClientEncoding int +func (c ClientEncoding) String() string { + switch c { + case EncodingMsgpack: + return "msgpack" + case EncodingJson: + return "json" + default: + return "unknown" + } +} + const ( EncodingMsgpack = ClientEncoding(0) EncodingJson = ClientEncoding(1) //nolint:stylecheck