chore: add support for encoding/decoding requests and responses

This commit is contained in:
Stefan Martinov 2025-01-17 15:39:43 +01:00
parent 9ee1a8d34c
commit 8e6bb7cf0d
6 changed files with 244 additions and 171 deletions

112
client.go
View File

@ -1,6 +1,7 @@
package gotsrpc
import (
"bytes"
"compress/gzip"
"context"
"fmt"
@ -11,7 +12,6 @@ import (
"github.com/golang/snappy"
"github.com/pkg/errors"
"github.com/ugorji/go/codec"
"golang.org/x/sync/errgroup"
)
const (
@ -26,6 +26,19 @@ const (
CompressorSnappy
)
func (c Compressor) String() string {
switch c {
case CompressorNone:
return "none"
case CompressorGZIP:
return "gzip"
case CompressorSnappy:
return "snappy"
default:
return "unknown"
}
}
// ClientTransport to use for calls
// var ClientTransport = &http.Transport{}
@ -76,10 +89,9 @@ func WithHTTPClient(c *http.Client) ClientOption {
}
}
// WithClientHandle allows you to specify a custom clientHandle.
func WithClientHandle(h *clientHandle) ClientOption {
func WithClientEncoding(encoding ClientEncoding) ClientOption {
return func(bc *BufferedClient) {
bc.handle = h
bc.handle = getHandleForType(encoding)
}
}
@ -102,7 +114,7 @@ func NewBufferedClient(opts ...ClientOption) *BufferedClient {
bc := &BufferedClient{
client: defaultHttpFactory(),
headers: make(http.Header),
handle: getHandleForType(EncodingJson),
handle: getHandleForType(EncodingMsgpack),
compressor: CompressorNone,
writerPoolMap: map[Compressor]*sync.Pool{
CompressorGZIP: {
@ -124,52 +136,45 @@ func NewBufferedClient(opts ...ClientOption) *BufferedClient {
// Call calls a method on the remove service
func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string, method string, args []interface{}, reply []interface{}) error {
// Marshall args
reader, writer := io.Pipe()
defer reader.Close()
buffer := &bytes.Buffer{}
// If no arguments are set, remove
g, _ := errgroup.WithContext(ctx)
if len(args) != 0 {
g.Go(func() error {
var encodeWriter io.Writer
switch c.compressor {
case CompressorGZIP:
gzipWriter := c.writerPoolMap[CompressorGZIP].Get().(*gzip.Writer)
gzipWriter.Reset(buffer)
// Close piped writer after encoding
defer writer.Close()
defer c.writerPoolMap[CompressorGZIP].Put(gzipWriter)
var encodeWriter io.Writer
switch c.compressor {
case CompressorGZIP:
gzipWriter := c.writerPoolMap[CompressorGZIP].Get().(*gzip.Writer)
gzipWriter.Reset(writer)
encodeWriter = gzipWriter
case CompressorSnappy:
snappyWriter := c.writerPoolMap[CompressorSnappy].Get().(*snappy.Writer)
snappyWriter.Reset(buffer)
defer c.writerPoolMap[CompressorGZIP].Put(gzipWriter)
defer c.writerPoolMap[CompressorSnappy].Put(snappyWriter)
encodeWriter = snappyWriter
case CompressorNone:
encodeWriter = buffer
default:
encodeWriter = buffer
}
encodeWriter = gzipWriter
defer gzipWriter.Close()
case CompressorSnappy:
snappyWriter := c.writerPoolMap[CompressorSnappy].Get().(*snappy.Writer)
snappyWriter.Reset(writer)
err := codec.NewEncoder(encodeWriter, c.handle.handle).Encode(args)
if err != nil {
return errors.Wrap(err, "could not encode data")
}
defer c.writerPoolMap[CompressorSnappy].Put(snappyWriter)
encodeWriter = snappyWriter
defer snappyWriter.Close()
case CompressorNone:
encodeWriter = writer
default:
encodeWriter = writer
}
return codec.NewEncoder(encodeWriter, c.handle.handle).Encode(args)
})
} else {
// Without arguments, skip the piping altogether
writer.Close()
if writer, ok := encodeWriter.(io.Closer); ok {
if err = writer.Close(); err != nil {
return errors.Wrap(err, "failed to write to request body")
}
}
// Create post url
postURL := fmt.Sprintf("%s%s/%s", url, endpoint, method)
req, err := newRequest(ctx, postURL, c.handle.contentType, reader, c.headers.Clone())
req, err := newRequest(ctx, postURL, c.handle.contentType, buffer, c.headers.Clone())
if err != nil {
return NewClientError(errors.Wrap(err, "failed to create request"))
}
@ -177,8 +182,10 @@ func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string,
switch c.compressor {
case CompressorGZIP:
req.Header.Set("Content-Encoding", "gzip")
req.Header.Set("Accept-Encoding", "gzip")
case CompressorSnappy:
req.Header.Set("Content-Encoding", "snappy")
req.Header.Set("Accept-Encoding", "snappy")
case CompressorNone:
// uncompressed, nothing to do
default:
@ -191,13 +198,6 @@ func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string,
}
defer resp.Body.Close()
if len(args) != 0 {
err = g.Wait()
if err != nil {
return NewClientError(errors.Wrap(err, "failed to send request data"))
}
}
// Check status
if resp.StatusCode != http.StatusOK {
var msg string
@ -219,7 +219,23 @@ func (c *BufferedClient) Call(ctx context.Context, url string, endpoint string,
}
}
if err := codec.NewDecoder(resp.Body, clientHandle.handle).Decode(wrappedReply); err != nil {
var responseBodyReader io.Reader
switch resp.Header.Get("Content-Encoding") {
case "snappy":
responseBodyReader = snappy.NewReader(resp.Body)
case "gzip":
gzipReader, err := gzip.NewReader(resp.Body)
if err != nil {
return NewClientError(errors.Wrap(err, "could not create gzip reader"))
}
responseBodyReader = gzipReader
defer gzipReader.Close()
default:
responseBodyReader = resp.Body
}
if err := codec.NewDecoder(responseBodyReader, clientHandle.handle).Decode(wrappedReply); err != nil {
return NewClientError(errors.Wrap(err, "failed to decode response"))
}

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
@ -33,67 +32,63 @@ func Test_newRequest(t *testing.T) {
}
func TestNewBufferedClient(t *testing.T) {
contentTypeHeaderMap := map[ClientEncoding]string{
EncodingMsgpack: "application/msgpack; charset=utf-8",
EncodingJson: "application/json; charset=utf-8",
}
contentEncodingHeaderMap := map[Compressor]string{
CompressorGZIP: "gzip",
CompressorSnappy: "snappy",
}
var testRequestData []interface{}
data, err := os.ReadFile("testdata/request.json")
require.NoError(t, err)
err = json.Unmarshal(data, &testRequestData)
require.NoError(t, err)
t.Run("gzip", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
//require.Equal(t, "application/msgpack; charset=utf-8", request.Header.Get("Content-Type"))
require.Equal(t, "gzip", request.Header.Get("Content-Encoding"))
data, _ := io.ReadAll(request.Body)
fmt.Println(string(data))
_, _ = writer.Write([]byte("[]"))
testClient := func(
encoding ClientEncoding,
compressor Compressor,
t *testing.T,
) {
requiredResponseMessage := "Fake Response Message"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var args []map[string]interface{}
err := LoadArgs(&args, nil, r)
require.NoError(t, err)
require.Equal(t, contentTypeHeaderMap[encoding], r.Header.Get("Content-Type"))
require.Equal(t, contentEncodingHeaderMap[compressor], r.Header.Get("Content-Encoding"))
_ = Reply([]interface{}{requiredResponseMessage}, nil, r, w)
}))
defer server.Close()
client := NewBufferedClient(
WithCompressor(CompressorGZIP),
WithCompressor(compressor),
WithHTTPClient(server.Client()),
WithClientEncoding(encoding),
)
assert.NotNil(t, client)
err := client.Call(context.Background(), server.URL, "/test", "test", testRequestData, nil)
assert.NoError(t, err)
})
t.Run("snappy", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
//require.Equal(t, "application/msgpack; charset=utf-8", request.Header.Get("Content-Type"))
require.Equal(t, "snappy", request.Header.Get("Content-Encoding"))
data, _ := io.ReadAll(request.Body)
fmt.Println(string(data))
require.NotNil(t, client)
_, _ = writer.Write([]byte("[]"))
}))
defer server.Close()
var actualResponseMessage string
response := []interface{}{&actualResponseMessage}
client := NewBufferedClient(
WithCompressor(CompressorSnappy),
)
assert.NotNil(t, client)
err := client.Call(context.Background(), server.URL, "/test", "test", testRequestData, nil)
assert.NoError(t, err)
})
t.Run("plain", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
//require.Equal(t, "application/msgpack; charset=utf-8", request.Header.Get("Content-Type"))
require.Empty(t, request.Header.Get("Content-Encoding"))
data, _ := io.ReadAll(request.Body)
fmt.Println(string(data))
_, _ = writer.Write([]byte("[]"))
}))
defer server.Close()
client := NewBufferedClient()
assert.NotNil(t, client)
err := client.Call(context.Background(), server.URL, "/test", "test", testRequestData, nil)
assert.NoError(t, err)
})
err := client.Call(context.Background(), server.URL, "/Example", "Example", testRequestData, response)
require.NoError(t, err)
require.Equal(t, requiredResponseMessage, actualResponseMessage)
}
for _, encoding := range []ClientEncoding{EncodingMsgpack, EncodingJson} {
for _, compressor := range []Compressor{CompressorNone, CompressorGZIP, CompressorSnappy} {
t.Run(fmt.Sprintf("%s/%s", encoding, compressor), func(t *testing.T) {
testClient(encoding, compressor, t)
})
}
}
}
func BenchmarkBufferedClient(b *testing.B) {
@ -105,12 +100,12 @@ func BenchmarkBufferedClient(b *testing.B) {
require.NoError(b, err)
benchClient := func(b *testing.B, client Client) {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.Write([]byte("[]"))
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = Reply([]interface{}{"HI"}, nil, r, w)
}))
defer server.Close()
b.ReportAllocs()
b.ResetTimer()
if bc, ok := client.(*BufferedClient); ok {
bc.client = server.Client()
}
@ -125,7 +120,7 @@ func BenchmarkBufferedClient(b *testing.B) {
"gzip": CompressorGZIP,
"snappy": CompressorSnappy,
}
runs := 5
runs := 3
for name, compressor := range benchmarks {
b.Run(name, func(b *testing.B) {

View File

@ -36,37 +36,25 @@ BenchmarkBufferedClient/deprecated/3
BenchmarkBufferedClient/deprecated/3-10 24580 48177 ns/op 26138 B/op 108 allocs/op
BenchmarkBufferedClient/deprecated/4
BenchmarkBufferedClient/deprecated/4-10 24999 57772 ns/op 26154 B/op 108 allocs/op
BenchmarkBufferedClient/snappy
BenchmarkBufferedClient/snappy/0
BenchmarkBufferedClient/snappy/0-10 16392 69553 ns/op 15282 B/op 114 allocs/op
BenchmarkBufferedClient/snappy/1
BenchmarkBufferedClient/snappy/1-10 17702 72923 ns/op 14944 B/op 114 allocs/op
BenchmarkBufferedClient/snappy/2
BenchmarkBufferedClient/snappy/2-10 17932 67446 ns/op 14819 B/op 114 allocs/op
BenchmarkBufferedClient/snappy/3
BenchmarkBufferedClient/snappy/3-10 16640 69216 ns/op 15155 B/op 114 allocs/op
BenchmarkBufferedClient/snappy/4
BenchmarkBufferedClient/snappy/4-10 16767 66247 ns/op 15010 B/op 114 allocs/op
BenchmarkBufferedClient/none
BenchmarkBufferedClient/none/0
BenchmarkBufferedClient/none/0-10 17706 68516 ns/op 12280 B/op 112 allocs/op
BenchmarkBufferedClient/none/0-10 22604 49513 ns/op 28943 B/op 119 allocs/op
BenchmarkBufferedClient/none/1
BenchmarkBufferedClient/none/1-10 17593 68580 ns/op 12308 B/op 112 allocs/op
BenchmarkBufferedClient/none/1-10 23991 55106 ns/op 28862 B/op 119 allocs/op
BenchmarkBufferedClient/none/2
BenchmarkBufferedClient/none/2-10 17292 67673 ns/op 12208 B/op 112 allocs/op
BenchmarkBufferedClient/none/3
BenchmarkBufferedClient/none/3-10 17086 71715 ns/op 12285 B/op 112 allocs/op
BenchmarkBufferedClient/none/4
BenchmarkBufferedClient/none/4-10 17067 68955 ns/op 12295 B/op 112 allocs/op
BenchmarkBufferedClient/none/2-10 22976 50808 ns/op 28946 B/op 119 allocs/op
BenchmarkBufferedClient/gzip
BenchmarkBufferedClient/gzip/0
BenchmarkBufferedClient/gzip/0-10 7190 153284 ns/op 22024 B/op 113 allocs/op
BenchmarkBufferedClient/gzip/0-10 9292 124257 ns/op 15796 B/op 117 allocs/op
BenchmarkBufferedClient/gzip/1
BenchmarkBufferedClient/gzip/1-10 6808 158344 ns/op 20757 B/op 113 allocs/op
BenchmarkBufferedClient/gzip/1-10 10520 112287 ns/op 15601 B/op 117 allocs/op
BenchmarkBufferedClient/gzip/2
BenchmarkBufferedClient/gzip/2-10 6889 156492 ns/op 19680 B/op 113 allocs/op
BenchmarkBufferedClient/gzip/3
BenchmarkBufferedClient/gzip/3-10 6927 148146 ns/op 18912 B/op 113 allocs/op
BenchmarkBufferedClient/gzip/4
BenchmarkBufferedClient/gzip/4-10 8340 146697 ns/op 20207 B/op 113 allocs/op
BenchmarkBufferedClient/gzip/2-10 9838 125777 ns/op 15751 B/op 117 allocs/op
BenchmarkBufferedClient/snappy
BenchmarkBufferedClient/snappy/0
BenchmarkBufferedClient/snappy/0-10 21604 62413 ns/op 15189 B/op 114 allocs/op
BenchmarkBufferedClient/snappy/1
BenchmarkBufferedClient/snappy/1-10 21208 54509 ns/op 15242 B/op 114 allocs/op
BenchmarkBufferedClient/snappy/2
BenchmarkBufferedClient/snappy/2-10 24153 51172 ns/op 15253 B/op 114 allocs/op
```

View File

@ -1,21 +1,23 @@
package gotsrpc
import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
"net/http"
"os"
"path"
"path/filepath"
"reflect"
"sort"
"strings"
"time"
"github.com/golang/snappy"
"github.com/pkg/errors"
"github.com/ugorji/go/codec"
@ -48,11 +50,23 @@ func ErrorMethodNotAllowed(w http.ResponseWriter) {
func LoadArgs(args interface{}, callStats *CallStats, r *http.Request) error {
start := time.Now()
var bodyReader io.Reader = r.Body
switch r.Header.Get("Content-Encoding") {
case "snappy":
bodyReader = snappy.NewReader(r.Body)
case "gzip":
gzipReader, err := gzip.NewReader(r.Body)
if err != nil {
return errors.Wrap(err, "could not create gzip reader")
}
bodyReader = gzipReader
defer gzipReader.Close()
}
handle := getHandlerForContentType(r.Header.Get("Content-Type")).handle
if errDecode := codec.NewDecoder(r.Body, handle).Decode(args); errDecode != nil {
_, _ = fmt.Fprintln(os.Stderr, errDecode.Error())
return errors.Wrap(errDecode, "could not decode arguments")
if err := codec.NewDecoder(bodyReader, handle).Decode(args); err != nil {
return errors.Wrap(err, "could not decode arguments")
}
if callStats != nil {
callStats.Unmarshalling = time.Since(start)
@ -85,49 +99,6 @@ func ClearStats(r *http.Request) {
*r = *r.WithContext(context.WithValue(r.Context(), contextStatsKey, nil))
}
// Reply despite the fact, that this is a public method - do not call it, it will be called by generated code
func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.ResponseWriter) error {
writer := newResponseWriterWithLength(w)
serializationStart := time.Now()
clientHandle := getHandlerForContentType(r.Header.Get("Content-Type"))
writer.Header().Set("Content-Type", clientHandle.contentType)
if clientHandle.beforeEncodeReply != nil {
if err := clientHandle.beforeEncodeReply(&response); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
return errors.Wrap(err, "error during before encoder reply")
}
}
if err := codec.NewEncoder(writer, clientHandle.handle).Encode(response); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
return errors.Wrap(err, "could not encode data to accepted format")
}
if stats != nil {
stats.ResponseSize = writer.length
stats.Marshalling = time.Since(serializationStart)
if len(response) > 0 {
errResp := response[len(response)-1]
if v, ok := errResp.(error); ok && v != nil {
if !reflect.ValueOf(v).IsNil() {
stats.ErrorCode = 1
stats.ErrorType = fmt.Sprintf("%T", v)
stats.ErrorMessage = v.Error()
if v, ok := v.(interface {
ErrorCode() int
}); ok {
stats.ErrorCode = v.ErrorCode()
}
}
}
}
}
return nil
}
func parserExcludeFiles(info os.FileInfo) bool {
return !strings.HasSuffix(info.Name(), "_test.go")
}

92
response.go Normal file
View File

@ -0,0 +1,92 @@
package gotsrpc
import (
"compress/gzip"
"fmt"
"io"
"net/http"
"reflect"
"slices"
"sync"
"time"
"github.com/golang/snappy"
"github.com/pkg/errors"
"github.com/ugorji/go/codec"
)
var (
responseCompressors = map[Compressor]*sync.Pool{
CompressorGZIP: {New: func() interface{} { return gzip.NewWriter(io.Discard) }},
CompressorSnappy: {New: func() interface{} { return snappy.NewBufferedWriter(io.Discard) }},
}
)
// Reply despite the fact, that this is a public method - do not call it, it will be called by generated code
func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.ResponseWriter) error {
responseWriter := newResponseWriterWithLength(w)
serializationStart := time.Now()
var responseBody io.Writer = responseWriter
clientHandle := getHandlerForContentType(r.Header.Get("Content-Type"))
responseWriter.Header().Set("Content-Type", clientHandle.contentType)
// TODO: Add weighted compression support based on Accepted-Encoding header
switch {
case slices.Contains(r.Header.Values("Accept-Encoding"), "snappy"):
responseWriter.Header().Set("Content-Encoding", "snappy")
responseWriter.Header().Set("Vary", "Accept-Encoding")
snappyWriter := responseCompressors[CompressorSnappy].Get().(*snappy.Writer)
snappyWriter.Reset(responseWriter)
defer responseCompressors[CompressorSnappy].Put(snappyWriter)
responseBody = snappyWriter
case slices.Contains(r.Header.Values("Accept-Encoding"), "gzip"):
responseWriter.Header().Set("Content-Encoding", "gzip")
responseWriter.Header().Set("Vary", "Accept-Encoding")
gzipWriter := responseCompressors[CompressorGZIP].Get().(*gzip.Writer)
gzipWriter.Reset(responseWriter)
defer responseCompressors[CompressorGZIP].Put(gzipWriter)
responseBody = gzipWriter
}
if clientHandle.beforeEncodeReply != nil {
if err := clientHandle.beforeEncodeReply(&response); err != nil {
return fmt.Errorf("error during before encoder reply: %w", err)
}
}
if err := codec.NewEncoder(responseBody, clientHandle.handle).Encode(response); err != nil {
return fmt.Errorf("could not encode data to accepted format: %w", err)
}
if writer, ok := responseBody.(io.Closer); ok {
if err := writer.Close(); err != nil {
return errors.Wrap(err, "failed to write to response body")
}
}
if stats != nil {
stats.ResponseSize = responseWriter.length
stats.Marshalling = time.Since(serializationStart)
if len(response) > 0 {
errResp := response[len(response)-1]
if v, ok := errResp.(error); ok && v != nil {
if !reflect.ValueOf(v).IsNil() {
stats.ErrorCode = 1
stats.ErrorType = fmt.Sprintf("%T", v)
stats.ErrorMessage = v.Error()
if v, ok := v.(interface {
ErrorCode() int
}); ok {
stats.ErrorCode = v.ErrorCode()
}
}
}
}
}
return nil
}

View File

@ -12,6 +12,17 @@ import (
type ClientEncoding int
func (c ClientEncoding) String() string {
switch c {
case EncodingMsgpack:
return "msgpack"
case EncodingJson:
return "json"
default:
return "unknown"
}
}
const (
EncodingMsgpack = ClientEncoding(0)
EncodingJson = ClientEncoding(1) //nolint:stylecheck