From add9ac903ac8cccff00aef1845edcb7e11e7f37f Mon Sep 17 00:00:00 2001 From: franklin Date: Fri, 21 May 2021 15:25:09 +0200 Subject: [PATCH] feat: don't send reponse on http error --- demo/gotsrpc_gen.go | 7 +++++-- go.go | 13 ++++++++++--- gotsrpc.go | 9 +++------ responsewriter.go | 24 ++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 11 deletions(-) create mode 100644 responsewriter.go diff --git a/demo/gotsrpc_gen.go b/demo/gotsrpc_gen.go index 6260bdb..71375a0 100644 --- a/demo/gotsrpc_gen.go +++ b/demo/gotsrpc_gen.go @@ -332,11 +332,14 @@ func (p *BarGoTSRPCProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } executionStart := time.Now() - helloRet := p.service.Hello(w, r, arg_number) + rw := gotsrpc.ResponseWriter{ResponseWriter: w} + helloRet := p.service.Hello(&rw, r, arg_number) if callStats != nil { callStats.Execution = time.Now().Sub(executionStart) } - gotsrpc.Reply([]interface{}{helloRet}, callStats, r, w) + if rw.Status() == http.StatusOK { + gotsrpc.Reply([]interface{}{helloRet}, callStats, r, w) + } return case "Inheritance": var ( diff --git a/go.go b/go.go index 5ba32fe..4d297d3 100644 --- a/go.go +++ b/go.go @@ -311,19 +311,26 @@ func renderTSRPCServiceProxies(services ServiceList, fullPackageName string, pac returnValueNames = append(returnValueNames, lcfirst(method.Name)+ucfirst(retArgName)) } g.l("executionStart := time.Now()") + if isSessionRequest { + g.l("rw := gotsrpc.ResponseWriter{ResponseWriter: w}") + callArgs = append([]string{"&rw", "r"}, callArgs...) + } if len(returnValueNames) > 0 { g.app(strings.Join(returnValueNames, ", ") + " := ") } - if isSessionRequest { - callArgs = append([]string{"w", "r"}, callArgs...) - } g.app("p.service." + method.Name + "(" + strings.Join(callArgs, ", ") + ")") g.nl() g.l("if callStats != nil {") g.ind(1).l("callStats.Execution = time.Now().Sub(executionStart)").ind(-1) g.l("}") + if isSessionRequest { + g.l("if rw.Status() == http.StatusOK {").ind(1) + } g.l("gotsrpc.Reply([]interface{}{" + strings.Join(returnValueNames, ", ") + "}, callStats, r, w)") + if isSessionRequest { + g.ind(-1).l("}") + } g.l("return") g.ind(-1) } diff --git a/gotsrpc.go b/gotsrpc.go index edcc6df..c0fe171 100644 --- a/gotsrpc.go +++ b/gotsrpc.go @@ -28,18 +28,15 @@ func GetCalledFunc(r *http.Request, endPoint string) string { } func ErrorFuncNotFound(w http.ResponseWriter) { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte("method not found")) + http.Error(w, "method not found", http.StatusNotFound) } func ErrorCouldNotLoadArgs(w http.ResponseWriter) { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("could not load args")) + http.Error(w, "could not load args", http.StatusBadRequest) } func ErrorMethodNotAllowed(w http.ResponseWriter) { - w.WriteHeader(http.StatusMethodNotAllowed) - w.Write([]byte("you gotta POST")) + http.Error(w, "you gotta POST", http.StatusMethodNotAllowed) } func LoadArgs(args interface{}, callStats *CallStats, r *http.Request) error { diff --git a/responsewriter.go b/responsewriter.go new file mode 100644 index 0000000..2fca722 --- /dev/null +++ b/responsewriter.go @@ -0,0 +1,24 @@ +package gotsrpc + +import ( + "net/http" +) + +type ResponseWriter struct { + http.ResponseWriter + wroteHeader bool + status int +} + +func (r *ResponseWriter) WriteHeader(status int) { + r.status = status + r.wroteHeader = true + r.ResponseWriter.WriteHeader(status) +} + +func (r *ResponseWriter) Status() int { + if !r.wroteHeader { + return http.StatusOK + } + return r.status +}