Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
bedakb committed Oct 2, 2022
1 parent f9adbf9 commit 1dcb271
Show file tree
Hide file tree
Showing 8 changed files with 417 additions and 0 deletions.
121 changes: 121 additions & 0 deletions aux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package sqext

import (
"bytes"
"fmt"
"strings"

"database/sql"

sq "github.com/Masterminds/squirrel"
"github.com/lann/builder"
)

type auxData struct {
PlaceholderFormat sq.PlaceholderFormat
RunWith sq.BaseRunner
Alias string
Columns []string
Recursive bool
Statement sq.Sqlizer
}

func (a *auxData) Exec() (sql.Result, error) {
if a.RunWith == nil {
return nil, sq.RunnerNotSet
}
return sq.ExecWith(a.RunWith, a)
}

func (a *auxData) ToSql() (string, []interface{}, error) {
if a.Alias == "" {
return "", nil, fmt.Errorf("auxillary statement must contain alias")
}

var sql bytes.Buffer

if a.Recursive {
sql.WriteString("RECURSIVE ")
}

sql.WriteString(a.Alias)

if len(a.Columns) > 0 {
sql.WriteString("(")
sql.WriteString(strings.Join(a.Columns, ", "))
sql.WriteString(")")
}

sql.WriteString(" AS (")
var args []interface{}
var err error
args, err = appendToSql(a.Statement, &sql, args)
if err != nil {
return "", []interface{}{}, err
}
sql.WriteString(")")

sqlStr, err := a.PlaceholderFormat.ReplacePlaceholders(sql.String())
if err != nil {
return "", []interface{}{}, err
}

return sqlStr, args, nil
}

// Builder

// AuxBuilder builds auxillary statements used by CTEs.
type AuxBuilder builder.Builder

func init() {
builder.Register(AuxBuilder{}, auxData{})
}

// Format methods

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b AuxBuilder) PlaceholderFormat(f sq.PlaceholderFormat) AuxBuilder {
return builder.Set(b, "PlaceholderFormat", f).(AuxBuilder)
}

// Runner methods

// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
// For most cases runner will be a database connection.
func (b AuxBuilder) RunWith(runner sq.BaseRunner) AuxBuilder {
return setRunWith(b, runner).(AuxBuilder)
}

// Exec builds and Execs the query with the Runner set by RunWith.
func (b AuxBuilder) Exec() (sql.Result, error) {
data := builder.GetStruct(b).(auxData)
return data.Exec()
}

// ToSql builds the query into a SQL string and bound args.
func (b AuxBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(auxData)
return data.ToSql()
}

// Alias assigns an alias for the auxillary statements.
func (b AuxBuilder) Alias(alias string) AuxBuilder {
return builder.Set(b, "Alias", alias).(AuxBuilder)
}

// Recursive adds RECURSIVE modifier to the auxillary statments.
func (b AuxBuilder) Recursive() AuxBuilder {
return builder.Set(b, "Recursive", true).(AuxBuilder)
}

// Columns adds result columns of auxillary statement.
func (b AuxBuilder) Columns(columns ...string) AuxBuilder {
return builder.Extend(b, "Columns", columns).(AuxBuilder)
}

// Statement sets a subquery into auxillary statement.
func (b AuxBuilder) Statement(stmt sq.Sqlizer) AuxBuilder {
return builder.Set(b, "Statement", stmt).(AuxBuilder)
}
54 changes: 54 additions & 0 deletions aux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sqext

import (
"testing"

sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert"
)

func TestAuxBuilderToSql(t *testing.T) {
b := Aux(sq.Select("*").From("events").Where(sq.Eq{"is_completed": true})).
Alias("completed_events").
Columns("a", "b").
Recursive()

sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "RECURSIVE completed_events(a, b) AS (SELECT * FROM events WHERE is_completed = ?)"
assert.Equal(t, expectedSql, sql)
assert.Equal(t, []interface{}{true}, args)
}

func TestAuxBuilderToSqlNoColumns(t *testing.T) {
b := Aux(sq.Select("*").From("a").Where(sq.Eq{"x": 20})).
Alias("a_data").
Recursive()

sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "RECURSIVE a_data AS (SELECT * FROM a WHERE x = ?)"
assert.Equal(t, expectedSql, sql)
assert.Equal(t, []interface{}{20}, args)
}

func TestAuxBuilderToSqlBasicAux(t *testing.T) {
b := Aux(sq.Select("a", "b").From("x").Where(sq.Eq{"y": 1})).
Alias("a_data")

sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "a_data AS (SELECT a, b FROM x WHERE y = ?)"
assert.Equal(t, expectedSql, sql)
assert.Equal(t, []interface{}{1}, args)
}

