Skip to content

Commit

Permalink
feat: introduce util.SplitStatements (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
atzoum authored Mar 20, 2024
1 parent b00a62e commit 4039fcc
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 9 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,15 @@ for row := range ch {
}
_ = row.Value
}
```
```

## Utilities

**SplitStatements**: Splits a string of semicolon-separated SQL statements into individual statements, stripping out comments.
```go
import sqlconnectutil "github.com/rudderlabs/sqlconnect-go/sqlconnect/util"

func main() {
	// statements holds the two trimmed statements, without their semicolons.
	statements := sqlconnectutil.SplitStatements("SELECT * FROM table; SELECT * FROM table;")
	_ = statements // use the value so the example compiles
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"os"
"regexp"
"strings"
"sync"
"testing"
Expand All @@ -19,6 +18,7 @@ import (

"github.com/rudderlabs/rudder-go-kit/testhelper/rand"
"github.com/rudderlabs/sqlconnect-go/sqlconnect"
sqlconnectutil "github.com/rudderlabs/sqlconnect-go/sqlconnect/util"
)

type Options struct {
Expand Down Expand Up @@ -586,7 +586,6 @@ func ExecuteStatements(t *testing.T, c sqlconnect.DB, schema, path string) {

func ReadSQLStatements(t *testing.T, schema, path string) []string {
t.Helper()
SQLComment := regexp.MustCompile(`(?m)--.*\n`)
data, err := os.ReadFile(path)
require.NoErrorf(t, err, "it should be able to read the sql script file %q", path)
tpl, err := template.New("data").Parse(string(data))
Expand All @@ -595,10 +594,5 @@ func ReadSQLStatements(t *testing.T, schema, path string) []string {
templateData := map[string]any{"schema": schema}
err = tpl.Execute(sql, templateData)
require.NoErrorf(t, err, "it should be able to execute the sql script file %q", path)
allStmts := sql.String()
stmts := lo.FilterMap(strings.Split(allStmts, ";"), func(stmt string, _ int) (string, bool) {
stmt = SQLComment.ReplaceAllString(strings.TrimSpace(stmt), "")
return stmt, stmt != ""
})
return stmts
return sqlconnectutil.SplitStatements(sql.String())
}
85 changes: 85 additions & 0 deletions sqlconnect/util/split_query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package util

import (
"strings"

"github.com/samber/lo"
)

// SplitStatements splits a string containing multiple SQL statements separated by semicolons
// into the individual statements.
//
// Comments are stripped from the output: a simple comment (--) is removed up to, but not
// including, the newline that ends it, and a bracketed comment (/* */) is removed entirely
// (nothing is inserted in its place). An unterminated bracketed comment swallows the rest
// of the input.
//
// Single-quoted SQL strings are copied verbatim, so semicolons, doubled (escaped) quotes and
// comment character sequences inside them do not affect the splitting. An unterminated string
// extends to the end of the input. No other quoting style is recognised.
//
// Each returned statement has leading and trailing spaces, tabs, carriage returns and
// newlines removed; statements that are empty after trimming are dropped.
func SplitStatements(statements string) []string {
	// Work on runes so that every lookahead index refers to a whole character,
	// regardless of how many bytes it occupies in the input.
	runes := []rune(statements)

	var (
		stmts []string
		stmt  strings.Builder // statement currently being assembled
	)
	flush := func() {
		stmts = append(stmts, stmt.String())
		stmt.Reset()
	}

	for i := 0; i < len(runes); i++ {
		c := runes[i]
		hasNext := i+1 < len(runes)

		switch {
		case c == '\'':
			// SQL string: copy everything up to and including the closing quote.
			stmt.WriteRune(c)
			for i++; i < len(runes); i++ {
				stmt.WriteRune(runes[i])
				if runes[i] != '\'' {
					continue
				}
				if i+1 < len(runes) && runes[i+1] == '\'' {
					// A doubled quote is an escaped quote, not the end of the string.
					i++
					stmt.WriteRune(runes[i])
					continue
				}
				break // closing quote
			}

		case c == '-' && hasNext && runes[i+1] == '-':
			// Simple comment: skip to the end of the line.
			for i < len(runes) && runes[i] != '\n' {
				i++
			}
			if i < len(runes) {
				// Keep the line break so the tokens around the comment stay separated.
				stmt.WriteRune('\n')
			}

		case c == '/' && hasNext && runes[i+1] == '*':
			// Bracketed comment: skip past the opening "/*" and look for the closing "*/".
			i += 2
			for i < len(runes) && !(runes[i] == '*' && i+1 < len(runes) && runes[i+1] == '/') {
				i++
			}
			i++ // land on the closing '/', the loop increment steps past it

		case c == ';':
			flush()

		default:
			stmt.WriteRune(c)
		}
	}
	if stmt.Len() > 0 {
		flush()
	}

	return lo.FilterMap(stmts, func(s string, _ int) (string, bool) {
		// remove leading and trailing whitespace and drop statements that end up empty
		s = strings.Trim(s, "\r\n\t ")
		return s, s != ""
	})
}
139 changes: 139 additions & 0 deletions sqlconnect/util/split_query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package util_test

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/rudderlabs/sqlconnect-go/sqlconnect/util"
)

// TestSplitStatements runs util.SplitStatements against a table of inputs and checks that
// the returned statements match the expected ones (compared with require.ElementsMatch).
func TestSplitStatements(t *testing.T) {
	cases := []struct {
		name  string   // subtest name
		input string   // raw SQL handed to SplitStatements
		want  []string // statements expected back
	}{
		{
			name:  "single statement",
			input: "SELECT * FROM table",
			want:  []string{"SELECT * FROM table"},
		},
		{
			name:  "single statement with semicolon",
			input: "SELECT * FROM table;",
			want:  []string{"SELECT * FROM table"},
		},
		{
			name: "multiple statements",
			input: `
