Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper functions for SQL(), Pos(), End() #120

6 changes: 3 additions & 3 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:

- name: Set up Go 1.19
uses: actions/setup-go@v5
- name: Set up Go 1.20
uses: actions/setup-go@v4
with:
go-version: 1.19
go-version: "1.20"
id: go

- name: Check out code into the Go module directory
Expand Down
69 changes: 48 additions & 21 deletions ast/pos.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,51 @@ import (
"github.com/cloudspannerecosystem/memefish/token"
)

// ================================================================================
//
// Helper functions for Pos(), End()
// These functions are intended for use within this file only.
//
// ================================================================================

// lastNode returns last element of Node slice.
// This function corresponds to NodeSliceVar[$] in ast.go.
func lastNode[T Node](s []T) T {
return s[len(s)-1]
}

// firstValidEnd returns the first valid Pos() in argument.
// "valid" means the node is not nil and Pos().Invalid() is not true.
// This function corresponds to "(n0 ?? n1 ?? ...).End()"
func firstValidEnd(ns ...Node) token.Pos {
for _, n := range ns {
if n != nil && !n.End().Invalid() {
return n.End()
}
}
return token.InvalidPos
}

// firstPos returns the Pos() of the first node.
// If argument is an empty slice, this function returns token.InvalidPos.
// This function corresponds to NodeSliceVar[0].pos in ast.go.
func firstPos[T Node](s []T) token.Pos {
if len(s) == 0 {
return token.InvalidPos
}
return s[0].Pos()
}

// lastEnd returns the End() of the last node.
// If argument is an empty slice, this function returns token.InvalidPos.
// This function corresponds to NodeSliceVar[$].end in ast.go.
func lastEnd[T Node](s []T) token.Pos {
if len(s) == 0 {
return token.InvalidPos
}
return lastNode(s).End()
}

// ================================================================================
//
// SELECT
Expand Down Expand Up @@ -39,25 +84,7 @@ func (c *CTE) End() token.Pos { return c.Rparen + 1 }
func (s *Select) Pos() token.Pos { return s.Select }

func (s *Select) End() token.Pos {
if s.Limit != nil {
return s.Limit.End()
}
if s.OrderBy != nil {
return s.OrderBy.End()
}
if s.Having != nil {
return s.Having.End()
}
if s.GroupBy != nil {
return s.GroupBy.End()
}
if s.Where != nil {
return s.Where.End()
}
if s.From != nil {
return s.From.End()
}
return s.Results[len(s.Results)-1].End()
return firstValidEnd(s.Limit, s.OrderBy, s.Having, s.GroupBy, s.Where, s.From, lastNode(s.Results))
}

func (c *CompoundQuery) Pos() token.Pos {
Expand Down Expand Up @@ -376,8 +403,8 @@ func (p *Param) End() token.Pos { return p.Atmark + 1 + token.Pos(len(p.Name)) }
func (i *Ident) Pos() token.Pos { return i.NamePos }
func (i *Ident) End() token.Pos { return i.NameEnd }

func (p *Path) Pos() token.Pos { return p.Idents[0].Pos() }
func (p *Path) End() token.Pos { return p.Idents[len(p.Idents)-1].End() }
func (p *Path) Pos() token.Pos { return firstPos(p.Idents) }
func (p *Path) End() token.Pos { return lastEnd(p.Idents) }

func (a *ArrayLiteral) Pos() token.Pos {
if !a.Array.Invalid() {
Expand Down
123 changes: 67 additions & 56 deletions ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,59 @@ package ast

import (
"github.com/cloudspannerecosystem/memefish/token"
"strings"
)

// ================================================================================
//
// Helper functions for SQL()
// These functions are intended for use within this file only.
//
// ================================================================================

// sqlOpt outputs:
//
// when node != nil: left + node.SQL() + right
// else : empty string
//
// This function corresponds to sqlOpt in ast.go
func sqlOpt[T interface {
Node
comparable
}](left string, node T, right string) string {
var zero T
if node == zero {
return ""
}
return left + node.SQL() + right
}

// strOpt outputs:
//
// when pred == true: s
// else : empty string
//
// This function corresponds to {{if pred}}s{{end}} in ast.go
func strOpt(pred bool, s string) string {
if pred {
return s
}
return ""
}

// sqlJoin outputs joined string of SQL() of all elems by sep.
// This function corresponds to sqlJoin in ast.go
func sqlJoin[T Node](elems []T, sep string) string {
var b strings.Builder
for i, r := range elems {
if i > 0 {
b.WriteString(sep)
}
b.WriteString(r.SQL())
}
return b.String()
}

type prec int

const (
Expand Down Expand Up @@ -116,36 +167,16 @@ func (c *CTE) SQL() string {
}

func (s *Select) SQL() string {
sql := "SELECT "
if s.Distinct {
sql += "DISTINCT "
}
if s.AsStruct {
sql += "AS STRUCT "
}
sql += s.Results[0].SQL()
for _, r := range s.Results[1:] {
sql += ", " + r.SQL()
}
if s.From != nil {
sql += " " + s.From.SQL()
}
if s.Where != nil {
sql += " " + s.Where.SQL()
}
if s.GroupBy != nil {
sql += " " + s.GroupBy.SQL()
}
if s.Having != nil {
sql += " " + s.Having.SQL()
}
if s.OrderBy != nil {
sql += " " + s.OrderBy.SQL()
}
if s.Limit != nil {
sql += " " + s.Limit.SQL()
}
return sql
return "SELECT " +
strOpt(s.Distinct, "DISTINCT ") +
strOpt(s.AsStruct, "AS STRUCT ") +
sqlJoin(s.Results, ", ") +
sqlOpt(" ", s.From, "") +
sqlOpt(" ", s.Where, "") +
sqlOpt(" ", s.GroupBy, "") +
sqlOpt(" ", s.Having, "") +
sqlOpt(" ", s.OrderBy, "") +
sqlOpt(" ", s.Limit, "")
}

func (c *CompoundQuery) SQL() string {
Expand Down Expand Up @@ -464,27 +495,11 @@ func (i *IndexExpr) SQL() string {
}

func (c *CallExpr) SQL() string {
sql := c.Func.SQL() + "("
if c.Distinct {
sql += "DISTINCT "
}
for i, a := range c.Args {
if i != 0 {
sql += ", "
}
sql += a.SQL()
}
if len(c.Args) > 0 && len(c.NamedArgs) > 0 {
sql += ", "
}
for i, v := range c.NamedArgs {
if i != 0 {
sql += ", "
}
sql += v.SQL()
}
sql += ")"
return sql
return c.Func.SQL() + "(" + strOpt(c.Distinct, "DISTINCT ") +
sqlJoin(c.Args, ", ") +
strOpt(len(c.Args) > 0 && len(c.NamedArgs) > 0, ", ") +
sqlJoin(c.NamedArgs, ", ") +
")"
}

func (n *NamedArg) SQL() string { return n.Name.SQL() + " => " + n.Value.SQL() }
Expand Down Expand Up @@ -595,11 +610,7 @@ func (i *Ident) SQL() string {
}

func (p *Path) SQL() string {
sql := p.Idents[0].SQL()
for _, id := range p.Idents[1:] {
sql += "." + id.SQL()
}
return sql
return sqlJoin(p.Idents, ".")
}

func (a *ArrayLiteral) SQL() string {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/cloudspannerecosystem/memefish

go 1.19
go 1.20

require (
github.com/MakeNowJust/heredoc/v2 v2.0.1
Expand Down
Loading