-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
417 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
package sqext | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"strings" | ||
|
||
"database/sql" | ||
|
||
sq "github.com/Masterminds/squirrel" | ||
"github.com/lann/builder" | ||
) | ||
|
||
type auxData struct { | ||
PlaceholderFormat sq.PlaceholderFormat | ||
RunWith sq.BaseRunner | ||
Alias string | ||
Columns []string | ||
Recursive bool | ||
Statement sq.Sqlizer | ||
} | ||
|
||
func (a *auxData) Exec() (sql.Result, error) { | ||
if a.RunWith == nil { | ||
return nil, sq.RunnerNotSet | ||
} | ||
return sq.ExecWith(a.RunWith, a) | ||
} | ||
|
||
func (a *auxData) ToSql() (string, []interface{}, error) { | ||
if a.Alias == "" { | ||
return "", nil, fmt.Errorf("auxillary statement must contain alias") | ||
} | ||
|
||
var sql bytes.Buffer | ||
|
||
if a.Recursive { | ||
sql.WriteString("RECURSIVE ") | ||
} | ||
|
||
sql.WriteString(a.Alias) | ||
|
||
if len(a.Columns) > 0 { | ||
sql.WriteString("(") | ||
sql.WriteString(strings.Join(a.Columns, ", ")) | ||
sql.WriteString(")") | ||
} | ||
|
||
sql.WriteString(" AS (") | ||
var args []interface{} | ||
var err error | ||
args, err = appendToSql(a.Statement, &sql, args) | ||
if err != nil { | ||
return "", []interface{}{}, err | ||
} | ||
sql.WriteString(")") | ||
|
||
sqlStr, err := a.PlaceholderFormat.ReplacePlaceholders(sql.String()) | ||
if err != nil { | ||
return "", []interface{}{}, err | ||
} | ||
|
||
return sqlStr, args, nil | ||
} | ||
|
||
// Builder | ||
|
||
// AuxBuilder builds auxillary statements used by CTEs. | ||
type AuxBuilder builder.Builder | ||
|
||
func init() { | ||
builder.Register(AuxBuilder{}, auxData{}) | ||
} | ||
|
||
// Format methods | ||
|
||
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the | ||
// query. | ||
func (b AuxBuilder) PlaceholderFormat(f sq.PlaceholderFormat) AuxBuilder { | ||
return builder.Set(b, "PlaceholderFormat", f).(AuxBuilder) | ||
} | ||
|
||
// Runner methods | ||
|
||
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. | ||
// For most cases runner will be a database connection. | ||
func (b AuxBuilder) RunWith(runner sq.BaseRunner) AuxBuilder { | ||
return setRunWith(b, runner).(AuxBuilder) | ||
} | ||
|
||
// Exec builds and Execs the query with the Runner set by RunWith. | ||
func (b AuxBuilder) Exec() (sql.Result, error) { | ||
data := builder.GetStruct(b).(auxData) | ||
return data.Exec() | ||
} | ||
|
||
// ToSql builds the query into a SQL string and bound args. | ||
func (b AuxBuilder) ToSql() (string, []interface{}, error) { | ||
data := builder.GetStruct(b).(auxData) | ||
return data.ToSql() | ||
} | ||
|
||
// Alias assigns an alias for the auxillary statements. | ||
func (b AuxBuilder) Alias(alias string) AuxBuilder { | ||
return builder.Set(b, "Alias", alias).(AuxBuilder) | ||
} | ||
|
||
// Recursive adds RECURSIVE modifier to the auxillary statments. | ||
func (b AuxBuilder) Recursive() AuxBuilder { | ||
return builder.Set(b, "Recursive", true).(AuxBuilder) | ||
} | ||
|
||
// Columns adds result columns of auxillary statement. | ||
func (b AuxBuilder) Columns(columns ...string) AuxBuilder { | ||
return builder.Extend(b, "Columns", columns).(AuxBuilder) | ||
} | ||
|
||
// Statement sets a subquery into auxillary statement. | ||
func (b AuxBuilder) Statement(stmt sq.Sqlizer) AuxBuilder { | ||
return builder.Set(b, "Statement", stmt).(AuxBuilder) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package sqext | ||
|
||
import ( | ||
"testing" | ||
|
||
sq "github.com/Masterminds/squirrel" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestAuxBuilderToSql(t *testing.T) { | ||
b := Aux(sq.Select("*").From("events").Where(sq.Eq{"is_completed": true})). | ||
Alias("completed_events"). | ||
Columns("a", "b"). | ||
Recursive() | ||
|
||
sql, args, err := b.ToSql() | ||
assert.NoError(t, err) | ||
|
||
expectedSql := "RECURSIVE completed_events(a, b) AS (SELECT * FROM events WHERE is_completed = ?)" | ||
assert.Equal(t, expectedSql, sql) | ||
assert.Equal(t, []interface{}{true}, args) | ||
} | ||
|
||
func TestAuxBuilderToSqlNoColumns(t *testing.T) { | ||
b := Aux(sq.Select("*").From("a").Where(sq.Eq{"x": 20})). | ||
Alias("a_data"). | ||
Recursive() | ||
|
||
sql, args, err := b.ToSql() | ||
assert.NoError(t, err) | ||
|
||
expectedSql := "RECURSIVE a_data AS (SELECT * FROM a WHERE x = ?)" | ||
assert.Equal(t, expectedSql, sql) | ||
assert.Equal(t, []interface{}{20}, args) | ||
} | ||
|
||
func TestAuxBuilderToSqlBasicAux(t *testing.T) { | ||
b := Aux(sq.Select("a", "b").From("x").Where(sq.Eq{"y": 1})). | ||
Alias("a_data") | ||
|
||
sql, args, err := b.ToSql() | ||
assert.NoError(t, err) | ||
|
||
expectedSql := "a_data AS (SELECT a, b FROM x WHERE y = ?)" | ||
assert.Equal(t, expectedSql, sql) | ||
assert.Equal(t, []interface{}{1}, args) | ||
} | ||
|
||
func TestAuxBuilderToSqlErr(t *testing.T) { | ||
b := Aux(sq.Select("*").From("events").Where(sq.Gt{"is_completed": true})).Recursive() | ||
|
||
_, _, err := b.ToSql() | ||
assert.Error(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
package sqext | ||
|
||
import ( | ||
"bytes" | ||
"database/sql" | ||
"fmt" | ||
|
||
sq "github.com/Masterminds/squirrel" | ||
"github.com/lann/builder" | ||
) | ||
|
||
type cteData struct { | ||
PlaceholderFormat sq.PlaceholderFormat | ||
RunWith sq.BaseRunner | ||
AuxStatements []sq.Sqlizer | ||
} | ||
|
||
func (c *cteData) Exec() (sql.Result, error) { | ||
if c.RunWith == nil { | ||
return nil, sq.RunnerNotSet | ||
} | ||
return sq.ExecWith(c.RunWith, c) | ||
} | ||
|
||
func (c *cteData) ToSql() (string, []interface{}, error) { | ||
if len(c.AuxStatements) == 0 { | ||
return "", []interface{}{}, fmt.Errorf("CTE must contain at least one auxillary statement") | ||
} | ||
|
||
var sql bytes.Buffer | ||
var args []interface{} | ||
|
||
sql.WriteString("WITH ") | ||
|
||
for i, stmt := range c.AuxStatements { | ||
var err error | ||
args, err = appendToSql(stmt, &sql, args) | ||
if err != nil { | ||
return "", []interface{}{}, err | ||
} | ||
|
||
if i != len(c.AuxStatements)-1 { | ||
sql.WriteString(", ") | ||
} | ||
} | ||
|
||
sqlStr, err := c.PlaceholderFormat.ReplacePlaceholders(sql.String()) | ||
if err != nil { | ||
return "", []interface{}{}, err | ||
} | ||
|
||
return sqlStr, args, nil | ||
} | ||
|
||
// Builder | ||
|
||
// CTEBuilder builds SQL WITH statement, also known as Common Table Expression (CTE). | ||
type CTEBuilder builder.Builder | ||
|
||
func init() { | ||
builder.Register(CTEBuilder{}, cteData{}) | ||
} | ||
|
||
// Format methods | ||
|
||
// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the | ||
// query. | ||
func (b CTEBuilder) PlaceholderFormat(f sq.PlaceholderFormat) CTEBuilder { | ||
return builder.Set(b, "PlaceholderFormat", f).(CTEBuilder) | ||
} | ||
|
||
// Runner methods | ||
|
||
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. | ||
// For most cases runner will be a database connection. | ||
func (b CTEBuilder) RunWith(runner sq.BaseRunner) CTEBuilder { | ||
return setRunWith(b, runner).(CTEBuilder) | ||
} | ||
|
||
// Exec builds and Execs the query with the Runner set by RunWith. | ||
func (b CTEBuilder) Exec() (sql.Result, error) { | ||
data := builder.GetStruct(b).(cteData) | ||
return data.Exec() | ||
} | ||
|
||
// ToSql builds the query into a SQL string and bound args. | ||
func (b CTEBuilder) ToSql() (string, []interface{}, error) { | ||
data := builder.GetStruct(b).(cteData) | ||
return data.ToSql() | ||
} | ||
|
||
// AuxStatements assigns auxillary statements to CTE. | ||
func (b CTEBuilder) AuxStatements(auxStmts ...AuxBuilder) CTEBuilder { | ||
return builder.Extend(b, "AuxStatements", auxStmts).(CTEBuilder) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package sqext | ||
|
||
import ( | ||
"testing" | ||
|
||
sq "github.com/Masterminds/squirrel" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestCTEBuilderToSql(t *testing.T) { | ||
insertA := sq.Insert("a"). | ||
Columns("a", "b"). | ||
Values(1, 3). | ||
Values(5, 8) | ||
|
||
selectB := sq.Select("a", "b").From("c").Where(sq.Eq{"a": 1}) | ||
|
||
b := CTE( | ||
Aux(insertA).Alias("a_inserted"), | ||
Aux(selectB).Alias("b_selected").Columns("a").Recursive(), | ||
) | ||
|
||
sql, args, err := b.ToSql() | ||
assert.NoError(t, err) | ||
|
||
expectedSql := "WITH a_inserted AS (INSERT INTO a (a,b) VALUES (?,?),(?,?)), RECURSIVE b_selected(a) AS (SELECT a, b FROM c WHERE a = ?)" | ||
assert.Equal(t, expectedSql, sql) | ||
assert.Equal(t, []interface{}{1, 3, 5, 8, 1}, args) | ||
} | ||
|
||
func TestCTEBuilderToSqlErr(t *testing.T) { | ||
b := CTE() | ||
_, _, err := b.ToSql() | ||
assert.Error(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
module github.com/bedakb/sqext | ||
|
||
go 1.19 | ||
|
||
require ( | ||
github.com/Masterminds/squirrel v1.5.3 // indirect | ||
github.com/davecgh/go-spew v1.1.1 // indirect | ||
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect | ||
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect | ||
github.com/pmezard/go-difflib v1.0.0 // indirect | ||
github.com/stretchr/objx v0.4.0 // indirect | ||
github.com/stretchr/testify v1.8.0 // indirect | ||
gopkg.in/yaml.v3 v3.0.1 // indirect | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
github.com/Masterminds/squirrel v1.5.3 h1:YPpoceAcxuzIljlr5iWpNKaql7hLeG1KLSrhvdHpkZc= | ||
github.com/Masterminds/squirrel v1.5.3/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= | ||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= | ||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= | ||
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= | ||
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= | ||
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= | ||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||
github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= | ||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= | ||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= | ||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | ||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= | ||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= | ||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package sqext | ||
|
||
import ( | ||
"io" | ||
|
||
sq "github.com/Masterminds/squirrel" | ||
) | ||
|
||
func appendToSql(q sq.Sqlizer, w io.Writer, args []interface{}) ([]interface{}, error) { | ||
sql, qArgs, err := q.ToSql() | ||
if err != nil { | ||
return nil, err | ||
} | ||
if sql == "" { | ||
return nil, nil | ||
} | ||
|
||
_, err = io.WriteString(w, sql) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
args = append(args, qArgs...) | ||
return args, nil | ||
} |
Oops, something went wrong.