From 1dcb271a07ff587803627f6549f68fa2d11595c3 Mon Sep 17 00:00:00 2001 From: Belmin Bedak Date: Sun, 2 Oct 2022 20:38:07 +0200 Subject: [PATCH] initial commit --- aux.go | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++ aux_test.go | 54 +++++++++++++++++++++++ cte.go | 95 ++++++++++++++++++++++++++++++++++++++++ cte_test.go | 35 +++++++++++++++ go.mod | 14 ++++++ go.sum | 22 ++++++++++ part.go | 25 +++++++++++ statement.go | 51 ++++++++++++++++++++++ 8 files changed, 417 insertions(+) create mode 100644 aux.go create mode 100644 aux_test.go create mode 100644 cte.go create mode 100644 cte_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 part.go create mode 100644 statement.go diff --git a/aux.go b/aux.go new file mode 100644 index 0000000..3818e7f --- /dev/null +++ b/aux.go @@ -0,0 +1,121 @@ +package sqext + +import ( + "bytes" + "fmt" + "strings" + + "database/sql" + + sq "github.com/Masterminds/squirrel" + "github.com/lann/builder" +) + +type auxData struct { + PlaceholderFormat sq.PlaceholderFormat + RunWith sq.BaseRunner + Alias string + Columns []string + Recursive bool + Statement sq.Sqlizer +} + +func (a *auxData) Exec() (sql.Result, error) { + if a.RunWith == nil { + return nil, sq.RunnerNotSet + } + return sq.ExecWith(a.RunWith, a) +} + +func (a *auxData) ToSql() (string, []interface{}, error) { + if a.Alias == "" { + return "", nil, fmt.Errorf("auxillary statement must contain alias") + } + + var sql bytes.Buffer + + if a.Recursive { + sql.WriteString("RECURSIVE ") + } + + sql.WriteString(a.Alias) + + if len(a.Columns) > 0 { + sql.WriteString("(") + sql.WriteString(strings.Join(a.Columns, ", ")) + sql.WriteString(")") + } + + sql.WriteString(" AS (") + var args []interface{} + var err error + args, err = appendToSql(a.Statement, &sql, args) + if err != nil { + return "", []interface{}{}, err + } + sql.WriteString(")") + + sqlStr, err := a.PlaceholderFormat.ReplacePlaceholders(sql.String()) + if err != nil { + return "", []interface{}{}, err + } + + return sqlStr, args, nil +} + +// Builder + +// AuxBuilder builds auxillary statements used by CTEs. +type AuxBuilder builder.Builder + +func init() { + builder.Register(AuxBuilder{}, auxData{}) +} + +// Format methods + +// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the +// query. +func (b AuxBuilder) PlaceholderFormat(f sq.PlaceholderFormat) AuxBuilder { + return builder.Set(b, "PlaceholderFormat", f).(AuxBuilder) +} + +// Runner methods + +// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. +// For most cases runner will be a database connection. +func (b AuxBuilder) RunWith(runner sq.BaseRunner) AuxBuilder { + return setRunWith(b, runner).(AuxBuilder) +} + +// Exec builds and Execs the query with the Runner set by RunWith. +func (b AuxBuilder) Exec() (sql.Result, error) { + data := builder.GetStruct(b).(auxData) + return data.Exec() +} + +// ToSql builds the query into a SQL string and bound args. +func (b AuxBuilder) ToSql() (string, []interface{}, error) { + data := builder.GetStruct(b).(auxData) + return data.ToSql() +} + +// Alias assigns an alias for the auxillary statements. +func (b AuxBuilder) Alias(alias string) AuxBuilder { + return builder.Set(b, "Alias", alias).(AuxBuilder) +} + +// Recursive adds RECURSIVE modifier to the auxillary statments. +func (b AuxBuilder) Recursive() AuxBuilder { + return builder.Set(b, "Recursive", true).(AuxBuilder) +} + +// Columns adds result columns of auxillary statement. +func (b AuxBuilder) Columns(columns ...string) AuxBuilder { + return builder.Extend(b, "Columns", columns).(AuxBuilder) +} + +// Statement sets a subquery into auxillary statement. +func (b AuxBuilder) Statement(stmt sq.Sqlizer) AuxBuilder { + return builder.Set(b, "Statement", stmt).(AuxBuilder) +} diff --git a/aux_test.go b/aux_test.go new file mode 100644 index 0000000..1dd5726 --- /dev/null +++ b/aux_test.go @@ -0,0 +1,54 @@ +package sqext + +import ( + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/assert" +) + +func TestAuxBuilderToSql(t *testing.T) { + b := Aux(sq.Select("*").From("events").Where(sq.Eq{"is_completed": true})). + Alias("completed_events"). + Columns("a", "b"). + Recursive() + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "RECURSIVE completed_events(a, b) AS (SELECT * FROM events WHERE is_completed = ?)" + assert.Equal(t, expectedSql, sql) + assert.Equal(t, []interface{}{true}, args) +} + +func TestAuxBuilderToSqlNoColumns(t *testing.T) { + b := Aux(sq.Select("*").From("a").Where(sq.Eq{"x": 20})). + Alias("a_data"). + Recursive() + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "RECURSIVE a_data AS (SELECT * FROM a WHERE x = ?)" + assert.Equal(t, expectedSql, sql) + assert.Equal(t, []interface{}{20}, args) +} + +func TestAuxBuilderToSqlBasicAux(t *testing.T) { + b := Aux(sq.Select("a", "b").From("x").Where(sq.Eq{"y": 1})). + Alias("a_data") + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "a_data AS (SELECT a, b FROM x WHERE y = ?)" + assert.Equal(t, expectedSql, sql) + assert.Equal(t, []interface{}{1}, args) +} + +func TestAuxBuilderToSqlErr(t *testing.T) { + b := Aux(sq.Select("*").From("events").Where(sq.Gt{"is_completed": true})).Recursive() + + _, _, err := b.ToSql() + assert.Error(t, err) +} diff --git a/cte.go b/cte.go new file mode 100644 index 0000000..b70e3ee --- /dev/null +++ b/cte.go @@ -0,0 +1,95 @@ +package sqext + +import ( + "bytes" + "database/sql" + "fmt" + + sq "github.com/Masterminds/squirrel" + "github.com/lann/builder" +) + +type cteData struct { + PlaceholderFormat sq.PlaceholderFormat + RunWith sq.BaseRunner + AuxStatements []sq.Sqlizer +} + +func (c *cteData) Exec() (sql.Result, error) { + if c.RunWith == nil { + return nil, sq.RunnerNotSet + } + return sq.ExecWith(c.RunWith, c) +} + +func (c *cteData) ToSql() (string, []interface{}, error) { + if len(c.AuxStatements) == 0 { + return "", []interface{}{}, fmt.Errorf("CTE must contain at least one auxillary statement") + } + + var sql bytes.Buffer + var args []interface{} + + sql.WriteString("WITH ") + + for i, stmt := range c.AuxStatements { + var err error + args, err = appendToSql(stmt, &sql, args) + if err != nil { + return "", []interface{}{}, err + } + + if i != len(c.AuxStatements)-1 { + sql.WriteString(", ") + } + } + + sqlStr, err := c.PlaceholderFormat.ReplacePlaceholders(sql.String()) + if err != nil { + return "", []interface{}{}, err + } + + return sqlStr, args, nil +} + +// Builder + +// CTEBuilder builds SQL WITH statement, also known as Common Table Expression (CTE). +type CTEBuilder builder.Builder + +func init() { + builder.Register(CTEBuilder{}, cteData{}) +} + +// Format methods + +// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the +// query. +func (b CTEBuilder) PlaceholderFormat(f sq.PlaceholderFormat) CTEBuilder { + return builder.Set(b, "PlaceholderFormat", f).(CTEBuilder) +} + +// Runner methods + +// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec. +// For most cases runner will be a database connection. +func (b CTEBuilder) RunWith(runner sq.BaseRunner) CTEBuilder { + return setRunWith(b, runner).(CTEBuilder) +} + +// Exec builds and Execs the query with the Runner set by RunWith. +func (b CTEBuilder) Exec() (sql.Result, error) { + data := builder.GetStruct(b).(cteData) + return data.Exec() +} + +// ToSql builds the query into a SQL string and bound args. +func (b CTEBuilder) ToSql() (string, []interface{}, error) { + data := builder.GetStruct(b).(cteData) + return data.ToSql() +} + +// AuxStatements assigns auxillary statements to CTE. +func (b CTEBuilder) AuxStatements(auxStmts ...AuxBuilder) CTEBuilder { + return builder.Extend(b, "AuxStatements", auxStmts).(CTEBuilder) +} diff --git a/cte_test.go b/cte_test.go new file mode 100644 index 0000000..42fafa2 --- /dev/null +++ b/cte_test.go @@ -0,0 +1,35 @@ +package sqext + +import ( + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/assert" +) + +func TestCTEBuilderToSql(t *testing.T) { + insertA := sq.Insert("a"). + Columns("a", "b"). + Values(1, 3). + Values(5, 8) + + selectB := sq.Select("a", "b").From("c").Where(sq.Eq{"a": 1}) + + b := CTE( + Aux(insertA).Alias("a_inserted"), + Aux(selectB).Alias("b_selected").Columns("a").Recursive(), + ) + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "WITH a_inserted AS (INSERT INTO a (a,b) VALUES (?,?),(?,?)), RECURSIVE b_selected(a) AS (SELECT a, b FROM c WHERE a = ?)" + assert.Equal(t, expectedSql, sql) + assert.Equal(t, []interface{}{1, 3, 5, 8, 1}, args) +} + +func TestCTEBuilderToSqlErr(t *testing.T) { + b := CTE() + _, _, err := b.ToSql() + assert.Error(t, err) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..76e97f9 --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/bedakb/sqext + +go 1.19 + +require ( + github.com/Masterminds/squirrel v1.5.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect + github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect + github.com/stretchr/testify v1.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..326f593 --- /dev/null +++ b/go.sum @@ -0,0 +1,22 @@ +github.com/Masterminds/squirrel v1.5.3 h1:YPpoceAcxuzIljlr5iWpNKaql7hLeG1KLSrhvdHpkZc= +github.com/Masterminds/squirrel v1.5.3/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/part.go b/part.go new file mode 100644 index 0000000..46f8b8c --- /dev/null +++ b/part.go @@ -0,0 +1,25 @@ +package sqext + +import ( + "io" + + sq "github.com/Masterminds/squirrel" +) + +func appendToSql(q sq.Sqlizer, w io.Writer, args []interface{}) ([]interface{}, error) { + sql, qArgs, err := q.ToSql() + if err != nil { + return nil, err + } + if sql == "" { + return nil, nil + } + + _, err = io.WriteString(w, sql) + if err != nil { + return nil, err + } + + args = append(args, qArgs...) + return args, nil +} diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..4f0652f --- /dev/null +++ b/statement.go @@ -0,0 +1,51 @@ +package sqext + +import ( + sq "github.com/Masterminds/squirrel" + "github.com/lann/builder" +) + +// StatementBuilderType is the type of StatementBuilder. +type StatementBuilderType builder.Builder + +// PlaceholderFormat sets the PlaceholderFormat field for any child builders. +func (b StatementBuilderType) PlaceholderFormat(f sq.PlaceholderFormat) StatementBuilderType { + return builder.Set(b, "PlaceholderFormat", f).(StatementBuilderType) +} + +// Aux returns a AuxBuilder for this StatementBuilderType. +func (b StatementBuilderType) Aux(stmt sq.Sqlizer) AuxBuilder { + return AuxBuilder(b).Statement(stmt) +} + +// CTE returns a CTEBuilder for this StatementBuilderType. +func (b StatementBuilderType) CTE(auxStmts ...AuxBuilder) CTEBuilder { + return CTEBuilder(b).AuxStatements(auxStmts...) +} + +// StatementBuilder is a parent builder for other builders. +var StatementBuilder = StatementBuilderType(builder.EmptyBuilder).PlaceholderFormat(sq.Question) + +// Aux returns a new AuxBuilder with a given sq.Sqlizer. +// +// See AuxBuilder.Statement. +func Aux(stmt sq.Sqlizer) AuxBuilder { + return StatementBuilder.Aux(stmt) +} + +// CTE returns a new CTEBulder with a list of given auxillary statements. +// +// See AuxBuilder.Statments. +func CTE(auxStmts ...AuxBuilder) CTEBuilder { + return StatementBuilder.CTE(auxStmts...) +} + +func setRunWith(b interface{}, runner sq.BaseRunner) interface{} { + switch r := runner.(type) { + case sq.StdSqlCtx: + runner = sq.WrapStdSqlCtx(r) + case sq.StdSql: + runner = sq.WrapStdSql(r) + } + return builder.Set(b, "RunWith", runner) +}