diff --git a/build.go b/build.go index 2e3d28d..07c4b1a 100644 --- a/build.go +++ b/build.go @@ -12,8 +12,9 @@ import ( "sort" "strings" - "github.com/foomo/gotsrpc/v2/config" "golang.org/x/tools/imports" + + "github.com/foomo/gotsrpc/v2/config" ) func deriveCommonJSMapping(conf *config.Config) { @@ -76,6 +77,20 @@ func Build(conf *config.Config, goPath string) { } sort.Strings(names) + missingTypes := map[string]bool{} + for _, mapping := range conf.Mappings { + for _, include := range mapping.Types { + missingTypes[include] = true + } + } + + missingConstants := map[string]bool{} + for _, mapping := range conf.Mappings { + for _, include := range mapping.Constants { + missingConstants[include] = true + } + } + for name, target := range conf.Targets { packageName := target.Package @@ -112,8 +127,7 @@ func Build(conf *config.Config, goPath string) { goPaths = append(goPaths, vendorDirectory) } - pkgName, services, structs, scalars, constantTypes, err := Read(goPaths, conf.Module, packageName, target.Services) - + pkgName, services, structs, scalars, constantTypes, err := Read(goPaths, conf.Module, packageName, target.Services, missingTypes, missingConstants) if err != nil { fmt.Fprintln(os.Stderr, "\t an error occured while trying to understand your code: ", err) os.Exit(2) diff --git a/config/config.go b/config/config.go index 03ab71b..05e210b 100644 --- a/config/config.go +++ b/config/config.go @@ -13,12 +13,12 @@ import ( ) type Target struct { - Package string `yaml:"package"` - Services map[string]string `yaml:"services"` - TypeScriptModule string `yaml:"module"` - Out string `yaml:"out"` - GoRPC []string `yaml:"gorpc"` - TSRPC []string `yaml:"tsrpc"` + Package string `yaml:"package"` + Services map[string]string `yaml:"services"` + TypeScriptModule string `yaml:"module"` + Out string `yaml:"out"` + GoRPC []string `yaml:"gorpc"` + TSRPC []string `yaml:"tsrpc"` } func (t *Target) IsGoRPC(service string) bool { @@ -43,9 +43,11 @@ func (t *Target) IsTSRPC(service string) bool { } type Mapping struct { - GoPackage string `yaml:"-"` - Out string `yaml:"out"` - TypeScriptModule string `yaml:"module"` + GoPackage string `yaml:"-"` + Out string `yaml:"out"` + Types []string `yaml:"types"` + Constants []string `yaml:"constants"` + TypeScriptModule string `yaml:"module"` } type TypeScriptMappings map[string]*Mapping diff --git a/demo/output-commonjs-async/demo.ts b/demo/output-commonjs-async/demo.ts index 8338386..f2de030 100644 --- a/demo/output-commonjs-async/demo.ts +++ b/demo/output-commonjs-async/demo.ts @@ -22,9 +22,7 @@ export interface AttributeDefinition { // github.com/foomo/gotsrpc/v2/demo.AttributeID export type AttributeID = string // github.com/foomo/gotsrpc/v2/demo.AttributeMapping -export type AttributeMapping = Record -// github.com/foomo/gotsrpc/v2/demo.Bar -export type Bar = any +export type AttributeMapping = Record // github.com/foomo/gotsrpc/v2/demo.Check export interface Check { Foo:string; @@ -72,7 +70,7 @@ export type LocalKey = string // github.com/foomo/gotsrpc/v2/demo.MapOfOtherStuff export type MapOfOtherStuff = Record // github.com/foomo/gotsrpc/v2/demo.MapWithLocalStuff -export type MapWithLocalStuff = Record +export type MapWithLocalStuff = Record // github.com/foomo/gotsrpc/v2/demo.OuterInline export interface OuterInline { one:string; diff --git a/demo/output-commonjs/demo.ts b/demo/output-commonjs/demo.ts index eaee0a6..5dfa224 100644 --- a/demo/output-commonjs/demo.ts +++ b/demo/output-commonjs/demo.ts @@ -22,9 +22,7 @@ export interface AttributeDefinition { // github.com/foomo/gotsrpc/v2/demo.AttributeID export type AttributeID = string // github.com/foomo/gotsrpc/v2/demo.AttributeMapping -export type AttributeMapping = Record -// github.com/foomo/gotsrpc/v2/demo.Bar -export type Bar = any +export type AttributeMapping = Record // github.com/foomo/gotsrpc/v2/demo.Check export interface Check { Foo:string; @@ -72,7 +70,7 @@ export type LocalKey = string // github.com/foomo/gotsrpc/v2/demo.MapOfOtherStuff export type MapOfOtherStuff = Record // github.com/foomo/gotsrpc/v2/demo.MapWithLocalStuff -export type MapWithLocalStuff = Record +export type MapWithLocalStuff = Record // github.com/foomo/gotsrpc/v2/demo.OuterInline export interface OuterInline { one:string; diff --git a/demo/output/demo.ts b/demo/output/demo.ts index fa9a332..9ccca79 100644 --- a/demo/output/demo.ts +++ b/demo/output/demo.ts @@ -20,9 +20,7 @@ module GoTSRPC.Demo { // github.com/foomo/gotsrpc/v2/demo.AttributeID export type AttributeID = string // github.com/foomo/gotsrpc/v2/demo.AttributeMapping - export type AttributeMapping = Record - // github.com/foomo/gotsrpc/v2/demo.Bar - export type Bar = any + export type AttributeMapping = Record // github.com/foomo/gotsrpc/v2/demo.Check export interface Check { Foo:string; @@ -70,7 +68,7 @@ module GoTSRPC.Demo { // github.com/foomo/gotsrpc/v2/demo.MapOfOtherStuff export type MapOfOtherStuff = Record // github.com/foomo/gotsrpc/v2/demo.MapWithLocalStuff - export type MapWithLocalStuff = Record + export type MapWithLocalStuff = Record // github.com/foomo/gotsrpc/v2/demo.OuterInline export interface OuterInline { one:string; diff --git a/servicereader.go b/servicereader.go index f9e4054..97e5758 100644 --- a/servicereader.go +++ b/servicereader.go @@ -264,6 +264,8 @@ func Read( gomod config.Namespace, packageName string, serviceMap map[string]string, + missingTypes map[string]bool, + missingConstants map[string]bool, ) ( pkgName string, services ServiceList, @@ -287,7 +289,6 @@ func Read( return } - missingTypes := map[string]bool{} for _, s := range services { for _, m := range s.Methods { collectStructTypes(m.Return, missingTypes) @@ -312,11 +313,11 @@ func Read( trace("---------------- found scalars -------------------") traceData(scalars) trace("---------------- /found scalars -------------------") - constantTypes = map[string]map[string]interface{}{} + allConstantTypes := map[string]map[string]interface{}{} for _, structDef := range structs { if structDef != nil { structPackage := structDef.Package - _, ok := constantTypes[structPackage] + _, ok := allConstantTypes[structPackage] if !ok { // fmt.Println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", structPackage) pkg, constPkgErr := parsePackage(goPaths, gomod, structPackage) @@ -324,14 +325,14 @@ func Read( err = constPkgErr return } - constantTypes[structPackage] = loadConstantTypes(pkg) + allConstantTypes[structPackage] = loadConstantTypes(pkg) } } } for _, scalarDef := range scalars { if scalarDef != nil { scalarPackage := scalarDef.Package - _, ok := constantTypes[scalarPackage] + _, ok := allConstantTypes[scalarPackage] if !ok { // fmt.Println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", structPackage) pkg, constPkgErr := parsePackage(goPaths, gomod, scalarPackage) @@ -339,11 +340,41 @@ func Read( err = constPkgErr return } - constantTypes[scalarPackage] = loadConstantTypes(pkg) + allConstantTypes[scalarPackage] = loadConstantTypes(pkg) } } } + flatStructs := map[string]bool{} + for _, s := range structs { + loadFlatStructs(s, flatStructs) + } + + constantTypes = map[string]map[string]interface{}{} + for constantTypePackage, constantType := range allConstantTypes { + for constantTypeName, constantTypeVales := range constantType { + fullName := constantTypePackage + "." + constantTypeName + _, scalarOK := scalars[fullName] + _, structOK := flatStructs[fullName] + _, constantsOK := missingConstants[fullName] + + if scalarOK || structOK || constantsOK { + missingConstants[fullName] = false + if _, ok := constantTypes[constantTypePackage]; !ok { + constantTypes[constantTypePackage] = map[string]interface{}{} + } + constantTypes[constantTypePackage][constantTypeName] = constantTypeVales + } + } + } + + for missingConstant, missing := range missingConstants { + if missing { + err = errors.New("could not resolve constant: " + missingConstant) + return + } + } + // fix arg and return field lists for _, service := range services { for _, method := range service.Methods { @@ -355,9 +386,41 @@ func Read( return } +func loadFlatStructs(s *Struct, flatStructs map[string]bool) { + if s.Map != nil { + if s.Map.Key != nil { + loadFlatStructsValue(s.Map.Key, flatStructs) + } + if s.Map.Value != nil && s.Map.Value.Scalar != nil { + loadFlatStructsValue(s.Map.Value, flatStructs) + } + } + if s.Fields != nil { + for _, field := range s.Fields { + loadFlatStructsValue(field.Value, flatStructs) + } + } + flatStructs[s.FullName()] = true +} + +func loadFlatStructsValue(s *Value, flatStructs map[string]bool) { + if s.Map != nil { + if s.Map.Key != nil { + loadFlatStructsValue(s.Map.Key, flatStructs) + } + if s.Map.Value != nil && s.Map.Value.Scalar != nil { + loadFlatStructsValue(s.Map.Value, flatStructs) + } + } + if s.Struct != nil { + loadFlatStructs(s.Struct, flatStructs) + } + if s.Scalar != nil { + flatStructs[s.Scalar.FullName()] = true + } +} func fixFieldStructs(fields []*Field, structs map[string]*Struct, scalars map[string]*Scalar) { for _, f := range fields { - if f.Value.StructType != nil { // do we have that struct or is it a hidden scalar name := f.Value.StructType.FullName() @@ -600,35 +663,40 @@ func getStructTypeForField(value *Value) *StructType { return strType } -func getScalarForField(value *Value) *Scalar { +func getScalarForField(value *Value) []*Scalar { //field.Value.StructType - var scalarType *Scalar + var scalarTypes []*Scalar switch true { case value.Scalar != nil: - scalarType = value.Scalar + scalarTypes = append(scalarTypes, value.Scalar) //case field.Value.ArrayType case value.Map != nil: - scalarType = getScalarForField(value.Map.Value) + if value.Map.Key != nil { + if v := getScalarForField(value.Map.Key); v != nil { + scalarTypes = append(scalarTypes, v...) + } + } + scalarTypes = append(scalarTypes, getScalarForField(value.Map.Value)...) case value.Array != nil: - scalarType = getScalarForField(value.Array.Value) + scalarTypes = append(scalarTypes, getScalarForField(value.Array.Value)...) } - return scalarType + return scalarTypes } func collectScalarTypes(fields []*Field, scalarTypes map[string]bool) { for _, field := range fields { - - scalarType := getScalarForField(field.Value) - if scalarType != nil { - fullName := scalarType.Package + "." + scalarType.Name - if len(scalarType.Package) == 0 { - fullName = scalarType.Name - } - switch fullName { - case "error", "net/http.Request", "net/http.ResponseWriter": - continue - default: - scalarTypes[fullName] = true + for _, scalarType := range getScalarForField(field.Value) { + if scalarType != nil { + fullName := scalarType.Package + "." + scalarType.Name + if len(scalarType.Package) == 0 { + fullName = scalarType.Name + } + switch fullName { + case "error", "net/http.Request", "net/http.ResponseWriter": + continue + default: + scalarTypes[fullName] = true + } } } } diff --git a/typereader.go b/typereader.go index 1befbd9..d10f85c 100644 --- a/typereader.go +++ b/typereader.go @@ -185,6 +185,8 @@ func readAstMapType(m *Map, mapType *ast.MapType, fileImports fileImportSpecMap) _, scalarType := getTypesFromAstType(mapType.Key.(*ast.Ident)) m.KeyType = string(scalarType) m.KeyGoType = mapType.Key.(*ast.Ident).Name + m.Key = &Value{} + readAstType(m.Key, mapType.Key.(*ast.Ident), fileImports, "") case "*ast.SelectorExpr": m.Key = &Value{} readAstSelectorExpr(m.Key, mapType.Key.(*ast.SelectorExpr), fileImports)