package gotsrpc import ( "errors" "fmt" "go/ast" "go/token" "reflect" "sort" "strings" ) func readServiceFile(file *ast.File, packageName string, services map[string]*Service) error { findService := func(serviceName string) (service *Service, ok bool) { for _, service := range services { if service.Name == serviceName { return service, true } } return nil, false } fileImports := getFileImports(file, packageName) for _, decl := range file.Decls { if reflect.ValueOf(decl).Type().String() == "*ast.FuncDecl" { funcDecl := decl.(*ast.FuncDecl) if funcDecl.Recv != nil { trace("that is a method named", funcDecl.Name) if len(funcDecl.Recv.List) == 1 { firstReceiverField := funcDecl.Recv.List[0] if "*ast.StarExpr" == reflect.ValueOf(firstReceiverField.Type).Type().String() { starExpr := firstReceiverField.Type.(*ast.StarExpr) if "*ast.Ident" == reflect.ValueOf(starExpr.X).Type().String() { ident := starExpr.X.(*ast.Ident) service, ok := findService(ident.Name) firstCharOfMethodName := funcDecl.Name.Name[0:1] if !ok || strings.ToLower(firstCharOfMethodName) == firstCharOfMethodName { // skip this method continue } trace(" on sth:", ident.Name) service.Methods = append(service.Methods, &Method{ Name: funcDecl.Name.Name, Args: readFields(funcDecl.Type.Params, fileImports), Return: readFields(funcDecl.Type.Results, fileImports), }) } } } } else { trace("no receiver for", funcDecl.Name) } } } for _, s := range services { sort.Sort(s.Methods) } return nil } type importSpec struct { alias string name string path string } type fileImportSpecMap map[string]importSpec func (fileImports fileImportSpecMap) getPackagePath(packageName string) string { importSpec, ok := fileImports[packageName] if ok { packageName = importSpec.path } return packageName } func standardImportName(importPath string) string { pathParts := strings.Split(importPath, "/") return pathParts[len(pathParts)-1] } func getFileImports(file *ast.File, packageName string) (imports fileImportSpecMap) { imports = fileImportSpecMap{"": importSpec{alias: "", name: "", path: packageName}} for _, decl := range file.Decls { if reflect.ValueOf(decl).Type().String() == "*ast.GenDecl" { genDecl := decl.(*ast.GenDecl) if genDecl.Tok == token.IMPORT { trace("got an import", genDecl.Specs) for _, spec := range genDecl.Specs { if "*ast.ImportSpec" == reflect.ValueOf(spec).Type().String() { spec := spec.(*ast.ImportSpec) importPath := spec.Path.Value[1 : len(spec.Path.Value)-1] importName := spec.Name.String() if importName == "" || importName == "" { importName = standardImportName(importPath) } imports[importName] = importSpec{ alias: importName, name: standardImportName(importPath), path: importPath, } //trace(" import >>>>>>>>>>>>>>>>>>>>", importName, importPath) } } } } } return imports } func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields []*Field) { trace("reading fields") fields = []*Field{} if fieldList == nil { return } for _, param := range fieldList.List { name, value, _ := readField(param, fileImports) fields = append(fields, &Field{ Name: name, Value: value, }) } trace("done reading fields") return } func readServicesInPackage(pkg *ast.Package, packageName string, serviceMap map[string]string) (services map[string]*Service, err error) { services = map[string]*Service{} for endpoint, serviceName := range serviceMap { services[endpoint] = &Service{ Name: serviceName, Methods: []*Method{}, } } for _, file := range pkg.Files { err = readServiceFile(file, packageName, services) if err != nil { return } } return } func loadConstants(pkg *ast.Package) map[string]*ast.BasicLit { constants := map[string]*ast.BasicLit{} for _, file := range pkg.Files { for _, decl := range file.Decls { if reflect.ValueOf(decl).Type().String() == "*ast.GenDecl" { genDecl := decl.(*ast.GenDecl) if genDecl.Tok == token.CONST { trace("got a const", genDecl.Specs) for _, spec := range genDecl.Specs { if "*ast.ValueSpec" == reflect.ValueOf(spec).Type().String() { spec := spec.(*ast.ValueSpec) for _, val := range spec.Values { if reflect.ValueOf(val).Type().String() == "*ast.BasicLit" { firstValueLit := val.(*ast.BasicLit) constName := spec.Names[0].String() for indexRune, r := range constName { if indexRune == 0 { if string(r) == strings.ToUpper(string(r)) { constants[constName] = firstValueLit } break } } } } } } } } } } return constants } func Read(goPaths []string, packageName string, serviceMap map[string]string) (services map[string]*Service, structs map[string]*Struct, scalars map[string]*Scalar, constants map[string]map[string]*ast.BasicLit, err error) { if len(serviceMap) == 0 { err = errors.New("nothing to do service names are empty") return } pkg, parseErr := parsePackage(goPaths, packageName) if parseErr != nil { err = parseErr return } services, err = readServicesInPackage(pkg, packageName, serviceMap) if err != nil { return } missingTypes := map[string]bool{} for _, s := range services { for _, m := range s.Methods { collectStructTypes(m.Return, missingTypes) collectStructTypes(m.Args, missingTypes) collectScalarTypes(m.Return, missingTypes) collectScalarTypes(m.Args, missingTypes) } } trace("missing") traceJSON(missingTypes) structs = map[string]*Struct{} scalars = map[string]*Scalar{} collectErr := collectTypes(goPaths, missingTypes, structs, scalars) if collectErr != nil { err = errors.New("error while collecting structs: " + collectErr.Error()) } trace("---------------- found structs -------------------") traceJSON(structs) trace("---------------- found scalars -------------------") traceJSON(scalars) constants = map[string]map[string]*ast.BasicLit{} for _, structDef := range structs { if structDef != nil { structPackage := structDef.Package _, ok := constants[structPackage] if !ok { pkg, constPkgErr := parsePackage(goPaths, structPackage) if constPkgErr != nil { err = constPkgErr return } constants[structPackage] = loadConstants(pkg) } } } // fix arg and return field lists for _, service := range services { for _, method := range service.Methods { fixFieldStructs(method.Args, structs, scalars) fixFieldStructs(method.Return, structs, scalars) } } traceJSON("---------------------------", services) return } func fixFieldStructs(fields []*Field, structs map[string]*Struct, scalars map[string]*Scalar) { for _, f := range fields { if f.Value.StructType != nil { // do we have that struct or is it a hidden scalar name := f.Value.StructType.FullName() _, strctExists := structs[name] if strctExists { continue } scalar, scalarExists := scalars[name] if scalarExists { f.Value.StructType = nil f.Value.Scalar = scalar } } } } func collectTypes(goPaths []string, missingTypes map[string]bool, structs map[string]*Struct, scalars map[string]*Scalar) error { scannedPackageStructs := map[string]map[string]*Struct{} scannedPackageScalars := map[string]map[string]*Scalar{} missingTypeNames := func() []string { missing := []string{} for name, isMissing := range missingTypes { if isMissing { missing = append(missing, name) } } return missing } lastNumMissing := len(missingTypeNames()) for typesPending(structs, scalars, missingTypes) { trace("pending", missingTypeNames()) for fullName, typeIsMissing := range missingTypes { if !typeIsMissing { continue } fullNameParts := strings.Split(fullName, ".") fullNameParts = fullNameParts[:len(fullNameParts)-1] //path := fullNameParts[:len(fullNameParts)-1][0] packageName := strings.Join(fullNameParts, ".") trace(fullName, "==========================>", fullNameParts, "=============>", packageName) packageStructs, structOK := scannedPackageStructs[packageName] packageScalars, scalarOK := scannedPackageScalars[packageName] if !structOK || !scalarOK { parsedPackageStructs, parsedPackageScalars, err := getTypesInPackage(goPaths, packageName) if err != nil { return err } trace("found structs in", goPaths, packageName) for structName, strct := range packageStructs { trace(" struct", structName, strct) if strct == nil { panic("how could that be") } } trace("found scalars in", goPaths, packageName) for scalarName, scalar := range parsedPackageScalars { trace(" scalar", scalarName, scalar) } traceJSON(parsedPackageScalars) packageStructs = parsedPackageStructs packageScalars = parsedPackageScalars scannedPackageStructs[packageName] = packageStructs scannedPackageScalars[packageName] = packageScalars } traceJSON("packageStructs", packageName, packageStructs) for packageStructName, packageStruct := range packageStructs { missing, needed := missingTypes[packageStructName] if needed && missing { trace("picked up package struct", packageStructName, packageStruct) missingTypes[packageStructName] = false if packageStruct == nil { panic("waaaaaaaaa") } structs[packageStructName] = packageStruct } } traceJSON("packageScalars", packageScalars) for packageScalarName, packageScalar := range packageScalars { missing, needed := missingTypes[packageScalarName] if needed && missing { trace("picked up package scalar", packageScalarName, packageScalar) missingTypes[packageScalarName] = false scalars[packageScalarName] = packageScalar } } } newNumMissingTypes := len(missingTypeNames()) if newNumMissingTypes > 0 && newNumMissingTypes == lastNumMissing { return errors.New(fmt.Sprintln("could not resolve at least one of the following types", missingTypeNames())) } lastNumMissing = newNumMissingTypes } return nil } func typesPending(structs map[string]*Struct, scalars map[string]*Scalar, missingTypes map[string]bool) bool { for _, missing := range missingTypes { if missing { return true } } for _, structType := range structs { if !structType.DepsSatisfied(missingTypes, structs, scalars) { return true } } return false } func (s *Struct) DepsSatisfied(missingTypes map[string]bool, structs map[string]*Struct, scalarTypes map[string]*Scalar) bool { needsWork := func(fullName string) bool { strct, strctOK := structs[fullName] scalar, scalarOK := scalarTypes[fullName] if !strctOK && !scalarOK { // hey there is more todo missingTypes[fullName] = true trace("need work ----------------------" + fullName) return true } if strct == nil && scalar == nil { trace("need work ----------------------" + fullName) return true } return false } for _, field := range s.Fields { var fieldStructType *StructType fieldStructType = nil if field.Value.StructType != nil { fieldStructType = field.Value.StructType } else if field.Value.Array != nil && field.Value.Array.Value.StructType != nil { fieldStructType = field.Value.Array.Value.StructType } else if field.Value.Map != nil && field.Value.Map.Value.StructType != nil { fieldStructType = field.Value.Map.Value.StructType } if fieldStructType != nil { if needsWork(fieldStructType.FullName()) { return false } } } return !needsWork(s.FullName()) } func (s *Struct) FullName() string { fullName := s.Package + "." + s.Name if len(fullName) == 0 { fullName = s.Name } return fullName } func (st *StructType) FullName() string { fullName := st.Package + "." + st.Name if len(fullName) == 0 { fullName = st.Name } return fullName } func getTypesInPackage(goPaths []string, packageName string) (structs map[string]*Struct, scalars map[string]*Scalar, err error) { pkg, err := parsePackage(goPaths, packageName) if err != nil { return nil, nil, err } structs, scalars, err = readStructs(pkg, packageName) if err != nil { return nil, nil, err } return structs, scalars, nil } func getStructTypeForField(value *Value) *StructType { //field.Value.StructType var strType *StructType switch true { case value.StructType != nil: strType = value.StructType //case field.Value.ArrayType case value.Map != nil: strType = getStructTypeForField(value.Map.Value) case value.Array != nil: strType = getStructTypeForField(value.Array.Value) } return strType } func getScalarForField(value *Value) *Scalar { //field.Value.StructType var scalarType *Scalar switch true { case value.Scalar != nil: scalarType = value.Scalar //case field.Value.ArrayType case value.Map != nil: scalarType = getScalarForField(value.Map.Value) case value.Array != nil: scalarType = getScalarForField(value.Array.Value) } return scalarType } func collectScalarTypes(fields []*Field, scalarTypes map[string]bool) { for _, field := range fields { scalarType := getScalarForField(field.Value) if scalarType != nil { fullName := scalarType.Package + "." + scalarType.Name if len(scalarType.Package) == 0 { fullName = scalarType.Name } scalarTypes[fullName] = true } } } func collectStructTypes(fields []*Field, structTypes map[string]bool) { for _, field := range fields { strType := getStructTypeForField(field.Value) if strType != nil { fullName := strType.Package + "." + strType.Name if len(strType.Package) == 0 { fullName = strType.Name } switch fullName { case "error", "net/http.Request", "net/http.ResponseWriter": continue default: structTypes[fullName] = true } } } } //func collectStructs(goPath, structs)