centralized gopath handling

This commit is contained in:
Frederik Löffert 2017-02-15 15:27:40 +01:00
parent 2c24d2c7a1
commit ba828f4713
3 changed files with 33 additions and 25 deletions

View File

@ -7,6 +7,7 @@ import (
"io/ioutil"
"os"
"path"
"runtime"
"sort"
"strings"
@ -58,7 +59,7 @@ func Build(conf *config.Config, goPath string) {
longPackageName := target.Package
longPackageNameParts := strings.Split(longPackageName, "/")
goRPCProxiesFilename := path.Join(goPath, "src", longPackageName, "gorpc.go")
goRPCProxiesFilename := path.Join(goPath, "src", longPackageName, "gorpc.go")
goRPCClientsFilename := path.Join(goPath, "src", longPackageName, "gorpcclient.go")
goTSRPCProxiesFilename := path.Join(goPath, "src", longPackageName, "gotsrpc.go")
goTSRPCClientsFilename := path.Join(goPath, "src", longPackageName, "gotsrpcclient.go")
@ -76,8 +77,8 @@ func Build(conf *config.Config, goPath string) {
remove(goTSRPCClientsFilename)
packageName := longPackageNameParts[len(longPackageNameParts)-1]
services, structs, scalarTypes, constants, err := Read(goPath, longPackageName, target.Services)
goPaths := []string{goPath, runtime.GOROOT()}
services, structs, scalarTypes, constants, err := Read(goPaths, longPackageName, target.Services)
if err != nil {
fmt.Fprintln(os.Stderr, " an error occured while trying to understand your code", err)

View File

@ -3,12 +3,12 @@ package gotsrpc
import (
"context"
"encoding/json"
"io/ioutil"
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"net/http"
"path"
"strings"
@ -94,12 +94,22 @@ func jsonDump(v interface{}) {
fmt.Println(string(jsonBytes))
}
func parsePackage(goPath string, packageName string) (pkg *ast.Package, err error) {
fset := token.NewFileSet()
dir := path.Join(goPath, "src", packageName)
pkgs, err := parser.ParseDir(fset, dir, nil, parser.AllErrors)
func parseDir(goPaths []string, packageName string) (map[string]*ast.Package, error) {
for _, goPath := range goPaths {
fset := token.NewFileSet()
dir := path.Join(goPath, "src", packageName)
pkgs, err := parser.ParseDir(fset, dir, nil, parser.AllErrors)
if err == nil {
return pkgs, nil
}
}
return nil, errors.New("could not parse dir for package name: " + packageName + " in goPaths " + strings.Join(goPaths, ", "))
}
func parsePackage(goPaths []string, packageName string) (pkg *ast.Package, err error) {
pkgs, err := parseDir(goPaths, packageName)
if err != nil {
return nil, err
return nil, errors.New("could not parse package " + packageName + ": " + err.Error())
}
packageNameParts := strings.Split(packageName, "/")
if len(packageNameParts) == 0 {

View File

@ -6,7 +6,6 @@ import (
"go/ast"
"go/token"
"reflect"
"runtime"
"sort"
"strings"
)
@ -182,13 +181,14 @@ func loadConstants(pkg *ast.Package) map[string]*ast.BasicLit {
}
func Read(goPath string, packageName string, serviceMap map[string]string) (services map[string]*Service, structs map[string]*Struct, scalars map[string]*Scalar, constants map[string]map[string]*ast.BasicLit, err error) {
func Read(goPaths []string, packageName string, serviceMap map[string]string) (services map[string]*Service, structs map[string]*Struct, scalars map[string]*Scalar, constants map[string]map[string]*ast.BasicLit, err error) {
if len(serviceMap) == 0 {
err = errors.New("nothing to do service names are empty")
return
}
pkg, err := parsePackage(goPath, packageName)
if err != nil {
pkg, parseErr := parsePackage(goPaths, packageName)
if parseErr != nil {
err = parseErr
return
}
services, err = readServicesInPackage(pkg, packageName, serviceMap)
@ -211,7 +211,7 @@ func Read(goPath string, packageName string, serviceMap map[string]string) (serv
structs = map[string]*Struct{}
scalars = map[string]*Scalar{}
collectErr := collectTypes(goPath, missingTypes, structs, scalars)
collectErr := collectTypes(goPaths, missingTypes, structs, scalars)
if collectErr != nil {
err = errors.New("error while collecting structs: " + collectErr.Error())
}
@ -225,7 +225,7 @@ func Read(goPath string, packageName string, serviceMap map[string]string) (serv
structPackage := structDef.Package
_, ok := constants[structPackage]
if !ok {
pkg, constPkgErr := parsePackage(goPath, structPackage)
pkg, constPkgErr := parsePackage(goPaths, structPackage)
if constPkgErr != nil {
err = constPkgErr
return
@ -265,7 +265,7 @@ func fixFieldStructs(fields []*Field, structs map[string]*Struct, scalars map[st
}
}
func collectTypes(goPath string, missingTypes map[string]bool, structs map[string]*Struct, scalars map[string]*Scalar) error {
func collectTypes(goPaths []string, missingTypes map[string]bool, structs map[string]*Struct, scalars map[string]*Scalar) error {
scannedPackageStructs := map[string]map[string]*Struct{}
scannedPackageScalars := map[string]map[string]*Scalar{}
missingTypeNames := func() []string {
@ -297,19 +297,19 @@ func collectTypes(goPath string, missingTypes map[string]bool, structs map[strin
packageStructs, structOK := scannedPackageStructs[packageName]
packageScalars, scalarOK := scannedPackageScalars[packageName]
if !structOK || !scalarOK {
parsedPackageStructs, parsedPackageScalars, err := getTypesInPackage(goPath, packageName)
parsedPackageStructs, parsedPackageScalars, err := getTypesInPackage(goPaths, packageName)
if err != nil {
return err
}
trace("found structs in", goPath, packageName)
trace("found structs in", goPaths, packageName)
for structName, strct := range packageStructs {
trace(" struct", structName, strct)
if strct == nil {
panic("how could that be")
}
}
trace("found scalars in", goPath, packageName)
trace("found scalars in", goPaths, packageName)
for scalarName, scalar := range parsedPackageScalars {
trace(" scalar", scalarName, scalar)
}
@ -417,13 +417,10 @@ func (st *StructType) FullName() string {
return fullName
}
func getTypesInPackage(goPath string, packageName string) (structs map[string]*Struct, scalars map[string]*Scalar, err error) {
pkg, err := parsePackage(goPath, packageName)
func getTypesInPackage(goPaths []string, packageName string) (structs map[string]*Struct, scalars map[string]*Scalar, err error) {
pkg, err := parsePackage(goPaths, packageName)
if err != nil {
pkg, err = parsePackage(runtime.GOROOT(), packageName)
if err != nil {
return nil, nil, err
}
return nil, nil, err
}
structs, scalars, err = readStructs(pkg, packageName)
if err != nil {