Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC of pretty printing #227

Closed
244 changes: 196 additions & 48 deletions ast/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,112 @@ import (
"github.com/cloudspannerecosystem/memefish/token"
)

// ================================================================================
//
// Experimental format with indentation
//
// ================================================================================

// formatOption is container of format configurations.
// It is conceptually immutable.
type formatOption struct {
newline bool
indent int
}

// FormatContext is container of format option and current indentation.
// You should initialize this struct using FormatContextCompact() or FormatContextPretty().
// If methods are called with nil receiver, they will work as FormatContextCompact() is receiver.
type FormatContext struct {
option formatOption
currentIndent int
}

// FormatContextCompact is format context without newline and indentation.
func FormatContextCompact() *FormatContext {
return &FormatContext{option: formatOption{}}
}

// FormatContextPretty is format context with newline and configured indentation.
func FormatContextPretty(indent int) *FormatContext {
return &FormatContext{option: formatOption{newline: true, indent: indent}}
}

// SQL is entry point of pretty printing.
// If node implements NodeFormat, it calls NodeFormat.sqlContext() instead of Node.SQL().
// If it is called with nil receiver, FormatContextCompact() is used instead.
func (fc *FormatContext) SQL(node Node) string {
if fc == nil {
fc = FormatContextCompact()
}

if nodeFormat, ok := node.(NodeFormat); ok {
return nodeFormat.sqlContext(fc)
} else {
return node.SQL()
}
}

// newlineOr returns newline with indentation if formatOptionPretty is used.
// Otherwise, it returns argument string.
func (fc *FormatContext) newlineOr(s string) string {
if fc == nil {
fc = FormatContextCompact()
}

return strIfElse(fc.option.newline, "\n", s) + strings.Repeat(" ", fc.currentIndent)
}

// indentScope executes function with FormatContext with deeper indentation.
func (fc *FormatContext) indentScope(f func(fc *FormatContext) string) string {
if fc == nil {
fc = FormatContextCompact()
}

newFc := *fc
newFc.currentIndent += fc.option.indent
return f(&newFc)
}

// sqlOptCtx is sqlOpt with FormatContext.
func sqlOptCtx[T interface {
Node
comparable
}](fc *FormatContext, left string, node T, right string) string {
if fc == nil {
fc = FormatContextCompact()
}

var zero T
if node == zero {
return ""
}
return left + fc.SQL(node) + right
}

// sqlJoinCtx is sqlJoin with FormatContext.
func sqlJoinCtx[T Node](fc *FormatContext, elems []T, sep string) string {
var b strings.Builder
for i, r := range elems {
if i > 0 {
b.WriteString(sep)
}
b.WriteString(fc.SQL(r))
}
return b.String()
}

// NodeFormat is Node with FormatContext support.
// If it is implemented, (*FormatContext).SQL calls sqlContext() instead of SQL()
type NodeFormat interface {
Node

// sqlContext is Node.SQL() with FormatContext conceptually.
// If it is called with nil FormatContext, FormatContextCompact() is used instead.
// Note: It would become to Node.SQL() finally.
sqlContext(fmtCtx *FormatContext) string
}

// ================================================================================
//
// Helper functions for SQL()
Expand All @@ -24,11 +130,7 @@ func sqlOpt[T interface {
Node
comparable
}](left string, node T, right string) string {
var zero T
if node == zero {
return ""
}
return left + node.SQL() + right
return sqlOptCtx(nil, left, node, right)
}

