Files
gotsrpc/internal/parser/servicereader.go
Kevin Franklin Kim da8de21dff feat: use go/package
2026-05-26 16:12:53 +02:00

1050 lines
27 KiB
Go

package parser
import (
"errors"
"fmt"
"go/ast"
"go/token"
"reflect"
"sort"
"strings"
"github.com/foomo/gotsrpc/v2/config"
"github.com/foomo/gotsrpc/v2/internal/model"
)
// interfaceInfo holds a parsed interface type together with the context
// needed to resolve the types its methods mention.
type interfaceInfo struct {
// iface is the AST node of the interface declaration.
iface *ast.InterfaceType
// imports are the imports of the file in which the interface is declared.
imports fileImportSpecMap
// typeParams lists the names of the interface's type parameters in
// declaration order; it is empty for non-generic interfaces.
typeParams []string
}
// interfaceResolver lazily loads interface definitions from any package by path
// and remembers the result per package.
type interfaceResolver struct {
// goPaths and gomod are handed to parsePackage whenever a package has to be loaded.
goPaths []string
gomod config.Namespace
// cache maps a package path to the interfaces declared in that package,
// keyed by interface name. Packages that could not be parsed are cached
// as an empty map.
cache map[string]map[string]interfaceInfo
}
// newInterfaceResolver returns a resolver with an empty cache that parses
// packages using the given goPaths and gomod settings.
func newInterfaceResolver(goPaths []string, gomod config.Namespace) *interfaceResolver {
	resolver := new(interfaceResolver)
	resolver.goPaths = goPaths
	resolver.gomod = gomod
	resolver.cache = make(map[string]map[string]interfaceInfo)
	return resolver
}
// seed stores ifaces as the cached result for packagePath, replacing any
// previous entry, so that later load and lookup calls for that path return it
// without parsing the package.
func (r *interfaceResolver) seed(packagePath string, ifaces map[string]interfaceInfo) {
	cache := r.cache
	cache[packagePath] = ifaces
}
// load returns the interfaces declared in packagePath, keyed by name.
//
// The result is cached per package path. When the package cannot be parsed
// (parsePackage returns an error or a nil package) an empty map is cached and
// returned instead; the error itself is dropped.
func (r *interfaceResolver) load(packagePath string) map[string]interfaceInfo {
	cached, hit := r.cache[packagePath]
	if hit {
		return cached
	}

	var found map[string]interfaceInfo
	if pkg, err := parsePackage(r.goPaths, r.gomod, packagePath); err == nil && pkg != nil {
		found = collectPackageInterfaces(pkg, packagePath)
	} else {
		found = map[string]interfaceInfo{}
	}
	r.cache[packagePath] = found
	return found
}
// lookup returns the interface called name from packagePath, loading the
// package on first use. The boolean reports whether such an interface exists.
func (r *interfaceResolver) lookup(packagePath, name string) (interfaceInfo, bool) {
	ifaces := r.load(packagePath)
	info, found := ifaces[name]
	return info, found
}
// extractInterfaceRef turns an embedded interface reference into its
// (packagePath, name) pair.
//
// Two shapes are understood: a bare identifier (Name), whose package path is
// looked up in imports under the empty alias, and a qualified identifier
// (pkg.Name), whose package path is looked up under the alias pkg. Any other
// expression yields ok == false.
func extractInterfaceRef(expr ast.Expr, imports fileImportSpecMap) (pkgPath, name string, ok bool) {
	if ident, isIdent := expr.(*ast.Ident); isIdent {
		return imports.getPackagePath(""), ident.Name, true
	}

	selector, isSelector := expr.(*ast.SelectorExpr)
	if !isSelector {
		return "", "", false
	}
	qualifier, isIdent := selector.X.(*ast.Ident)
	if !isIdent {
		return "", "", false
	}
	return imports.getPackagePath(qualifier.Name), selector.Sel.Name, true
}
// resolvedMethod pairs an AST function type with the file imports it was declared in.
type resolvedMethod struct {
// name is the method name as written in the interface.
name string
// funcTyp is the method's signature.
funcTyp *ast.FuncType
// imports are the imports of the file that declares the method.
imports fileImportSpecMap
// typeSubst maps type parameter names of the declaring interface to the
// concrete type expressions supplied where the interface was embedded.
// It is nil or empty for methods of non-generic interfaces.
typeSubst map[string]ast.Expr
substImports fileImportSpecMap // imports for resolving typeSubst expressions
}
// Read parses the package packageName and collects everything required to
// describe the services named in serviceMap.
//
// serviceMap maps an endpoint to the name of the service type to read; it must
// not be empty.
//
// missingTypes and missingConstants are updated in place. Every struct or
// scalar type referenced by a service method is recorded in missingTypes and
// flipped to false once it has been resolved. Every constant type that ends up
// in the result is set to false in missingConstants; an entry that is still
// true afterwards makes Read fail. Nil maps are accepted and treated as empty.
//
// On success Read returns the package name, the services sorted as
// readServicesInPackage leaves them, the resolved structs and scalars keyed by
// full name, and the constant types grouped by package. When err is non-nil
// the other results must not be relied upon.
func Read(
	goPaths []string,
	gomod config.Namespace,
	packageName string,
	serviceMap map[string]string,
	missingTypes map[string]bool,
	missingConstants map[string]bool,
) (
	pkgName string,
	services model.ServiceList,
	structs map[string]*model.Struct,
	scalars map[string]*model.Scalar,
	constantTypes map[string]map[string]any,
	err error,
) {
	if len(serviceMap) == 0 {
		err = errors.New("nothing to do service names are empty")
		return
	}
	// Both maps are written to below; writing to a nil map would panic.
	if missingTypes == nil {
		missingTypes = map[string]bool{}
	}
	if missingConstants == nil {
		missingConstants = map[string]bool{}
	}

	pkg, parseErr := parsePackage(goPaths, gomod, packageName)
	if parseErr != nil {
		err = parseErr
		return
	}
	// Guard before dereferencing: other callers of parsePackage in this file
	// treat a nil package without an error as possible.
	if pkg == nil {
		err = errors.New("package cannot be nil")
		return
	}
	pkgName = pkg.Name

	services, err = readServicesInPackage(pkg, packageName, serviceMap, newInterfaceResolver(goPaths, gomod))
	if err != nil {
		return
	}

	for _, s := range services {
		for _, m := range s.Methods {
			collectStructTypes(m.Return, missingTypes)
			collectStructTypes(m.Args, missingTypes)
			collectScalarTypes(m.Return, missingTypes)
			collectScalarTypes(m.Args, missingTypes)
		}
	}
	trace("missing")
	traceData(missingTypes)

	structs = map[string]*model.Struct{}
	scalars = map[string]*model.Scalar{}
	if collectErr := collectTypes(goPaths, gomod, missingTypes, structs, scalars); collectErr != nil {
		// Stop here: continuing would work on a partial type set and could
		// overwrite this error with a less relevant one.
		err = fmt.Errorf("error while collecting structs: %w", collectErr)
		return
	}
	trace("---------------- found structs -------------------")
	traceData(structs)
	trace("---------------- /found structs -------------------")
	trace("---------------- found scalars -------------------")
	traceData(scalars)
	trace("---------------- /found scalars -------------------")

	// Load the constant types of every package that contributed a struct or
	// a scalar, parsing each package at most once.
	allConstantTypes := map[string]map[string]any{}
	loadPackageConstants := func(packagePath string) error {
		if _, loaded := allConstantTypes[packagePath]; loaded {
			return nil
		}
		constPkg, constPkgErr := parsePackage(goPaths, gomod, packagePath)
		if constPkgErr != nil {
			return constPkgErr
		}
		allConstantTypes[packagePath] = loadConstantTypes(constPkg)
		return nil
	}
	for _, structDef := range structs {
		if structDef == nil {
			continue
		}
		if err = loadPackageConstants(structDef.Package); err != nil {
			return
		}
	}
	for _, scalarDef := range scalars {
		if scalarDef == nil {
			continue
		}
		if err = loadPackageConstants(scalarDef.Package); err != nil {
			return
		}
	}

	flatStructs := map[string]bool{}
	for _, s := range structs {
		loadFlatStructs(s, flatStructs)
	}

	// Keep only constant types that are actually referenced: as a scalar, as
	// a (nested) struct, or explicitly requested through missingConstants.
	constantTypes = map[string]map[string]any{}
	for constantTypePackage, constantType := range allConstantTypes {
		for constantTypeName, constantTypeValues := range constantType {
			fullName := constantTypePackage + "." + constantTypeName
			_, scalarOK := scalars[fullName]
			_, structOK := flatStructs[fullName]
			_, constantsOK := missingConstants[fullName]
			if !scalarOK && !structOK && !constantsOK {
				continue
			}
			missingConstants[fullName] = false
			if _, ok := constantTypes[constantTypePackage]; !ok {
				constantTypes[constantTypePackage] = map[string]any{}
			}
			constantTypes[constantTypePackage][constantTypeName] = constantTypeValues
		}
	}
	for missingConstant, missing := range missingConstants {
		if missing {
			err = errors.New("could not resolve constant: " + missingConstant)
			return
		}
	}

	// fix arg and return field lists
	for _, service := range services {
		for _, method := range service.Methods {
			fixFieldStructs(method.Args, structs, scalars)
			fixFieldStructs(method.Return, structs, scalars)
		}
	}
	traceData("---------------------------", services)
	return
}
// readServiceFile scans one file of the package and appends the methods it
// finds to the matching entries of services.
//
// Two kinds of declarations contribute methods:
//   - exported methods declared on a pointer receiver whose type name matches
//     a service name;
//   - interface types whose name matches a service name; these mark the
//     service as an interface and contribute all methods of the interface,
//     including those of embedded interfaces resolved through resolver.
//
// After the file has been processed the methods of every service are sorted.
// The function currently always returns nil.
func readServiceFile(file *ast.File, packageName string, services model.ServiceList, resolver *interfaceResolver) error {
	findService := func(serviceName string) (*model.Service, bool) {
		for _, service := range services {
			if service.Name == serviceName {
				return service, true
			}
		}
		return nil, false
	}

	fileImports := getFileImports(file, packageName)
	for _, decl := range file.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			if d.Recv == nil {
				trace("no receiver for", d.Name)
				continue
			}
			trace("that is a method named", d.Name)
			if len(d.Recv.List) != 1 {
				continue
			}
			// Only pointer receivers on a plain named type are considered.
			starExpr, ok := d.Recv.List[0].Type.(*ast.StarExpr)
			if !ok {
				continue
			}
			ident, ok := starExpr.X.(*ast.Ident)
			if !ok {
				continue
			}
			service, ok := findService(ident.Name)
			// IsExported applies Go's own rule (first rune is an upper case
			// letter); comparing the first byte with its lower-cased form
			// misjudges names that start with a multi-byte or caseless rune.
			if !ok || !d.Name.IsExported() {
				continue
			}
			trace(" on sth:", ident.Name)
			service.Methods = append(service.Methods, &model.Method{
				Name:   d.Name.Name,
				Args:   readFields(d.Type.Params, fileImports),
				Return: readFields(d.Type.Results, fileImports),
			})
		case *ast.GenDecl:
			if d.Tok != token.TYPE {
				continue
			}
			for _, spec := range d.Specs {
				typeSpec, ok := spec.(*ast.TypeSpec)
				if !ok {
					continue
				}
				ident := typeSpec.Name
				trace("that is an interface named", ident.Name)
				service, ok := findService(ident.Name)
				if !ok {
					continue
				}
				iSpec, ok := typeSpec.Type.(*ast.InterfaceType)
				if !ok {
					continue
				}
				service.IsInterface = true
				// The service interface itself is pre-marked as visited so a
				// cyclic embedding cannot loop back into it.
				visited := map[string]bool{packageName + "." + ident.Name: true}
				for _, m := range resolveInterfaceMethods(iSpec, fileImports, resolver, visited, nil, nil) {
					trace(" on sth:", m.name)
					tpNames := make([]string, 0, len(m.typeSubst))
					for k := range m.typeSubst {
						tpNames = append(tpNames, k)
					}
					// Map iteration order is random; sort so the names are
					// handed on in a reproducible order.
					sort.Strings(tpNames)
					args := readFields(m.funcTyp.Params, m.imports, tpNames...)
					ret := readFields(m.funcTyp.Results, m.imports, tpNames...)
					if len(m.typeSubst) > 0 {
						substituteTypeParams(args, m.typeSubst, m.substImports)
						substituteTypeParams(ret, m.typeSubst, m.substImports)
					}
					service.Methods = append(service.Methods, &model.Method{
						Name:   m.name,
						Args:   args,
						Return: ret,
					})
				}
			}
		}
	}

	for _, s := range services {
		sort.Sort(s.Methods)
	}
	return nil
}
// readFields converts an AST field list (parameters or results) into model
// fields, producing one entry per declared name.
//
// A nil fieldList yields an empty, non-nil slice. Names that share one AST
// field (for example "a, b int") also share the same value pointer. The third
// result of readField is ignored.
func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap, typeParams ...string) (fields []*model.Field) {
	trace("reading fields")
	fields = make([]*model.Field, 0)
	if fieldList == nil {
		return fields
	}
	for _, astField := range fieldList.List {
		names, value, _ := readField(astField, fileImports, typeParams)
		for _, fieldName := range names {
			entry := &model.Field{Name: fieldName, Value: value}
			fields = append(fields, entry)
		}
	}
	trace("done reading fields")
	return fields
}
// substituteTypeParams rewrites the value of every field in place, replacing
// type parameter references with the concrete types found in subst.
// substImports are the imports used to resolve the substitution expressions;
// they may differ from the imports of the file declaring the method.
func substituteTypeParams(fields []*model.Field, subst map[string]ast.Expr, substImports fileImportSpecMap) {
	for i := range fields {
		substituteValue(fields[i].Value, subst, substImports)
	}
}
// substituteValue walks v and replaces every type parameter reference that
// has an entry in subst with the corresponding concrete type.
//
// A value that is itself a type parameter is reset and reloaded from the
// substitution expression, keeping only its pointer flag; type parameters
// without an entry in subst are left untouched. Otherwise the array element,
// map key, map value and type arguments are visited recursively. A nil v is
// ignored.
func substituteValue(v *model.Value, subst map[string]ast.Expr, substImports fileImportSpecMap) {
	if v == nil {
		return
	}

	if v.TypeParam == "" {
		if arr := v.Array; arr != nil {
			substituteValue(arr.Value, subst, substImports)
		}
		if m := v.Map; m != nil {
			substituteValue(m.Key, subst, substImports)
			substituteValue(m.Value, subst, substImports)
		}
		for _, typeArg := range v.TypeArgs {
			substituteValue(typeArg, subst, substImports)
		}
		return
	}

	concrete, known := subst[v.TypeParam]
	if !known {
		return
	}
	keepPtr := v.IsPtr
	*v = model.Value{IsPtr: keepPtr}
	loadValueExpr(v, concrete, substImports, nil)
}
// readServicesInPackage builds one service per serviceMap entry (endpoint to
// service name) and fills in its methods from the files of pkg.
//
// The interfaces of the package are collected up front and seeded into
// resolver. Files are visited in sorted key order, and the resulting service
// list is sorted before it is returned. A nil pkg is rejected with an error.
func readServicesInPackage(pkg *parsedPackage, packageName string, serviceMap map[string]string, resolver *interfaceResolver) (services model.ServiceList, err error) {
	if pkg == nil {
		return nil, errors.New("package cannot be nil")
	}

	services = model.ServiceList{}
	for endpoint, name := range serviceMap {
		svc := &model.Service{
			Name:     name,
			Methods:  []*model.Method{},
			Endpoint: endpoint,
		}
		services = append(services, svc)
	}

	resolver.seed(packageName, collectPackageInterfaces(pkg, packageName))

	fileNames := make([]string, 0, len(pkg.Files))
	for fileName := range pkg.Files {
		fileNames = append(fileNames, fileName)
	}
	sort.Strings(fileNames)

	for _, fileName := range fileNames {
		if err = readServiceFile(pkg.Files[fileName], packageName, services, resolver); err != nil {
			return services, err
		}
	}
	sort.Sort(services)
	return services, err
}
// loadConstantTypes scans the type and const declarations of pkg and returns
// a map keyed by type name.
//
// For a named type without typed constants the value is a string describing
// its kind: "any" for interfaces and byte, "string", "boolean", or "number"
// for the numeric kinds. Other underlying types are skipped.
//
// For a type that has constants declared with an explicit type and a basic
// literal value, the value is a map[string]*ast.BasicLit from constant name
// to its literal; this replaces any kind string recorded earlier. Constants
// without an explicit type or with a non-literal value (for example iota)
// are ignored.
//
// A nil pkg yields an empty map.
func loadConstantTypes(pkg *parsedPackage) map[string]any {
	constantTypes := map[string]any{}
	if pkg == nil {
		return constantTypes
	}
	for _, file := range pkg.Files {
		for _, decl := range file.Decls {
			genDecl, ok := decl.(*ast.GenDecl)
			if !ok {
				continue
			}
			switch genDecl.Tok {
			case token.TYPE:
				trace("got a type", genDecl.Specs)
				for _, spec := range genDecl.Specs {
					typeSpec, ok := spec.(*ast.TypeSpec)
					if !ok {
						continue
					}
					// Never overwrite an entry, in particular not the
					// constants map of a const block seen earlier.
					if _, seen := constantTypes[typeSpec.Name.Name]; seen {
						continue
					}
					switch specType := typeSpec.Type.(type) {
					case *ast.InterfaceType:
						constantTypes[typeSpec.Name.Name] = "any"
					case *ast.Ident:
						switch specType.Name {
						case "byte":
							constantTypes[typeSpec.Name.Name] = "any"
						case "string":
							constantTypes[typeSpec.Name.Name] = "string"
						case "bool":
							constantTypes[typeSpec.Name.Name] = "boolean"
						case "float", "float32", "float64",
							"int", "int8", "int16", "int32", "int64",
							"uint", "uint8", "uint16", "uint32", "uint64":
							constantTypes[typeSpec.Name.Name] = "number"
						default:
							trace("unhandled type", reflect.ValueOf(typeSpec.Type).Type().String())
						}
					default:
						trace("ignoring type", reflect.ValueOf(typeSpec.Type).Type().String())
					}
				}
			case token.CONST:
				trace("got a const", genDecl.Specs)
				for _, spec := range genDecl.Specs {
					valueSpec, ok := spec.(*ast.ValueSpec)
					if !ok {
						continue
					}
					specType, ok := valueSpec.Type.(*ast.Ident)
					if !ok {
						continue
					}
					for i, val := range valueSpec.Values {
						lit, ok := val.(*ast.BasicLit)
						if !ok || i >= len(valueSpec.Names) {
							continue
						}
						values, ok := constantTypes[specType.Name].(map[string]*ast.BasicLit)
						if !ok {
							// Either unseen or so far only known by its kind string.
							values = map[string]*ast.BasicLit{}
							constantTypes[specType.Name] = values
						}
						// Pair the i-th value with the i-th name so that
						// "const A, B T = 1, 2" records both constants.
						values[valueSpec.Names[i].Name] = lit
					}
				}
			default:
				trace("ignoring", genDecl.Tok)
			}
		}
	}
	return constantTypes
}
// loadFlatStructs records the full name of s in flatStructs, after visiting
// the values it is built from: the map key, the map value (only when that
// value carries a scalar) and the value of every field.
func loadFlatStructs(s *model.Struct, flatStructs map[string]bool) {
	if m := s.Map; m != nil {
		if key := m.Key; key != nil {
			loadFlatStructsValue(key, flatStructs)
		}
		if val := m.Value; val != nil && val.Scalar != nil {
			loadFlatStructsValue(val, flatStructs)
		}
	}
	for i := range s.Fields {
		loadFlatStructsValue(s.Fields[i].Value, flatStructs)
	}
	flatStructs[s.FullName()] = true
}
// loadFlatStructsValue records in flatStructs the full names reachable from
// the value s: its map key, its map value (only when that value carries a
// scalar), its nested struct, its own scalar and all of its type arguments.
func loadFlatStructsValue(s *model.Value, flatStructs map[string]bool) {
	if m := s.Map; m != nil {
		if key := m.Key; key != nil {
			loadFlatStructsValue(key, flatStructs)
		}
		if val := m.Value; val != nil && val.Scalar != nil {
			loadFlatStructsValue(val, flatStructs)
		}
	}
	if nested := s.Struct; nested != nil {
		loadFlatStructs(nested, flatStructs)
	}
	if scalar := s.Scalar; scalar != nil {
		flatStructs[scalar.FullName()] = true
	}
	for i := range s.TypeArgs {
		loadFlatStructsValue(s.TypeArgs[i], flatStructs)
	}
}
// fixFieldStructs reconciles fields that reference a struct type with what
// was actually resolved for that name.
//
// When the name is a known struct, the field inherits the struct's IsError
// flag. When it is a known scalar instead, the struct reference is dropped
// and replaced by that scalar. Fields with unknown names are left untouched.
func fixFieldStructs(fields []*model.Field, structs map[string]*model.Struct, scalars map[string]*model.Scalar) {
	for _, field := range fields {
		value := field.Value
		if value.StructType == nil {
			continue
		}
		fullName := value.StructType.FullName()
		if strct, isStruct := structs[fullName]; isStruct {
			value.IsError = strct.IsError
			continue
		}
		if scalar, isScalar := scalars[fullName]; isScalar {
			value.StructType = nil
			value.Scalar = scalar
		}
	}
}
// collectTypes resolves every type flagged as missing in missingTypes and
// stores the result in structs or scalars, keyed by full type name.
//
// The package of a type is derived from its full name (everything before the
// last dot) and each package is parsed at most once. Resolving a struct can
// surface further missing types through typesPending, so the process repeats
// until nothing is pending. If a round does not reduce the number of missing
// types, the scanned packages are dumped to stdout and an error listing the
// unresolved names is returned.
func collectTypes(goPaths []string, gomod config.Namespace, missingTypes map[string]bool, structs map[string]*model.Struct, scalars map[string]*model.Scalar) error {
	scannedPackageStructs := map[string]map[string]*model.Struct{}
	scannedPackageScalars := map[string]map[string]*model.Scalar{}

	// missingTypeNames returns the names still flagged as missing, sorted so
	// traces and error messages are reproducible.
	missingTypeNames := func() []string {
		var missing []string
		for name, isMissing := range missingTypes {
			if isMissing {
				missing = append(missing, name)
			}
		}
		sort.Strings(missing)
		return missing
	}

	lastNumMissing := len(missingTypeNames())
	for typesPending(structs, scalars, missingTypes) {
		trace("pending", missingTypeNames())
		for fullName, typeIsMissing := range missingTypes {
			if !typeIsMissing {
				continue
			}
			fullNameParts := strings.Split(fullName, ".")
			fullNameParts = fullNameParts[:len(fullNameParts)-1]
			packageName := strings.Join(fullNameParts, ".")
			trace(fullName, "==========================>", fullNameParts, "=============>", packageName)

			packageStructs, structOK := scannedPackageStructs[packageName]
			packageScalars, scalarOK := scannedPackageScalars[packageName]
			if !structOK || !scalarOK {
				var err error
				packageStructs, packageScalars, err = getTypesInPackage(goPaths, gomod, packageName)
				if err != nil {
					return err
				}
				// Trace what was just parsed, not the empty cache lookups.
				trace("found structs in", goPaths, packageName)
				for structName, strct := range packageStructs {
					trace(" struct", structName, strct)
				}
				trace("found scalars in", goPaths, packageName)
				for scalarName, scalar := range packageScalars {
					trace(" scalar", scalarName, scalar)
				}
				traceData(packageScalars)
				scannedPackageStructs[packageName] = packageStructs
				scannedPackageScalars[packageName] = packageScalars
			}

			traceData("packageStructs", packageName, packageStructs)
			for packageStructName, packageStruct := range packageStructs {
				if missing, needed := missingTypes[packageStructName]; !needed || !missing {
					continue
				}
				trace("picked up package struct", packageStructName, packageStruct)
				if packageStruct == nil {
					// Report the inconsistency to the caller instead of
					// crashing the whole process.
					return fmt.Errorf("package %q returned a nil struct for %q", packageName, packageStructName)
				}
				missingTypes[packageStructName] = false
				structs[packageStructName] = packageStruct
			}

			traceData("packageScalars", packageScalars)
			for packageScalarName, packageScalar := range packageScalars {
				if missing, needed := missingTypes[packageScalarName]; !needed || !missing {
					continue
				}
				trace("picked up package scalar", packageScalarName, packageScalar)
				missingTypes[packageScalarName] = false
				scalars[packageScalarName] = packageScalar
			}
		}

		newNumMissingTypes := len(missingTypeNames())
		if newNumMissingTypes > 0 && newNumMissingTypes == lastNumMissing {
			// No progress in this round: dump what was scanned and give up.
			for scannedPackage, scannedScalars := range scannedPackageScalars {
				fmt.Println("scanned scalars ", scannedPackage)
				for _, scalar := range scannedScalars {
					fmt.Println(" ", scalar.Name)
				}
			}
			for scannedPackage, scannedStructs := range scannedPackageStructs {
				fmt.Println("scanned struct ", scannedPackage)
				for _, strct := range scannedStructs {
					fmt.Println(" ", strct.Name)
				}
			}
			return fmt.Errorf("could not resolve at least one of the following types %v", missingTypeNames())
		}
		lastNumMissing = newNumMissingTypes
	}
	return nil
}
// typesPending reports whether type resolution still has work to do: either
// a type is flagged as missing, or a collected struct has dependencies that
// are not satisfied yet. Checking the dependencies may flag further entries
// in missingTypes.
func typesPending(structs map[string]*model.Struct, scalars map[string]*model.Scalar, missingTypes map[string]bool) bool {
	for _, stillMissing := range missingTypes {
		if stillMissing {
			return true
		}
	}
	for name := range structs {
		if satisfied := depsSatisfied(structs[name], missingTypes, structs, scalars); !satisfied {
			return true
		}
	}
	return false
}
// needsWorkValue reports whether value, or anything nested inside it, refers
// to a type for which needsWork returns true.
//
// Exactly one of scalar, struct type, array element or map key/value is
// inspected, in that order of precedence; type arguments are always
// inspected afterwards. A nil value refers to nothing and needs no work,
// which also makes the recursion safe for maps without a key value and
// arrays without an element value.
func needsWorkValue(value *model.Value, needsWork func(fullName string) bool) bool {
	if value == nil {
		return false
	}
	switch {
	case value.Scalar != nil:
		if needsWork(value.Scalar.FullName()) {
			return true
		}
	case value.StructType != nil:
		if needsWork(value.StructType.FullName()) {
			return true
		}
	case value.Array != nil:
		if needsWorkValue(value.Array.Value, needsWork) {
			return true
		}
	case value.Map != nil:
		if needsWorkValue(value.Map.Key, needsWork) || needsWorkValue(value.Map.Value, needsWork) {
			return true
		}
	}
	for _, arg := range value.TypeArgs {
		if needsWorkValue(arg, needsWork) {
			return true
		}
	}
	return false
}
// depsSatisfied reports whether every type that s depends on, and s itself,
// has been resolved into structs or scalars.
//
// Checks run in a fixed order (fields, inline fields, union fields, array
// element, map key, map value, s itself) and stop at the first unresolved
// type. A type that is in neither map is flagged in missingTypes as a side
// effect, so the order and the early exit determine which types get flagged.
func depsSatisfied(s *model.Struct, missingTypes map[string]bool, structs map[string]*model.Struct, scalars map[string]*model.Scalar) bool {
	// needsWork reports whether fullName is unresolved, flagging it as
	// missing when it is not known at all.
	needsWork := func(fullName string) bool {
		strct, haveStruct := structs[fullName]
		scalar, haveScalar := scalars[fullName]
		switch {
		case !haveStruct && !haveScalar:
			missingTypes[fullName] = true
			trace("need work ----------------------" + fullName)
			return true
		case strct == nil && scalar == nil:
			trace("need work ----------------------" + fullName)
			return true
		}
		return false
	}
	// fieldsSatisfied reports whether no field in the list needs work.
	fieldsSatisfied := func(fields []*model.Field) bool {
		for _, field := range fields {
			if needsWorkValue(field.Value, needsWork) {
				return false
			}
		}
		return true
	}

	if !fieldsSatisfied(s.Fields) || !fieldsSatisfied(s.InlineFields) || !fieldsSatisfied(s.UnionFields) {
		return false
	}
	if arr := s.Array; arr != nil && arr.Value != nil && needsWorkValue(arr.Value, needsWork) {
		return false
	}
	if m := s.Map; m != nil {
		if m.Key != nil && needsWorkValue(m.Key, needsWork) {
			return false
		}
		if m.Value != nil && needsWorkValue(m.Value, needsWork) {
			return false
		}
	}
	if needsWork(s.FullName()) {
		return false
	}
	return true
}
// getTypesInPackage parses packageName and returns the structs and scalars
// that readStructs finds in it. On any error both maps are nil.
func getTypesInPackage(goPaths []string, gomod config.Namespace, packageName string) (
	structs map[string]*model.Struct,
	scalars map[string]*model.Scalar,
	err error,
) {
	pkg, parseErr := parsePackage(goPaths, gomod, packageName)
	if parseErr != nil {
		return nil, nil, parseErr
	}
	if structs, scalars, err = readStructs(pkg, packageName); err != nil {
		return nil, nil, err
	}
	return structs, scalars, nil
}
// getStructTypesForField returns the struct types referenced by value: the
// value's own struct type, or else those of its map value, or else those of
// its array element, followed by the struct types of all type arguments.
// Map keys are not inspected.
func getStructTypesForField(value *model.Value) []*model.StructType {
	var found []*model.StructType
	if value.StructType != nil {
		found = append(found, value.StructType)
	} else if value.Map != nil {
		found = append(found, getStructTypesForField(value.Map.Value)...)
	} else if value.Array != nil {
		found = append(found, getStructTypesForField(value.Array.Value)...)
	}
	for i := range value.TypeArgs {
		found = append(found, getStructTypesForField(value.TypeArgs[i])...)
	}
	return found
}
// getScalarForField returns the scalars referenced by value.
//
// A value that carries a scalar contributes that scalar. A value that
// references a struct type contributes nothing of its own. Otherwise the map
// key (when present) and map value, or else the array element, are searched.
// In every case the scalars of all type arguments are appended last.
func getScalarForField(value *model.Value) []*model.Scalar {
	var found []*model.Scalar
	if value.Scalar != nil {
		found = append(found, value.Scalar)
	} else if value.StructType == nil {
		if m := value.Map; m != nil {
			if m.Key != nil {
				found = append(found, getScalarForField(m.Key)...)
			}
			found = append(found, getScalarForField(m.Value)...)
		} else if value.Array != nil {
			found = append(found, getScalarForField(value.Array.Value)...)
		}
	}
	for i := range value.TypeArgs {
		found = append(found, getScalarForField(value.TypeArgs[i])...)
	}
	return found
}
// collectScalarTypes flags in scalarTypes the full name of every scalar
// referenced by fields. The name is "package.Name", or just "Name" when the
// scalar has no package. The names "error", "net/http.Request",
// "net/http.ResponseWriter" and "context.Context" are never flagged.
func collectScalarTypes(fields []*model.Field, scalarTypes map[string]bool) {
	for _, field := range fields {
		for _, scalar := range getScalarForField(field.Value) {
			if scalar == nil {
				continue
			}
			fullName := scalar.Name
			if scalar.Package != "" {
				fullName = scalar.Package + "." + scalar.Name
			}
			if fullName == "error" ||
				fullName == "net/http.Request" ||
				fullName == "net/http.ResponseWriter" ||
				fullName == "context.Context" {
				continue
			}
			scalarTypes[fullName] = true
		}
	}
}
// collectStructTypes flags in structTypes the full name of every struct type
// referenced by fields. The name is "package.Name", or just "Name" when the
// type has no package. The names "error", "net/http.Request",
// "net/http.ResponseWriter" and "context.Context" are never flagged.
func collectStructTypes(fields []*model.Field, structTypes map[string]bool) {
	for _, field := range fields {
		for _, structType := range getStructTypesForField(field.Value) {
			if structType == nil {
				continue
			}
			fullName := structType.Name
			if structType.Package != "" {
				fullName = structType.Package + "." + structType.Name
			}
			if fullName == "error" ||
				fullName == "net/http.Request" ||
				fullName == "net/http.ResponseWriter" ||
				fullName == "context.Context" {
				continue
			}
			structTypes[fullName] = true
		}
	}
}
// collectPackageInterfaces walks every file of pkg and returns the interface
// types declared in it, keyed by interface name. Each entry carries the
// interface's AST node, the imports of the declaring file and the names of
// its type parameters in declaration order.
func collectPackageInterfaces(pkg *parsedPackage, packageName string) map[string]interfaceInfo {
	interfaces := make(map[string]interfaceInfo)
	for _, file := range pkg.Files {
		imports := getFileImports(file, packageName)
		for _, decl := range file.Decls {
			gen, isGen := decl.(*ast.GenDecl)
			if !isGen || gen.Tok != token.TYPE {
				continue
			}
			for _, spec := range gen.Specs {
				ts, isTypeSpec := spec.(*ast.TypeSpec)
				if !isTypeSpec {
					continue
				}
				ifaceType, isIface := ts.Type.(*ast.InterfaceType)
				if !isIface {
					continue
				}
				info := interfaceInfo{iface: ifaceType, imports: imports}
				if ts.TypeParams != nil {
					for _, group := range ts.TypeParams.List {
						for _, paramName := range group.Names {
							info.typeParams = append(info.typeParams, paramName.Name)
						}
					}
				}
				interfaces[ts.Name.Name] = info
			}
		}
	}
	return interfaces
}
// resolveExpr returns the substitution for expr when expr is a plain
// identifier that has an entry in typeSubst; any other expression is
// returned unchanged.
func resolveExpr(expr ast.Expr, typeSubst map[string]ast.Expr) ast.Expr {
	ident, isIdent := expr.(*ast.Ident)
	if !isIdent {
		return expr
	}
	if replacement, bound := typeSubst[ident.Name]; bound {
		return replacement
	}
	return expr
}
// resolveInterfaceMethods returns the methods of iface in declaration order,
// with the methods of embedded interfaces expanded in place.
//
// Embedded interfaces are looked up through resolver and may live in another
// package. visited holds "packagePath.Name" keys of interfaces that were
// already expanded; an embedding whose key is present is skipped, which
// protects against cycles. The key is recorded before the lookup, so an
// unknown interface is not retried either.
//
// typeSubst maps the type parameter names of iface to concrete type
// expressions, and substImports are the imports needed to resolve those
// expressions. Both are attached unchanged to the methods declared directly
// on iface. For a generic embedding such as Base[T] or Keyed[K, V], each type
// argument that is a plain identifier bound in typeSubst is replaced by its
// substitution before it is bound to the embedded interface's parameter at
// the same position; surplus arguments are ignored. The substitution
// expressions are resolved with substImports as soon as one argument was
// replaced and substImports is non-nil, and with imports otherwise.
// Non-generic embeddings are expanded without any substitution.
func resolveInterfaceMethods(iface *ast.InterfaceType, imports fileImportSpecMap, resolver *interfaceResolver, visited map[string]bool, typeSubst map[string]ast.Expr, substImports fileImportSpecMap) []resolvedMethod {
	// lookupEmbedded resolves an embedded interface reference, marking it
	// as visited. It reports false when the reference has an unsupported
	// shape, was visited before, or names an unknown interface.
	lookupEmbedded := func(ref ast.Expr) (interfaceInfo, bool) {
		pkgPath, name, ok := extractInterfaceRef(ref, imports)
		if !ok {
			return interfaceInfo{}, false
		}
		key := pkgPath + "." + name
		if visited[key] {
			return interfaceInfo{}, false
		}
		visited[key] = true
		return resolver.lookup(pkgPath, name)
	}

	// expandGeneric expands an embedded generic interface instantiated
	// with the given type arguments.
	expandGeneric := func(target ast.Expr, typeArgs []ast.Expr) []resolvedMethod {
		info, ok := lookupEmbedded(target)
		if !ok {
			return nil
		}
		childSubst := map[string]ast.Expr{}
		childImports := imports
		for i, typeArg := range typeArgs {
			resolved := resolveExpr(typeArg, typeSubst)
			if resolved != typeArg && substImports != nil {
				childImports = substImports
			}
			if i < len(info.typeParams) {
				childSubst[info.typeParams[i]] = resolved
			}
		}
		return resolveInterfaceMethods(info.iface, info.imports, resolver, visited, childSubst, childImports)
	}

	var methods []resolvedMethod
	for _, field := range iface.Methods.List {
		switch ft := field.Type.(type) {
		case *ast.FuncType:
			// A method declared directly on the interface.
			if len(field.Names) > 0 {
				methods = append(methods, resolvedMethod{
					name:         field.Names[0].Name,
					funcTyp:      ft,
					imports:      imports,
					typeSubst:    typeSubst,
					substImports: substImports,
				})
			}
		case *ast.Ident, *ast.SelectorExpr:
			// Embedded non-generic interface: Base or pkg.Base.
			if info, ok := lookupEmbedded(ft); ok {
				methods = append(methods, resolveInterfaceMethods(info.iface, info.imports, resolver, visited, nil, nil)...)
			}
		case *ast.IndexExpr:
			// Embedded generic interface with one type argument: Base[T].
			methods = append(methods, expandGeneric(ft.X, []ast.Expr{ft.Index})...)
		case *ast.IndexListExpr:
			// Embedded generic interface with several type arguments: Keyed[K, V].
			methods = append(methods, expandGeneric(ft.X, ft.Indices)...)
		}
	}
	return methods
}