diff --git a/build.go b/build.go index 10ca8d8..6202d2c 100644 --- a/build.go +++ b/build.go @@ -28,7 +28,7 @@ func Build(conf *config.Config, goPath string) { } packageName := longPackageNameParts[len(longPackageNameParts)-1] - services, structs, err := Read(goPath, longPackageName, target.Services) + services, structs, constants, err := Read(goPath, longPackageName, target.Services) if err != nil { fmt.Fprintln(os.Stderr, " an error occured while trying to understand your code", err) @@ -46,7 +46,7 @@ func Build(conf *config.Config, goPath string) { fmt.Fprintln(os.Stderr, " could not write service file", target.Out, updateErr) os.Exit(3) } - err = RenderStructsToPackages(structs, conf.Mappings, mappedTypeScript) + err = RenderStructsToPackages(structs, conf.Mappings, constants, mappedTypeScript) if err != nil { fmt.Fprintln(os.Stderr, "struct gen err for target", name, err) os.Exit(4) @@ -80,17 +80,21 @@ func Build(conf *config.Config, goPath string) { } fmt.Fprintln(os.Stderr, "building structs for go package", goPackage, "to ts module", mapping.TypeScriptModule, "in file", mapping.Out) - moduleCode := newCode().l("module " + mapping.TypeScriptModule + "{").ind(1) - structNames := []string{} + moduleCode := newCode(" ").l("module " + mapping.TypeScriptModule + " {").ind(1) + structNames := []string{"___goConstants"} for structName := range mappedStructsMap { structNames = append(structNames, structName) } sort.Strings(structNames) for _, structName := range structNames { - structCode := mappedStructsMap[structName] - moduleCode.app(structCode.ind(-1).l("").string()) + + structCode, ok := mappedStructsMap[structName] + if ok { + moduleCode.app(structCode.ind(-1).l("").string()) + } } + moduleCode.ind(-1).l("}") updateErr := updateCode(mapping.Out, moduleCode.string()) if updateErr != nil { diff --git a/code.go b/code.go index 939317b..777538d 100644 --- a/code.go +++ b/code.go @@ -6,23 +6,28 @@ type code struct { line string lines []string indent int + tab string } -func newCode() *code { +func newCode(tab string) *code { return &code{ line: "", lines: []string{}, indent: 0, + tab: tab, } } func (c *code) ind(inc int) *code { c.indent += inc + if c.indent < 0 { + c.indent = 0 + } return c } func (c *code) nl() *code { - c.lines = append(c.lines, strings.Repeat(" ", c.indent)+c.line) + c.lines = append(c.lines, strings.Repeat(c.tab, c.indent)+c.line) c.line = "" return c } diff --git a/go.go b/go.go index b0a0286..f11f80d 100644 --- a/go.go +++ b/go.go @@ -261,7 +261,7 @@ func renderServiceProxies(services []*Service, fullPackageName string, packageNa } func RenderGo(services []*Service, longPackageName, packageName string) (gocode string, err error) { - g := newCode() + g := newCode(" ") err = renderServiceProxies(services, longPackageName, packageName, g) if err != nil { return diff --git a/gotsrpc.go b/gotsrpc.go index 91443ba..6b23dca 100644 --- a/gotsrpc.go +++ b/gotsrpc.go @@ -70,7 +70,6 @@ func parsePackage(goPath string, packageName string) (pkg *ast.Package, err erro strippedPackageName := packageNameParts[len(packageNameParts)-1] foundPackages := []string{} for pkgName, pkg := range pkgs { - //fmt.Println("pkgName", pkgName) if pkgName == strippedPackageName { return pkg, nil } diff --git a/servicereader.go b/servicereader.go index 0051c87..d870390 100644 --- a/servicereader.go +++ b/servicereader.go @@ -141,7 +141,38 @@ func readServicesInPackage(pkg *ast.Package, packageName string, serviceNames [] return } -func Read(goPath string, packageName string, serviceNames []string) (services []*Service, structs map[string]*Struct, err error) { +func loadConstants(pkg *ast.Package) map[string]*ast.BasicLit { + constants := map[string]*ast.BasicLit{} + for _, file := range pkg.Files { + for _, decl := range file.Decls { + if reflect.ValueOf(decl).Type().String() == "*ast.GenDecl" { + genDecl := decl.(*ast.GenDecl) + if genDecl.Tok == token.CONST { + trace("got a const", genDecl.Specs) + for _, spec := range genDecl.Specs { + if "*ast.ValueSpec" == reflect.ValueOf(spec).Type().String() { + spec := spec.(*ast.ValueSpec) + for _, val := range spec.Values { + if reflect.ValueOf(val).Type().String() == "*ast.BasicLit" { + + firstValueLit := val.(*ast.BasicLit) + //fmt.Println("a value spec", spec.Names[0], firstValueLit.Kind, firstValueLit.Value) + constants[spec.Names[0].String()] = firstValueLit //.Value + + } + + } + } + } + } + } + } + } + return constants + +} + +func Read(goPath string, packageName string, serviceNames []string) (services []*Service, structs map[string]*Struct, constants map[string]map[string]*ast.BasicLit, err error) { if len(serviceNames) == 0 { err = errors.New("nothing to do service names are empty") return @@ -150,6 +181,7 @@ func Read(goPath string, packageName string, serviceNames []string) (services [] if err != nil { return } + services, err = readServicesInPackage(pkg, packageName, serviceNames) if err != nil { return @@ -168,9 +200,22 @@ func Read(goPath string, packageName string, serviceNames []string) (services [] } collectErr := collectStructs(goPath, structs) if collectErr != nil { - err = errors.New("error while collecting structs: " + collectErr.Error()) } + constants = map[string]map[string]*ast.BasicLit{} + for _, structDef := range structs { + structPackage := structDef.Package + _, ok := constants[structPackage] + if !ok { + pkg, constPkgErr := parsePackage(goPath, structPackage) + if constPkgErr != nil { + err = constPkgErr + return + } + constants[structPackage] = loadConstants(pkg) + + } + } return } @@ -199,7 +244,7 @@ func collectStructs(goPath string, structs map[string]*Struct) error { packageName := strings.Join(fullNameParts, ".") - // trace(fullName, "==========================>", fullNameParts, "=============>", packageName) + trace(fullName, "==========================>", fullNameParts, "=============>", packageName) packageStructs, ok := scannedPackages[packageName] if !ok { @@ -211,7 +256,7 @@ func collectStructs(goPath string, structs map[string]*Struct) error { scannedPackages[packageName] = packageStructs } for packageStructName, packageStruct := range packageStructs { - // trace("------------------------------------>", packageStructName, packageStruct) + trace("------------------------------------>", packageStructName, packageStruct) existingStruct, needed := structs[packageStructName] if needed && existingStruct == nil { structs[packageStructName] = packageStruct @@ -306,20 +351,35 @@ func getStructsInPackage(goPath string, packageName string) (structs map[string] return structs, nil } +func getStructTypeForField(value *Value) *StructType { + //field.Value.StructType + var strType *StructType + switch true { + case value.StructType != nil: + strType = value.StructType + //case field.Value.ArrayType + case value.Map != nil: + strType = getStructTypeForField(value.Map.Value) + case value.Array != nil: + strType = getStructTypeForField(value.Array.Value) + } + return strType +} + func collecStructTypes(fields []*Field, structTypes map[string]*StructType) { for _, field := range fields { - if field.Value.StructType != nil { - fullName := field.Value.StructType.Package + "." + field.Value.StructType.Name - if len(field.Value.StructType.Package) == 0 { - fullName = field.Value.StructType.Name + strType := getStructTypeForField(field.Value) + if strType != nil { + fullName := strType.Package + "." + strType.Name + if len(strType.Package) == 0 { + fullName = strType.Name } switch fullName { case "error", "net/http.Request", "net/http.ResponseWriter": continue default: - structTypes[fullName] = field.Value.StructType + structTypes[fullName] = strType } - } } } diff --git a/typereader.go b/typereader.go index 0146dbd..d82a7f5 100644 --- a/typereader.go +++ b/typereader.go @@ -184,11 +184,13 @@ func readAstStructType(v *Value, structType *ast.StructType, fileImports fileImp } func (v *Value) loadExpr(expr ast.Expr, fileImports fileImportSpecMap) { - //fmt.Println(field.Names[0].Name, field.Type, reflect.ValueOf(field.Type).Type().String()) + switch reflect.ValueOf(expr).Type().String() { case "*ast.ArrayType": + fieldArray := expr.(*ast.ArrayType) v.Array = &Array{Value: &Value{}} + switch reflect.ValueOf(fieldArray.Elt).Type().String() { case "*ast.Ident": readAstType(v.Array.Value, fieldArray.Elt.(*ast.Ident), fileImports) diff --git a/typescript.go b/typescript.go index e3c01a0..cc263d6 100644 --- a/typescript.go +++ b/typescript.go @@ -3,11 +3,15 @@ package gotsrpc import ( "errors" "fmt" + "go/ast" + "sort" "strings" "github.com/foomo/gotsrpc/config" ) +const goConstPseudoPackage = "__goConstants" + var SkipGoTSRPC = false func (f *Field) tsName() string { @@ -67,16 +71,6 @@ func renderStruct(str *Struct, mappings config.TypeScriptMappings, ts *code) err return nil } -/* - export class ServiceClient { - static defaultInst = new ServiceClient() - constructor(public endPoint:string = "/service") { } - hello(name:string, success:(reply:string, err:Err) => void, err:(request:XMLHttpRequest) => void) { - GoTSRPC.call(this.endPoint, "Hello", [name], success, err); - } - } -*/ - func renderService(service *Service, mappings config.TypeScriptMappings, ts *code) error { clientName := service.Name + "Client" ts.l("export class " + clientName + " {").ind(1). @@ -131,7 +125,7 @@ func renderService(service *Service, mappings config.TypeScriptMappings, ts *cod ts.l("}") return nil } -func RenderStructsToPackages(structs map[string]*Struct, mappings config.TypeScriptMappings, mappedTypeScript map[string]map[string]*code) (err error) { +func RenderStructsToPackages(structs map[string]*Struct, mappings config.TypeScriptMappings, constants map[string]map[string]*ast.BasicLit, mappedTypeScript map[string]map[string]*code) (err error) { codeMap := map[string]map[string]*code{} for _, mapping := range mappings { @@ -148,26 +142,93 @@ func RenderStructsToPackages(structs map[string]*Struct, mappings config.TypeScr err = errors.New("missing code mapping for go package : " + str.Package + " => you have to add a mapping from this go package to a TypeScript module in your build-config.yml in the mappings section") return } - packageCodeMap[str.Name] = newCode().ind(1) + packageCodeMap[str.Name] = newCode(" ").ind(1) err = renderStruct(str, mappings, packageCodeMap[str.Name]) if err != nil { return } } + ensureCodeInPackage := func(goPackage string) { + _, ok := mappedTypeScript[goPackage] + if !ok { + mappedTypeScript[goPackage] = map[string]*code{} + } + return + } for _, mapping := range mappings { for structName, structCode := range codeMap[mapping.GoPackage] { - _, ok := mappedTypeScript[mapping.GoPackage] - if !ok { - mappedTypeScript[mapping.GoPackage] = map[string]*code{} - } + ensureCodeInPackage(mapping.GoPackage) mappedTypeScript[mapping.GoPackage][structName] = structCode } - //.ind(-1).l("}").string() + } + for packageName, packageConstants := range constants { + if len(packageConstants) > 0 { + ensureCodeInPackage(packageName) + _, done := mappedTypeScript[packageName][goConstPseudoPackage] + if done { + continue + } + constCode := newCode(" ").ind(1).l("// constants from " + packageName).l("export const GoConst = {").ind(1) + //constCode.l() + mappedTypeScript[packageName][goConstPseudoPackage] = constCode + constPrefixParts := split(packageName, []string{"/", ".", "-"}) + constPrefix := "" + for _, constPrefixPart := range constPrefixParts { + constPrefix += ucFirst(constPrefixPart) + } + constNames := []string{} + for constName, _ := range packageConstants { + constNames = append(constNames, constName) + } + sort.Strings(constNames) + for _, constName := range constNames { + basicLit := packageConstants[constName] + constCode.l(fmt.Sprint(constName, " : ", basicLit.Value, ",")) + } + constCode.ind(-1).l("}") + + } + } return nil } + +func split(str string, seps []string) []string { + res := []string{} + strs := []string{str} + for _, sep := range seps { + nextStrs := []string{} + for _, str := range strs { + for _, part := range strings.Split(str, sep) { + nextStrs = append(nextStrs, part) + } + } + strs = nextStrs + res = nextStrs + } + return res +} + +func ucFirst(str string) string { + strUpper := strings.ToUpper(str) + constPrefix := "" + var firstRune rune + for _, strUpperRune := range strUpper { + firstRune = strUpperRune + break + } + constPrefix += string(firstRune) + for i, strRune := range str { + if i == 0 { + continue + } + constPrefix += string(strRune) + } + return constPrefix +} + func RenderTypeScriptServices(services []*Service, mappings config.TypeScriptMappings, tsModuleName string) (typeScript string, err error) { - ts := newCode() + ts := newCode(" ") if !SkipGoTSRPC { ts.l(`module GoTSRPC { export function call(endPoint:string, method:string, args:any[], success:any, err:any) { diff --git a/typescript_test.go b/typescript_test.go new file mode 100644 index 0000000..7ef8f2b --- /dev/null +++ b/typescript_test.go @@ -0,0 +1,14 @@ +package gotsrpc + +import "testing" + +func TestSplit(t *testing.T) { + res := split("git.bestbytes.net/foo-bar", []string{".", "/", "-"}) + for i, expected := range []string{"git", "bestbytes", "net", "foo", "bar"} { + actual := res[i] + if actual != expected { + t.Fatal("expected", expected, "got", actual) + } + } + +}