From e16b9e33ab68bb431bd7385f23855fbffb45d62a Mon Sep 17 00:00:00 2001 From: Colton Hirst Date: Mon, 25 Nov 2024 21:25:34 -0700 Subject: [PATCH] Result columns (#7) Begin supporting expressions beyond just `*` in select list --- .gitignore | 1 + .vscode/settings.json | 7 ++ compiler/ast.go | 101 ++++++++++----- compiler/lexer.go | 65 ++++++++-- compiler/lexer_test.go | 92 +++++++++++++- compiler/parser.go | 181 ++++++++++++++++++++++---- compiler/parser_test.go | 273 +++++++++++++++++++++++++++++++++++----- db/db.go | 8 +- planner/select.go | 151 +++++++++++++++++----- planner/select_test.go | 106 ++++++++++++++-- 10 files changed, 834 insertions(+), 151 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 98e6ef6..ed94bd7 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *.db +*.sqlite diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..e1a5959 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "files.trimTrailingWhitespace": true, + "files.insertFinalNewline": true, + "cSpell.words": [ + "chirst" + ] +} diff --git a/compiler/ast.go b/compiler/ast.go index bd07d6e..fabb2f6 100644 --- a/compiler/ast.go +++ b/compiler/ast.go @@ -13,13 +13,20 @@ type StmtBase struct { type SelectStmt struct { *StmtBase - From *From - ResultColumn ResultColumn + From *From + ResultColumns []ResultColumn } +// ResultColumn is the column definitions in a select statement. type ResultColumn struct { - All bool - Count bool + // All is * in a select statement for example SELECT * FROM foo + All bool + // AllTable is all for a table for example SELECT foo.* FROM foo + AllTable string + // Expression contains the more complicated result column rules + Expression Expr + // Alias is the alias for an expression for example SELECT 1 AS "bar" + Alias string } type From struct { @@ -45,31 +52,61 @@ type InsertStmt struct { ColValues [][]string } -// type Expr interface { -// Type() string -// } - -// type BinaryExpr struct { -// Left Expr -// Operator string -// Right Expr -// } - -// type UnaryExpr struct { -// Operator string -// Operand Expr -// } - -// type ColumnRef struct { -// Table string -// Column string -// } - -// type IntLit struct { -// Value int -// } - -// type FunctionExpr struct { -// Name string -// Args []Expr -// } +// Expr defines the interface of an expression. +type Expr interface { + Type() string // TODO this pattern may not be the best +} + +// BinaryExpr is for an expression with two operands. +type BinaryExpr struct { + Left Expr + Operator string + Right Expr +} + +func (*BinaryExpr) Type() string { return "BinaryExpr" } + +// UnaryExpr is an expression with one operand. +type UnaryExpr struct { + Operator string + Operand Expr +} + +func (*UnaryExpr) Type() string { return "UnaryExpr" } + +// ColumnRef is an expression with no operands. It references a column on a +// table. +type ColumnRef struct { + Table string + Column string +} + +func (*ColumnRef) Type() string { return "ColumnRef" } + +// IntLit is an expression that is a literal integer such as "1". +type IntLit struct { + Value int +} + +func (*IntLit) Type() string { return "IntLit" } + +// StringLit is an expression that is a literal string such as "'asdf'". +type StringLit struct { + Value string +} + +func (*StringLit) Type() string { return "StringLit" } + +// FunctionExpr is an expression that represents a function. +type FunctionExpr struct { + // FnType corresponds to the type of function. For example fnCount is for + // COUNT(*) + FnType string + Args []Expr +} + +const ( + FnCount = "COUNT" +) + +func (*FunctionExpr) Type() string { return "FunctionExpr" } diff --git a/compiler/lexer.go b/compiler/lexer.go index 2c3668e..423ed12 100644 --- a/compiler/lexer.go +++ b/compiler/lexer.go @@ -30,12 +30,12 @@ const ( tkSeparator // tkOperator is a symbol that operates on arguments. tkOperator - // tkPunctuator is punctuation that is neither a separator or operator. - tkPunctuator // tkLiteral is a quoted text value like 'foo'. tkLiteral // tkNumeric is a numeric value like 1, 1.2, or -3. tkNumeric + // tkPunctuator is punctuation that is neither a separator or operator. + // tkPunctuator ) // Keywords where kw is keyword @@ -55,8 +55,10 @@ const ( kwText = "TEXT" kwPrimary = "PRIMARY" kwKey = "KEY" + kwAs = "AS" ) +// keywords is a list of all keywords. var keywords = []string{ kwExplain, kwQuery, @@ -73,11 +75,35 @@ var keywords = []string{ kwText, kwPrimary, kwKey, + kwAs, } -func (*lexer) isKeyword(w string) bool { - uw := strings.ToUpper(w) - return slices.Contains(keywords, uw) +// Operators where op is operator. +const ( + opSub = "-" + opAdd = "+" + opDiv = "/" + opMul = "*" + opExp = "^" +) + +// operators is a list of all operators. +var operators = []string{ + opSub, + opAdd, + opDiv, + opMul, + opExp, +} + +// opPrecedence defines operator precedence. The higher the number the higher +// the precedence. +var opPrecedence = map[string]int{ + opSub: 1, + opAdd: 1, + opDiv: 2, + opMul: 2, + opExp: 3, } type lexer struct { @@ -112,12 +138,12 @@ func (l *lexer) getToken() token { return l.scanWord() case l.isDigit(r): return l.scanDigit() - case l.isAsterisk(r): - return l.scanAsterisk() case l.isSeparator(r): return l.scanSeparator() case l.isSingleQuote(r): return l.scanLiteral() + case l.isOperator(r): + return l.scanOperator() } return token{tkEOF, ""} } @@ -164,11 +190,6 @@ func (l *lexer) scanDigit() token { return token{tokenType: tkNumeric, value: l.src[l.start:l.end]} } -func (l *lexer) scanAsterisk() token { - l.next() - return token{tokenType: tkPunctuator, value: l.src[l.start:l.end]} -} - func (l *lexer) scanSeparator() token { l.next() return token{tokenType: tkSeparator, value: l.src[l.start:l.end]} @@ -183,6 +204,11 @@ func (l *lexer) scanLiteral() token { return token{tokenType: tkLiteral, value: l.src[l.start:l.end]} } +func (l *lexer) scanOperator() token { + l.next() + return token{tokenType: tkOperator, value: l.src[l.start:l.end]} +} + func (*lexer) isWhiteSpace(r rune) bool { return r == ' ' || r == '\t' || r == '\n' } @@ -204,9 +230,22 @@ func (*lexer) isDigit(r rune) bool { } func (*lexer) isSeparator(r rune) bool { - return r == ',' || r == '(' || r == ')' || r == ';' + return r == ',' || r == '(' || r == ')' || r == ';' || r == '.' } func (*lexer) isSingleQuote(r rune) bool { return r == '\'' } + +func (*lexer) isKeyword(w string) bool { + uw := strings.ToUpper(w) + return slices.Contains(keywords, uw) +} + +func (*lexer) isOperator(o rune) bool { + ros := []rune{} + for _, op := range operators { + ros = append(ros, rune(op[0])) + } + return slices.Contains(ros, o) +} diff --git a/compiler/lexer_test.go b/compiler/lexer_test.go index 0864439..cac44f0 100644 --- a/compiler/lexer_test.go +++ b/compiler/lexer_test.go @@ -17,7 +17,7 @@ func TestLexSelect(t *testing.T) { expected: []token{ {tkKeyword, "SELECT"}, {tkWhitespace, " "}, - {tkPunctuator, "*"}, + {tkOperator, "*"}, {tkWhitespace, " "}, {tkKeyword, "FROM"}, {tkWhitespace, " "}, @@ -31,7 +31,7 @@ func TestLexSelect(t *testing.T) { {tkWhitespace, " "}, {tkKeyword, "COUNT"}, {tkSeparator, "("}, - {tkPunctuator, "*"}, + {tkOperator, "*"}, {tkSeparator, ")"}, {tkWhitespace, " "}, {tkKeyword, "FROM"}, @@ -44,7 +44,7 @@ func TestLexSelect(t *testing.T) { expected: []token{ {tkKeyword, "SELECT"}, {tkWhitespace, " "}, - {tkPunctuator, "*"}, + {tkOperator, "*"}, {tkWhitespace, " "}, {tkKeyword, "FROM"}, {tkWhitespace, " "}, @@ -59,7 +59,7 @@ func TestLexSelect(t *testing.T) { expected: []token{ {tkKeyword, "SELECT"}, {tkWhitespace, " "}, - {tkPunctuator, "*"}, + {tkOperator, "*"}, {tkWhitespace, " "}, {tkKeyword, "FROM"}, {tkWhitespace, " "}, @@ -107,6 +107,90 @@ func TestLexSelect(t *testing.T) { {tkSeparator, ";"}, }, }, + { + sql: "SELECT foo.id FROM foo", + expected: []token{ + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkSeparator, "."}, + {tkIdentifier, "id"}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + }, + }, + { + sql: "SELECT foo.* FROM foo", + expected: []token{ + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkSeparator, "."}, + {tkOperator, "*"}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + }, + }, + { + sql: "SELECT 1 AS bar FROM foo", + expected: []token{ + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + {tkWhitespace, " "}, + {tkKeyword, "AS"}, + {tkWhitespace, " "}, + {tkIdentifier, "bar"}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + }, + }, + { + sql: "SELECT 1 + 2 - 3 * 4 + 5 / 6 ^ 7 - 8 * 9", + expected: []token{ + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "2"}, + {tkWhitespace, " "}, + {tkOperator, "-"}, + {tkWhitespace, " "}, + {tkNumeric, "3"}, + {tkWhitespace, " "}, + {tkOperator, "*"}, + {tkWhitespace, " "}, + {tkNumeric, "4"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "5"}, + {tkWhitespace, " "}, + {tkOperator, "/"}, + {tkWhitespace, " "}, + {tkNumeric, "6"}, + {tkWhitespace, " "}, + {tkOperator, "^"}, + {tkWhitespace, " "}, + {tkNumeric, "7"}, + {tkWhitespace, " "}, + {tkOperator, "-"}, + {tkWhitespace, " "}, + {tkNumeric, "8"}, + {tkWhitespace, " "}, + {tkOperator, "*"}, + {tkWhitespace, " "}, + {tkNumeric, "9"}, + }, + }, } for _, c := range cases { t.Run(c.sql, func(t *testing.T) { diff --git a/compiler/parser.go b/compiler/parser.go index 53064c7..465c8a9 100644 --- a/compiler/parser.go +++ b/compiler/parser.go @@ -5,7 +5,9 @@ package compiler // Machine). import ( + "errors" "fmt" + "strconv" ) const ( @@ -33,9 +35,8 @@ func (p *parser) parseStmt() (Stmt, error) { t := p.tokens[p.start] sb := &StmtBase{} if t.value == kwExplain { - t = p.nextNonSpace() - nv := p.peekNextNonSpace().value - if nv == kwQuery { + nv := p.nextNonSpace() + if nv.value == kwQuery { tp := p.nextNonSpace() if tp.value == kwPlan { sb.ExplainQueryPlan = true @@ -45,6 +46,7 @@ func (p *parser) parseStmt() (Stmt, error) { } } else { sb.Explain = true + t = nv } } switch t.value { @@ -63,28 +65,18 @@ func (p *parser) parseSelect(sb *StmtBase) (*SelectStmt, error) { if p.tokens[p.end].value != kwSelect { return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value) } - r := p.nextNonSpace() - if r.value == "*" { - stmt.ResultColumn = ResultColumn{ - All: true, - } - } else if r.value == kwCount { - if v := p.nextNonSpace().value; v != "(" { - return nil, fmt.Errorf(tokenErr, v) - } - if v := p.nextNonSpace().value; v != "*" { - return nil, fmt.Errorf(tokenErr, v) - } - if v := p.nextNonSpace().value; v != ")" { - return nil, fmt.Errorf(tokenErr, v) + for { + resultColumn, err := p.parseResultColumn() + if err != nil { + return nil, err } - stmt.ResultColumn = ResultColumn{ - Count: true, + stmt.ResultColumns = append(stmt.ResultColumns, *resultColumn) + n := p.peekNextNonSpace() + if n.value != "," { + break } - } else { - return nil, fmt.Errorf(tokenErr, r.value) + p.nextNonSpace() } - f := p.nextNonSpace() if f.tokenType == tkEOF || f.value == ";" { return stmt, nil @@ -103,6 +95,139 @@ func (p *parser) parseSelect(sb *StmtBase) (*SelectStmt, error) { return stmt, nil } +// parseResultColumn parses a single result column +func (p *parser) parseResultColumn() (*ResultColumn, error) { + resultColumn := &ResultColumn{} + r := p.nextNonSpace() + // There are three cases to handle here. + // 1. * + // 2. tableName.* + // 3. expression AS alias + // We simply try and identify the first two then fall into expression + // parsing if the first two cases are not present. This is a smart way to do + // things since expressions are not limited to result columns. + if r.value == "*" { + resultColumn.All = true + return resultColumn, nil + } else if r.tokenType == tkIdentifier { + if p.peekNextNonSpace().value == "." { + if p.peekNonSpaceBy(2).value == "*" { + p.nextNonSpace() // move to . + p.nextNonSpace() // move to * + resultColumn.AllTable = r.value + return resultColumn, nil + } + } + } + p.rewind() + expr, err := p.parseExpression(0) + if err != nil { + return nil, err + } + resultColumn.Expression = expr + err = p.parseAlias(resultColumn) + return resultColumn, err +} + +// Vaughan Pratt's top down operator precedence parsing algorithm. +// Definitions: +// - Left binding power (LBP) an integer representing operator precedence level. +// - Null denotation (NUD) nothing to it's left (prefix). +// - Left denotation (LED) something to it's left (infix). +// - Right binding power (RBP) parse prefix operator then iteratively parse +// infix expressions. +// +// Begin with rbp 0 +func (p *parser) parseExpression(rbp int) (Expr, error) { + left, err := p.getOperand() + if err != nil { + return nil, err + } + for { + nextToken := p.peekNextNonSpace() + if nextToken.tokenType != tkOperator { + return left, nil + } + lbp := opPrecedence[nextToken.value] + if lbp <= rbp { + return left, nil + } + p.nextNonSpace() + right, err := p.parseExpression(lbp) + if err != nil { + return nil, err + } + left = &BinaryExpr{ + Left: left, + Operator: nextToken.value, + Right: right, + } + } +} + +// getOperand is a parseExpression helper who parses token groups into atomic +// expressions serving as operands in the expression tree. A good example of +// this would be in the statement `SELECT foo.bar + 1;`. `foo.bar` is processed +// as three tokens, but needs to be "squashed" into the expression `ColumnRef`. +func (p *parser) getOperand() (Expr, error) { + first := p.nextNonSpace() + if first.tokenType == tkLiteral { + return &StringLit{Value: first.value}, nil + } + if first.tokenType == tkNumeric { + intValue, err := strconv.Atoi(first.value) + if err != nil { + return nil, errors.New("failed to parse numeric token") + } + return &IntLit{Value: intValue}, nil + } + if first.tokenType == tkIdentifier { + next := p.peekNextNonSpace() + if next.value == "." { + p.nextNonSpace() + prop := p.peekNextNonSpace() + if prop.tokenType == tkIdentifier { + p.nextNonSpace() + return &ColumnRef{ + Table: first.value, + Column: prop.value, + }, nil + } + } + return &ColumnRef{ + Column: first.value, + }, nil + } + if first.tokenType == tkKeyword && first.value == kwCount { + if v := p.nextNonSpace().value; v != "(" { + return nil, fmt.Errorf(tokenErr, v) + } + if v := p.nextNonSpace().value; v != "*" { + return nil, fmt.Errorf(tokenErr, v) + } + if v := p.nextNonSpace().value; v != ")" { + return nil, fmt.Errorf(tokenErr, v) + } + return &FunctionExpr{FnType: FnCount}, nil + } + // TODO support unary prefix expression + // TODO support parens + return nil, errors.New("failed to parse null denotation") +} + +func (p *parser) parseAlias(resultColumn *ResultColumn) error { + a := p.peekNextNonSpace().value + if a == kwAs { + p.nextNonSpace() + alias := p.nextNonSpace() + if alias.tokenType != tkIdentifier { + return fmt.Errorf(identErr, alias.value) + } + resultColumn.Alias = alias.value + } + return nil +} + func (p *parser) parseCreate(sb *StmtBase) (*CreateStmt, error) { stmt := &CreateStmt{StmtBase: sb} if p.tokens[p.end].value != kwCreate { @@ -237,7 +362,12 @@ func (p *parser) nextNonSpace() token { } func (p *parser) peekNextNonSpace() token { - tmpEnd := p.end + return p.peekNonSpaceBy(1) +} + +// peekNonSpaceBy will peek more than one space ahead. +func (p *parser) peekNonSpaceBy(next int) token { + tmpEnd := p.end + next if tmpEnd > len(p.tokens)-1 { return token{tkEOF, ""} } @@ -249,3 +379,8 @@ func (p *parser) peekNextNonSpace() token { } return p.tokens[tmpEnd] } + +func (p *parser) rewind() token { + p.end = p.end - 1 + return p.tokens[p.end] +} diff --git a/compiler/parser_test.go b/compiler/parser_test.go index 2c56509..d48a046 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -2,6 +2,7 @@ package compiler import ( "reflect" + "slices" "testing" ) @@ -20,7 +21,7 @@ func TestParseSelect(t *testing.T) { {tkWhitespace, " "}, {tkKeyword, "SELECT"}, {tkWhitespace, " "}, - {tkPunctuator, "*"}, + {tkOperator, "*"}, {tkWhitespace, " "}, {tkKeyword, "FROM"}, {tkWhitespace, " "}, @@ -33,8 +34,10 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - All: true, + ResultColumns: []ResultColumn{ + { + All: true, + }, }, }, }, @@ -49,7 +52,7 @@ func TestParseSelect(t *testing.T) { {tkWhitespace, " "}, {tkKeyword, "SELECT"}, {tkWhitespace, " "}, - {tkPunctuator, "*"}, + {tkOperator, "*"}, {tkWhitespace, " "}, {tkKeyword, "FROM"}, {tkWhitespace, " "}, @@ -63,35 +66,10 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - All: true, - }, - }, - }, - { - name: "with count", - tokens: []token{ - {tkKeyword, "SELECT"}, - {tkWhitespace, " "}, - {tkKeyword, "COUNT"}, - {tkSeparator, "("}, - {tkPunctuator, "*"}, - {tkSeparator, ")"}, - {tkWhitespace, " "}, - {tkKeyword, "FROM"}, - {tkWhitespace, " "}, - {tkIdentifier, "foo"}, - }, - expect: &SelectStmt{ - StmtBase: &StmtBase{ - Explain: false, - }, - From: &From{ - TableName: "foo", - }, - ResultColumn: ResultColumn{ - Count: true, - All: false, + ResultColumns: []ResultColumn{ + { + All: true, + }, }, }, }, @@ -277,3 +255,232 @@ func TestParseInsert(t *testing.T) { } } } + +type resultColumnTestCase struct { + name string + tokens []token + expect []ResultColumn +} + +func TestParseResultColumn(t *testing.T) { + template := []token{ + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + } + cases := []resultColumnTestCase{ + { + name: "*", + tokens: []token{ + {tkOperator, "*"}, + }, + expect: []ResultColumn{ + { + All: true, + }, + }, + }, + { + name: "foo.*", + tokens: []token{ + {tkIdentifier, "foo"}, + {tkOperator, "."}, + {tkOperator, "*"}, + }, + expect: []ResultColumn{ + { + AllTable: "foo", + }, + }, + }, + { + name: "COUNT(*)", + tokens: []token{ + {tkKeyword, "COUNT"}, + {tkSeparator, "("}, + {tkOperator, "*"}, + {tkSeparator, ")"}, + }, + expect: []ResultColumn{ + { + Expression: &FunctionExpr{FnType: FnCount}, + }, + }, + }, + { + name: "COUNT(*) + 1", + tokens: []token{ + {tkKeyword, "COUNT"}, + {tkSeparator, "("}, + {tkOperator, "*"}, + {tkSeparator, ")"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + }, + expect: []ResultColumn{ + { + Expression: &BinaryExpr{ + Left: &FunctionExpr{FnType: FnCount}, + Operator: "+", + Right: &IntLit{Value: 1}, + }, + }, + }, + }, + { + name: "(1 + 2 - (3 * 4) + (5 / (6 ^ 7)) - (8 * 9))", + tokens: []token{ + {tkNumeric, "1"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "2"}, + {tkWhitespace, " "}, + {tkOperator, "-"}, + {tkWhitespace, " "}, + {tkNumeric, "3"}, + {tkWhitespace, " "}, + {tkOperator, "*"}, + {tkWhitespace, " "}, + {tkNumeric, "4"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "5"}, + {tkWhitespace, " "}, + {tkOperator, "/"}, + {tkWhitespace, " "}, + {tkNumeric, "6"}, + {tkWhitespace, " "}, + {tkOperator, "^"}, + {tkWhitespace, " "}, + {tkNumeric, "7"}, + {tkWhitespace, " "}, + {tkOperator, "-"}, + {tkWhitespace, " "}, + {tkNumeric, "8"}, + {tkWhitespace, " "}, + {tkOperator, "*"}, + {tkWhitespace, " "}, + {tkNumeric, "9"}, + }, + expect: []ResultColumn{ + { + Expression: &BinaryExpr{ + Left: &BinaryExpr{ + Left: &BinaryExpr{ + Left: &BinaryExpr{ + Left: &IntLit{Value: 1}, + Operator: opAdd, + Right: &IntLit{Value: 2}, + }, + Operator: opSub, + Right: &BinaryExpr{ + Left: &IntLit{Value: 3}, + Operator: opMul, + Right: &IntLit{Value: 4}, + }, + }, + Operator: opAdd, + Right: &BinaryExpr{ + Left: &IntLit{Value: 5}, + Operator: opDiv, + Right: &BinaryExpr{ + Left: &IntLit{Value: 6}, + Operator: opExp, + Right: &IntLit{Value: 7}, + }, + }, + }, + Operator: opSub, + Right: &BinaryExpr{ + Left: &IntLit{Value: 8}, + Operator: opMul, + Right: &IntLit{Value: 9}, + }, + }, + }, + }, + }, + { + name: "foo.id AS bar", + tokens: []token{ + {tkIdentifier, "foo"}, + {tkSeparator, "."}, + {tkIdentifier, "id"}, + {tkWhitespace, " "}, + {tkKeyword, "AS"}, + {tkWhitespace, " "}, + {tkIdentifier, "bar"}, + }, + expect: []ResultColumn{ + { + Expression: &ColumnRef{ + Table: "foo", + Column: "id", + }, + Alias: "bar", + }, + }, + }, + { + name: "1 + 2 AS foo, id, id2 AS id1", + tokens: []token{ + {tkNumeric, "1"}, + {tkWhitespace, " "}, + {tkOperator, "+"}, + {tkWhitespace, " "}, + {tkNumeric, "2"}, + {tkWhitespace, " "}, + {tkKeyword, "AS"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkIdentifier, "id"}, + {tkSeparator, ","}, + {tkWhitespace, " "}, + {tkIdentifier, "id2"}, + {tkWhitespace, " "}, + {tkKeyword, "AS"}, + {tkWhitespace, " "}, + {tkIdentifier, "id1"}, + }, + expect: []ResultColumn{ + { + Expression: &BinaryExpr{ + Left: &IntLit{Value: 1}, + Operator: "+", + Right: &IntLit{Value: 2}, + }, + Alias: "foo", + }, + { + Expression: &ColumnRef{Column: "id"}, + }, + { + Expression: &ColumnRef{Column: "id2"}, + Alias: "id1", + }, + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + tks := slices.Insert(template, 2, c.tokens...) + ret, err := NewParser(tks).Parse() + if err != nil { + t.Errorf("want no err got err %s", err) + } + retSelect, _ := ret.(*SelectStmt) + if !reflect.DeepEqual(retSelect.ResultColumns, c.expect) { + t.Errorf("got %#v want %#v", retSelect.ResultColumns, c.expect) + } + }) + } +} diff --git a/db/db.go b/db/db.go index fad69c5..e01eca5 100644 --- a/db/db.go +++ b/db/db.go @@ -23,11 +23,11 @@ type statementPlanner interface { } type dbCatalog interface { - GetColumns(tableOrIndexName string) ([]string, error) - GetRootPageNumber(tableOrIndexName string) (int, error) - TableExists(tableName string) bool + GetColumns(string) ([]string, error) + GetRootPageNumber(string) (int, error) + TableExists(string) bool GetVersion() string - GetPrimaryKeyColumn(tableName string) (string, error) + GetPrimaryKeyColumn(string) (string, error) } type DB struct { diff --git a/planner/select.go b/planner/select.go index 3939acc..24b7c6f 100644 --- a/planner/select.go +++ b/planner/select.go @@ -1,6 +1,10 @@ package planner import ( + "errors" + "fmt" + "slices" + "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -83,20 +87,87 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } var child logicalNode - if p.stmt.ResultColumn.All { - scanColumns, err := p.getScanColumns() - if err != nil { - return nil, err - } - child = &scanNode{ - tableName: tableName, - rootPage: rootPageNumber, - scanColumns: scanColumns, - } - } else { - child = &countNode{ - tableName: tableName, - rootPage: rootPageNumber, + for _, resultColumn := range p.stmt.ResultColumns { + if resultColumn.All { + scanColumns, err := p.getScanColumns() + if err != nil { + return nil, err + } + switch c := child.(type) { + case *scanNode: + c.scanColumns = append(c.scanColumns, scanColumns...) + case nil: + child = &scanNode{ + tableName: tableName, + rootPage: rootPageNumber, + scanColumns: scanColumns, + } + default: + return nil, errors.New("expected scanNode") + } + } else if resultColumn.Expression != nil { + switch e := resultColumn.Expression.(type) { + case *compiler.ColumnRef: + if e.Table == "" { + e.Table = p.stmt.From.TableName + } + cols, err := p.catalog.GetColumns(e.Table) + if err != nil { + return nil, err + } + colIdx := slices.Index(cols, e.Column) + pkCol, err := p.catalog.GetPrimaryKeyColumn(e.Table) + if err != nil { + return nil, err + } + pkColIdx := slices.Index(cols, pkCol) + if pkColIdx < colIdx { + colIdx -= 1 + } + switch c := child.(type) { + case *scanNode: + c.scanColumns = append(c.scanColumns, scanColumn{ + isPrimaryKey: pkCol == e.Column, + colIdx: colIdx, + }) + case nil: + child = &scanNode{ + tableName: e.Table, + rootPage: rootPageNumber, + scanColumns: []scanColumn{ + { + isPrimaryKey: pkCol == e.Column, + colIdx: colIdx, + }, + }, + } + default: + return nil, fmt.Errorf("expected scan node but got %#v", c) + } + case *compiler.FunctionExpr: + if e.FnType != compiler.FnCount { + return nil, fmt.Errorf( + "only %s function is supported", e.FnType, + ) + } + switch child.(type) { + case nil: + child = &countNode{ + tableName: tableName, + rootPage: rootPageNumber, + } + default: + return nil, errors.New( + "count with other result columns not supported", + ) + } + default: + return nil, fmt.Errorf( + "unhandled expression for result column %#v", resultColumn, + ) + } + } else { + return nil, fmt.Errorf("unhandled result column %#v", resultColumn) } } projections, err := p.getProjections() @@ -137,27 +208,39 @@ func (p *selectQueryPlanner) getScanColumns() ([]scanColumn, error) { } func (p *selectQueryPlanner) getProjections() ([]projection, error) { - if p.stmt.ResultColumn.All { - cols, err := p.catalog.GetColumns(p.stmt.From.TableName) - if err != nil { - return nil, err + var projections []projection + for _, resultColumn := range p.stmt.ResultColumns { + if resultColumn.All { + cols, err := p.catalog.GetColumns(p.stmt.From.TableName) + if err != nil { + return nil, err + } + for _, c := range cols { + projections = append(projections, projection{ + colName: c, + }) + } + } else if resultColumn.Expression != nil { + switch e := resultColumn.Expression.(type) { + case *compiler.ColumnRef: + colName := e.Column + if resultColumn.Alias != "" { + colName = resultColumn.Alias + } + projections = append(projections, projection{ + colName: colName, + }) + case *compiler.FunctionExpr: + projections = append(projections, projection{ + isCount: true, + colName: resultColumn.Alias, + }) + default: + return nil, fmt.Errorf("unhandled result column expression %#v", e) + } } - projections := []projection{} - for _, c := range cols { - projections = append(projections, projection{ - colName: c, - }) - } - return projections, nil - } - if p.stmt.ResultColumn.Count { - return []projection{ - { - isCount: true, - }, - }, nil } - panic("unhandled projection") + return projections, nil } // ExecutionPlan returns the bytecode execution plan for the planner. Calling @@ -184,7 +267,7 @@ func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { case *countNode: p.buildOptimizedCountScan(c) default: - panic("unhandled node") + return nil, fmt.Errorf("unhandled node %#v", c) } p.executionPlan.Append(&vm.HaltCmd{}) return p.executionPlan, nil diff --git a/planner/select_test.go b/planner/select_test.go index 5263175..8d0f8e7 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -50,8 +50,92 @@ func TestGetPlan(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - All: true, + ResultColumns: []compiler.ResultColumn{ + { + All: true, + }, + }, + } + mockCatalog := &mockSelectCatalog{} + mockCatalog.primaryKeyColumnName = "id" + mockCatalog.columns = []string{"name", "id", "age"} + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + for i, c := range expectedCommands { + if !reflect.DeepEqual(c, plan.Commands[i]) { + t.Errorf("got %#v want %#v", plan.Commands[i], c) + } + } +} + +func TestGetPlanSelectColumn(t *testing.T) { + expectedCommands := []vm.Command{ + &vm.InitCmd{P2: 1}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 7}, + &vm.RowIdCmd{P1: 1, P2: 1}, + &vm.ResultRowCmd{P1: 1, P2: 1}, + &vm.NextCmd{P1: 1, P2: 4}, + &vm.HaltCmd{}, + } + ast := &compiler.SelectStmt{ + StmtBase: &compiler.StmtBase{}, + From: &compiler.From{ + TableName: "foo", + }, + ResultColumns: []compiler.ResultColumn{ + { + Expression: &compiler.ColumnRef{ + Column: "id", + }, + }, + }, + } + mockCatalog := &mockSelectCatalog{} + mockCatalog.primaryKeyColumnName = "id" + mockCatalog.columns = []string{"name", "id", "age"} + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + for i, c := range expectedCommands { + if !reflect.DeepEqual(c, plan.Commands[i]) { + t.Errorf("got %#v want %#v", plan.Commands[i], c) + } + } +} + +func TestGetPlanSelectMultiColumn(t *testing.T) { + expectedCommands := []vm.Command{ + &vm.InitCmd{P2: 1}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RowIdCmd{P1: 1, P2: 1}, + &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, + &vm.ResultRowCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 4}, + &vm.HaltCmd{}, + } + ast := &compiler.SelectStmt{ + StmtBase: &compiler.StmtBase{}, + From: &compiler.From{ + TableName: "foo", + }, + ResultColumns: []compiler.ResultColumn{ + { + Expression: &compiler.ColumnRef{ + Column: "id", + }, + }, + { + Expression: &compiler.ColumnRef{ + Column: "age", + }, + }, }, } mockCatalog := &mockSelectCatalog{} @@ -85,8 +169,10 @@ func TestGetPlanPKMiddleOrdinal(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - All: true, + ResultColumns: []compiler.ResultColumn{ + { + All: true, + }, }, } mockCatalog := &mockSelectCatalog{} @@ -116,8 +202,10 @@ func TestGetCountAggregate(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - Count: true, + ResultColumns: []compiler.ResultColumn{ + { + Expression: &compiler.FunctionExpr{FnType: compiler.FnCount}, + }, }, } mockCatalog := &mockSelectCatalog{} @@ -149,8 +237,10 @@ func TestGetPlanNoPrimaryKey(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - All: true, + ResultColumns: []compiler.ResultColumn{ + { + All: true, + }, }, } mockCatalog := &mockSelectCatalog{}