diff --git a/ast/context.go b/ast/context.go index 7c98479..86f9339 100644 --- a/ast/context.go +++ b/ast/context.go @@ -4,12 +4,14 @@ type Context struct { QueryType string // select, insert, update, delete Variable map[string]string Sqls map[string]*SqlNode + Config *Config } -func NewContext() *Context { +func NewContext(config *Config) *Context { return &Context{ Variable: map[string]string{}, Sqls: map[string]*SqlNode{}, + Config: config, } } @@ -26,3 +28,10 @@ func (c *Context) GetSql(k string) (*SqlNode, bool) { sql, ok := c.Sqls[k] return sql, ok } + +type Config struct { + SkipErrorQuery bool + WithQueryId bool +} + +type ConfigFn func() func(*Config) diff --git a/ast/mapper.go b/ast/mapper.go index 2254625..4dc8151 100644 --- a/ast/mapper.go +++ b/ast/mapper.go @@ -68,7 +68,7 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) { return strings.TrimSuffix(buff.String(), "\n"), nil } -func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) { +func (m *Mapper) GetStmts(ctx *Context) ([]string, error) { var stmts []string ctx.Sqls = m.SqlNodes for _, a := range m.QueryNodes { @@ -77,7 +77,7 @@ func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) { stmts = append(stmts, data) continue } - if skipErrorQuery { + if ctx.Config.SkipErrorQuery { continue } return nil, err diff --git a/ast/query.go b/ast/query.go index 7482d14..294fa1b 100644 --- a/ast/query.go +++ b/ast/query.go @@ -3,6 +3,7 @@ package ast import ( "bytes" "encoding/xml" + "fmt" "github.com/actiontech/mybatis-mapper-2-sql/sqlfmt" ) @@ -32,6 +33,7 @@ func (s *QueryNode) Scan(start *xml.StartElement) error { func (s *QueryNode) GetStmt(ctx *Context) (string, error) { buff := bytes.Buffer{} ctx.QueryType = s.Type + for _, a := range s.Children { data, err := a.GetStmt(ctx) if err != nil { @@ -39,5 +41,13 @@ func (s *QueryNode) GetStmt(ctx *Context) (string, error) { } buff.WriteString(data) } - return sqlfmt.FormatSQL(buff.String()), nil + fmtSQL := sqlfmt.FormatSQL(buff.String()) + if ctx.Config.WithQueryId { + buff.Reset() + buff.WriteString(fmt.Sprintf("/* id: %s */\n", s.Id)) + buff.WriteString(fmtSQL) + return buff.String(), nil + } else { + return fmtSQL, nil + } } diff --git a/config.go b/config.go new file mode 100644 index 0000000..9711f99 --- /dev/null +++ b/config.go @@ -0,0 +1,15 @@ +package parser + +import "github.com/actiontech/mybatis-mapper-2-sql/ast" + +func SkipErrorQuery() func(*ast.Config) { + return func(c *ast.Config) { + c.SkipErrorQuery = true + } +} + +func WithQueryId() func(*ast.Config) { + return func(c *ast.Config) { + c.WithQueryId = true + } +} diff --git a/parser.go b/parser.go index f424043..dad656c 100644 --- a/parser.go +++ b/parser.go @@ -20,7 +20,7 @@ func ParseXML(data string) (string, error) { if n == nil { return "", nil } - stmt, err := n.GetStmt(ast.NewContext()) + stmt, err := n.GetStmt(ast.NewContext(&ast.Config{})) if err != nil { return "", err } @@ -28,8 +28,9 @@ func ParseXML(data string) (string, error) { } // ParseXMLQuery is a parser for parse all query in XML to []string one by one; -// you can set `skipErrorQuery` true to ignore invalid query. -func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) { +// ConfigFn: +// `SkipErrorQuery` to ignore invalid query. +func ParseXMLQuery(data string, configFns ...ast.ConfigFn) ([]string, error) { r := strings.NewReader(data) d := xml.NewDecoder(r) n, err := parse(d) @@ -43,7 +44,13 @@ func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) { if !ok { return nil, fmt.Errorf("the mapper is not found") } - stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery) + + config := &ast.Config{} + for _, configFn := range configFns { + configFn()(config) + } + + stmts, err := m.GetStmts(ast.NewContext(config)) if err != nil { return nil, err } diff --git a/parser_ibatis_test.go b/parser_ibatis_test.go index 30cf553..9b371d0 100644 --- a/parser_ibatis_test.go +++ b/parser_ibatis_test.go @@ -52,7 +52,7 @@ id = #id# } func TestParseIBatisInclude(t *testing.T) { - testParserQuery(t, false, ` + testParserQuery(t, ` @@ -76,7 +76,7 @@ SELECT id, name "SELECT `id`,`name` FROM `items` WHERE `parentid`=6", }) - testParserQuery(t, false, ` + testParserQuery(t, ` @@ -102,7 +102,7 @@ SELECT id, name } func TestParseIBatisAll(t *testing.T) { - testParserQuery(t, true, ` + testParserQuery(t, ` @@ -195,5 +195,6 @@ func TestParseIBatisAll(t *testing.T) { "SELECT * FROM `EMPLOYEE` WHERE (`username`=? OR `username`=?) AND `id` IS NULL AND `id`=?", "SELECT * FROM `EMPLOYEE` WHERE `ACC_FIRST_NAME`=? OR `ACC_LAST_NAME`=? AND `ACC_EMAIL` LIKE ? AND `ACC_ID`=? ORDER BY `ACC_LAST_NAME`", "SELECT * FROM `EMPLOYEE` ORDER BY ?", - }) -} \ No newline at end of file + }, + SkipErrorQuery) +} diff --git a/parser_test.go b/parser_test.go index 89ce67b..32eb5c9 100644 --- a/parser_test.go +++ b/parser_test.go @@ -2,6 +2,8 @@ package parser import ( "testing" + + "github.com/actiontech/mybatis-mapper-2-sql/ast" ) func testParser(t *testing.T, xmlData, expect string) { @@ -573,8 +575,8 @@ func TestParserSQLRefIdNotFound(t *testing.T) { } } -func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) { - actual, err := ParseXMLQuery(xmlData, skipError) +func testParserQuery(t *testing.T, xmlData string, expect []string, configFns ...ast.ConfigFn) { + actual, err := ParseXMLQuery(xmlData, configFns...) if err != nil { t.Errorf("parse error: %v", err) return @@ -593,7 +595,7 @@ func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []stri } func TestParserQueryFullFile(t *testing.T) { - testParserQuery(t, false, + testParserQuery(t, ` @@ -796,7 +798,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) { from t -`, false) +`) if err == nil { t.Errorf("expect has error, but no error") } @@ -806,7 +808,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) { } func TestParserQueryHasInvalidQueryButSkip(t *testing.T) { - testParserQuery(t, true, + testParserQuery(t, ` @@ -830,11 +832,12 @@ func TestParserQueryHasInvalidQueryButSkip(t *testing.T) { `, []string{ "SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?", - }) + }, + SkipErrorQuery) } func TestIssue302(t *testing.T) { - testParserQuery(t, false, + testParserQuery(t, ` @@ -1046,3 +1049,98 @@ func TestOtherwise_issue1193(t *testing.T) { `, "SELECT * FROM `fruits` WHERE `name`=? AND `price`=? AND `category`=?;", ) } + +func TestWithQueryId_issue1331(t *testing.T) { + testParserQuery(t, ` + + + + `, []string{"SELECT * FROM `fruits` WHERE `name`=? AND `price`=?"}, + ) + testParserQuery(t, ` + + + + `, []string{"/* id: testChoose */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?"}, + WithQueryId, + ) + testParserQuery(t, ` + + + + + `, []string{ + "/* id: testChoose */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?", + "/* id: testChoose2 */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?", + }, + WithQueryId, + ) +}