diff --git a/compiler.go b/compiler.go index 59e8d17..e71ac58 100644 --- a/compiler.go +++ b/compiler.go @@ -59,7 +59,7 @@ func (c *compiler) compile() (string, error) { if c.curStmt != nil { s = c.curStmt } - return "", fmt.Errorf("line %d: %s", s.T().LineNumber, err) + return "", fmt.Errorf("line %d: %w", s.T().LineNumber, err) } c.write(bb, res) @@ -768,7 +768,7 @@ func (c *compiler) evalCallExpression(node *ast.CallExpression) (interface{}, er res := rv.Call(args) if len(res) > 0 { if e, ok := res[len(res)-1].Interface().(error); ok { - return nil, fmt.Errorf("could not call %s function: %s", node.Function, e) + return nil, fmt.Errorf("could not call %s function: %w", node.Function, e) } return res[0].Interface(), nil } diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..9736f7c --- /dev/null +++ b/error_test.go @@ -0,0 +1,22 @@ +package plush + +import ( + "database/sql" + + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestErrorType(t *testing.T) { + r := require.New(t) + + ctx := NewContext() + ctx.Set("sqlError", func() error { + return sql.ErrNoRows + }) + + _, err := Render(`<%= sqlError() %>`, ctx) + r.True(errors.Is(err, sql.ErrNoRows)) +}