package gotsrpc import ( "fmt" "sort" "strings" "github.com/foomo/gotsrpc/v2/config" ) func (v *Value) isHTTPResponseWriter() bool { return (v.StructType != nil && v.StructType.Name == "ResponseWriter" && v.StructType.Package == "net/http") || (v.Scalar != nil && v.Scalar.Name == "ResponseWriter" && v.Scalar.Package == "net/http") } func (v *Value) isHTTPRequest() bool { return (v.IsPtr && v.StructType != nil && v.StructType.Name == "Request" && v.StructType.Package == "net/http") || (v.IsPtr && v.Scalar != nil && v.Scalar.Name == "Request" && v.Scalar.Package == "net/http") } func (v *Value) goType(aliases map[string]string, packageName string) (t string) { if v.IsPtr { t = "*" } switch { case v.Array != nil: t += "[]" + v.Array.Value.goType(aliases, packageName) case len(v.GoScalarType) > 0: t += v.GoScalarType case v.StructType != nil: if packageName != v.StructType.Package && aliases[v.StructType.Package] != "" { t += aliases[v.StructType.Package] + "." } t += v.StructType.Name case v.Map != nil: t += `map[` + v.Map.Key.goType(aliases, packageName) + `]` + v.Map.Value.goType(aliases, packageName) case v.Scalar != nil: if packageName != v.Scalar.Package && aliases[v.Scalar.Package] != "" { t += aliases[v.Scalar.Package] + "." } t += v.Scalar.Name case v.IsInterface: t += "interface{}" default: // TODO fmt.Println("WARN: can't resolve goType") } return } func lcfirst(str string) string { return strfirst(str, strings.ToLower) } func ucfirst(str string) string { return strfirst(str, strings.ToUpper) } func strfirst(str string, strfunc func(string) string) string { res := "" for i, char := range str { if i == 0 { res += strfunc(string(char)) } else { res += string(char) } } return res } func extractImport(packageName string, fullPackageName string, aliases map[string]string) { r := strings.NewReplacer(".", "_", "/", "_", "-", "_") if packageName != fullPackageName { if _, ok := aliases[packageName]; !ok { packageParts := strings.Split(packageName, "/") beautifulAlias := packageParts[len(packageParts)-1] uglyAlias := r.Replace(packageName) alias := uglyAlias // beautifulAlias for _, otherAlias := range aliases { if otherAlias == beautifulAlias { alias = uglyAlias break } } aliases[packageName] = alias } } } func extractImports(fields []*Field, fullPackageName string, aliases map[string]string) { for _, f := range fields { extractImportValue(f.Value, fullPackageName, aliases) } } func extractImportValue(value *Value, fullPackageName string, aliases map[string]string) { switch { case value.StructType != nil: extractImport(value.StructType.Package, fullPackageName, aliases) case value.Array != nil: extractImportValue(value.Array.Value, fullPackageName, aliases) case value.Map != nil: extractImportValue(value.Map.Key, fullPackageName, aliases) extractImportValue(value.Map.Value, fullPackageName, aliases) case value.Scalar != nil: extractImport(value.Scalar.Package, fullPackageName, aliases) } } func renderTSRPCServiceProxies(services ServiceList, fullPackageName string, packageName string, config *config.Target, unions map[string][]string, g *code) error { aliases := map[string]string{ "time": "time", "net/http": "http", "io": "io", "github.com/foomo/gotsrpc/v2": "gotsrpc", } for _, service := range services { // Check if we should render this service as ts rpc // Note: remove once there's a separate gorcp generator if !config.IsTSRPC(service.Name) { continue } for _, m := range service.Methods { extractImports(m.Args, fullPackageName, aliases) } } for pkg := range unions { extractImport(pkg, fullPackageName, aliases) } g.l(renderImports(aliases, packageName)) renderInit(unions, aliases, packageName, g) for _, service := range services { // Check if we should render this service as ts rcp // Note: remove once there's a separate gorcp generator if !config.IsTSRPC(service.Name) { continue } servicePointer := "*" if service.IsInterface { servicePointer = "" } proxyName := service.Name + "GoTSRPCProxy" g.l("const (") for _, method := range service.Methods { g.l(proxyName + method.Name + " = \"" + method.Name + "\"") } g.l(")") g.l(` type ` + proxyName + ` struct { EndPoint string service ` + servicePointer + service.Name + ` } func NewDefault` + proxyName + `(service ` + servicePointer + service.Name + `) *` + proxyName + ` { return New` + proxyName + `(service, "` + service.Endpoint + `") } func New` + proxyName + `(service ` + servicePointer + service.Name + `, endpoint string) *` + proxyName + ` { return &` + proxyName + `{ EndPoint: endpoint, service: service, } } // ServeHTTP exposes your service func (p *` + proxyName + `) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { return } else if r.Method != http.MethodPost { gotsrpc.ErrorMethodNotAllowed(w) return } defer io.Copy(io.Discard, r.Body) // Drain Request Body `) g.l("funcName := gotsrpc.GetCalledFunc(r, p.EndPoint)") g.l("callStats, _ := gotsrpc.GetStatsForRequest(r)") g.l("callStats.Func = funcName") g.l("callStats.Package = \"" + fullPackageName + "\"") g.l("callStats.Service = \"" + service.Name + "\"") g.l(`switch funcName {`) // indenting into switch cases g.ind(4) for _, method := range service.Methods { // a case for each method g.l("case " + proxyName + method.Name + ":") g.ind(1) var callArgs []string isSessionRequest := false g.l("var (") g.ind(1) g.l("args []interface{}") g.l("rets []interface{}") g.ind(-1) g.l(")") if len(method.Args) > 0 { var args []string var argsDecls []string skipArgI := 0 nonHTTPRelatedArgs := goMethodArgsWithoutHTTPContextRelatedArgs(method) isSessionRequest = len(method.Args)-len(nonHTTPRelatedArgs) == 2 for _, arg := range nonHTTPRelatedArgs { argName := "arg_" + arg.Name argsDecls = append(argsDecls, argName+" "+arg.Value.goType(aliases, packageName)) args = append(args, "&"+argName) callArgs = append(callArgs, argName) skipArgI++ } if len(args) > 0 { g.l("var (") for _, argDecl := range argsDecls { g.l(argDecl) } g.l(")") g.l("args = []interface{}{" + strings.Join(args, ", ") + "}") g.l("if err := gotsrpc.LoadArgs(&args, callStats, r); err != nil {") g.ind(1) g.l("gotsrpc.ErrorCouldNotLoadArgs(w)") g.l("return") g.ind(-1) g.l("}") } } var returnValueNames []string for retI, retField := range method.Return { retArgName := retField.Name if len(retArgName) == 0 { retArgName = "ret" if retI > 0 { retArgName += "_" + fmt.Sprint(retI) } } returnValueNames = append(returnValueNames, lcfirst(method.Name)+ucfirst(retArgName)) } g.l("executionStart := time.Now()") if isSessionRequest { g.l("rw := gotsrpc.ResponseWriter{ResponseWriter: w}") callArgs = append([]string{"&rw", "r"}, callArgs...) } if len(returnValueNames) > 0 { g.app(strings.Join(returnValueNames, ", ") + " := ") } g.app("p.service." + method.Name + "(" + strings.Join(callArgs, ", ") + ")") g.nl() g.l("callStats.Execution = time.Since(executionStart)") if isSessionRequest { g.l("if rw.Status() == http.StatusOK {").ind(1) } g.l("rets = []interface{}{" + strings.Join(returnValueNames, ", ") + "}") g.l("if err := gotsrpc.Reply(rets, callStats, r, w); err != nil {") g.ind(1) g.l("gotsrpc.ErrorCouldNotReply(w)") g.l("return") g.ind(-1) g.l("}") if isSessionRequest { g.ind(-1).l("}") } g.l("gotsrpc.Monitor(w, r, args, rets, callStats)") g.l("return") g.ind(-1) } g.l("default:") g.ind(1).l("gotsrpc.ClearStats(r)") g.ind(1).l("gotsrpc.ErrorFuncNotFound(w)") g.ind(-2).l("}") // close switch g.ind(-1).l("}") // close ServeHttp } return nil } type goMethod struct { name string params []string args []string rets []string returns []string } func newMethodSignature(method *Method, aliases map[string]string, fullPackageName string) goMethod { var args []string var params []string params = append(params, "ctx go_context.Context") for _, a := range goMethodArgsWithoutHTTPContextRelatedArgs(method) { args = append(args, a.Name) params = append(params, a.Name+" "+a.Value.goType(aliases, fullPackageName)) } var rets []string var returns []string for i, r := range method.Return { name := r.Name if len(name) == 0 { name = fmt.Sprintf("ret%s_%d", method.Name, i) } rets = append(rets, "&"+name) returns = append(returns, name+" "+r.Value.goType(aliases, fullPackageName)) } returns = append(returns, "clientErr error") return goMethod{ name: method.Name, params: params, args: args, rets: rets, returns: returns, } } func (ms *goMethod) renderSignature() string { return ms.name + `(` + strings.Join(ms.params, ", ") + `) (` + strings.Join(ms.returns, ", ") + `)` } func renderTSRPCServiceClients(services ServiceList, fullPackageName string, packageName string, config *config.Target, g *code) error { aliases := map[string]string{ "github.com/pkg/errors": "pkg_errors", "github.com/foomo/gotsrpc/v2": "gotsrpc", "net/http": "go_net_http", "context": "go_context", } for _, service := range services { // Check if we should render this service as ts rcp // Note: remove once there's a separate gorcp generator if !config.IsTSRPC(service.Name) { continue } for _, m := range service.Methods { extractImports(m.Args, fullPackageName, aliases) extractImports(m.Return, fullPackageName, aliases) } } g.l(renderImports(aliases, packageName)) for _, service := range services { // Check if we should render this service as ts rcp // Note: remove once there's a separate gorcp generator if !config.IsTSRPC(service.Name) { continue } interfaceName := service.Name + "GoTSRPCClient" clientName := "HTTP" + interfaceName // Render Interface g.l(`type ` + interfaceName + ` interface { `) for _, method := range service.Methods { ms := newMethodSignature(method, aliases, fullPackageName) g.l(ms.renderSignature()) } g.l(`} `) // Render Constructors g.l(` type ` + clientName + ` struct { URL string EndPoint string Client gotsrpc.Client } func NewDefault` + interfaceName + `(url string) *` + clientName + ` { return New` + interfaceName + `(url, "` + service.Endpoint + `") } func New` + interfaceName + `(url string, endpoint string) *` + clientName + ` { return New` + interfaceName + `WithClient(url, endpoint, nil) } func New` + interfaceName + `WithClient(url string, endpoint string, client *go_net_http.Client) *` + clientName + ` { return &` + clientName + `{ URL: url, EndPoint: endpoint, Client: gotsrpc.NewClientWithHttpClient(client), } }`) for _, method := range service.Methods { ms := newMethodSignature(method, aliases, fullPackageName) g.l(`func (tsc *` + clientName + `) ` + ms.renderSignature() + ` {`) g.l(`args := []interface{}{` + strings.Join(ms.args, ", ") + `}`) g.l(`reply := []interface{}{` + strings.Join(ms.rets, ", ") + `}`) g.l(`clientErr = tsc.Client.Call(ctx, tsc.URL, tsc.EndPoint, "` + method.Name + `", args, reply)`) g.l(`if clientErr != nil {`) g.ind(1).l(`clientErr = pkg_errors.WithMessage(clientErr, "failed to call ` + packageName + `.` + service.Name + `GoTSRPCProxy ` + method.Name + `")`).ind(-1) g.l(`}`) g.l(`return`) g.l(`}`) g.nl() } } return nil } func renderGoRPCServiceProxies(services ServiceList, fullPackageName string, packageName string, config *config.Target, g *code) error { aliases := map[string]string{ "fmt": "fmt", "time": "time", "strings": "strings", "reflect": "reflect", "crypto/tls": "tls", "encoding/gob": "gob", "github.com/valyala/gorpc": "gorpc", "github.com/foomo/gotsrpc/v2": "gotsrpc", } for _, service := range services { if !config.IsGoRPC(service.Name) { continue } for _, m := range service.Methods { extractImports(m.Args, fullPackageName, aliases) extractImports(m.Return, fullPackageName, aliases) } } g.l(renderImports(aliases, packageName)) for _, service := range services { if !config.IsGoRPC(service.Name) { continue } servicePointer := "*" if service.IsInterface { servicePointer = "" } proxyName := service.Name + "GoRPCProxy" // Types g.l(`type (`) // Proxy type g.l(` ` + proxyName + ` struct { server *gorpc.Server service ` + servicePointer + service.Name + ` callStatsHandler gotsrpc.GoRPCCallStatsHandlerFun } `) // Request & Response types for _, method := range service.Methods { // Request type g.l(ucfirst(service.Name+method.Name) + `Request struct {`) for _, a := range goMethodArgsWithoutHTTPContextRelatedArgs(method) { g.l(ucfirst(a.Name) + ` ` + a.Value.goType(aliases, fullPackageName)) } g.l(`}`) // Response type g.l(ucfirst(service.Name+method.Name) + `Response struct {`) for i, r := range method.Return { name := r.Name if len(name) == 0 { name = fmt.Sprintf("ret%s_%d", method.Name, i) } g.l(ucfirst(name) + ` ` + r.Value.goType(aliases, fullPackageName)) } g.l(`}`) g.nl() } g.l(`)`) g.nl() // Init g.l(`func init() {`) for _, method := range service.Methods { g.l(`gob.Register(` + ucfirst(service.Name+method.Name) + `Request{})`) g.l(`gob.Register(` + ucfirst(service.Name+method.Name) + `Response{})`) } g.l(`}`) // Constructor g.l(` func New` + proxyName + `(addr string, service ` + servicePointer + service.Name + `, tlsConfig *tls.Config) *` + proxyName + ` { proxy := &` + proxyName + `{ service: service, } if tlsConfig != nil { proxy.server = gorpc.NewTLSServer(addr, proxy.handler, tlsConfig) } else { proxy.server = gorpc.NewTCPServer(addr, proxy.handler) } return proxy } func (p *` + proxyName + `) Start() error { return p.server.Start() } func (p *` + proxyName + `) Serve() error { return p.server.Serve() } func (p *` + proxyName + `) Stop() { p.server.Stop() } func (p *` + proxyName + `) SetCallStatsHandler(handler gotsrpc.GoRPCCallStatsHandlerFun) { p.callStatsHandler = handler } `) g.nl() // Handler g.l(`func (p *` + proxyName + `) handler(clientAddr string, request interface{}) (response interface{}) {`) g.l(`start := time.Now()`) g.nl() g.l(`reqType := reflect.TypeOf(request).String()`) g.l(`funcNameParts := strings.Split(reqType, ".")`) g.l(`funcName := funcNameParts[len(funcNameParts)-1]`) g.nl() g.l(`switch funcName {`) for _, method := range service.Methods { argParams := []string{} nonHTTPRelatedMethodArgs := goMethodArgsWithoutHTTPContextRelatedArgs(method) diffNONHTTPRelatedMethodArgs := len(method.Args) - len(nonHTTPRelatedMethodArgs) for i := 0; i < diffNONHTTPRelatedMethodArgs; i++ { argParams = append(argParams, "nil") } for _, a := range nonHTTPRelatedMethodArgs { argParams = append(argParams, "req."+ucfirst(a.Name)) } rets := []string{} retParams := []string{} for i, r := range method.Return { name := r.Name if len(name) == 0 { name = fmt.Sprintf("ret%s_%d", method.Name, i) } rets = append(rets, name) retParams = append(retParams, ucfirst(name)+`: `+name) } g.l(`case "` + service.Name + method.Name + `Request":`) if len(nonHTTPRelatedMethodArgs) > 0 { g.l(`req := request.(` + service.Name + method.Name + `Request)`) } if len(rets) > 0 { g.l(strings.Join(rets, ", ") + ` := p.service.` + method.Name + `(` + strings.Join(argParams, ", ") + `)`) } else { g.l(`p.service.` + method.Name + `(` + strings.Join(argParams, ", ") + `)`) } g.l(`response = ` + service.Name + method.Name + `Response{` + strings.Join(retParams, ", ") + `}`) } g.l(`default:`) g.l(`fmt.Println("Unknown request type", reflect.TypeOf(request).String())`) g.l(`}`) g.nl() g.l(`if p.callStatsHandler != nil {`) g.l(`p.callStatsHandler(&gotsrpc.CallStats{`) g.l(`Func: funcName,`) g.l(`Package: "` + fullPackageName + `",`) g.l(`Service: "` + service.Name + `",`) g.l(`Execution: time.Since(start),`) g.l(`})`) g.l(`}`) g.nl() g.l(`return`) g.l(`}`) } return nil } func renderGoRPCServiceClients(services ServiceList, fullPackageName string, packageName string, config *config.Target, g *code) error { aliases := map[string]string{ "crypto/tls": "tls", "github.com/valyala/gorpc": "gorpc", } for _, service := range services { if !config.IsGoRPC(service.Name) { continue } for _, m := range service.Methods { extractImports(m.Args, fullPackageName, aliases) extractImports(m.Return, fullPackageName, aliases) } } imports := "" for packageName, alias := range aliases { imports += alias + " \"" + packageName + "\"\n" } g.l(renderImports(aliases, packageName)) for _, service := range services { if !config.IsGoRPC(service.Name) { continue } clientName := service.Name + "GoRPCClient" // Client type g.l(` type ` + clientName + ` struct { Client *gorpc.Client } `) // Constructor g.l(` func New` + clientName + `(addr string, tlsConfig *tls.Config) *` + clientName + ` { client := &` + clientName + `{} if tlsConfig == nil { client.Client = gorpc.NewTCPClient(addr) } else { client.Client = gorpc.NewTLSClient(addr, tlsConfig) } return client } func (tsc *` + clientName + `) Start() { tsc.Client.Start() } func (tsc *` + clientName + `) Stop() { tsc.Client.Stop() } `) g.nl() // Methods for _, method := range service.Methods { args := []string{} params := []string{} for _, a := range goMethodArgsWithoutHTTPContextRelatedArgs(method) { args = append(args, ucfirst(a.Name)+`: `+a.Name) params = append(params, a.Name+" "+a.Value.goType(aliases, fullPackageName)) } rets := []string{} returns := []string{} for i, r := range method.Return { name := r.Name if len(name) == 0 { name = fmt.Sprintf("ret%s_%d", method.Name, i) } rets = append(rets, "response."+ucfirst(name)) returns = append(returns, name+" "+r.Value.goType(aliases, fullPackageName)) } returns = append(returns, "clientErr error") g.l(`func (tsc *` + clientName + `) ` + method.Name + `(` + strings.Join(params, ", ") + `) (` + strings.Join(returns, ", ") + `) {`) g.l(`req := ` + service.Name + method.Name + `Request{` + strings.Join(args, ", ") + `}`) if len(rets) > 0 { g.l(`rpcCallRes, rpcCallErr := tsc.Client.Call(req)`) } else { g.l(`_, rpcCallErr := tsc.Client.Call(req)`) } g.l(`if rpcCallErr != nil {`) g.l(`clientErr = rpcCallErr`) g.l(`return`) g.l(`}`) if len(rets) > 0 { g.l(`response := rpcCallRes.(` + service.Name + method.Name + `Response)`) g.l(`return ` + strings.Join(rets, ", ") + `, nil`) } else { g.l(`return nil`) } g.l(`}`) g.nl() } } return nil } func RenderGoTSRPCProxies(services ServiceList, longPackageName, packageName string, config *config.Target, unions map[string][]string) (gocode string, err error) { g := newCode(" ") err = renderTSRPCServiceProxies(services, longPackageName, packageName, config, unions, g) if err != nil { return } gocode = g.string() return } func RenderGoTSRPCClients(services ServiceList, longPackageName, packageName string, config *config.Target) (gocode string, err error) { g := newCode(" ") err = renderTSRPCServiceClients(services, longPackageName, packageName, config, g) if err != nil { return } gocode = g.string() return } func RenderGoRPCProxies(services ServiceList, longPackageName, packageName string, config *config.Target) (gocode string, err error) { g := newCode(" ") err = renderGoRPCServiceProxies(services, longPackageName, packageName, config, g) if err != nil { return } gocode = g.string() return } func RenderGoRPCClients(services ServiceList, longPackageName, packageName string, config *config.Target) (gocode string, err error) { g := newCode(" ") err = renderGoRPCServiceClients(services, longPackageName, packageName, config, g) if err != nil { return } gocode = g.string() return } func goMethodArgsWithoutHTTPContextRelatedArgs(m *Method) (filteredArgs []*Field) { filteredArgs = []*Field{} for argI, arg := range m.Args { if argI == 0 && arg.Value.isHTTPResponseWriter() { continue } if argI == 1 && arg.Value.isHTTPRequest() { continue } filteredArgs = append(filteredArgs, arg) } return } func renderInit(unions map[string][]string, aliases map[string]string, packageName string, g *code) { if len(unions) > 0 { g.l("func init() {") g.ind(1) var strs []string for pkg, us := range unions { for _, name := range us { var str string if packageName != pkg && aliases[pkg] != "" { str += aliases[pkg] + "." } str += name strs = append(strs, str) } } sort.Strings(strs) for _, str := range strs { g.l("gotsrpc.MustRegisterUnionExt(" + str + "{})") } g.ind(-1) g.l("}") } } func renderImports(aliases map[string]string, packageName string) string { imports := "" for importPath, alias := range aliases { imports += alias + " \"" + importPath + "\"\n" } return ` // Code generated by gotsrpc https://github.com/foomo/gotsrpc/v2 - DO NOT EDIT. package ` + packageName + ` import ( ` + imports + ` ) ` }