// strOpt outputs:
Expand Down Expand Up @@ -60,14 +162,7 @@ func strIfElse(pred bool, ifStr string, elseStr string) string {
// sqlJoin outputs joined string of SQL() of all elems by sep.
// This function corresponds to sqlJoin in ast.go
func sqlJoin[T Node](elems []T, sep string) string {
var b strings.Builder
for i, r := range elems {
if i > 0 {
b.WriteString(sep)
}
b.WriteString(r.SQL())
}
return b.String()
return sqlJoinCtx(nil, elems, sep)
}

// formatBoolUpper formats bool value as uppercase.
Expand Down Expand Up @@ -152,17 +247,26 @@ func paren(p prec, e Expr) string {
//
// ================================================================================

func (q *QueryStatement) sqlContext(fc *FormatContext) string {
return sqlOptCtx(fc, "", q.Hint, fc.newlineOr(" ")) +
fc.SQL(q.Query)
}

func (q *QueryStatement) SQL() string {
return sqlOpt("", q.Hint, " ") + q.Query.SQL()
return q.sqlContext(nil)
}

func (q *Query) sqlContext(fc *FormatContext) string {
return sqlOptCtx(fc, "", q.With, fc.newlineOr(" ")) +
fc.SQL(q.Query) +
sqlOptCtx(fc, fc.newlineOr(" "), q.OrderBy, "") +
sqlOptCtx(fc, fc.newlineOr(" "), q.Limit, "") +
strOpt(len(q.PipeOperators) > 0, fc.newlineOr(" ")) +
sqlJoinCtx(fc, q.PipeOperators, fc.newlineOr(" "))
}

func (q *Query) SQL() string {
return sqlOpt("", q.With, " ") +
q.Query.SQL() +
sqlOpt(" ", q.OrderBy, "") +
sqlOpt(" ", q.Limit, "") +
strOpt(len(q.PipeOperators) > 0, " ") +
sqlJoin(q.PipeOperators, " ")
return q.sqlContext(nil)
}

func (h *Hint) SQL() string {
Expand All @@ -173,23 +277,41 @@ func (h *HintRecord) SQL() string {
return h.Key.SQL() + "=" + h.Value.SQL()
}

func (w *With) sqlContext(fc *FormatContext) string {
return "WITH " + sqlJoinCtx(fc, w.CTEs, ", ")
}

func (w *With) SQL() string {
return "WITH " + sqlJoin(w.CTEs, ", ")
return w.sqlContext(nil)
}
func (c *CTE) sqlContext(fc *FormatContext) string {
return c.Name.SQL() + " AS (" +
fc.indentScope(func(fc *FormatContext) string {
return fc.newlineOr("") + fc.SQL(c.QueryExpr)
}) +
fc.newlineOr("") + ")"
}

func (c *CTE) SQL() string {
return c.Name.SQL() + " AS (" + c.QueryExpr.SQL() + ")"
return c.sqlContext(nil)
}

func (s *Select) sqlContext(fc *FormatContext) string {
return "SELECT" +
strOpt(s.AllOrDistinct != "", " "+string(s.AllOrDistinct)) +
sqlOptCtx(fc, " ", s.As, "") +
fc.indentScope(func(fc *FormatContext) string {
return strIfElse(len(s.Results) > 1, fc.newlineOr(" "), " ") +
sqlJoinCtx(fc, s.Results, ","+fc.newlineOr(" "))
}) +
sqlOptCtx(fc, fc.newlineOr(" "), s.From, "") +
sqlOptCtx(fc, fc.newlineOr(" "), s.Where, "") +
sqlOptCtx(fc, fc.newlineOr(" "), s.GroupBy, "") +
sqlOptCtx(fc, fc.newlineOr(" "), s.Having, "")
}

func (s *Select) SQL() string {
return "SELECT " +
strOpt(s.AllOrDistinct != "", string(s.AllOrDistinct)+" ") +
sqlOpt("", s.As, " ") +
sqlJoin(s.Results, ", ") +
sqlOpt(" ", s.From, "") +
sqlOpt(" ", s.Where, "") +
sqlOpt(" ", s.GroupBy, "") +
sqlOpt(" ", s.Having, "")
return s.sqlContext(nil)
}

func (a *AsStruct) SQL() string { return "AS STRUCT" }
Expand Down Expand Up @@ -232,8 +354,11 @@ func (e *ExprSelectItem) SQL() string {
return e.Expr.SQL()
}

func (f *From) sqlContext(fc *FormatContext) string {
return "FROM " + fc.SQL(f.Source)
}
func (f *From) SQL() string {
return "FROM " + f.Source.SQL()
return f.sqlContext(nil)
}

func (w *Where) SQL() string {
Expand Down Expand Up @@ -318,26 +443,37 @@ func (e *PathTableExpr) SQL() string {
sqlOpt(" ", e.Sample, "")
}

func (s *SubQueryTableExpr) sqlContext(fc *FormatContext) string {
return "(" +
fc.indentScope(func(fc *FormatContext) string {
return fc.newlineOr("") + fc.SQL(s.Query)
}) +
fc.newlineOr("") + ")" +
sqlOptCtx(fc, " ", s.As, "") +
sqlOptCtx(fc, " ", s.Sample, "")
}

func (s *SubQueryTableExpr) SQL() string {
return "(" + s.Query.SQL() + ")" +
sqlOpt(" ", s.As, "") +
sqlOpt(" ", s.Sample, "")
return s.sqlContext(nil)
}

func (p *ParenTableExpr) SQL() string {
return "(" + p.Source.SQL() + ")" +
sqlOpt(" ", p.Sample, "")
}

func (j *Join) SQL() string {
return j.Left.SQL() +
strOpt(j.Op != CommaJoin, " ") +
func (j *Join) sqlContext(fc *FormatContext) string {
return fc.SQL(j.Left) +
strOpt(j.Op != CommaJoin, fc.newlineOr(" ")) +
string(j.Op) + " " +
sqlOpt("", j.Hint, " ") +
j.Right.SQL() +
sqlOpt(" ", j.Cond, "")
sqlOptCtx(fc, "", j.Hint, " ") +
fc.SQL(j.Right) +
sqlOptCtx(fc, " ", j.Cond, "")
}

func (j *Join) SQL() string {
return j.sqlContext(nil)
}
func (o *On) SQL() string {
return "ON " + o.Expr.SQL()
}
Expand Down Expand Up @@ -745,16 +881,28 @@ func (a *AlterProtoBundleDelete) SQL() string { return "DELETE " + a.Types.SQL()

func (d *DropProtoBundle) SQL() string { return "DROP PROTO BUNDLE" }

func (c *CreateTable) SQL() string {
func (c *CreateTable) sqlContext(fc *FormatContext) string {
return "CREATE TABLE " +
strOpt(c.IfNotExists, "IF NOT EXISTS ") +
c.Name.SQL() + " (" +
sqlJoin(c.Columns, ", ") + strOpt(len(c.Columns) > 0 && (len(c.TableConstraints) > 0 || len(c.Synonyms) > 0), ", ") +
sqlJoin(c.TableConstraints, ", ") + strOpt(len(c.TableConstraints) > 0 && len(c.Synonyms) > 0, ", ") +
sqlJoin(c.Synonyms, ", ") +
") PRIMARY KEY (" + sqlJoin(c.PrimaryKeys, ", ") + ")" +
sqlOpt("", c.Cluster, "") +
sqlOpt("", c.RowDeletionPolicy, "")
fc.SQL(c.Name) +
" (" +
fc.indentScope(func(fc *FormatContext) string {
return fc.newlineOr("") + sqlJoinCtx(fc, c.Columns, ","+fc.newlineOr(" ")) +
strOpt(len(c.Columns) > 0 && (len(c.TableConstraints) > 0 || len(c.Synonyms) > 0), ","+fc.newlineOr(" ")) +
sqlJoinCtx(fc, c.TableConstraints, ","+fc.newlineOr(" ")) +
strOpt(len(c.TableConstraints) > 0 && len(c.Synonyms) > 0, ","+fc.newlineOr(" ")) +
sqlJoinCtx(fc, c.Synonyms, ","+fc.newlineOr(" "))
}) +
fc.newlineOr("") +
") PRIMARY KEY (" +
sqlJoinCtx(fc, c.PrimaryKeys, ", ") +
")" +
sqlOptCtx(fc, "", c.Cluster, "") +
sqlOptCtx(fc, "", c.RowDeletionPolicy, "")
}

func (c *CreateTable) SQL() string {
return c.sqlContext(nil)
}

func (s *Synonym) SQL() string { return "SYNONYM (" + s.Name.SQL() + ")" }
Expand Down
6 changes: 5 additions & 1 deletion tools/parse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
"unicode"

"github.com/MakeNowJust/heredoc/v2"
"github.com/k0kubun/pp"

"github.com/cloudspannerecosystem/memefish"
"github.com/cloudspannerecosystem/memefish/ast"
"github.com/cloudspannerecosystem/memefish/token"
"github.com/cloudspannerecosystem/memefish/tools/util/poslang"
"github.com/k0kubun/pp"
)

var usage = heredoc.Doc(`
Expand Down Expand Up @@ -119,6 +120,9 @@ func main() {
fmt.Println("--- SQL")
fmt.Println(node.SQL())

fmt.Println("--- SQL with indentation")
fmt.Println(ast.FormatContextPretty(2).SQL(node))

if *pos != "" {
fmt.Println("--- POS")

Expand Down
Loading