diff --git a/compiler.go b/compiler.go index 8a4e8a0..ef61832 100644 --- a/compiler.go +++ b/compiler.go @@ -425,74 +425,132 @@ func (c *compiler) evalCallExpression(node *ast.CallExpression) (interface{}, er rt := rv.Type() rtNumIn := rt.NumIn() + isVariadic := rt.IsVariadic() + args := []reflect.Value{} - if len(node.Arguments) > rtNumIn { - return nil, errors.WithStack(errors.Errorf("%s too many arguments (%d for %d)", node.String(), len(node.Arguments), rtNumIn)) - } + if !isVariadic { + if len(node.Arguments) > rtNumIn { + return nil, errors.WithStack(errors.Errorf("%s too many arguments (%d for %d)", node.String(), len(node.Arguments), rtNumIn)) + } - args := []reflect.Value{} - for pos, a := range node.Arguments { - v, err := c.evalExpression(a) - if err != nil { - return nil, errors.WithStack(err) + for pos, a := range node.Arguments { + v, err := c.evalExpression(a) + if err != nil { + return nil, errors.WithStack(err) + } + + var ar reflect.Value + expectedT := rt.In(pos) + if v != nil { + ar = reflect.ValueOf(v) + } else { + ar = reflect.New(expectedT).Elem() + } + + actualT := ar.Type() + if !actualT.AssignableTo(expectedT) { + return nil, errors.WithStack(errors.Errorf("%+v (%T) is an invalid argument for %s at pos %d: expected (%s)", v, v, node.Function.String(), pos, expectedT)) + } + + args = append(args, ar) } - var ar reflect.Value - expectedT := rt.In(pos) - if v != nil { - ar = reflect.ValueOf(v) - } else { - ar = reflect.New(expectedT).Elem() + hc := func(arg reflect.Type) { + if arg.Name() == helperContextKind { + hargs := HelperContext{ + Context: c.ctx, + compiler: c, + block: node.Block, + } + args = append(args, reflect.ValueOf(hargs)) + return + } + if arg.Name() == "Data" { + args = append(args, reflect.ValueOf(c.ctx.export())) + return + } + if arg.Kind() == reflect.Map { + args = append(args, reflect.ValueOf(map[string]interface{}{})) + } } - actualT := ar.Type() - if !actualT.AssignableTo(expectedT) { - return nil, errors.WithStack(errors.Errorf("%+v (%T) is an invalid argument for %s at pos %d - evalCallExpression", v, v, node.Function.String(), pos)) + if len(args) < rtNumIn { + // missing some args, let's see if we can figure out what they are. + diff := rtNumIn - len(args) + switch diff { + case 2: + // check last is help + // check if last -1 is map + arg := rt.In(rtNumIn - 2) + hc(arg) + last := rt.In(rtNumIn - 1) + hc(last) + case 1: + // check if help or map + last := rt.In(rtNumIn - 1) + hc(last) + } } - args = append(args, ar) - } + if len(args) > rtNumIn { + return nil, errors.WithStack(errors.Errorf("%s too many arguments (%d for %d) - %+v", node.String(), len(args), rtNumIn, args)) + } + } else { + // Variadic func + nodeArgs := node.Arguments + nodeArgsLen := len(nodeArgs) + if nodeArgsLen < rtNumIn { + return nil, errors.WithStack(errors.Errorf("%s too few arguments (%d for %d) - %+v", node.String(), len(args), rtNumIn, args)) + } + var pos int - hc := func(arg reflect.Type) { - if arg.Name() == helperContextKind { - hargs := HelperContext{ - Context: c.ctx, - compiler: c, - block: node.Block, + // Handle normal args + for pos = 0; pos < rtNumIn-1; pos++ { + v, err := c.evalExpression(nodeArgs[pos]) + if err != nil { + return nil, errors.WithStack(err) } - args = append(args, reflect.ValueOf(hargs)) - return - } - if arg.Name() == "Data" { - args = append(args, reflect.ValueOf(c.ctx.export())) - return + + var ar reflect.Value + expectedT := rt.In(pos) + if v != nil { + ar = reflect.ValueOf(v) + } else { + ar = reflect.New(expectedT).Elem() + } + + actualT := ar.Type() + if !actualT.AssignableTo(expectedT) { + return nil, errors.WithStack(errors.Errorf("%+v (%T) is an invalid argument for %s at pos %d: expected (%s)", v, v, node.Function.String(), pos, expectedT)) + } + + args = append(args, ar) } - if arg.Kind() == reflect.Map { - args = append(args, reflect.ValueOf(map[string]interface{}{})) + + // Unroll variadic arg + expectedT := rt.In(pos).Elem() + for ; pos < nodeArgsLen; pos++ { + v, err := c.evalExpression(nodeArgs[pos]) + if err != nil { + return nil, errors.WithStack(err) + } + + var ar reflect.Value + if v != nil { + ar = reflect.ValueOf(v) + } else { + ar = reflect.New(expectedT) + } + + actualT := ar.Type() + if !actualT.AssignableTo(expectedT) { + return nil, errors.WithStack(errors.Errorf("%+v (%T) is an invalid argument for %s at pos %d: expected (%s)", v, v, node.Function.String(), pos, expectedT)) + } + + args = append(args, ar) } } - if len(args) < rtNumIn { - // missing some args, let's see if we can figure out what they are. - diff := rtNumIn - len(args) - switch diff { - case 2: - // check last is help - // check if last -1 is map - arg := rt.In(rtNumIn - 2) - hc(arg) - last := rt.In(rtNumIn - 1) - hc(last) - case 1: - // check if help or map - last := rt.In(rtNumIn - 1) - hc(last) - } - } - - if len(args) > rtNumIn { - return nil, errors.WithStack(errors.Errorf("%s too many arguments (%d for %d) - %+v", node.String(), len(args), rtNumIn, args)) - } if len(args) < rtNumIn { return nil, errors.WithStack(errors.Errorf("%s too few arguments (%d for %d) - %+v", node.String(), len(args), rtNumIn, args)) } diff --git a/plush_test.go b/plush_test.go index ed67f40..e67c19b 100644 --- a/plush_test.go +++ b/plush_test.go @@ -345,6 +345,32 @@ func Test_UndefinedArg(t *testing.T) { r.Equal(ErrUnknownIdentifier, errors.Cause(err)) } +func Test_VariadicHelper(t *testing.T) { + r := require.New(t) + input := `<%= foo(1, 2, 3) %>` + ctx := NewContext() + ctx.Set("foo", func(args ...int) int { + return len(args) + }) + + s, err := Render(input, ctx) + r.NoError(err) + r.Equal("3", s) +} + +func Test_VariadicHelperWithWrongParam(t *testing.T) { + r := require.New(t) + input := `<%= foo(1, 2, "test") %>` + ctx := NewContext() + ctx.Set("foo", func(args ...int) int { + return len(args) + }) + + _, err := Render(input, ctx) + r.Error(err) + r.Contains(err.Error(), "test (string) is an invalid argument for foo at pos 2: expected (int)") +} + func Test_RunScript(t *testing.T) { r := require.New(t) bb := &bytes.Buffer{}