SELECT * FROM table1;
SELECT * FROM table2;
`,
			want: []string{"SELECT * FROM table1", "SELECT * FROM table2"},
		},
		{
			name: "multiple statements with simple comments",
			input: `
SELECT * FROM table1; -- this is an inline comment
-- this is another comment on its own line
SELECT * FROM table2
`,
			want: []string{"SELECT * FROM table1", "SELECT * FROM table2"},
		},
		{
			name: "multiple statements with bracketed comments",
			input: `
SELECT * FROM table1; /* this is a bracketed comment
that spans multiple lines */
/* this is another bracketed comment */
SELECT * FROM table2;
`,
			want: []string{"SELECT * FROM table1", "SELECT * FROM table2"},
		},
		{
			name: "multiple statements with both types of comments",
			input: `
SELECT * FROM table1; -- this is an inline comment
/* this is a bracketed comment
that spans multiple lines */
-- this is another inline comment
SELECT * FROM table2;
/* this is another bracketed
comment */
`,
			want: []string{"SELECT * FROM table1", "SELECT * FROM table2"},
		},
		{
			name: "multiple statements with semicolon inside comments",
			input: `
SELECT * FROM table1; -- this is an inline comment;
/* this is a bracketed comment;
that spans multiple lines */
-- this is another inline; comment;
SELECT * FROM table2;
/* this is another bracketed
comment; */
`,
			want: []string{"SELECT * FROM table1", "SELECT * FROM table2"},
		},
		{
			name: "multiple mulitline statements with semicolon in sql string",
			input: `
SELECT *
FROM table1
WHERE value='some;value';
SELECT
*
FROM table2
WHERE value='another '' ; value;';
SELECT * FROM table3 WHERE value='' AND value1='some' ;
`,
			want: []string{
				"SELECT * \n\t\t\tFROM table1 \n\t\tWHERE value='some;value'",
				"SELECT \n\t\t\t* \n\t\tFROM table2 \n\t\t WHERE value='another '' ; value;'",
				"SELECT * FROM table3 WHERE value='' AND value1='some'",
			},
		},
		{
			name:  "single statement with simple comment char sequence in sql string",
			input: `SELECT * FROM table1 WHERE value='some --value'`,
			want: []string{
				"SELECT * FROM table1 WHERE value='some --value'",
			},
		},
		{
			name:  "single statement with bracketed comment char sequence in sql string",
			input: `SELECT * FROM table1 WHERE value='some /* value */'`,
			want: []string{
				"SELECT * FROM table1 WHERE value='some /* value */'",
			},
		},
		{
			name:  "single statement with all kinds of impediments in sql string",
			input: `SELECT * FROM table1 WHERE value='''some /* value */ -- comment'''`,
			want: []string{
				"SELECT * FROM table1 WHERE value='''some /* value */ -- comment'''",
			},
		},
	}

	// Subtests run sequentially, in table order.
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := util.SplitStatements(tc.input)
			require.ElementsMatch(t, tc.want, got)
		})
	}
}

0 comments on commit 4039fcc

Please sign in to comment.