func TestAuxBuilderToSqlErr(t *testing.T) {
b := Aux(sq.Select("*").From("events").Where(sq.Gt{"is_completed": true})).Recursive()

_, _, err := b.ToSql()
assert.Error(t, err)
}
95 changes: 95 additions & 0 deletions cte.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package sqext

import (
"bytes"
"database/sql"
"fmt"

sq "github.com/Masterminds/squirrel"
"github.com/lann/builder"
)

type cteData struct {
PlaceholderFormat sq.PlaceholderFormat
RunWith sq.BaseRunner
AuxStatements []sq.Sqlizer
}

func (c *cteData) Exec() (sql.Result, error) {
if c.RunWith == nil {
return nil, sq.RunnerNotSet
}
return sq.ExecWith(c.RunWith, c)
}

func (c *cteData) ToSql() (string, []interface{}, error) {
if len(c.AuxStatements) == 0 {
return "", []interface{}{}, fmt.Errorf("CTE must contain at least one auxillary statement")
}

var sql bytes.Buffer
var args []interface{}

sql.WriteString("WITH ")

for i, stmt := range c.AuxStatements {
var err error
args, err = appendToSql(stmt, &sql, args)
if err != nil {
return "", []interface{}{}, err
}

if i != len(c.AuxStatements)-1 {
sql.WriteString(", ")
}
}

sqlStr, err := c.PlaceholderFormat.ReplacePlaceholders(sql.String())
if err != nil {
return "", []interface{}{}, err
}

return sqlStr, args, nil
}

// Builder

// CTEBuilder builds SQL WITH statement, also known as Common Table Expression (CTE).
type CTEBuilder builder.Builder

func init() {
builder.Register(CTEBuilder{}, cteData{})
}

// Format methods

// PlaceholderFormat sets PlaceholderFormat (e.g. Question or Dollar) for the
// query.
func (b CTEBuilder) PlaceholderFormat(f sq.PlaceholderFormat) CTEBuilder {
return builder.Set(b, "PlaceholderFormat", f).(CTEBuilder)
}

// Runner methods

// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
// For most cases runner will be a database connection.
func (b CTEBuilder) RunWith(runner sq.BaseRunner) CTEBuilder {
return setRunWith(b, runner).(CTEBuilder)
}

// Exec builds and Execs the query with the Runner set by RunWith.
func (b CTEBuilder) Exec() (sql.Result, error) {
data := builder.GetStruct(b).(cteData)
return data.Exec()
}

// ToSql builds the query into a SQL string and bound args.
func (b CTEBuilder) ToSql() (string, []interface{}, error) {
data := builder.GetStruct(b).(cteData)
return data.ToSql()
}

// AuxStatements assigns auxillary statements to CTE.
func (b CTEBuilder) AuxStatements(auxStmts ...AuxBuilder) CTEBuilder {
return builder.Extend(b, "AuxStatements", auxStmts).(CTEBuilder)
}
35 changes: 35 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sqext

import (
"testing"

sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert"
)

func TestCTEBuilderToSql(t *testing.T) {
insertA := sq.Insert("a").
Columns("a", "b").
Values(1, 3).
Values(5, 8)

selectB := sq.Select("a", "b").From("c").Where(sq.Eq{"a": 1})

b := CTE(
Aux(insertA).Alias("a_inserted"),
Aux(selectB).Alias("b_selected").Columns("a").Recursive(),
)

sql, args, err := b.ToSql()
assert.NoError(t, err)

expectedSql := "WITH a_inserted AS (INSERT INTO a (a,b) VALUES (?,?),(?,?)), RECURSIVE b_selected(a) AS (SELECT a, b FROM c WHERE a = ?)"
assert.Equal(t, expectedSql, sql)
assert.Equal(t, []interface{}{1, 3, 5, 8, 1}, args)
}

func TestCTEBuilderToSqlErr(t *testing.T) {
b := CTE()
_, _, err := b.ToSql()
assert.Error(t, err)
}
14 changes: 14 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module github.com/bedakb/sqext

go 1.19

require (
github.com/Masterminds/squirrel v1.5.3 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.4.0 // indirect
github.com/stretchr/testify v1.8.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
22 changes: 22 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
github.com/Masterminds/squirrel v1.5.3 h1:YPpoceAcxuzIljlr5iWpNKaql7hLeG1KLSrhvdHpkZc=
github.com/Masterminds/squirrel v1.5.3/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
25 changes: 25 additions & 0 deletions part.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package sqext

import (
"io"

sq "github.com/Masterminds/squirrel"
)

func appendToSql(q sq.Sqlizer, w io.Writer, args []interface{}) ([]interface{}, error) {
sql, qArgs, err := q.ToSql()
if err != nil {
return nil, err
}
if sql == "" {
return nil, nil
}

_, err = io.WriteString(w, sql)
if err != nil {
return nil, err
}

args = append(args, qArgs...)
return args, nil
}
Loading

0 comments on commit 1dcb271

Please sign in to comment.