-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcontext.go
141 lines (123 loc) · 3.66 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package codegen
import (
"go/token"
"go/types"
"path/filepath"
"reflect"
"strings"
"text/template"
"github.com/pkg/errors"
"golang.org/x/tools/go/types/typeutil"
)
// GenContext represents the context in which a code generation operation is
// run. It caches parsed templates, tracks the imports the generated code
// needs, records which template invocations have already run, and collects
// the generated output.
type GenContext struct {
	// PackageName is the name of the package the generated code belongs to.
	PackageName string

	// templates caches parsed templates keyed by the fully qualified name of
	// the generator type they belong to.
	templates map[string]*template.Template
	// imports lists the import paths added via AddImport, in insertion order.
	imports []string
	// importsSeen is the set of import paths already handled, used to
	// de-duplicate imports.
	importsSeen map[string]struct{}
	// fset maps token positions back to source file names; it is used to
	// locate template files next to a generator type's declaration.
	fset *token.FileSet
	// packages indexes known packages by import path.
	packages map[string]*types.Package
	// invocationsSeen records every invocation already run so that
	// duplicates can be skipped.
	invocationsSeen []invocationSeen
	// generated holds the output of each template run, in execution order.
	generated []string
}

// invocationSeen identifies a single template invocation: a generator type
// applied to a struct type with a particular set of arguments.
type invocationSeen struct {
	GenTypeName string            // fully qualified generator type name
	StructName  string            // fully qualified target struct type name
	Args        map[string]string // arguments passed to the template
}
// NewGenContext creates a GenContext for generating code into rootPackage.
// Every package returned by typeutil.Dependencies for rootPackage is indexed
// by import path so GetType can resolve names from it. rootPackage's own path
// is pre-marked as seen so AddImport never records a self-import.
func NewGenContext(fset *token.FileSet, rootPackage *types.Package) *GenContext {
	deps := typeutil.Dependencies(rootPackage)
	byPath := make(map[string]*types.Package, len(deps))
	for _, dep := range deps {
		byPath[dep.Path()] = dep
	}

	return &GenContext{
		PackageName: rootPackage.Name(),
		templates:   map[string]*template.Template{},
		importsSeen: map[string]struct{}{rootPackage.Path(): {}},
		fset:        fset,
		packages:    byPath,
	}
}
// RunTemplate runs the template belonging to invocation.GenType against
// aStruct and appends the generated output to the context.
//
// An invocation is identified by the fully qualified generator type name,
// the fully qualified struct name, and the invocation arguments. If an
// identical invocation has already been recorded, it is skipped and nil is
// returned.
//
// The invocation is recorded before the template runs. This means a
// re-entrant call for the same invocation made during the run (ctx is passed
// to the template) is skipped. It also means the invocation stays recorded
// even if the template fails.
func (ctx *GenContext) RunTemplate(invocation Invocation, aStruct *types.Named) error {
	onStruct := invocationSeen{
		GenTypeName: fullTypeName(invocation.GenType),
		StructName:  fullTypeName(aStruct),
		Args:        invocation.Args,
	}
	for _, seen := range ctx.invocationsSeen {
		// reflect.DeepEqual is required because Args is a map, which makes
		// invocationSeen incomparable with ==.
		if reflect.DeepEqual(seen, onStruct) {
			return nil
		}
	}
	ctx.invocationsSeen = append(ctx.invocationsSeen, onStruct)

	// Named tmpl rather than template so the text/template package is not
	// shadowed inside this function.
	tmpl, err := ctx.templateForGenType(invocation.GenType)
	if err != nil {
		return errors.Wrap(err, "getting template")
	}
	generated, err := RunTemplate(tmpl, aStruct, invocation.Args, ctx)
	if err != nil {
		return errors.Wrap(err, "running template")
	}
	ctx.generated = append(ctx.generated, generated)
	return nil
}
// AddImport records pkg as an import path needed by the generated code.
// Paths already marked as seen are ignored (NewGenContext pre-marks the root
// package), so each path appears at most once in Imports, in first-added
// order.
func (ctx *GenContext) AddImport(pkg string) {
	if _, dup := ctx.importsSeen[pkg]; !dup {
		ctx.importsSeen[pkg] = struct{}{}
		ctx.imports = append(ctx.imports, pkg)
	}
}
// Imports returns the import paths added via AddImport, in the order they
// were first added. The result is a fresh, non-nil copy, so callers may
// modify it without affecting the context.
func (ctx *GenContext) Imports() []string {
	out := make([]string, 0, len(ctx.imports))
	return append(out, ctx.imports...)
}
// Generated returns the output of every template run so far, in execution
// order. The result is a fresh, non-nil copy that callers may modify freely.
func (ctx *GenContext) Generated() []string {
	return append(make([]string, 0, len(ctx.generated)), ctx.generated...)
}
// GetType resolves a fully qualified type name of the form
// "<import path>.<TypeName>" (for example "example.com/foo.Bar") against the
// context's packages and returns the type's underlying type.
//
// It returns an error in any of these cases:
//   - fullName contains no ".";
//   - the package is not one of the context's packages;
//   - the package scope has no object with that name;
//   - the object found is not a type (e.g. a function, variable or constant).
func (ctx *GenContext) GetType(fullName string) (types.Type, error) {
	// Split on the last dot: import paths may themselves contain dots
	// ("github.com/..."), but type names cannot.
	lastDot := strings.LastIndex(fullName, ".")
	if lastDot == -1 {
		return nil, errors.Errorf("%s not a fully qualified type name", fullName)
	}
	pkgName, name := fullName[:lastDot], fullName[lastDot+1:]

	pkg, ok := ctx.packages[pkgName]
	if !ok {
		return nil, errors.Errorf("package %s not found", pkgName)
	}
	obj := pkg.Scope().Lookup(name)
	if obj == nil {
		return nil, errors.Errorf("type %s not found in package %s", name, pkgName)
	}
	// Scope.Lookup returns any package-level object; only a *types.TypeName
	// denotes a type. Without this check a function or variable name would
	// silently resolve to its signature's or value's underlying type.
	typeName, ok := obj.(*types.TypeName)
	if !ok {
		return nil, errors.Errorf("%s in package %s is not a type", name, pkgName)
	}
	// Callers want the underlying type (e.g. *types.Struct), not the named
	// wrapper.
	return typeName.Type().Underlying(), nil
}
// fullTypeName returns the fully qualified name of named in the form
// "<import path>.<TypeName>", e.g. "example.com/foo.Bar".
//
// Predeclared named types such as error belong to no package. For those, only
// the bare type name is returned instead of dereferencing a nil package.
func fullTypeName(named *types.Named) string {
	obj := named.Obj()
	if obj.Pkg() == nil {
		return obj.Name()
	}
	return obj.Pkg().Path() + "." + obj.Name()
}
// templateForGenType returns the parsed template for genType, parsing and
// caching it on first use. The template file must live in the same
// directory as the source file declaring genType and be named after the
// type: type Foo declared in dir/gen.go uses dir/Foo.tmpl.
//
// Templates are cached per fully qualified type name, so each template file
// is parsed at most once per GenContext. Failed parses are not cached.
func (ctx *GenContext) templateForGenType(genType *types.Named) (*template.Template, error) {
	fullName := fullTypeName(genType)
	// Named tmpl rather than template so the text/template package is not
	// shadowed inside this function.
	if tmpl, ok := ctx.templates[fullName]; ok {
		return tmpl, nil
	}

	obj := genType.Obj()
	fpath := ctx.fset.Position(obj.Pos()).Filename
	if fpath == "" {
		// Without a source file there is no directory to look in.
		// filepath.Dir("") is ".", so continuing would silently search the
		// current working directory instead.
		return nil, errors.Errorf("no source file position for type %s", fullName)
	}
	templatePath := filepath.Join(filepath.Dir(fpath), obj.Name()+".tmpl")

	tmpl, err := ParseTemplate(templatePath)
	if err != nil {
		return nil, errors.Wrap(err, "parsing template")
	}
	ctx.templates[fullName] = tmpl
	return tmpl, nil
}