diff --git a/model.go b/model.go index 2bee865..57e96ac 100644 --- a/model.go +++ b/model.go @@ -4,6 +4,7 @@ type ScalarType string const ( ScalarTypeString ScalarType = "string" + ScalarTypeAny ScalarType = "any" ScalarTypeByte ScalarType = "byte" ScalarTypeNumber ScalarType = "number" ScalarTypeBool ScalarType = "bool" diff --git a/servicereader.go b/servicereader.go index 9bf3b5c..1804a25 100644 --- a/servicereader.go +++ b/servicereader.go @@ -110,7 +110,7 @@ func getFileImports(file *ast.File, packageName string) (imports fileImportSpecM name: standardImportName(importPath), path: importPath, } - //trace(" import >>>>>>>>>>>>>>>>>>>>", importName, importPath) + // trace(" import >>>>>>>>>>>>>>>>>>>>", importName, importPath) } } } @@ -135,7 +135,6 @@ func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields } trace("done reading fields") return - } func readServicesInPackage(pkg *ast.Package, packageName string, serviceMap map[string]string) (services ServiceList, err error) { @@ -193,7 +192,6 @@ func loadConstants(pkg *ast.Package) map[string]*ast.BasicLit { } } return constants - } func Read( @@ -278,7 +276,6 @@ func Read( func fixFieldStructs(fields []*Field, structs map[string]*Struct, scalars map[string]*Scalar) { for _, f := range fields { - if f.Value.StructType != nil { // do we have that struct or is it a hidden scalar name := f.Value.StructType.FullName() @@ -320,7 +317,7 @@ func collectTypes(goPaths []string, gomod config.Namespace, missingTypes map[str fullNameParts := strings.Split(fullName, ".") fullNameParts = fullNameParts[:len(fullNameParts)-1] - //path := fullNameParts[:len(fullNameParts)-1][0] + // path := fullNameParts[:len(fullNameParts)-1][0] packageName := strings.Join(fullNameParts, ".") @@ -377,7 +374,7 @@ func collectTypes(goPaths []string, gomod config.Namespace, missingTypes map[str } newNumMissingTypes := len(missingTypeNames()) if newNumMissingTypes > 0 && newNumMissingTypes == lastNumMissing { - //packageStructs, structOK := scannedPackageStructs[packageName] + // packageStructs, structOK := scannedPackageStructs[packageName] for scalarName, scalars := range scannedPackageScalars { fmt.Println("scanned scalars ", scalarName) for _, scalar := range scalars { @@ -488,16 +485,17 @@ func getTypesInPackage( if err != nil { return nil, nil, err } + return structs, scalars, nil } func getStructTypeForField(value *Value) *StructType { - //field.Value.StructType + // field.Value.StructType var strType *StructType switch true { case value.StructType != nil: strType = value.StructType - //case field.Value.ArrayType + // case field.Value.ArrayType case value.Map != nil: strType = getStructTypeForField(value.Map.Value) case value.Array != nil: @@ -507,12 +505,12 @@ func getStructTypeForField(value *Value) *StructType { } func getScalarForField(value *Value) *Scalar { - //field.Value.StructType + // field.Value.StructType var scalarType *Scalar switch true { case value.Scalar != nil: scalarType = value.Scalar - //case field.Value.ArrayType + // case field.Value.ArrayType case value.Map != nil: scalarType = getScalarForField(value.Map.Value) case value.Array != nil: @@ -552,5 +550,3 @@ func collectStructTypes(fields []*Field, structTypes map[string]bool) { } } } - -//func collectStructs(goPath, structs) diff --git a/typereader.go b/typereader.go index 5081568..8b8427c 100644 --- a/typereader.go +++ b/typereader.go @@ -7,7 +7,7 @@ import ( "reflect" "strings" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" ) var ReaderTrace = false @@ -34,9 +34,9 @@ func readStructs(pkg *ast.Package, packageName string) (structs map[string]*Stru structType.IsError = true } } - //jsonDump(errorTypes) - //jsonDump(scalarTypes) - //jsonDump(structs) + // jsonDump(errorTypes) + // jsonDump(scalarTypes) + // jsonDump(structs) return } @@ -45,6 +45,7 @@ func trace(args ...interface{}) { fmt.Fprintln(os.Stderr, args...) } } + func traceData(args ...interface{}) { if ReaderTrace { for _, arg := range args { @@ -105,6 +106,8 @@ func getScalarFromAstIdent(ident *ast.Ident) ScalarType { switch ident.Name { case "string": return ScalarTypeString + case "any": + return ScalarTypeAny case "bool": return ScalarTypeBool case "byte": @@ -137,6 +140,7 @@ func getTypesFromAstType(ident *ast.Ident) (structType string, scalarType Scalar func readAstType(v *Value, fieldIdent *ast.Ident, fileImports fileImportSpecMap) { structType, scalarType := getTypesFromAstType(fieldIdent) v.ScalarType = scalarType + if len(structType) > 0 { v.StructType = &StructType{ Name: structType, @@ -145,31 +149,27 @@ func readAstType(v *Value, fieldIdent *ast.Ident, fileImports fileImportSpecMap) } else { v.GoScalarType = fieldIdent.Name if fieldIdent.Obj != nil && fieldIdent.Obj.Decl != nil && reflect.ValueOf(fieldIdent.Obj.Decl).Type().String() == "*ast.TypeSpec" { - //typeSpec := fieldIdent.Obj.Decl.(*ast.TypeSpec) - //fmt.Println("-------------------------------------->", fieldIdent.Name, reflect.ValueOf(typeSpec.Type).Type()) + // typeSpec := fieldIdent.Obj.Decl.(*ast.TypeSpec) + // fmt.Println("-------------------------------------->", fieldIdent.Name, reflect.ValueOf(typeSpec.Type).Type()) v.Scalar = &Scalar{ Package: fileImports.getPackagePath(""), Name: fieldIdent.Name, Type: scalarType, } - //jsonDump(v) } - } - // } func readAstStarExpr(v *Value, starExpr *ast.StarExpr, fileImports fileImportSpecMap) { v.IsPtr = true - switch reflect.ValueOf(starExpr.X).Type().String() { - case "*ast.Ident": - ident := starExpr.X.(*ast.Ident) - readAstType(v, ident, fileImports) - case "*ast.StructType": + switch starExprType := starExpr.X.(type) { + case *ast.Ident: + readAstType(v, starExprType, fileImports) + case *ast.StructType: // nested anonymous - readAstStructType(v, starExpr.X.(*ast.StructType), fileImports) - case "*ast.SelectorExpr": - readAstSelectorExpr(v, starExpr.X.(*ast.SelectorExpr), fileImports) + readAstStructType(v, starExprType, fileImports) + case *ast.SelectorExpr: + readAstSelectorExpr(v, starExprType, fileImports) default: trace("a pointer on what", reflect.ValueOf(starExpr.X).Type().String()) } @@ -179,11 +179,11 @@ func readAstMapType(m *Map, mapType *ast.MapType, fileImports fileImportSpecMap) trace(" map key", mapType.Key, reflect.ValueOf(mapType.Key).Type().String()) trace(" map value", mapType.Value, reflect.ValueOf(mapType.Value).Type().String()) // key - switch reflect.ValueOf(mapType.Key).Type().String() { - case "*ast.Ident": - _, scalarType := getTypesFromAstType(mapType.Key.(*ast.Ident)) + switch keyType := mapType.Key.(type) { + case *ast.Ident: + _, scalarType := getTypesFromAstType(keyType) m.KeyType = string(scalarType) - m.KeyGoType = mapType.Key.(*ast.Ident).Name + m.KeyGoType = keyType.Name default: // todo: implement support for "*ast.Scalar" type (sca) // this is important for scalar types in map keys @@ -204,25 +204,25 @@ func readAstMapType(m *Map, mapType *ast.MapType, fileImports fileImportSpecMap) // }) //}) - //fmt.Println("--------------------------->", reflect.ValueOf(mapType.Key).Type().String()) + // fmt.Println("--------------------------->", reflect.ValueOf(mapType.Key).Type().String()) } // value m.Value.loadExpr(mapType.Value, fileImports) } func readAstSelectorExpr(v *Value, selectorExpr *ast.SelectorExpr, fileImports fileImportSpecMap) { - switch reflect.ValueOf(selectorExpr.X).Type().String() { - case "*ast.Ident": + switch selExpType := selectorExpr.X.(type) { + case *ast.Ident: // that could be the package name - //selectorIdent := selectorExpr.X.(*ast.Ident) + // selectorIdent := selectorExpr.X.(*ast.Ident) // fmt.Println(selectorExpr, selectorExpr.X.(*ast.Ident)) - readAstType(v, selectorExpr.X.(*ast.Ident), fileImports) + readAstType(v, selExpType, fileImports) if v.StructType != nil { v.StructType.Package = fileImports.getPackagePath(v.StructType.Name) v.StructType.Name = selectorExpr.Sel.Name } - //fmt.Println(selectorExpr.X.(*ast.Ident).Name, ".", selectorExpr.Sel) - //readAstType(v, selectorExpr.Sel, fileImports) + // fmt.Println(selectorExpr.X.(*ast.Ident).Name, ".", selectorExpr.Sel) + // readAstType(v, selectorExpr.Sel, fileImports) default: trace("selectorExpr.Sel !?", selectorExpr.X, reflect.ValueOf(selectorExpr.X).Type().String()) } @@ -235,54 +235,52 @@ func readAstStructType(v *Value, structType *ast.StructType, fileImports fileImp func readAstInterfaceType(v *Value, interfaceType *ast.InterfaceType, fileImports fileImportSpecMap) { v.IsInterface = true - } func (v *Value) loadExpr(expr ast.Expr, fileImports fileImportSpecMap) { - - switch reflect.ValueOf(expr).Type().String() { - case "*ast.ArrayType": + switch expr.(type) { + case *ast.ArrayType: fieldArray := expr.(*ast.ArrayType) v.Array = &Array{Value: &Value{}} - switch reflect.ValueOf(fieldArray.Elt).Type().String() { - case "*ast.ArrayType": - //readAstArrayType(v.Array.Value, fieldArray.Elt.(*ast.ArrayType), fileImports) - v.Array.Value.loadExpr(fieldArray.Elt.(*ast.ArrayType), fileImports) - case "*ast.Ident": - readAstType(v.Array.Value, fieldArray.Elt.(*ast.Ident), fileImports) - case "*ast.StarExpr": - readAstStarExpr(v.Array.Value, fieldArray.Elt.(*ast.StarExpr), fileImports) - case "*ast.MapType": + switch faEltType := fieldArray.Elt.(type) { + case *ast.ArrayType: + // readAstArrayType(v.Array.Value, fieldArray.Elt.(*ast.ArrayType), fileImports) + v.Array.Value.loadExpr(faEltType, fileImports) + case *ast.Ident: + readAstType(v.Array.Value, faEltType, fileImports) + case *ast.StarExpr: + readAstStarExpr(v.Array.Value, faEltType, fileImports) + case *ast.MapType: v.Array.Value.Map = &Map{ Value: &Value{}, } - readAstMapType(v.Array.Value.Map, fieldArray.Elt.(*ast.MapType), fileImports) - case "*ast.SelectorExpr": - readAstSelectorExpr(v.Array.Value, fieldArray.Elt.(*ast.SelectorExpr), fileImports) - case "*ast.StructType": - readAstStructType(v.Array.Value, fieldArray.Elt.(*ast.StructType), fileImports) - case "*ast.InterfaceType": - readAstInterfaceType(v.Array.Value, fieldArray.Elt.(*ast.InterfaceType), fileImports) + readAstMapType(v.Array.Value.Map, faEltType, fileImports) + case *ast.SelectorExpr: + readAstSelectorExpr(v.Array.Value, faEltType, fileImports) + case *ast.StructType: + readAstStructType(v.Array.Value, faEltType, fileImports) + case *ast.InterfaceType: + readAstInterfaceType(v.Array.Value, faEltType, fileImports) default: trace("---------------------> array of", reflect.ValueOf(fieldArray.Elt).Type().String()) } - case "*ast.Ident": + case *ast.Ident: fieldIdent := expr.(*ast.Ident) readAstType(v, fieldIdent, fileImports) - case "*ast.StarExpr": + case *ast.StarExpr: // a pointer on sth readAstStarExpr(v, expr.(*ast.StarExpr), fileImports) - case "*ast.MapType": + case *ast.MapType: v.Map = &Map{ Value: &Value{}, } readAstMapType(v.Map, expr.(*ast.MapType), fileImports) - case "*ast.SelectorExpr": + case *ast.SelectorExpr: readAstSelectorExpr(v, expr.(*ast.SelectorExpr), fileImports) - case "*ast.StructType": + case *ast.StructType: readAstStructType(v, expr.(*ast.StructType), fileImports) - case "*ast.InterfaceType": + case *ast.InterfaceType: readAstInterfaceType(v, expr.(*ast.InterfaceType), fileImports) default: trace("what kind of field ident would that be ?!", reflect.ValueOf(expr).Type().String()) @@ -295,6 +293,7 @@ func readField(astField *ast.Field, fileImports fileImportSpecMap) (name string, name = astField.Names[0].Name } v = &Value{} + v.loadExpr(astField.Type, fileImports) if astField.Tag != nil { jsonInfo = extractJSONInfo(astField.Tag.Value[1 : len(astField.Tag.Value)-1]) @@ -303,7 +302,6 @@ func readField(astField *ast.Field, fileImports fileImportSpecMap) (name string, } func readFieldList(fieldList []*ast.Field, fileImports fileImportSpecMap) (fields []*Field) { - fields = []*Field{} for _, field := range fieldList { name, value, jsonInfo := readField(field, fileImports) if len(name) == 0 { @@ -328,23 +326,18 @@ func readFieldList(fieldList []*ast.Field, fileImports fileImportSpecMap) (field func extractErrorTypes(file *ast.File, packageName string, errorTypes map[string]bool) (err error) { for _, d := range file.Decls { - if reflect.ValueOf(d).Type().String() == "*ast.FuncDecl" { - funcDecl := d.(*ast.FuncDecl) + if funcDecl, ok := d.(*ast.FuncDecl); ok { if funcDecl.Recv != nil && len(funcDecl.Recv.List) == 1 { firstReceiverField := funcDecl.Recv.List[0] - if "*ast.StarExpr" == reflect.ValueOf(firstReceiverField.Type).Type().String() { - starExpr := firstReceiverField.Type.(*ast.StarExpr) - if "*ast.Ident" == reflect.ValueOf(starExpr.X).Type().String() { - ident := starExpr.X.(*ast.Ident) + if starExpr, ok := firstReceiverField.Type.(*ast.StarExpr); ok { + if ident, ok := starExpr.X.(*ast.Ident); ok { if funcDecl.Name.Name == "Error" && funcDecl.Type.Params.NumFields() == 0 && funcDecl.Type.Results.NumFields() == 1 { returnValueField := funcDecl.Type.Results.List[0] - refl := reflect.ValueOf(returnValueField.Type) - if refl.Type().String() == "*ast.Ident" { - returnValueIdent := returnValueField.Type.(*ast.Ident) + if returnValueIdent, ok := returnValueField.Type.(*ast.Ident); ok { if returnValueIdent.Name == "string" { errorTypes[packageName+"."+ident.Name] = true } - //fmt.Println("error for:", ident.Name, returnValueIdent.Name) + // fmt.Println("error for:", ident.Name, returnValueIdent.Name) } } } @@ -361,12 +354,9 @@ func extractTypes(file *ast.File, packageName string, structs map[string]*Struct if obj.Kind == ast.Typ && obj.Decl != nil { structName := packageName + "." + name - if reflect.ValueOf(obj.Decl).Type().String() == "*ast.TypeSpec" { - typeSpec := obj.Decl.(*ast.TypeSpec) - typeSpecRefl := reflect.ValueOf(typeSpec.Type) - typeName := typeSpecRefl.Type().String() - switch typeName { - case "*ast.StructType": + if typeSpec, ok := obj.Decl.(*ast.TypeSpec); ok { + switch tst := typeSpec.Type.(type) { + case *ast.StructType: structs[structName] = &Struct{ Name: name, Fields: []*Field{}, @@ -375,7 +365,8 @@ func extractTypes(file *ast.File, packageName string, structs map[string]*Struct structType := typeSpec.Type.(*ast.StructType) trace("StructType", obj.Name) structs[structName].Fields = readFieldList(structType.Fields.List, fileImports) - case "*ast.Ident": + + case *ast.Ident: trace("Scalar", obj.Name) scalarIdent := typeSpec.Type.(*ast.Ident) scalarTypes[structName] = &Scalar{ @@ -383,7 +374,7 @@ func extractTypes(file *ast.File, packageName string, structs map[string]*Struct Package: packageName, Type: getScalarFromAstIdent(scalarIdent), } - case "*ast.ArrayType": + case *ast.ArrayType: arrayValue := &Value{} arrayValue.loadExpr(typeSpec.Type, fileImports) structs[structName] = &Struct{ @@ -391,7 +382,7 @@ func extractTypes(file *ast.File, packageName string, structs map[string]*Struct Package: packageName, Array: arrayValue.Array, } - case "*ast.MapType": + case *ast.MapType: mapValue := &Value{} mapValue.loadExpr(typeSpec.Type, fileImports) structs[structName] = &Struct{ @@ -400,7 +391,7 @@ func extractTypes(file *ast.File, packageName string, structs map[string]*Struct Map: mapValue.Map, } default: - fmt.Println(" ignoring", obj.Name, typeSpecRefl.Type().String()) + fmt.Printf(" ignoring %s %T\n", obj.Name, tst) } } }