From ef3ef2bc50c4b0bdffa10862033a9bbee9553c1c Mon Sep 17 00:00:00 2001 From: Jan Halfar Date: Tue, 27 Nov 2018 12:50:10 +0100 Subject: [PATCH] new socket client with connection pool and added http client --- client/client.go | 126 ++++-------------- client/client_test.go | 262 +++++++++++++++++++++++++------------- client/connectionpool.go | 130 +++++++++++++++++++ client/httptransport.go | 62 +++++++++ client/sockettransport.go | 122 ++++++++++++++++++ client/transport.go | 8 ++ 6 files changed, 524 insertions(+), 186 deletions(-) create mode 100644 client/connectionpool.go create mode 100644 client/httptransport.go create mode 100644 client/sockettransport.go create mode 100644 client/transport.go diff --git a/client/client.go b/client/client.go index 60c2c88..652bf79 100644 --- a/client/client.go +++ b/client/client.go @@ -1,12 +1,7 @@ package client import ( - "encoding/json" - "errors" - "fmt" - "io" - "net" - "strconv" + "time" "github.com/foomo/contentserver/content" "github.com/foomo/contentserver/requests" @@ -14,34 +9,47 @@ import ( "github.com/foomo/contentserver/server" ) -type serverResponse struct { - Reply interface{} -} - // Client a content server client type Client struct { - Server string - conn net.Conn + t transport +} + +func NewClient( + server string, + connectionPoolSize int, + waitTimeout time.Duration, +) (c *Client, err error) { + c = &Client{ + t: newSocketTransport(server, connectionPoolSize, waitTimeout), + } + return +} + +func NewHTTPClient(server string) (c *Client, err error) { + c = &Client{ + t: newHTTPTransport(server), + } + return } // Update tell the server to update itself func (c *Client) Update() (response *responses.Update, err error) { response = &responses.Update{} - err = c.call(server.HandlerUpdate, &requests.Update{}, response) + err = c.t.call(server.HandlerUpdate, &requests.Update{}, response) return } // GetContent request site content func (c *Client) GetContent(request *requests.Content) (response *content.SiteContent, err error) { response = &content.SiteContent{} - err = c.call(server.HandlerGetContent, request, response) + err = c.t.call(server.HandlerGetContent, request, response) return } // GetURIs resolve uris for ids in a dimension func (c *Client) GetURIs(dimension string, IDs []string) (uriMap map[string]string, err error) { uriMap = map[string]string{} - err = c.call( + err = c.t.call( server.HandlerGetURIs, &requests.URIs{ Dimension: dimension, @@ -59,97 +67,17 @@ func (c *Client) GetNodes(env *requests.Env, nodes map[string]*requests.Node) (n Nodes: nodes, } nodesResponse = map[string]*content.Node{} - err = c.call(server.HandlerGetNodes, r, &nodesResponse) + err = c.t.call(server.HandlerGetNodes, r, &nodesResponse) return } // GetRepo get the whole repo func (c *Client) GetRepo() (response map[string]*content.RepoNode, err error) { response = map[string]*content.RepoNode{} - err = c.call(server.HandlerGetRepo, &requests.Repo{}, &response) + err = c.t.call(server.HandlerGetRepo, &requests.Repo{}, &response) return } -// func (c *Client) closeConnection() error { -// if c.conn != nil { -// err := c.conn.Close() -// if err != nil { -// return err -// } -// c.conn = nil -// } -// return nil -// } - -// func (c *Client) getConnection() (conn net.Conn, err error) { -// // we need some pooling here -// return -// } - -func (c *Client) call(handler server.Handler, request interface{}, response interface{}) error { - jsonBytes, err := json.Marshal(request) - if err != nil { - return fmt.Errorf("could not marshal request : %q", err) - } - conn, err := net.Dial("tcp", c.Server) - if err != nil { - return fmt.Errorf("can not call server - connection error: %q", err) - } - defer conn.Close() - // write header result will be like handler:2{} - jsonBytes = append([]byte(fmt.Sprintf("%s:%d", handler, len(jsonBytes))), jsonBytes...) - - // send request - written := 0 - l := len(jsonBytes) - for written < l { - n, err := conn.Write(jsonBytes[written:]) - if err != nil { - return fmt.Errorf("failed to send request: %q", err) - } - written += n - } - - // read response - responseBytes := []byte{} - buf := make([]byte, 4096) - responseLength := 0 - for { - n, err := conn.Read(buf) - if err != nil && err != io.EOF { - return fmt.Errorf("an error occured while reading the response: %q", err) - } - if n == 0 { - break - } - responseBytes = append(responseBytes, buf[0:n]...) - if responseLength == 0 { - for index, byte := range responseBytes { - if byte == 123 { - // opening bracket - responseLength, err = strconv.Atoi(string(responseBytes[0:index])) - if err != nil { - return errors.New("could not read response length: " + err.Error()) - } - responseBytes = responseBytes[index:] - break - } - } - } - if responseLength > 0 && len(responseBytes) == responseLength { - break - } - } - // unmarshal response - responseJSONErr := json.Unmarshal(responseBytes, &serverResponse{Reply: response}) - if responseJSONErr != nil { - // is it an error ? - remoteErr := responses.Error{} - remoteErrJSONErr := json.Unmarshal(responseBytes, remoteErr) - if remoteErrJSONErr == nil { - return remoteErr - } - return fmt.Errorf("could not unmarshal response : %q %q", remoteErrJSONErr, string(responseBytes)) - } - return nil +func (c *Client) ShutDown() { + c.t.shutdown() } diff --git a/client/client_test.go b/client/client_test.go index 7dbebf5..068c0fb 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -2,6 +2,8 @@ package client import ( "encoding/json" + "net" + "strconv" "sync" "testing" "time" @@ -9,6 +11,7 @@ import ( "github.com/foomo/contentserver/content" "github.com/foomo/contentserver/log" "github.com/foomo/contentserver/repo/mock" + "github.com/foomo/contentserver/requests" "github.com/foomo/contentserver/server" ) @@ -23,123 +26,208 @@ func dump(t *testing.T, v interface{}) { t.Log(string(jsonBytes)) } -func getTestClient(t testing.TB) *Client { +func getFreePort() int { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + panic(err) + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + panic(err) + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port +} + +func getAvailableAddr() string { + return "127.0.0.1:" + strconv.Itoa(getFreePort()) +} + +var testServerSocketAddr string +var testServerWebserverAddr string + +func initTestServer(t testing.TB) (socketAddr, webserverAddr string) { + socketAddr = getAvailableAddr() + webserverAddr = getAvailableAddr() + testServer, varDir := mock.GetMockData(t) log.SelectedLevel = log.LevelError - addr := "127.0.0.1:9999" - if !testServerIsRunning { - testServerIsRunning = true - testServer, varDir := mock.GetMockData(t) - go server.Run(testServer.URL+"/repo-two-dimensions.json", addr, varDir) + go server.RunServerSocketAndWebServer(testServer.URL+"/repo-two-dimensions.json", socketAddr, webserverAddr, varDir) + socketClient, errClient := NewClient(socketAddr, 1, time.Duration(time.Millisecond*100)) + if errClient != nil { + panic(errClient) + } + i := 0 + for { time.Sleep(time.Millisecond * 100) + r, err := socketClient.GetRepo() + if err != nil { + continue + } + if r["dimension_foo"].Nodes["id-a"].Data["baz"].(float64) == float64(1) { + break + } + if i > 100 { + panic("this is taking too long") + } + i++ } - return &Client{ - Server: addr, + return +} + +func getTestClients(t testing.TB) (socketClient *Client, httpClient *Client) { + if testServerSocketAddr == "" { + socketAddr, webserverAddr := initTestServer(t) + testServerSocketAddr = socketAddr + testServerWebserverAddr = webserverAddr } + socketClient, errClient := NewClient(testServerSocketAddr, 30, time.Duration(time.Millisecond*100)) + if errClient != nil { + t.Log(errClient) + t.Fail() + } + httpClient, errHTTPClient := NewHTTPClient("http://" + testServerWebserverAddr + server.PathContentserver) + if errHTTPClient != nil { + t.Log(errHTTPClient) + t.Fail() + } + return +} + +func testWithClients(t *testing.T, testFunc func(c *Client)) { + socketClient, httpClient := getTestClients(t) + defer socketClient.ShutDown() + defer httpClient.ShutDown() + testFunc(socketClient) + testFunc(httpClient) } func TestUpdate(t *testing.T) { - c := getTestClient(t) - response, err := c.Update() - if err != nil { - t.Fatal("unexpected err", err) - } - if !response.Success { - t.Fatal("update has to return .Sucesss true", response) - } - stats := response.Stats - if !(stats.RepoRuntime > float64(0.0)) || !(stats.OwnRuntime > float64(0.0)) { - t.Fatal("stats invalid") - } + testWithClients(t, func(c *Client) { + response, err := c.Update() + if err != nil { + t.Fatal("unexpected err", err) + } + if !response.Success { + t.Fatal("update has to return .Sucesss true", response) + } + stats := response.Stats + if !(stats.RepoRuntime > float64(0.0)) || !(stats.OwnRuntime > float64(0.0)) { + t.Fatal("stats invalid") + } + }) } func TestGetURIs(t *testing.T) { - c := getTestClient(t) - request := mock.MakeValidURIsRequest() - uriMap, err := c.GetURIs(request.Dimension, request.IDs) - if err != nil { - t.Fatal(err) - } - if uriMap[request.IDs[0]] != "/a" { - t.Fatal(uriMap) - } + testWithClients(t, func(c *Client) { + defer c.ShutDown() + request := mock.MakeValidURIsRequest() + uriMap, err := c.GetURIs(request.Dimension, request.IDs) + if err != nil { + t.Fatal(err) + } + if uriMap[request.IDs[0]] != "/a" { + t.Fatal(uriMap) + } + }) } func TestGetRepo(t *testing.T) { - c := getTestClient(t) - r, err := c.GetRepo() - if err != nil { - t.Fatal(err) - } - if r["dimension_foo"].Nodes["id-a"].Data["baz"].(float64) != float64(1) { - t.Fatal("failed to drill deep for data") - } + testWithClients(t, func(c *Client) { + r, err := c.GetRepo() + if err != nil { + t.Fatal(err) + } + if r["dimension_foo"].Nodes["id-a"].Data["baz"].(float64) != float64(1) { + t.Fatal("failed to drill deep for data") + } + }) } func TestGetNodes(t *testing.T) { - c := getTestClient(t) - nodesRequest := mock.MakeNodesRequest() - nodes, err := c.GetNodes(nodesRequest.Env, nodesRequest.Nodes) - if err != nil { - t.Fatal(err) - } - testNode, ok := nodes["test"] - if !ok { - t.Fatal("that should be a node") - } - testData, ok := testNode.Item.Data["foo"] - if !ok { - t.Fatal("where is foo") - } - if testData != "bar" { - t.Fatal("testData should have bennd bar not", testData) - } - + testWithClients(t, func(c *Client) { + nodesRequest := mock.MakeNodesRequest() + nodes, err := c.GetNodes(nodesRequest.Env, nodesRequest.Nodes) + if err != nil { + t.Fatal(err) + } + testNode, ok := nodes["test"] + if !ok { + t.Fatal("that should be a node") + } + testData, ok := testNode.Item.Data["foo"] + if !ok { + t.Fatal("where is foo") + } + if testData != "bar" { + t.Fatal("testData should have bennd bar not", testData) + } + }) } func TestGetContent(t *testing.T) { - c := getTestClient(t) - request := mock.MakeValidContentRequest() - response, err := c.GetContent(request) - if err != nil { - t.Fatal("unexpected err", err) - } - if request.URI != response.URI { - dump(t, request) - dump(t, response) - t.Fatal("uri mismatch") - } - if response.Status != content.StatusOk { - t.Fatal("unexpected status") + testWithClients(t, func(c *Client) { + request := mock.MakeValidContentRequest() + response, err := c.GetContent(request) + if err != nil { + t.Fatal("unexpected err", err) + } + if request.URI != response.URI { + dump(t, request) + dump(t, response) + t.Fatal("uri mismatch") + } + if response.Status != content.StatusOk { + t.Fatal("unexpected status") + } + }) +} + +func BenchmarkSocketClientAndServerGetContent(b *testing.B) { + socketClient, _ := getTestClients(b) + benchmarkServerAndClientGetContent(b, 30, 100, socketClient) + +} +func BenchmarkWebClientAndServerGetContent(b *testing.B) { + _, httpClient := getTestClients(b) + benchmarkServerAndClientGetContent(b, 30, 100, httpClient) +} + +func benchmarkServerAndClientGetContent(b *testing.B, numGroups, numCalls int, client GetContentClient) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := time.Now() + benchmarkClientAndServerGetContent(b, numGroups, numCalls, client) + dur := time.Since(start) + totalCalls := numGroups * numCalls + b.Log("requests per second", int(float64(totalCalls)/(float64(dur)/float64(1000000000))), dur, totalCalls) } } -// not very meaningful yet -func BenchmarkServerAndClient(b *testing.B) { +type GetContentClient interface { + GetContent(request *requests.Content) (response *content.SiteContent, err error) +} + +func benchmarkClientAndServerGetContent(b testing.TB, numGroups, numCalls int, client GetContentClient) { var wg sync.WaitGroup - stats := make([]int, 100) - for group := 0; group < 100; group++ { - wg.Add(1) + wg.Add(numGroups) + for group := 0; group < numGroups; group++ { go func(g int) { defer wg.Done() - c := getTestClient(b) request := mock.MakeValidContentRequest() - for i := 0; i < 1000; i++ { - response, err := c.GetContent(request) - if err != nil { - b.Fatal("unexpected err", err) + for i := 0; i < numCalls; i++ { + response, err := client.GetContent(request) + if err == nil { + if request.URI != response.URI { + b.Fatal("uri mismatch") + } + if response.Status != content.StatusOk { + b.Fatal("unexpected status") + } } - if request.URI != response.URI { - b.Fatal("uri mismatch") - } - if response.Status != content.StatusOk { - b.Fatal("unexpected status") - } - stats[g] = i } }(group) - } // Wait for all HTTP fetches to complete. wg.Wait() - b.Log(stats) + return } diff --git a/client/connectionpool.go b/client/connectionpool.go new file mode 100644 index 0000000..de73455 --- /dev/null +++ b/client/connectionpool.go @@ -0,0 +1,130 @@ +package client + +import ( + "net" + "time" +) + +type connectionPool struct { + server string + conn net.Conn + chanConnGet chan chan net.Conn + chanConnReturn chan connReturn + chanDrainPool chan int +} + +func newConnectionPool(server string, connectionPoolSize int, waitTimeout time.Duration) *connectionPool { + connPool := &connectionPool{ + server: server, + chanConnGet: make(chan chan net.Conn), + chanConnReturn: make(chan connReturn), + chanDrainPool: make(chan int), + } + go connPool.run(connectionPoolSize, waitTimeout) + return connPool +} + +func (c *connectionPool) run(connectionPoolSize int, waitTimeout time.Duration) { + type poolEntry struct { + busy bool + err error + conn net.Conn + } + type waitPoolEntry struct { + entryTime time.Time + chanConn chan net.Conn + } + connectionPool := make(map[int]*poolEntry, connectionPoolSize) + waitPool := map[int]*waitPoolEntry{} + for i := 0; i < connectionPoolSize; i++ { + connectionPool[i] = &poolEntry{ + conn: nil, + busy: false, + } + } +RunLoop: + for { + // fmt.Println("----------------------- run loop ------------------------") + select { + case <-c.chanDrainPool: + // fmt.Println("<-c.chanDrainPool") + for _, waitPoolEntry := range waitPool { + waitPoolEntry.chanConn <- nil + } + break RunLoop + case <-time.After(waitTimeout): + // fmt.Println("tick", len(connectionPool), len(waitPool)) + // for i, poolEntry := range connectionPool { + // fmt.Println(i, poolEntry) + // } + // for i, waitPoolEntry := range waitPool { + // fmt.Println(i, waitPoolEntry) + // } + case chanReturnNextConn := <-c.chanConnGet: + // fmt.Println("chanReturnNextConn := <-c.chanConnGet:") + nextI := 0 + for i := range waitPool { + if i >= nextI { + nextI = i + 1 + } + } + waitPool[nextI] = &waitPoolEntry{ + chanConn: chanReturnNextConn, + entryTime: time.Now(), + } + // fmt.Println("sbdy wants a new conn", nextI) + case connReturn := <-c.chanConnReturn: + // fmt.Println("connReturn := <-c.chanConnReturn:") + for _, poolEntry := range connectionPool { + if connReturn.conn == poolEntry.conn { + poolEntry.busy = false + if connReturn.err != nil { + poolEntry.err = connReturn.err + poolEntry.conn.Close() + poolEntry.conn = nil + } + } + } + } + // refill connection pool + for _, poolEntry := range connectionPool { + if poolEntry.conn == nil { + newConn, errDial := net.Dial("tcp", c.server) + poolEntry.err = errDial + poolEntry.conn = newConn + } + } + // redistribute available connections + for _, poolEntry := range connectionPool { + if len(waitPool) == 0 { + break + } + if poolEntry.err == nil && poolEntry.conn != nil && !poolEntry.busy { + for i, waitPoolEntry := range waitPool { + // fmt.Println("---------------------------> serving wait pool", i, waitPoolEntry) + poolEntry.busy = true + delete(waitPool, i) + waitPoolEntry.chanConn <- poolEntry.conn + break + } + } + } + // waitpool cleanup + waitPoolLoosers := []int{} + now := time.Now() + for i, waitPoolEntry := range waitPool { + if now.Sub(waitPoolEntry.entryTime) > waitTimeout { + waitPoolLoosers = append(waitPoolLoosers, i) + waitPoolEntry.chanConn <- nil + } + } + for _, i := range waitPoolLoosers { + delete(waitPool, i) + } + + } + c.chanDrainPool = nil + c.chanConnReturn = nil + c.chanConnGet = nil + //fmt.Println("runloop is done", waitPool) +} diff --git a/client/httptransport.go b/client/httptransport.go new file mode 100644 index 0000000..aab070d --- /dev/null +++ b/client/httptransport.go @@ -0,0 +1,62 @@ +package client + +import ( + "bytes" + "encoding/json" + "errors" + "io/ioutil" + "net/http" + + "github.com/foomo/contentserver/server" +) + +type httpTransport struct { + client *http.Client + endpoint string +} + +func newHTTPTransport(server string) transport { + return &httpTransport{ + endpoint: server, + client: http.DefaultClient, + } +} + +func (ht *httpTransport) shutdown() { + // nothing to do here +} + +func (ht *httpTransport) call(handler server.Handler, request interface{}, response interface{}) error { + requestBytes, errMarshal := json.Marshal(request) + if errMarshal != nil { + return errMarshal + } + req, errNewRequest := http.NewRequest( + http.MethodPost, + ht.endpoint+"/"+string(handler), + bytes.NewBuffer(requestBytes), + ) + if errNewRequest != nil { + return errNewRequest + } + httpResponse, errDo := ht.client.Do(req) + if errDo != nil { + return errDo + } + if httpResponse.StatusCode != http.StatusOK { + return errors.New("non 200 reply") + } + if httpResponse.Body == nil { + return errors.New("empty response body") + } + responseBytes, errRead := ioutil.ReadAll(httpResponse.Body) + httpResponse.Body.Close() + if errRead != nil { + return errRead + } + errUnmarshal := json.Unmarshal(responseBytes, &serverResponse{Reply: response}) + if errUnmarshal != nil { + return errUnmarshal + } + return errUnmarshal +} diff --git a/client/sockettransport.go b/client/sockettransport.go new file mode 100644 index 0000000..97b524d --- /dev/null +++ b/client/sockettransport.go @@ -0,0 +1,122 @@ +package client + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/foomo/contentserver/responses" + "github.com/foomo/contentserver/server" +) + +type serverResponse struct { + Reply interface{} +} + +type connReturn struct { + conn net.Conn + err error +} + +type socketTransport struct { + connPool *connectionPool +} + +func newSocketTransport(server string, connectionPoolSize int, waitTimeout time.Duration) transport { + return &socketTransport{ + connPool: newConnectionPool(server, connectionPoolSize, waitTimeout), + } +} + +func (st *socketTransport) shutdown() { + if st.connPool.chanDrainPool != nil { + st.connPool.chanDrainPool <- 1 + } +} + +func (c *socketTransport) call(handler server.Handler, request interface{}, response interface{}) error { + if c.connPool.chanDrainPool == nil { + return errors.New("connection pool has been drained, client is dead") + } + jsonBytes, err := json.Marshal(request) + if err != nil { + return fmt.Errorf("could not marshal request : %q", err) + } + netChan := make(chan net.Conn) + c.connPool.chanConnGet <- netChan + conn := <-netChan + if conn == nil { + return errors.New("could not get a connection") + } + returnConn := func(err error) { + c.connPool.chanConnReturn <- connReturn{ + conn: conn, + err: err, + } + } + // write header result will be like handler:2{} + jsonBytes = append([]byte(fmt.Sprintf("%s:%d", handler, len(jsonBytes))), jsonBytes...) + + // send request + written := 0 + l := len(jsonBytes) + for written < l { + n, err := conn.Write(jsonBytes[written:]) + if err != nil { + returnConn(err) + return fmt.Errorf("failed to send request: %q", err) + } + written += n + } + + // read response + responseBytes := []byte{} + buf := make([]byte, 4096) + responseLength := 0 + for { + n, err := conn.Read(buf) + if err != nil && err != io.EOF { + returnConn(err) + return fmt.Errorf("an error occured while reading the response: %q", err) + } + if n == 0 { + break + } + responseBytes = append(responseBytes, buf[0:n]...) + if responseLength == 0 { + for index, byte := range responseBytes { + if byte == 123 { + // opening bracket + responseLength, err = strconv.Atoi(string(responseBytes[0:index])) + if err != nil { + returnConn(err) + return errors.New("could not read response length: " + err.Error()) + } + responseBytes = responseBytes[index:] + break + } + } + } + if responseLength > 0 && len(responseBytes) == responseLength { + break + } + } + // unmarshal response + responseJSONErr := json.Unmarshal(responseBytes, &serverResponse{Reply: response}) + if responseJSONErr != nil { + // is it an error ? + remoteErr := responses.Error{} + remoteErrJSONErr := json.Unmarshal(responseBytes, remoteErr) + if remoteErrJSONErr == nil { + returnConn(remoteErrJSONErr) + return remoteErr + } + return fmt.Errorf("could not unmarshal response : %q %q", remoteErrJSONErr, string(responseBytes)) + } + returnConn(nil) + return nil +} diff --git a/client/transport.go b/client/transport.go new file mode 100644 index 0000000..963d98c --- /dev/null +++ b/client/transport.go @@ -0,0 +1,8 @@ +package client + +import "github.com/foomo/contentserver/server" + +type transport interface { + call(handler server.Handler, request interface{}, response interface{}) error + shutdown() +}