Skip to content

Commit

Permalink
Merge branch 'main' into haris/schemas-fix-nullable-fields
Browse files Browse the repository at this point in the history
  • Loading branch information
hariso authored Nov 7, 2024
2 parents 749884c + 9537066 commit 1d0075b
Show file tree
Hide file tree
Showing 31 changed files with 428 additions and 56 deletions.
16 changes: 16 additions & 0 deletions lang/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@

package lang

// Ptr returns a pointer to the value passed in.
func Ptr[T any](t T) *T {
return &t
}

// ValOrZero returns the value of the pointer passed in or the zero value of the
// type if the pointer is nil.
func ValOrZero[T any](t *T) T {
if t == nil {
return Zero[T]()
}
return *t
}

// Zero returns the zero value of the type passed in.
func Zero[T any]() T {
var t T
return t
}
6 changes: 3 additions & 3 deletions paramgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"os"
"strings"

"github.com/conduitio/conduit-commons/paramgen/internal"
"github.com/conduitio/conduit-commons/paramgen/paramgen"
)

func main() {
Expand All @@ -32,12 +32,12 @@ func main() {
args := parseFlags()

// parse the sdk parameters
params, pkg, err := internal.ParseParameters(args.path, args.structName)
params, pkg, err := paramgen.ParseParameters(args.path, args.structName)
if err != nil {
log.Fatalf("error: failed to parse parameters: %v", err)
}

code := internal.GenerateCode(params, pkg, args.structName)
code := paramgen.GenerateCode(params, pkg, args.structName)

path := strings.TrimSuffix(args.path, "/") + "/" + args.output
err = os.WriteFile(path, []byte(code), 0o600)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package internal
package paramgen

import (
"os"
Expand All @@ -39,6 +39,10 @@ func TestIntegration(t *testing.T) {
havePath: "./testdata/tags",
structName: "Config",
wantPath: "./testdata/tags/want.go",
}, {
havePath: "./testdata/dependencies",
structName: "Config",
wantPath: "./testdata/dependencies/want.go",
}}

for _, tc := range testCases {
Expand Down
166 changes: 121 additions & 45 deletions paramgen/internal/paramgen.go → paramgen/paramgen/paramgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
// limitations under the License.

//nolint:err113,wrapcheck,staticcheck // we don't care about wrapping errors here, also ignore usage of ast.Package (deprecated)
package internal
package paramgen

import (
"encoding/json"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
"io/fs"
"os/exec"
"reflect"
Expand Down Expand Up @@ -115,31 +116,34 @@ func parsePackage(path string) (*ast.Package, error) {
filterTests := func(info fs.FileInfo) bool {
return !strings.HasSuffix(info.Name(), "_test.go")
}
pkgs, err := parser.ParseDir(fset, path, filterTests, parser.ParseComments)
pkgs, err := parser.ParseDir(fset, path, filterTests, parser.ParseComments|parser.SkipObjectResolution)
if err != nil {
return nil, fmt.Errorf("couldn't parse directory %s: %w", path, err)
}
// Make sure they are all in one package.
if len(pkgs) == 0 {
return nil, fmt.Errorf("no source-code package in directory %s", path)
}
// Ignore files with go:build constraint set to "tools" (common pattern in
// Conduit connectors).
for pkgName, pkg := range pkgs {
// Ignore files with go:build constraint set to "tools" (common pattern in
// Conduit connectors).
maps.DeleteFunc(pkg.Files, func(_ string, f *ast.File) bool {
return hasBuildConstraint(f, "tools")
})
if len(pkg.Files) == 0 {
// Remove empty packages or the main package (can't be imported).
if len(pkg.Files) == 0 || pkgName == "main" {
delete(pkgs, pkgName)
}
}
if len(pkgs) > 1 {

// Make sure there is only 1 package.
switch len(pkgs) {
case 0:
return nil, fmt.Errorf("no source-code package in directory %s", path)
case 1:
for _, pkg := range pkgs {
return pkg, nil
}
panic("unreachable")
default:
return nil, fmt.Errorf("multiple packages %v in directory %s", maps.Keys(pkgs), path)
}
for _, v := range pkgs {
return v, nil // return first package
}
panic("unreachable")
}

// hasBuildConstraint is a very naive way to check if a file has a build
Expand Down Expand Up @@ -205,6 +209,10 @@ func (p *parameterParser) Parse(structType *ast.StructType) (map[string]config.P
}

func (p *parameterParser) parseIdent(ident *ast.Ident, field *ast.Field) (params map[string]config.Parameter, err error) {
if field != nil && p.shouldSkipField(field) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseIdent] %w", err)
Expand Down Expand Up @@ -252,6 +260,10 @@ func (p *parameterParser) parseIdent(ident *ast.Ident, field *ast.Field) (params
}

func (p *parameterParser) parseTypeSpec(ts *ast.TypeSpec, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseTypeSpec] %w", err)
Expand All @@ -267,12 +279,18 @@ func (p *parameterParser) parseTypeSpec(ts *ast.TypeSpec, f *ast.Field) (params
return p.parseIdent(v, f)
case *ast.MapType:
return p.parseMapType(v, f)
case *ast.InterfaceType:
return nil, fmt.Errorf("error parsing type spec for %s.%s.%s: interface types not supported", p.pkg.Name, ts.Name.Name, p.getFieldNameOrUnknown(f))
default:
return nil, fmt.Errorf("unexpected type: %T", ts.Type)
}
}

func (p *parameterParser) parseStructType(st *ast.StructType, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseStructType] %w", err)
Expand Down Expand Up @@ -303,6 +321,10 @@ func (p *parameterParser) parseStructType(st *ast.StructType, f *ast.Field) (par
}

func (p *parameterParser) parseField(f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseField] %w", err)
Expand All @@ -313,34 +335,45 @@ func (p *parameterParser) parseField(f *ast.Field) (params map[string]config.Par
return nil, nil //nolint:nilnil // ignore unexported fields
}

switch v := f.Type.(type) {
case *ast.Ident:
// identifier (builtin type or type in same package)
return p.parseIdent(v, f)
case *ast.StructType:
// nested type
return p.parseStructType(v, f)
case *ast.SelectorExpr:
return p.parseSelectorExpr(v, f)
case *ast.MapType:
return p.parseMapType(v, f)
case *ast.ArrayType:
strType := fmt.Sprintf("%s", v.Elt)
if !p.isBuiltinType(strType) && !strings.Contains(strType, "time Duration") {
return nil, fmt.Errorf("unsupported slice type: %s", strType)
}
expr := f.Type
for {
switch v := expr.(type) {
case *ast.StarExpr:
// dereference pointer
expr = v.X
continue
case *ast.Ident:
// identifier (builtin type or type in same package)
return p.parseIdent(v, f)
case *ast.StructType:
// nested type
return p.parseStructType(v, f)
case *ast.SelectorExpr:
return p.parseSelectorExpr(v, f)
case *ast.MapType:
return p.parseMapType(v, f)
case *ast.ArrayType:
strType := fmt.Sprintf("%s", v.Elt)
if !p.isBuiltinType(strType) && !strings.Contains(strType, "time Duration") {
return nil, fmt.Errorf("unsupported slice type: %s", strType)
}

name, param, err := p.parseSingleParameter(f, config.ParameterTypeString)
if err != nil {
return nil, err
name, param, err := p.parseSingleParameter(f, config.ParameterTypeString)
if err != nil {
return nil, err
}
return map[string]config.Parameter{name: param}, nil
default:
return nil, fmt.Errorf("unknown type: %T", f.Type)
}
return map[string]config.Parameter{name: param}, nil
default:
return nil, fmt.Errorf("unknown type: %T", f.Type)
}
}

func (p *parameterParser) parseMapType(mt *ast.MapType, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

if fmt.Sprintf("%s", mt.Key) != "string" {
return nil, fmt.Errorf("unsupported map key type: %s", mt.Key)
}
Expand Down Expand Up @@ -378,6 +411,10 @@ func (p *parameterParser) parseMapType(mt *ast.MapType, f *ast.Field) (params ma
}

func (p *parameterParser) parseSelectorExpr(se *ast.SelectorExpr, f *ast.Field) (params map[string]config.Parameter, err error) {
if f != nil && p.shouldSkipField(f) {
return nil, nil //nolint:nilnil // ignore this validation
}

defer func() {
if err != nil {
err = fmt.Errorf("[parseSelectorExpr] %w", err)
Expand Down Expand Up @@ -428,17 +465,21 @@ func (p *parameterParser) findPackage(importPath string) (*ast.Package, error) {
// first cleanup string
importPath = strings.Trim(importPath, `"`)

if !strings.HasPrefix(importPath, p.mod.Path) {
// we only allow types declared in the same module
return nil, fmt.Errorf("we do not support parameters from package %v (please use builtin types or time.Duration)", importPath)
}

if pkg, ok := p.imports[importPath]; ok {
// it's cached already
return pkg, nil
}

pkgDir := p.mod.Dir + strings.TrimPrefix(importPath, p.mod.Path)
if !strings.HasPrefix(importPath, p.mod.Path) {
// Import path is not part of the module, we need to find the package path
var err error
pkgDir, err = p.packageToPath(importPath)
if err != nil {
return nil, fmt.Errorf("could not get package path for %q: %w", importPath, err)
}
}

pkg, err := parsePackage(pkgDir)
if err != nil {
return nil, fmt.Errorf("could not parse package dir %q: %w", pkgDir, err)
Expand Down Expand Up @@ -514,6 +555,11 @@ func (p *parameterParser) attachPrefix(f *ast.Field, params map[string]config.Pa
return prefixedParams
}

func (p *parameterParser) shouldSkipField(f *ast.Field) bool {
val := p.getTag(f.Tag, tagParamName)
return val == "-"
}

func (p *parameterParser) isBuiltinType(name string) bool {
switch name {
case "string", "bool", "int", "uint", "int8", "uint8", "int16", "uint16", "int32", "uint32", "int64", "uint64",
Expand Down Expand Up @@ -593,7 +639,7 @@ func (p *parameterParser) getParamType(i *ast.Ident) config.ParameterType {
// lowercase letter. If the string starts with multiple uppercase letters, all
// but the last character in the sequence will be converted into lowercase
// letters (e.g. HTTPRequest -> httpRequest).
func (p *parameterParser) formatFieldName(name string) string {
func (*parameterParser) formatFieldName(name string) string {
if name == "" {
return ""
}
Expand All @@ -619,7 +665,7 @@ func (p *parameterParser) formatFieldName(name string) string {
return newName
}

func (p *parameterParser) formatFieldComment(f *ast.Field) string {
func (*parameterParser) formatFieldComment(f *ast.Field) string {
doc := f.Doc
if doc == nil {
// fallback to line comment
Expand All @@ -644,7 +690,7 @@ func (p *parameterParser) formatFieldComment(f *ast.Field) string {
return c
}

func (p *parameterParser) getTag(lit *ast.BasicLit, tag string) string {
func (*parameterParser) getTag(lit *ast.BasicLit, tag string) string {
if lit == nil {
return ""
}
Expand All @@ -671,7 +717,7 @@ func (p *parameterParser) parseValidateTag(tag string) ([]config.Validation, err
return validations, nil
}

func (p *parameterParser) parseValidation(str string) (config.Validation, error) {
func (*parameterParser) parseValidation(str string) (config.Validation, error) {
if str == validationRequired {
return config.ValidationRequired{}, nil
}
Expand Down Expand Up @@ -715,3 +761,33 @@ func (p *parameterParser) parseValidation(str string) (config.Validation, error)
return nil, fmt.Errorf("invalid value for tag validate: %s", str)
}
}

// packageToPath takes a package import path and returns the path to the directory
// of that package.
func (p *parameterParser) packageToPath(pkg string) (string, error) {
cmd := exec.Command("go", "list", "-f", "{{.Dir}}", pkg)
cmd.Dir = p.mod.Dir
stdout, err := cmd.StdoutPipe()
if err != nil {
return "", fmt.Errorf("error piping stdout of go list command: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return "", fmt.Errorf("error piping stderr of go list command: %w", err)
}
if err := cmd.Start(); err != nil {
return "", fmt.Errorf("error starting go list command: %w", err)
}
path, err := io.ReadAll(stdout)
if err != nil {
return "", fmt.Errorf("error reading stdout of go list command: %w", err)
}
errMsg, err := io.ReadAll(stderr)
if err != nil {
return "", fmt.Errorf("error reading stderr of go list command: %w", err)
}
if err := cmd.Wait(); err != nil {
return "", fmt.Errorf("error running command %q (error message: %q): %w", cmd.String(), errMsg, err)
}
return strings.TrimRight(string(path), "\n"), nil
}
Loading

0 comments on commit 1d0075b

Please sign in to comment.