Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support actiontech/sqle#1331 #11

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion ast/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ type Context struct {
QueryType string // select, insert, update, delete
Variable map[string]string
Sqls map[string]*SqlNode
Config *Config
}

func NewContext() *Context {
func NewContext(config *Config) *Context {
return &Context{
Variable: map[string]string{},
Sqls: map[string]*SqlNode{},
Config: config,
}
}

Expand All @@ -26,3 +28,10 @@ func (c *Context) GetSql(k string) (*SqlNode, bool) {
sql, ok := c.Sqls[k]
return sql, ok
}

type Config struct {
SkipErrorQuery bool
WithQueryId bool
}

type ConfigFn func() func(*Config)
4 changes: 2 additions & 2 deletions ast/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (m *Mapper) GetStmt(ctx *Context) (string, error) {
return strings.TrimSuffix(buff.String(), "\n"), nil
}

func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
func (m *Mapper) GetStmts(ctx *Context) ([]string, error) {
var stmts []string
ctx.Sqls = m.SqlNodes
for _, a := range m.QueryNodes {
Expand All @@ -77,7 +77,7 @@ func (m *Mapper) GetStmts(ctx *Context, skipErrorQuery bool) ([]string, error) {
stmts = append(stmts, data)
continue
}
if skipErrorQuery {
if ctx.Config.SkipErrorQuery {
continue
}
return nil, err
Expand Down
12 changes: 11 additions & 1 deletion ast/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ast
import (
"bytes"
"encoding/xml"
"fmt"

"github.com/actiontech/mybatis-mapper-2-sql/sqlfmt"
)
Expand Down Expand Up @@ -32,12 +33,21 @@ func (s *QueryNode) Scan(start *xml.StartElement) error {
func (s *QueryNode) GetStmt(ctx *Context) (string, error) {
buff := bytes.Buffer{}
ctx.QueryType = s.Type

for _, a := range s.Children {
data, err := a.GetStmt(ctx)
if err != nil {
return "", err
}
buff.WriteString(data)
}
return sqlfmt.FormatSQL(buff.String()), nil
fmtSQL := sqlfmt.FormatSQL(buff.String())
if ctx.Config.WithQueryId {
buff.Reset()
buff.WriteString(fmt.Sprintf("/* id: %s */\n", s.Id))
buff.WriteString(fmtSQL)
return buff.String(), nil
} else {
return fmtSQL, nil
}
}
15 changes: 15 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package parser

import "github.com/actiontech/mybatis-mapper-2-sql/ast"

func SkipErrorQuery() func(*ast.Config) {
return func(c *ast.Config) {
c.SkipErrorQuery = true
}
}

func WithQueryId() func(*ast.Config) {
return func(c *ast.Config) {
c.WithQueryId = true
}
}
15 changes: 11 additions & 4 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ func ParseXML(data string) (string, error) {
if n == nil {
return "", nil
}
stmt, err := n.GetStmt(ast.NewContext())
stmt, err := n.GetStmt(ast.NewContext(&ast.Config{}))
if err != nil {
return "", err
}
return stmt, nil
}

// ParseXMLQuery is a parser for parse all query in XML to []string one by one;
// you can set `skipErrorQuery` true to ignore invalid query.
func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {
// ConfigFn:
// `SkipErrorQuery` to ignore invalid query.
func ParseXMLQuery(data string, configFns ...ast.ConfigFn) ([]string, error) {
r := strings.NewReader(data)
d := xml.NewDecoder(r)
n, err := parse(d)
Expand All @@ -43,7 +44,13 @@ func ParseXMLQuery(data string, skipErrorQuery bool) ([]string, error) {
if !ok {
return nil, fmt.Errorf("the mapper is not found")
}
stmts, err := m.GetStmts(ast.NewContext(), skipErrorQuery)

config := &ast.Config{}
for _, configFn := range configFns {
configFn()(config)
}

stmts, err := m.GetStmts(ast.NewContext(config))
if err != nil {
return nil, err
}
Expand Down
11 changes: 6 additions & 5 deletions parser_ibatis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ id = #id#
}

func TestParseIBatisInclude(t *testing.T) {
testParserQuery(t, false, `<?xml version="1.0" encoding="UTF-8"?>
testParserQuery(t, `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE sqlMap PUBLIC "-//ibatis.apache.org//DTD SQL Map 2.0//EN" "http://ibatis.apache.org/dtd/sql-map-2.dtd">

<sqlMap namespace="Employee">
Expand All @@ -76,7 +76,7 @@ SELECT id, name
"SELECT `id`,`name` FROM `items` WHERE `parentid`=6",
})

testParserQuery(t, false, `<?xml version="1.0" encoding="UTF-8"?>
testParserQuery(t, `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE sqlMap PUBLIC "-//ibatis.apache.org//DTD SQL Map 2.0//EN" "http://ibatis.apache.org/dtd/sql-map-2.dtd">

<sqlMap namespace="Employee">
Expand All @@ -102,7 +102,7 @@ SELECT id, name
}

func TestParseIBatisAll(t *testing.T) {
testParserQuery(t, true, `
testParserQuery(t, `
<!DOCTYPE sqlMap PUBLIC "-//ibatis.apache.org//DTD SQL Map 2.0//EN" "http://ibatis.apache.org/dtd/sql-map-2.dtd">

<sqlMap namespace="Employee">
Expand Down Expand Up @@ -195,5 +195,6 @@ func TestParseIBatisAll(t *testing.T) {
"SELECT * FROM `EMPLOYEE` WHERE (`username`=? OR `username`=?) AND `id` IS NULL AND `id`=?",
"SELECT * FROM `EMPLOYEE` WHERE `ACC_FIRST_NAME`=? OR `ACC_LAST_NAME`=? AND `ACC_EMAIL` LIKE ? AND `ACC_ID`=? ORDER BY `ACC_LAST_NAME`",
"SELECT * FROM `EMPLOYEE` ORDER BY ?",
})
}
},
SkipErrorQuery)
}
114 changes: 106 additions & 8 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package parser

import (
"testing"

"github.com/actiontech/mybatis-mapper-2-sql/ast"
)

func testParser(t *testing.T, xmlData, expect string) {
Expand Down Expand Up @@ -573,8 +575,8 @@ func TestParserSQLRefIdNotFound(t *testing.T) {
}
}

func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []string) {
actual, err := ParseXMLQuery(xmlData, skipError)
func testParserQuery(t *testing.T, xmlData string, expect []string, configFns ...ast.ConfigFn) {
actual, err := ParseXMLQuery(xmlData, configFns...)
if err != nil {
t.Errorf("parse error: %v", err)
return
Expand All @@ -593,7 +595,7 @@ func testParserQuery(t *testing.T, skipError bool, xmlData string, expect []stri
}

func TestParserQueryFullFile(t *testing.T) {
testParserQuery(t, false,
testParserQuery(t,
`
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
Expand Down Expand Up @@ -796,7 +798,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) {
<include refid="someinclude2" />
from t
</select>
</mapper>`, false)
</mapper>`)
if err == nil {
t.Errorf("expect has error, but no error")
}
Expand All @@ -806,7 +808,7 @@ func TestParserQueryHasInvalidQuery(t *testing.T) {
}

func TestParserQueryHasInvalidQueryButSkip(t *testing.T) {
testParserQuery(t, true,
testParserQuery(t,
`
<mapper namespace="Test">
<sql id="someinclude">
Expand All @@ -830,11 +832,12 @@ func TestParserQueryHasInvalidQueryButSkip(t *testing.T) {
</select>
</mapper>`, []string{
"SELECT `name`,`category`,`price` FROM `fruits` WHERE `name` LIKE ?",
})
},
SkipErrorQuery)
}

func TestIssue302(t *testing.T) {
testParserQuery(t, false,
testParserQuery(t,
`
<mapper namespace="Test">
<select id="selectUserByState" resultType="com.bz.model.entity.User">
Expand All @@ -851,7 +854,7 @@ func TestIssue302(t *testing.T) {
</mapper>`, []string{
"SELECT * FROM `user` WHERE `name`=? AND `name`=?",
})
testParserQuery(t, false,
testParserQuery(t,
`
<mapper namespace="Test">
<select id="selectUserByState" resultType="com.bz.model.entity.User">
Expand Down Expand Up @@ -1046,3 +1049,98 @@ func TestOtherwise_issue1193(t *testing.T) {
`, "SELECT * FROM `fruits` WHERE `name`=? AND `price`=? AND `category`=?;",
)
}

func TestWithQueryId_issue1331(t *testing.T) {
testParserQuery(t, `
<mapper namespace="Test">
<select id="testChoose">
SELECT
*
FROM
fruits
<where>
<choose>
<when test="name != null">
AND name = #{name}
</when>
<otherwise>
<if test="price != null and price !=''">
AND price = ${price}
</if>
</otherwise>
</choose>
</where>
</select>
</mapper>
`, []string{"SELECT * FROM `fruits` WHERE `name`=? AND `price`=?"},
)
testParserQuery(t, `
<mapper namespace="Test">
<select id="testChoose">
SELECT
*
FROM
fruits
<where>
<choose>
<when test="name != null">
AND name = #{name}
</when>
<otherwise>
<if test="price != null and price !=''">
AND price = ${price}
</if>
</otherwise>
</choose>
</where>
</select>
</mapper>
`, []string{"/* id: testChoose */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?"},
WithQueryId,
)
testParserQuery(t, `
<mapper namespace="Test">
<select id="testChoose">
SELECT
*
FROM
fruits
<where>
<choose>
<when test="name != null">
AND name = #{name}
</when>
<otherwise>
<if test="price != null and price !=''">
AND price = ${price}
</if>
</otherwise>
</choose>
</where>
</select>
<select id="testChoose2">
SELECT
*
FROM
fruits
<where>
<choose>
<when test="name != null">
AND name = #{name}
</when>
<otherwise>
<if test="price != null and price !=''">
AND price = ${price}
</if>
</otherwise>
</choose>
</where>
</select>
</mapper>
`, []string{
"/* id: testChoose */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?",
"/* id: testChoose2 */\nSELECT * FROM `fruits` WHERE `name`=? AND `price`=?",
},
WithQueryId,
)
}