From 135ef36342f488aee5594ea3da2f8590d7478e65 Mon Sep 17 00:00:00 2001 From: chirst Date: Mon, 25 Nov 2024 20:51:14 -0700 Subject: [PATCH] planner handles more than one result column --- .gitignore | 1 + .vscode/settings.json | 5 ++ planner/select.go | 170 ++++++++++++++++++++++++----------------- planner/select_test.go | 44 +++++++++++ 4 files changed, 151 insertions(+), 69 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 98e6ef6..11424f8 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ *.db +*.sqlite \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..0f35012 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "cSpell.words": [ + "chirst" + ] +} \ No newline at end of file diff --git a/planner/select.go b/planner/select.go index 9dd904a..529633d 100644 --- a/planner/select.go +++ b/planner/select.go @@ -1,6 +1,8 @@ package planner import ( + "errors" + "fmt" "slices" "github.com/chirst/cdb/compiler" @@ -85,53 +87,84 @@ func (p *selectQueryPlanner) getQueryPlan() (*QueryPlan, error) { return nil, err } var child logicalNode - resultColumn := p.stmt.ResultColumns[0] - if resultColumn.Count { - child = &countNode{ - tableName: tableName, - rootPage: rootPageNumber, - } - } else if resultColumn.All { - scanColumns, err := p.getScanColumns() - if err != nil { - return nil, err - } - child = &scanNode{ - tableName: tableName, - rootPage: rootPageNumber, - scanColumns: scanColumns, - } - } else if resultColumn.Expression != nil { - switch e := resultColumn.Expression.(type) { - case *compiler.ColumnRef: - if e.Table == "" { - // TODO should do better at checking no table - e.Table = p.stmt.From.TableName + for _, resultColumn := range p.stmt.ResultColumns { + if resultColumn.Count { + switch child.(type) { + case nil: + child = &countNode{ + tableName: tableName, + rootPage: rootPageNumber, + } + default: + return nil, errors.New( + "count with other result columns not supported", + ) } - cols, err := p.catalog.GetColumns(e.Table) + } else if resultColumn.All { + scanColumns, err := p.getScanColumns() if err != nil { return nil, err } - colIdx := slices.Index(cols, e.Column) - pkCol, err := p.catalog.GetPrimaryKeyColumn(e.Table) - if err != nil { - return nil, err + switch c := child.(type) { + case *scanNode: + c.scanColumns = append(c.scanColumns, scanColumns...) + case nil: + child = &scanNode{ + tableName: tableName, + rootPage: rootPageNumber, + scanColumns: scanColumns, + } + default: + return nil, errors.New("expected scanNode") } - child = &scanNode{ - tableName: e.Table, - rootPage: rootPageNumber, - scanColumns: []scanColumn{ - { + } else if resultColumn.Expression != nil { + switch e := resultColumn.Expression.(type) { + case *compiler.ColumnRef: + if e.Table == "" { + // TODO should do better at checking no table + e.Table = p.stmt.From.TableName + } + cols, err := p.catalog.GetColumns(e.Table) + if err != nil { + return nil, err + } + colIdx := slices.Index(cols, e.Column) + pkCol, err := p.catalog.GetPrimaryKeyColumn(e.Table) + if err != nil { + return nil, err + } + pkColIdx := slices.Index(cols, pkCol) + if pkColIdx < colIdx { + colIdx -= 1 + } + switch c := child.(type) { + case *scanNode: + c.scanColumns = append(c.scanColumns, scanColumn{ isPrimaryKey: pkCol == e.Column, colIdx: colIdx, - }, - }, + }) + case nil: + child = &scanNode{ + tableName: e.Table, + rootPage: rootPageNumber, + scanColumns: []scanColumn{ + { + isPrimaryKey: pkCol == e.Column, + colIdx: colIdx, + }, + }, + } + default: + return nil, fmt.Errorf("expected scan node but got %#v", c) + } + default: + return nil, fmt.Errorf( + "unhandled expression for result column %#v", resultColumn, + ) } - default: - panic("unhandled expression") + } else { + return nil, fmt.Errorf("unhandled result column %#v", resultColumn) } - } else { - panic("unhandled result column") } projections, err := p.getProjections() if err != nil { @@ -171,40 +204,39 @@ func (p *selectQueryPlanner) getScanColumns() ([]scanColumn, error) { } func (p *selectQueryPlanner) getProjections() ([]projection, error) { - resultColumn := p.stmt.ResultColumns[0] - if resultColumn.Count { - return []projection{ - { - isCount: true, - }, - }, nil - } - if resultColumn.All { - cols, err := p.catalog.GetColumns(p.stmt.From.TableName) - if err != nil { - return nil, err - } - projections := []projection{} - for _, c := range cols { + var projections []projection + for _, resultColumn := range p.stmt.ResultColumns { + if resultColumn.Count { projections = append(projections, projection{ - colName: c, + isCount: true, + colName: resultColumn.Alias, }) - } - return projections, nil - } - if resultColumn.Expression != nil { - switch e := resultColumn.Expression.(type) { - case *compiler.ColumnRef: - return []projection{ - { - colName: e.Column, - }, - }, nil - default: - panic("unhandled expression") + } else if resultColumn.All { + cols, err := p.catalog.GetColumns(p.stmt.From.TableName) + if err != nil { + return nil, err + } + for _, c := range cols { + projections = append(projections, projection{ + colName: c, + }) + } + } else if resultColumn.Expression != nil { + switch e := resultColumn.Expression.(type) { + case *compiler.ColumnRef: + colName := e.Column + if resultColumn.Alias != "" { + colName = resultColumn.Alias + } + projections = append(projections, projection{ + colName: colName, + }) + default: + return nil, fmt.Errorf("unhandled result column expression %#v", e) + } } } - panic("unhandled projection") + return projections, nil } // ExecutionPlan returns the bytecode execution plan for the planner. Calling @@ -231,7 +263,7 @@ func (p *selectExecutionPlanner) getExecutionPlan() (*vm.ExecutionPlan, error) { case *countNode: p.buildOptimizedCountScan(c) default: - panic("unhandled node") + return nil, fmt.Errorf("unhandled node %#v", c) } p.executionPlan.Append(&vm.HaltCmd{}) return p.executionPlan, nil diff --git a/planner/select_test.go b/planner/select_test.go index 5c27f03..7618778 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ -108,6 +108,50 @@ func TestGetPlanSelectColumn(t *testing.T) { } } +func TestGetPlanSelectMultiColumn(t *testing.T) { + expectedCommands := []vm.Command{ + &vm.InitCmd{P2: 1}, + &vm.TransactionCmd{P1: 0}, + &vm.OpenReadCmd{P1: 1, P2: 2}, + &vm.RewindCmd{P1: 1, P2: 8}, + &vm.RowIdCmd{P1: 1, P2: 1}, + &vm.ColumnCmd{P1: 1, P2: 1, P3: 2}, + &vm.ResultRowCmd{P1: 1, P2: 2}, + &vm.NextCmd{P1: 1, P2: 4}, + &vm.HaltCmd{}, + } + ast := &compiler.SelectStmt{ + StmtBase: &compiler.StmtBase{}, + From: &compiler.From{ + TableName: "foo", + }, + ResultColumns: []compiler.ResultColumn{ + { + Expression: &compiler.ColumnRef{ + Column: "id", + }, + }, + { + Expression: &compiler.ColumnRef{ + Column: "age", + }, + }, + }, + } + mockCatalog := &mockSelectCatalog{} + mockCatalog.primaryKeyColumnName = "id" + mockCatalog.columns = []string{"name", "id", "age"} + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() + if err != nil { + t.Errorf("expected no err got err %s", err) + } + for i, c := range expectedCommands { + if !reflect.DeepEqual(c, plan.Commands[i]) { + t.Errorf("got %#v want %#v", plan.Commands[i], c) + } + } +} + func TestGetPlanPKMiddleOrdinal(t *testing.T) { expectedCommands := []vm.Command{ &vm.InitCmd{P2: 1},