adding support for recursive struct resolution across packages including import conflict resolution

This commit is contained in:
Florian Schlegel 2016-07-21 18:23:48 +02:00
parent 018e8a68dc
commit d56bb49d60
8 changed files with 415 additions and 65 deletions

View File

@ -56,12 +56,10 @@ func main() {
services, structs, err := gotsrpc.Read(goPath, longPackageName, args[1:])
if err != nil {
fmt.Fprintln(os.Stderr, "an error occured", err)
fmt.Fprintln(os.Stderr, "an error occured while trying to understand your code", err)
os.Exit(2)
}
//jsonDump(services)
//jsonDump(structs)
jsonDump(structs)
ts, err := gotsrpc.RenderTypeScript(services, structs, *flagTsModule)
if err != nil {
fmt.Fprintln(os.Stderr, "could not generate ts code", err)

39
config.go Normal file
View File

@ -0,0 +1,39 @@
package gotsrpc
import (
"io/ioutil"
"gopkg.in/yaml.v2"
)
type GoTypeScriptMapping struct {
GoPackage string `yaml:"-"`
TypeScriptDir string `yaml:"dir"`
TypeScriptModule string `yaml:"module"`
}
type Config struct {
Mappings map[string]*GoTypeScriptMapping
}
func loadConfigfile(file string) (conf *Config, err error) {
yamlBytes, readErr := ioutil.ReadFile(file)
if err != nil {
err = readErr
return
}
return loadConfig(yamlBytes)
}
func loadConfig(yamlBytes []byte) (conf *Config, err error) {
conf = &Config{}
yamlErr := yaml.Unmarshal(yamlBytes, conf)
if yamlErr != nil {
err = yamlErr
return
}
for goPackage, mapping := range conf.Mappings {
mapping.GoPackage = goPackage
}
return
}

33
config_test.go Normal file
View File

@ -0,0 +1,33 @@
package gotsrpc
import "testing"
const sampleConf = `---
mappings:
foo/bar:
module: Sample.Module
dir: path/to/ts
github.com/foomo/gotsrpc:
module: Sample.Module.RPC
dir: path/to/other/folder
`
func TestLoadConfig(t *testing.T) {
c, err := loadConfig([]byte(sampleConf))
if err != nil {
t.Fatal(err)
}
goPackage := "foo/bar"
foo, ok := c.Mappings[goPackage]
if !ok {
t.Fatal("foo/bar not found")
}
if foo.GoPackage != goPackage {
t.Fatal("wrong go package value")
}
if foo.TypeScriptDir != "path/to/ts" || foo.TypeScriptModule != "Sample.Module" {
t.Fatal("unexpected data", foo)
}
}

56
go.go
View File

@ -6,23 +6,23 @@ import (
)
func (v *Value) isHTTPResponseWriter() bool {
return v.StructType != nil && v.StructType.Name == "ResponseWriter" && v.StructType.Package == "http"
return v.StructType != nil && v.StructType.Name == "ResponseWriter" && v.StructType.Package == "net/http"
}
func (v *Value) isHTTPRequest() bool {
return v.IsPtr && v.StructType != nil && v.StructType.Name == "Request" && v.StructType.Package == "http"
return v.IsPtr && v.StructType != nil && v.StructType.Name == "Request" && v.StructType.Package == "net/http"
}
func (v *Value) goType() (t string) {
func (v *Value) goType(aliases map[string]string) (t string) {
if v.IsPtr {
t = "*"
}
switch true {
case v.Array != nil:
t += "[]" + v.Array.Value.goType()
t += "[]" + v.Array.Value.goType(aliases)
case len(v.GoScalarType) > 0:
t += v.GoScalarType
case v.StructType != nil:
t += v.StructType.Name
t += aliases[v.StructType.Package] + "." + v.StructType.Name
}
return
}
@ -93,12 +93,52 @@ func strfirst(str string, strfunc func(string) string) string {
func renderServiceProxies(services []*Service, packageName string, g *code) error {
aliases := map[string]string{
"net/http": "http",
"github.com/foomo/gotsrpc": "gotsrpc",
}
r := strings.NewReplacer(".", "_", "/", "_", "-", "_")
extractImports := func(fields []*Field) {
for _, f := range fields {
if f.Value.StructType != nil {
st := f.Value.StructType
if st.Package != packageName {
alias, ok := aliases[st.Package]
if !ok {
packageParts := strings.Split(st.Package, "/")
beautifulAlias := packageParts[len(packageParts)-1]
uglyAlias := r.Replace(st.Package)
alias = beautifulAlias
for _, otherAlias := range aliases {
if otherAlias == beautifulAlias {
alias = uglyAlias
break
}
}
aliases[st.Package] = alias
}
}
}
}
}
for _, s := range services {
for _, m := range s.Methods {
extractImports(m.Args)
extractImports(m.Return)
}
}
imports := ""
for packageName, alias := range aliases {
imports += alias + " \"" + packageName + "\"\n"
}
g.l(`
// this file was auto generated by gotsrpc https://github.com/foomo/gotsrpc
package ` + packageName + `
import (
"net/http"
"github.com/foomo/gotsrpc"
` + imports + `
)
`)
for _, service := range services {
@ -165,7 +205,7 @@ func renderServiceProxies(services []*Service, packageName string, g *code) erro
callArgs = append(callArgs, fmt.Sprint(arg.Value.GoScalarType+"(args[", skipArgI, "].(float64))"))
default:
// assert
callArgs = append(callArgs, fmt.Sprint("args[", skipArgI, "].("+arg.Value.goType()+")"))
callArgs = append(callArgs, fmt.Sprint("args[", skipArgI, "].("+arg.Value.goType(aliases)+")"))
}

View File

@ -58,6 +58,7 @@ type Method struct {
}
type Struct struct {
Name string
Fields []*Field
Package string
Name string
Fields []*Field
}

View File

@ -3,11 +3,13 @@ package gotsrpc
import (
"errors"
"go/ast"
"go/token"
"reflect"
"runtime"
"strings"
)
func readServiceFile(file *ast.File, services []*Service) error {
func readServiceFile(file *ast.File, packageName string, services []*Service) error {
findService := func(serviceName string) (service *Service, ok bool) {
for _, service := range services {
if service.Name == serviceName {
@ -16,6 +18,9 @@ func readServiceFile(file *ast.File, services []*Service) error {
}
return nil, false
}
fileImports := getFileImports(file, packageName)
for _, decl := range file.Decls {
if reflect.ValueOf(decl).Type().String() == "*ast.FuncDecl" {
funcDecl := decl.(*ast.FuncDecl)
@ -34,8 +39,8 @@ func readServiceFile(file *ast.File, services []*Service) error {
if ok && strings.ToLower(firstCharOfMethodName) != firstCharOfMethodName {
service.Methods = append(service.Methods, &Method{
Name: funcDecl.Name.Name,
Args: readFields(funcDecl.Type.Params),
Return: readFields(funcDecl.Type.Results),
Args: readFields(funcDecl.Type.Params, fileImports),
Return: readFields(funcDecl.Type.Results, fileImports),
})
}
}
@ -46,17 +51,68 @@ func readServiceFile(file *ast.File, services []*Service) error {
}
}
}
return nil
}
func readFields(fieldList *ast.FieldList) (fields []*Field) {
type importSpec struct {
alias string
name string
path string
}
type fileImportSpecMap map[string]importSpec
func (fileImports fileImportSpecMap) getPackagePath(packageName string) string {
importSpec, ok := fileImports[packageName]
if ok {
packageName = importSpec.path
}
return packageName
}
func standardImportName(importPath string) string {
pathParts := strings.Split(importPath, "/")
return pathParts[len(pathParts)-1]
}
func getFileImports(file *ast.File, packageName string) (imports fileImportSpecMap) {
imports = fileImportSpecMap{"": importSpec{alias: "", name: "", path: packageName}}
for _, decl := range file.Decls {
if reflect.ValueOf(decl).Type().String() == "*ast.GenDecl" {
genDecl := decl.(*ast.GenDecl)
if genDecl.Tok == token.IMPORT {
trace("got an import", genDecl.Specs)
for _, spec := range genDecl.Specs {
if "*ast.ImportSpec" == reflect.ValueOf(spec).Type().String() {
spec := spec.(*ast.ImportSpec)
importPath := spec.Path.Value[1 : len(spec.Path.Value)-1]
importName := spec.Name.String()
if importName == "" || importName == "<nil>" {
importName = standardImportName(importPath)
}
imports[importName] = importSpec{
alias: importName,
name: standardImportName(importPath),
path: importPath,
}
//trace(" import >>>>>>>>>>>>>>>>>>>>", importName, importPath)
}
}
}
}
}
return imports
}
func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields []*Field) {
fields = []*Field{}
if fieldList == nil {
return
}
for _, param := range fieldList.List {
name, value, _ := readField(param)
name, value, _ := readField(param, fileImports)
fields = append(fields, &Field{
Name: name,
Value: value,
@ -66,7 +122,7 @@ func readFields(fieldList *ast.FieldList) (fields []*Field) {
}
func readServicesInPackage(pkg *ast.Package, serviceNames []string) (services []*Service, err error) {
func readServicesInPackage(pkg *ast.Package, packageName string, serviceNames []string) (services []*Service, err error) {
services = []*Service{}
for _, serviceName := range serviceNames {
services = append(services, &Service{
@ -75,7 +131,7 @@ func readServicesInPackage(pkg *ast.Package, serviceNames []string) (services []
})
}
for _, file := range pkg.Files {
err = readServiceFile(file, services)
err = readServiceFile(file, packageName, services)
if err != nil {
return
}
@ -93,10 +149,160 @@ func Read(goPath string, packageName string, serviceNames []string) (services []
if err != nil {
return
}
services, err = readServicesInPackage(pkg, serviceNames)
services, err = readServicesInPackage(pkg, packageName, serviceNames)
if err != nil {
return
}
structs, err = readStructs(pkg)
jsonTrace(services)
structTypes := map[string]*StructType{}
for _, s := range services {
for _, m := range s.Methods {
collecStructTypes(m.Return, structTypes)
collecStructTypes(m.Args, structTypes)
}
}
jsonTrace(structTypes)
structs = map[string]*Struct{}
for wantedName := range structTypes {
structs[wantedName] = nil
}
collectErr := collectStructs(goPath, structs)
if collectErr != nil {
err = errors.New("error while collecting structs: " + collectErr.Error())
}
jsonTrace(structs)
return
}
func collectStructs(goPath string, structs map[string]*Struct) error {
scannedPackages := map[string]map[string]*Struct{}
for structsPending(structs) {
for fullName, strct := range structs {
if strct != nil {
continue
}
fullNameParts := strings.Split(fullName, ".")
fullNameParts = fullNameParts[:len(fullNameParts)-1]
//path := fullNameParts[:len(fullNameParts)-1][0]
packageName := strings.Join(fullNameParts, ".")
//trace(fullName, "==========================>", fullNameParts, "=============>", packageName)
packageStructs, ok := scannedPackages[packageName]
if !ok {
parsedPackageStructs, err := getStructsInPackage(goPath, packageName)
if err != nil {
return err
}
packageStructs = parsedPackageStructs
scannedPackages[packageName] = packageStructs
}
for packageStructName, packageStruct := range packageStructs {
existingStruct, needed := structs[packageStructName]
if needed && existingStruct == nil {
structs[packageStructName] = packageStruct
}
}
}
}
return nil
}
func structsPending(structs map[string]*Struct) bool {
for _, structType := range structs {
if structType == nil || !structType.DepsSatisfied(structs) {
return true
}
}
return false
}
func (s *Struct) DepsSatisfied(structs map[string]*Struct) bool {
needsWork := func(fullName string) bool {
strct, ok := structs[fullName]
if !ok {
// hey there is more todo
structs[fullName] = nil
return true
}
if strct == nil {
trace("need work " + fullName)
return true
}
return false
}
for _, field := range s.Fields {
var fieldStructType *StructType
fieldStructType = nil
if field.Value.StructType != nil {
fieldStructType = field.Value.StructType
} else if field.Value.Array != nil && field.Value.Array.Value.StructType != nil {
fieldStructType = field.Value.Array.Value.StructType
} else if field.Value.Map != nil && field.Value.Map.Value.StructType != nil {
fieldStructType = field.Value.Map.Value.StructType
}
if fieldStructType != nil {
if needsWork(fieldStructType.FullName()) {
return false
}
}
}
return !needsWork(s.FullName())
}
func (s *Struct) FullName() string {
fullName := s.Package + "." + s.Name
if len(fullName) == 0 {
fullName = s.Name
}
return fullName
}
func (st *StructType) FullName() string {
fullName := st.Package + "." + st.Name
if len(fullName) == 0 {
fullName = st.Name
}
return fullName
}
func getStructsInPackage(goPath string, packageName string) (structs map[string]*Struct, err error) {
pkg, err := parsePackage(goPath, packageName)
if err != nil {
pkg, err = parsePackage(runtime.GOROOT(), packageName)
if err != nil {
return nil, err
}
}
structs, err = readStructs(pkg, packageName)
if err != nil {
return nil, err
}
return structs, nil
}
func collecStructTypes(fields []*Field, structTypes map[string]*StructType) {
for _, field := range fields {
if field.Value.StructType != nil {
fullName := field.Value.StructType.Package + "." + field.Value.StructType.Name
if len(field.Value.StructType.Package) == 0 {
fullName = field.Value.StructType.Name
}
switch fullName {
case "error", "net/http.Request", "net/http.ResponseWriter":
continue
default:
structTypes[fullName] = field.Value.StructType
}
}
}
}
//func collectStructs(goPath, structs)

View File

@ -1,6 +1,7 @@
package gotsrpc
import (
"encoding/json"
"fmt"
"go/ast"
"os"
@ -10,11 +11,13 @@ import (
var ReaderTrace = false
func readStructs(pkg *ast.Package) (structs map[string]*Struct, err error) {
func readStructs(pkg *ast.Package, packageName string) (structs map[string]*Struct, err error) {
structs = map[string]*Struct{}
trace("reading files in package", packageName)
for _, file := range pkg.Files {
//readFile(filename, file)
err = extractStructs(file, structs)
err = extractStructs(file, packageName, structs)
if err != nil {
return
}
@ -27,6 +30,18 @@ func trace(args ...interface{}) {
fmt.Fprintln(os.Stderr, args...)
}
}
func jsonTrace(args ...interface{}) {
if ReaderTrace {
for _, arg := range args {
jsonBytes, jsonErr := json.MarshalIndent(arg, "", " ")
if jsonErr != nil {
trace(arg)
continue
}
trace(string(jsonBytes))
}
}
}
func extractJSONInfo(tag string) *JSONInfo {
t := reflect.StructTag(tag)
@ -94,46 +109,49 @@ func getTypesFromAstType(ident *ast.Ident) (structType string, scalarType Scalar
}
func readAstType(v *Value, fieldIdent *ast.Ident) {
structType, scalarType := getTypesFromAstType(fieldIdent)
_, scalarType := getTypesFromAstType(fieldIdent)
v.ScalarType = scalarType
if len(structType) > 0 {
v.StructType = &StructType{
Name: structType,
}
} else {
v.GoScalarType = fieldIdent.Name
}
// if len(structType) > 0 {
// v.StructType = &StructType{
// Name: structType,
// //Package: fieldIdent.String(),
// }
// trace("----------------->", fieldIdent)
// } else {
v.GoScalarType = fieldIdent.Name
// }
}
func readAstStarExpr(v *Value, starExpr *ast.StarExpr) {
func readAstStarExpr(v *Value, starExpr *ast.StarExpr, fileImports fileImportSpecMap) {
v.IsPtr = true
switch reflect.ValueOf(starExpr.X).Type().String() {
case "*ast.Ident":
ident := starExpr.X.(*ast.Ident)
v.StructType = &StructType{
Name: ident.Name,
Name: ident.Name,
Package: fileImports.getPackagePath(""),
}
case "*ast.StructType":
// nested anonymous
readAstStructType(v, starExpr.X.(*ast.StructType))
readAstStructType(v, starExpr.X.(*ast.StructType), fileImports)
case "*ast.SelectorExpr":
readAstSelectorExpr(v, starExpr.X.(*ast.SelectorExpr))
readAstSelectorExpr(v, starExpr.X.(*ast.SelectorExpr), fileImports)
default:
trace("a pointer on what", reflect.ValueOf(starExpr.X).Type().String())
}
}
func readAstArrayType(v *Value, arrayType *ast.ArrayType) {
func readAstArrayType(v *Value, arrayType *ast.ArrayType, fileImports fileImportSpecMap) {
switch reflect.ValueOf(arrayType.Elt).Type().String() {
case "*ast.StarExpr":
readAstStarExpr(v, arrayType.Elt.(*ast.StarExpr))
readAstStarExpr(v, arrayType.Elt.(*ast.StarExpr), fileImports)
default:
trace("array type elt", reflect.ValueOf(arrayType.Elt).Type().String())
}
}
func readAstMapType(m *Map, mapType *ast.MapType) {
func readAstMapType(m *Map, mapType *ast.MapType, fileImports fileImportSpecMap) {
trace(" map key", mapType.Key, reflect.ValueOf(mapType.Key).Type().String())
trace(" map value", mapType.Value, reflect.ValueOf(mapType.Value).Type().String())
// key
@ -143,16 +161,17 @@ func readAstMapType(m *Map, mapType *ast.MapType) {
m.KeyType = string(scalarType)
}
// value
m.Value.loadExpr(mapType.Value)
m.Value.loadExpr(mapType.Value, fileImports)
}
func readAstSelectorExpr(v *Value, selectorExpr *ast.SelectorExpr) {
func readAstSelectorExpr(v *Value, selectorExpr *ast.SelectorExpr, fileImports fileImportSpecMap) {
switch reflect.ValueOf(selectorExpr.X).Type().String() {
case "*ast.Ident":
// that could be the package name
selectorIdent := selectorExpr.X.(*ast.Ident)
packageName := selectorIdent.Name
v.StructType = &StructType{
Package: selectorIdent.Name,
Package: fileImports.getPackagePath(packageName),
Name: selectorExpr.Sel.Name,
}
default:
@ -160,12 +179,12 @@ func readAstSelectorExpr(v *Value, selectorExpr *ast.SelectorExpr) {
}
}
func readAstStructType(v *Value, structType *ast.StructType) {
func readAstStructType(v *Value, structType *ast.StructType, fileImports fileImportSpecMap) {
v.Struct = &Struct{}
v.Struct.Fields = readFieldList(structType.Fields.List)
v.Struct.Fields = readFieldList(structType.Fields.List, fileImports)
}
func (v *Value) loadExpr(expr ast.Expr) {
func (v *Value) loadExpr(expr ast.Expr, fileImports fileImportSpecMap) {
//fmt.Println(field.Names[0].Name, field.Type, reflect.ValueOf(field.Type).Type().String())
switch reflect.ValueOf(expr).Type().String() {
case "*ast.ArrayType":
@ -175,14 +194,14 @@ func (v *Value) loadExpr(expr ast.Expr) {
case "*ast.Ident":
readAstType(v.Array.Value, fieldArray.Elt.(*ast.Ident))
case "*ast.StarExpr":
readAstStarExpr(v.Array.Value, fieldArray.Elt.(*ast.StarExpr))
readAstStarExpr(v.Array.Value, fieldArray.Elt.(*ast.StarExpr), fileImports)
case "*ast.ArrayType":
readAstArrayType(v.Array.Value, fieldArray.Elt.(*ast.ArrayType))
readAstArrayType(v.Array.Value, fieldArray.Elt.(*ast.ArrayType), fileImports)
case "*ast.MapType":
v.Array.Value.Map = &Map{
Value: &Value{},
}
readAstMapType(v.Array.Value.Map, fieldArray.Elt.(*ast.MapType))
readAstMapType(v.Array.Value.Map, fieldArray.Elt.(*ast.MapType), fileImports)
default:
trace("---------------------> array of", reflect.ValueOf(fieldArray.Elt).Type().String())
}
@ -191,40 +210,44 @@ func (v *Value) loadExpr(expr ast.Expr) {
readAstType(v, fieldIdent)
case "*ast.StarExpr":
// a pointer on sth
readAstStarExpr(v, expr.(*ast.StarExpr))
readAstStarExpr(v, expr.(*ast.StarExpr), fileImports)
case "*ast.MapType":
v.Map = &Map{
Value: &Value{},
}
readAstMapType(v.Map, expr.(*ast.MapType))
readAstMapType(v.Map, expr.(*ast.MapType), fileImports)
case "*ast.SelectorExpr":
readAstSelectorExpr(v, expr.(*ast.SelectorExpr))
readAstSelectorExpr(v, expr.(*ast.SelectorExpr), fileImports)
case "*ast.StructType":
readAstStructType(v, expr.(*ast.StructType))
readAstStructType(v, expr.(*ast.StructType), fileImports)
default:
trace("what kind of field ident would that be ?!", reflect.ValueOf(expr).Type().String())
}
}
func readField(astField *ast.Field) (name string, v *Value, jsonInfo *JSONInfo) {
func readField(astField *ast.Field, fileImports fileImportSpecMap) (name string, v *Value, jsonInfo *JSONInfo) {
name = ""
if len(astField.Names) > 0 {
name = astField.Names[0].Name
}
trace(" ", name)
trace(" reading field with name", name, "of type", astField.Type)
v = &Value{}
v.loadExpr(astField.Type)
v.loadExpr(astField.Type, fileImports)
if astField.Tag != nil {
jsonInfo = extractJSONInfo(astField.Tag.Value[1 : len(astField.Tag.Value)-1])
}
return
}
func readFieldList(fieldList []*ast.Field) (fields []*Field) {
func readFieldList(fieldList []*ast.Field, fileImports fileImportSpecMap) (fields []*Field) {
fields = []*Field{}
for _, field := range fieldList {
name, value, jsonInfo := readField(field)
name, value, jsonInfo := readField(field, fileImports)
if len(name) == 0 {
trace("i do not understand this one", field, name, value, jsonInfo)
continue
}
if strings.Compare(strings.ToLower(name[:1]), name[:1]) == 0 {
continue
}
@ -240,17 +263,22 @@ func readFieldList(fieldList []*ast.Field) (fields []*Field) {
return
}
func extractStructs(file *ast.File, structs map[string]*Struct) error {
func extractStructs(file *ast.File, packageName string, structs map[string]*Struct) error {
trace("reading file", file.Name.Name)
//for _, imp := range file.Imports {
// fmt.Println("import", imp.Name, imp.Path)
//}
fileImports := getFileImports(file, packageName)
for name, obj := range file.Scope.Objects {
//fmt.Println(name, obj.Kind, obj.Data)
if obj.Kind == ast.Typ && obj.Decl != nil {
//ast.StructType
structs[name] = &Struct{
Name: name,
Fields: []*Field{},
structName := packageName + "." + name
structs[structName] = &Struct{
Name: name,
Fields: []*Field{},
Package: packageName,
}
if reflect.ValueOf(obj.Decl).Type().String() == "*ast.TypeSpec" {
typeSpec := obj.Decl.(*ast.TypeSpec)
@ -258,7 +286,7 @@ func extractStructs(file *ast.File, structs map[string]*Struct) error {
if typeSpecRefl.Type().String() == "*ast.StructType" {
structType := typeSpec.Type.(*ast.StructType)
trace("StructType", obj.Name)
structs[name].Fields = readFieldList(structType.Fields.List)
structs[structName].Fields = readFieldList(structType.Fields.List, fileImports)
} else {
// fmt.Println(" what would that be", typeSpecRefl.Type().String())
}

View File

@ -1,6 +1,7 @@
package gotsrpc
import (
"errors"
"fmt"
"strings"
)
@ -39,6 +40,7 @@ func (v *Value) tsType() string {
}
func renderStruct(str *Struct, ts *code) error {
ts.l("// " + str.Package + "." + str.Name).ind(1)
ts.l("export interface " + str.Name + " {").ind(1)
for _, f := range str.Fields {
if f.JSONInfo != nil && f.JSONInfo.Ignore {
@ -151,7 +153,10 @@ func RenderTypeScript(services []*Service, structs map[string]*Struct, tsModuleN
ts.l("module " + tsModuleName + " {")
ts.ind(1)
for _, str := range structs {
for name, str := range structs {
if str == nil {
return "", errors.New("could not resolve: " + name)
}
err = renderStruct(str, ts)
if err != nil {
return