diff --git a/compiler/ast.go b/compiler/ast.go index 2ce761e..09bc5cc 100644 --- a/compiler/ast.go +++ b/compiler/ast.go @@ -13,8 +13,8 @@ type StmtBase struct { type SelectStmt struct { *StmtBase - From *From - ResultColumn ResultColumn + From *From + ResultColumns []ResultColumn } // ResultColumn is the column definitions in a select statement. diff --git a/compiler/parser.go b/compiler/parser.go index f27b19a..c695894 100644 --- a/compiler/parser.go +++ b/compiler/parser.go @@ -64,10 +64,11 @@ func (p *parser) parseSelect(sb *StmtBase) (*SelectStmt, error) { if p.tokens[p.end].value != kwSelect { return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value) } - err := p.parseResultColumn(stmt) + resultColumn, err := p.parseResultColumn() if err != nil { return nil, err } + stmt.ResultColumns = append(stmt.ResultColumns, *resultColumn) f := p.nextNonSpace() if f.tokenType == tkEOF || f.value == ";" { return stmt, nil @@ -87,30 +88,28 @@ func (p *parser) parseSelect(sb *StmtBase) (*SelectStmt, error) { } // parseResultColumn parses a single result column -func (p *parser) parseResultColumn(stmt *SelectStmt) error { +func (p *parser) parseResultColumn() (*ResultColumn, error) { + resultColumn := &ResultColumn{} r := p.nextNonSpace() // Handle a result column for all or *. if r.value == "*" { - stmt.ResultColumn = ResultColumn{ - All: true, - } - return nil + resultColumn.All = true + return resultColumn, nil } else if r.value == kwCount { // Handle the result column for the COUNT(*) aggregate. TODO this will // probably be refactored to an expression. if v := p.nextNonSpace().value; v != "(" { - return fmt.Errorf(tokenErr, v) + return nil, fmt.Errorf(tokenErr, v) } if v := p.nextNonSpace().value; v != "*" { - return fmt.Errorf(tokenErr, v) + return nil, fmt.Errorf(tokenErr, v) } if v := p.nextNonSpace().value; v != ")" { - return fmt.Errorf(tokenErr, v) - } - stmt.ResultColumn = ResultColumn{ - Count: true, + return nil, fmt.Errorf(tokenErr, v) } - return p.parseAlias(stmt) + resultColumn.Count = true + err := p.parseAlias(resultColumn) + return resultColumn, err } else if r.tokenType == tkIdentifier { // Handle an identifier such as a table or column name if p.peekNextNonSpace().value == "." { @@ -120,42 +119,45 @@ func (p *parser) parseResultColumn(stmt *SelectStmt) error { v := p.nextNonSpace().value // A star after the dot is to select all the cols in a table. if v == "*" { - stmt.ResultColumn.AllTable = r.value + resultColumn.AllTable = r.value } else { // Otherwise after the dot has to be a specific column name. - stmt.ResultColumn.Expression = &ColumnRef{ + resultColumn.Expression = &ColumnRef{ Table: r.value, Column: v, } - return p.parseAlias(stmt) + err := p.parseAlias(resultColumn) + return resultColumn, err } - return nil + return resultColumn, nil } else if p.peekNextNonSpace().tokenType == tkWhitespace { // If the identifier is followed by whitespace the identifier is a // column name. There is no table name meaning the table will have // to be resolved in the planner. - stmt.ResultColumn.Expression = &ColumnRef{ + resultColumn.Expression = &ColumnRef{ Column: r.value, } - return p.parseAlias(stmt) + err := p.parseAlias(resultColumn) + return resultColumn, err } else { - return fmt.Errorf(tokenErr, r.value) + return nil, fmt.Errorf(tokenErr, r.value) } } else if r.tokenType == tkNumeric { // A numeric value may begin a complex expression. vi, err := strconv.Atoi(r.value) if err != nil { - return err + return nil, err } - stmt.ResultColumn.Expression = &IntLit{ + resultColumn.Expression = &IntLit{ Value: vi, } - return p.parseAlias(stmt) + err = p.parseAlias(resultColumn) + return resultColumn, err } - return fmt.Errorf(tokenErr, r.value) + return nil, fmt.Errorf(tokenErr, r.value) } -func (p *parser) parseAlias(stmt *SelectStmt) error { +func (p *parser) parseAlias(resultColumn *ResultColumn) error { a := p.peekNextNonSpace().value if a == kwAs { p.nextNonSpace() @@ -163,7 +165,7 @@ func (p *parser) parseAlias(stmt *SelectStmt) error { if alias.tokenType != tkIdentifier { return fmt.Errorf(identErr, alias.value) } - stmt.ResultColumn.Alias = alias.value + resultColumn.Alias = alias.value } return nil } diff --git a/compiler/parser_test.go b/compiler/parser_test.go index f3067a5..d54f8b3 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -33,8 +33,10 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - All: true, + ResultColumns: []ResultColumn{ + { + All: true, + }, }, }, }, @@ -63,8 +65,10 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - All: true, + ResultColumns: []ResultColumn{ + { + All: true, + }, }, }, }, @@ -89,9 +93,11 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - Count: true, - All: false, + ResultColumns: []ResultColumn{ + { + Count: true, + All: false, + }, }, }, }, @@ -113,10 +119,12 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - Expression: &ColumnRef{ - Table: "foo", - Column: "id", + ResultColumns: []ResultColumn{ + { + Expression: &ColumnRef{ + Table: "foo", + Column: "id", + }, }, }, }, @@ -139,8 +147,10 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - AllTable: "foo", + ResultColumns: []ResultColumn{ + { + AllTable: "foo", + }, }, }, }, @@ -164,11 +174,13 @@ func TestParseSelect(t *testing.T) { From: &From{ TableName: "foo", }, - ResultColumn: ResultColumn{ - Expression: &IntLit{ - Value: 1, + ResultColumns: []ResultColumn{ + { + Expression: &IntLit{ + Value: 1, + }, + Alias: "bar", }, - Alias: "bar", }, }, }, diff --git a/planner/select.go b/planner/select.go index 3939acc..b2c9ae6 100644 --- a/planner/select.go +++ b/planner/select.go @@ -83,7 +83,8 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } var child logicalNode - if p.stmt.ResultColumn.All { + resultColumn := p.stmt.ResultColumns[0] + if resultColumn.All { scanColumns, err := p.getScanColumns() if err != nil { return nil, err @@ -137,7 +138,8 @@ func (p *selectQueryPlanner) getScanColumns() ([]scanColumn, error) { } func (p *selectQueryPlanner) getProjections() ([]projection, error) { - if p.stmt.ResultColumn.All { + resultColumn := p.stmt.ResultColumns[0] + if resultColumn.All { cols, err := p.catalog.GetColumns(p.stmt.From.TableName) if err != nil { return nil, err @@ -150,7 +152,7 @@ func (p *selectQueryPlanner) getProjections() ([]projection, error) { } return projections, nil } - if p.stmt.ResultColumn.Count { + if resultColumn.Count { return []projection{ { isCount: true, diff --git a/planner/select_test.go b/planner/select_test.go index 5263175..d3038e2 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -50,8 +50,10 @@ func TestGetPlan(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - All: true, + ResultColumns: []compiler.ResultColumn{ + { + All: true, + }, }, } mockCatalog := &mockSelectCatalog{} @@ -85,8 +87,10 @@ func TestGetPlanPKMiddleOrdinal(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - All: true, + ResultColumns: []compiler.ResultColumn{ + { + All: true, + }, }, } mockCatalog := &mockSelectCatalog{} @@ -116,8 +120,10 @@ func TestGetCountAggregate(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - Count: true, + ResultColumns: []compiler.ResultColumn{ + { + Count: true, + }, }, } mockCatalog := &mockSelectCatalog{} @@ -149,8 +155,10 @@ func TestGetPlanNoPrimaryKey(t *testing.T) { From: &compiler.From{ TableName: "foo", }, - ResultColumn: compiler.ResultColumn{ - All: true, + ResultColumns: []compiler.ResultColumn{ + { + All: true, + }, }, } mockCatalog := &mockSelectCatalog{}