diff --git a/README.md b/README.md index dc04be2..08f3a62 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,14 @@ a subset of SQL described below. graph LR begin(( )) explain([EXPLAIN]) +queryPlan([QUERY PLAN]) select([SELECT]) all[*] from([FROM]) begin --> explain +explain --> queryPlan +queryPlan --> select begin --> select explain --> select select --> all @@ -33,6 +36,7 @@ Create supports the `PRIMARY KEY` column constraint for a single integer column. graph LR begin(( )) explain([EXPLAIN]) +queryPlan([QUERY PLAN]) create([CREATE]) table([TABLE]) colTypeInt([INTEGER]) @@ -45,6 +49,8 @@ colIdent["Column Identifier"] pkConstraint["PRIMARY KEY"] begin --> explain +explain --> queryPlan +queryPlan --> create begin --> create explain --> create create --> table @@ -69,6 +75,7 @@ colSep --> colIdent graph LR begin(( )) explain([EXPLAIN]) +queryPlan([QUERY PLAN]) insert([INSERT]) into([INTO]) tableIdent["Table Identifier"] @@ -83,6 +90,8 @@ literal["literal"] valSep[","] begin --> explain +explain --> queryPlan +queryPlan --> insert begin --> insert explain --> insert insert --> into diff --git a/compiler/ast.go b/compiler/ast.go index 6a608e3..716d6a5 100644 --- a/compiler/ast.go +++ b/compiler/ast.go @@ -7,7 +7,8 @@ package compiler type Stmt interface{} type StmtBase struct { - Explain bool + Explain bool + ExplainQueryPlan bool } type SelectStmt struct { @@ -43,3 +44,32 @@ type InsertStmt struct { ColNames []string ColValues []string } + +// type Expr interface { +// Type() string +// } + +// type BinaryExpr struct { +// Left Expr +// Operator string +// Right Expr +// } + +// type UnaryExpr struct { +// Operator string +// Operand Expr +// } + +// type ColumnRef struct { +// Table string +// Column string +// } + +// type IntLit struct { +// Value int +// } + +// type FunctionExpr struct { +// Name string +// Args []Expr +// } diff --git a/compiler/lexer.go b/compiler/lexer.go index 1b5ba26..2c3668e 100644 --- a/compiler/lexer.go +++ b/compiler/lexer.go @@ 
-41,6 +41,8 @@ const ( // Keywords where kw is keyword const ( kwExplain = "EXPLAIN" + kwQuery = "QUERY" + kwPlan = "PLAN" kwSelect = "SELECT" kwCount = "COUNT" kwFrom = "FROM" @@ -57,6 +59,8 @@ const ( var keywords = []string{ kwExplain, + kwQuery, + kwPlan, kwSelect, kwCount, kwFrom, diff --git a/compiler/lexer_test.go b/compiler/lexer_test.go index a902e8f..0864439 100644 --- a/compiler/lexer_test.go +++ b/compiler/lexer_test.go @@ -76,6 +76,20 @@ func TestLexSelect(t *testing.T) { {tkNumeric, "1"}, }, }, + { + sql: "EXPLAIN QUERY PLAN SELECT 1", + expected: []token{ + {tkKeyword, "EXPLAIN"}, + {tkWhitespace, " "}, + {tkKeyword, "QUERY"}, + {tkWhitespace, " "}, + {tkKeyword, "PLAN"}, + {tkWhitespace, " "}, + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkNumeric, "1"}, + }, + }, { sql: "SELECT 12", expected: []token{ @@ -95,10 +109,12 @@ func TestLexSelect(t *testing.T) { }, } for _, c := range cases { - ret := NewLexer(c.sql).Lex() - if !reflect.DeepEqual(ret, c.expected) { - t.Errorf("expected %#v got %#v", c.expected, ret) - } + t.Run(c.sql, func(t *testing.T) { + ret := NewLexer(c.sql).Lex() + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) + } + }) } } @@ -141,10 +157,12 @@ func TestLexCreate(t *testing.T) { }, } for _, c := range cases { - ret := NewLexer(c.sql).Lex() - if !reflect.DeepEqual(ret, c.expected) { - t.Errorf("expected %#v got %#v", c.expected, ret) - } + t.Run(c.sql, func(t *testing.T) { + ret := NewLexer(c.sql).Lex() + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) + } + }) } } @@ -195,9 +213,11 @@ func TestLexInsert(t *testing.T) { }, } for _, c := range cases { - ret := NewLexer(c.sql).Lex() - if !reflect.DeepEqual(ret, c.expected) { - t.Errorf("expected %#v got %#v", c.expected, ret) - } + t.Run(c.sql, func(t *testing.T) { + ret := NewLexer(c.sql).Lex() + if !reflect.DeepEqual(ret, c.expected) { + t.Errorf("expected %#v got %#v", c.expected, ret) 
+ } + }) } } diff --git a/compiler/parser.go b/compiler/parser.go index 92e499c..fc8bc71 100644 --- a/compiler/parser.go +++ b/compiler/parser.go @@ -33,8 +33,19 @@ func (p *parser) parseStmt() (Stmt, error) { t := p.tokens[p.start] sb := &StmtBase{} if t.value == kwExplain { - sb.Explain = true t = p.nextNonSpace() + nv := p.peekNextNonSpace().value + if nv == kwQuery { + tp := p.nextNonSpace() + if tp.value == kwPlan { + sb.ExplainQueryPlan = true + t = p.nextNonSpace() + } else { + return nil, fmt.Errorf(tokenErr, p.tokens[p.end].value) + } + } else { + sb.Explain = true + } } switch t.value { case kwSelect: diff --git a/compiler/parser_test.go b/compiler/parser_test.go index e2afd9a..37f97a5 100644 --- a/compiler/parser_test.go +++ b/compiler/parser_test.go @@ -6,6 +6,7 @@ import ( ) type selectTestCase struct { + name string tokens []token expect Stmt } @@ -13,6 +14,7 @@ type selectTestCase struct { func TestParseSelect(t *testing.T) { cases := []selectTestCase{ { + name: "with explain", tokens: []token{ {tkKeyword, "EXPLAIN"}, {tkWhitespace, " "}, @@ -37,6 +39,37 @@ func TestParseSelect(t *testing.T) { }, }, { + name: "with explain query plan", + tokens: []token{ + {tkKeyword, "EXPLAIN"}, + {tkWhitespace, " "}, + {tkKeyword, "QUERY"}, + {tkWhitespace, " "}, + {tkKeyword, "PLAN"}, + {tkWhitespace, " "}, + {tkKeyword, "SELECT"}, + {tkWhitespace, " "}, + {tkPunctuator, "*"}, + {tkWhitespace, " "}, + {tkKeyword, "FROM"}, + {tkWhitespace, " "}, + {tkIdentifier, "foo"}, + }, + expect: &SelectStmt{ + StmtBase: &StmtBase{ + Explain: false, + ExplainQueryPlan: true, + }, + From: &From{ + TableName: "foo", + }, + ResultColumn: ResultColumn{ + All: true, + }, + }, + }, + { + name: "with count", tokens: []token{ {tkKeyword, "SELECT"}, {tkWhitespace, " "}, @@ -64,13 +97,15 @@ func TestParseSelect(t *testing.T) { }, } for _, c := range cases { - ret, err := NewParser(c.tokens).Parse() - if err != nil { - t.Errorf("want no err got err %s", err.Error()) - } - if 
!reflect.DeepEqual(ret, c.expect) { - t.Errorf("got %#v want %#v", ret, c.expect) - } + t.Run(c.name, func(t *testing.T) { + ret, err := NewParser(c.tokens).Parse() + if err != nil { + t.Errorf("want no err got err %s", err.Error()) + } + if !reflect.DeepEqual(ret, c.expect) { + t.Errorf("got %#v want %#v", ret, c.expect) + } + }) } } diff --git a/db/db.go b/db/db.go index d8cdc36..45f3a3a 100644 --- a/db/db.go +++ b/db/db.go @@ -5,7 +5,6 @@ package db import ( "errors" - "fmt" "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/kv" @@ -17,6 +16,11 @@ type executor interface { Execute(*vm.ExecutionPlan) *vm.ExecuteResult } +type statementPlanner interface { + ExecutionPlan() (*vm.ExecutionPlan, error) + QueryPlan() (*planner.QueryPlan, error) +} + type dbCatalog interface { GetColumns(tableOrIndexName string) ([]string, error) GetRootPageNumber(tableOrIndexName string) (int, error) @@ -49,9 +53,21 @@ func (db *DB) Execute(sql string) vm.ExecuteResult { if err != nil { return vm.ExecuteResult{Err: err} } + + planner := db.getPlannerFor(statement) + qp, err := planner.QueryPlan() + if err != nil { + return vm.ExecuteResult{Err: err} + } + if qp.ExplainQueryPlan { + return vm.ExecuteResult{ + Text: qp.ToString(), + } + } + var executeResult vm.ExecuteResult for { - executionPlan, err := db.getExecutionPlanFor(statement) + executionPlan, err := planner.ExecutionPlan() if err != nil { return vm.ExecuteResult{Err: err} } @@ -63,14 +79,14 @@ func (db *DB) Execute(sql string) vm.ExecuteResult { return executeResult } -func (db *DB) getExecutionPlanFor(statement compiler.Stmt) (*vm.ExecutionPlan, error) { +func (db *DB) getPlannerFor(statement compiler.Stmt) statementPlanner { switch s := statement.(type) { case *compiler.SelectStmt: - return planner.NewSelect(db.catalog).GetPlan(s) + return planner.NewSelect(db.catalog, s) case *compiler.CreateStmt: - return planner.NewCreate(db.catalog).GetPlan(s) + return planner.NewCreate(db.catalog, s) case *compiler.InsertStmt: - 
return planner.NewInsert(db.catalog).GetPlan(s) + return planner.NewInsert(db.catalog, s) } - return nil, fmt.Errorf("statement not supported") + panic("statement not supported") } diff --git a/planner/create.go b/planner/create.go index 97f961b..e2cface 100644 --- a/planner/create.go +++ b/planner/create.go @@ -22,66 +22,72 @@ type createCatalog interface { } type createPlanner struct { - catalog createCatalog + qp *createQueryPlanner + ep *createExecutionPlanner } -func NewCreate(catalog createCatalog) *createPlanner { +type createQueryPlanner struct { + catalog createCatalog + stmt *compiler.CreateStmt + queryPlan *createNode +} + +type createExecutionPlanner struct { + queryPlan *createNode + executionPlan *vm.ExecutionPlan +} + +func NewCreate(catalog createCatalog, stmt *compiler.CreateStmt) *createPlanner { return &createPlanner{ - catalog: catalog, + qp: &createQueryPlanner{ + catalog: catalog, + stmt: stmt, + }, + ep: &createExecutionPlanner{ + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), + }, } } -func (c *createPlanner) GetPlan(s *compiler.CreateStmt) (*vm.ExecutionPlan, error) { - executionPlan := vm.NewExecutionPlan(c.catalog.GetVersion()) - executionPlan.Explain = s.Explain - err := c.ensureTableDoesNotExist(s) +func (p *createPlanner) QueryPlan() (*QueryPlan, error) { + tableName, err := p.qp.ensureTableDoesNotExist() if err != nil { return nil, err } - jSchema, err := getSchemaString(s) + jSchema, err := p.qp.getSchemaString() if err != nil { return nil, err } - // objectType could be an index, trigger, or in this case a table. - objectType := "table" - // objectName is the name of the index, trigger, or table. - objectName := s.TableName - // tableName is name of the table this object is associated with. 
- tableName := s.TableName - commands := []vm.Command{} - commands = append(commands, &vm.InitCmd{P2: 1}) - commands = append(commands, &vm.TransactionCmd{P2: 1}) - commands = append(commands, &vm.CreateBTreeCmd{P2: 1}) - commands = append(commands, &vm.OpenWriteCmd{P1: 1, P2: 1}) - commands = append(commands, &vm.NewRowIdCmd{P1: 1, P2: 2}) - commands = append(commands, &vm.StringCmd{P1: 3, P4: objectType}) - commands = append(commands, &vm.StringCmd{P1: 4, P4: objectName}) - commands = append(commands, &vm.StringCmd{P1: 5, P4: tableName}) - commands = append(commands, &vm.CopyCmd{P1: 1, P2: 6}) - commands = append(commands, &vm.StringCmd{P1: 7, P4: string(jSchema)}) - commands = append(commands, &vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) - commands = append(commands, &vm.InsertCmd{P1: 1, P2: 8, P3: 2}) - commands = append(commands, &vm.ParseSchemaCmd{}) - commands = append(commands, &vm.HaltCmd{}) - executionPlan.Commands = commands - return executionPlan, nil + createNode := &createNode{ + objectType: "table", + objectName: tableName, + tableName: tableName, + schema: jSchema, + } + qp := newQueryPlan(createNode, p.qp.stmt.ExplainQueryPlan) + p.ep.queryPlan = createNode + return qp, nil } -func (c *createPlanner) ensureTableDoesNotExist(s *compiler.CreateStmt) error { - if c.catalog.TableExists(s.TableName) { - return errTableExists +func (p *createQueryPlanner) ensureTableDoesNotExist() (string, error) { + tableName := p.stmt.TableName + if p.catalog.TableExists(tableName) { + return "", errTableExists } - return nil + return tableName, nil } -func getSchemaString(s *compiler.CreateStmt) (string, error) { - if err := ensurePrimaryKeyCount(s); err != nil { +func (p *createQueryPlanner) getSchemaString() (string, error) { + if err := p.ensurePrimaryKeyCount(); err != nil { return "", err } - if err := ensurePrimaryKeyInteger(s); err != nil { + if err := p.ensurePrimaryKeyInteger(); err != nil { return "", err } - jSchema, err := schemaFrom(s).ToJSON() + jSchema, err 
:= p.schemaFrom().ToJSON() if err != nil { return "", err } @@ -91,14 +97,14 @@ func getSchemaString(s *compiler.CreateStmt) (string, error) { // The id column must be an integer. The index key is capable of being something // other than an integer, but is not worth implementing at the moment. Integer // primary keys are superior for auto incrementing and being unique. -func ensurePrimaryKeyInteger(s *compiler.CreateStmt) error { - hasPK := slices.ContainsFunc(s.ColDefs, func(cd compiler.ColDef) bool { +func (p *createQueryPlanner) ensurePrimaryKeyInteger() error { + hasPK := slices.ContainsFunc(p.stmt.ColDefs, func(cd compiler.ColDef) bool { return cd.PrimaryKey }) if !hasPK { return nil } - hasIntegerPK := slices.ContainsFunc(s.ColDefs, func(cd compiler.ColDef) bool { + hasIntegerPK := slices.ContainsFunc(p.stmt.ColDefs, func(cd compiler.ColDef) bool { return cd.PrimaryKey && cd.ColType == "INTEGER" }) if !hasIntegerPK { @@ -108,9 +114,9 @@ func ensurePrimaryKeyInteger(s *compiler.CreateStmt) error { } // Only one primary key is supported at this time. 
-func ensurePrimaryKeyCount(s *compiler.CreateStmt) error { +func (p *createQueryPlanner) ensurePrimaryKeyCount() error { count := 0 - for _, cd := range s.ColDefs { + for _, cd := range p.stmt.ColDefs { if cd.PrimaryKey { count += 1 } @@ -121,11 +127,11 @@ func ensurePrimaryKeyCount(s *compiler.CreateStmt) error { return nil } -func schemaFrom(s *compiler.CreateStmt) *kv.TableSchema { +func (p *createQueryPlanner) schemaFrom() *kv.TableSchema { schema := kv.TableSchema{ Columns: []kv.TableColumn{}, } - for _, cd := range s.ColDefs { + for _, cd := range p.stmt.ColDefs { schema.Columns = append(schema.Columns, kv.TableColumn{ Name: cd.ColName, ColType: cd.ColType, @@ -134,3 +140,28 @@ func schemaFrom(s *compiler.CreateStmt) *kv.TableSchema { } return &schema } + +func (cp *createPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if cp.qp.queryPlan == nil { + _, err := cp.QueryPlan() + if err != nil { + return nil, err + } + } + p := cp.ep + p.executionPlan.Append(&vm.InitCmd{P2: 1}) + p.executionPlan.Append(&vm.TransactionCmd{P2: 1}) + p.executionPlan.Append(&vm.CreateBTreeCmd{P2: 1}) + p.executionPlan.Append(&vm.OpenWriteCmd{P1: 1, P2: 1}) + p.executionPlan.Append(&vm.NewRowIdCmd{P1: 1, P2: 2}) + p.executionPlan.Append(&vm.StringCmd{P1: 3, P4: p.queryPlan.objectType}) + p.executionPlan.Append(&vm.StringCmd{P1: 4, P4: p.queryPlan.objectName}) + p.executionPlan.Append(&vm.StringCmd{P1: 5, P4: p.queryPlan.tableName}) + p.executionPlan.Append(&vm.CopyCmd{P1: 1, P2: 6}) + p.executionPlan.Append(&vm.StringCmd{P1: 7, P4: string(p.queryPlan.schema)}) + p.executionPlan.Append(&vm.MakeRecordCmd{P1: 3, P2: 5, P3: 8}) + p.executionPlan.Append(&vm.InsertCmd{P1: 1, P2: 8, P3: 2}) + p.executionPlan.Append(&vm.ParseSchemaCmd{}) + p.executionPlan.Append(&vm.HaltCmd{}) + return p.executionPlan, nil +} diff --git a/planner/create_test.go b/planner/create_test.go index e71f934..e67fc9e 100644 --- a/planner/create_test.go +++ b/planner/create_test.go @@ -70,7 +70,7 @@ func 
TestCreateWithNoIDColumn(t *testing.T) { &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, } - plan, err := NewCreate(mc).GetPlan(stmt) + plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err.Error()) } @@ -129,7 +129,7 @@ func TestCreateWithAlternateNamedIDColumn(t *testing.T) { &vm.ParseSchemaCmd{}, &vm.HaltCmd{}, } - plan, err := NewCreate(mc).GetPlan(stmt) + plan, err := NewCreate(mc, stmt).ExecutionPlan() if err != nil { t.Fatal(err.Error()) } @@ -153,7 +153,7 @@ func TestCreatePrimaryKeyWithTextType(t *testing.T) { }, } mc := &mockCreateCatalog{} - _, err := NewCreate(mc).GetPlan(stmt) + _, err := NewCreate(mc, stmt).ExecutionPlan() if !errors.Is(err, errInvalidPKColumnType) { t.Fatalf("got error %s expected error %s", err, errInvalidPKColumnType) } @@ -171,7 +171,7 @@ func TestCreateWithExistingTable(t *testing.T) { }, } mc := &mockCreateCatalog{tableExistsRes: true} - _, err := NewCreate(mc).GetPlan(stmt) + _, err := NewCreate(mc, stmt).ExecutionPlan() if !errors.Is(err, errTableExists) { t.Fatalf("got error %s expected error %s", err, errTableExists) } @@ -195,7 +195,7 @@ func TestCreateWithMoreThanOnePrimaryKey(t *testing.T) { }, } mc := &mockCreateCatalog{} - _, err := NewCreate(mc).GetPlan(stmt) + _, err := NewCreate(mc, stmt).ExecutionPlan() if !errors.Is(err, errMoreThanOnePK) { t.Fatalf("got error %s expected error %s", err, errMoreThanOnePK) } diff --git a/planner/insert.go b/planner/insert.go index beec403..07c1f09 100644 --- a/planner/insert.go +++ b/planner/insert.go @@ -23,84 +23,121 @@ type insertCatalog interface { } type insertPlanner struct { - catalog insertCatalog + qp *insertQueryPlanner + ep *insertExecutionPlanner } -func NewInsert(catalog insertCatalog) *insertPlanner { +type insertQueryPlanner struct { + catalog insertCatalog + stmt *compiler.InsertStmt + queryPlan *insertNode +} + +type insertExecutionPlanner struct { + queryPlan *insertNode + executionPlan *vm.ExecutionPlan +} + +func NewInsert(catalog insertCatalog, stmt 
*compiler.InsertStmt) *insertPlanner { return &insertPlanner{ - catalog: catalog, + qp: &insertQueryPlanner{ + catalog: catalog, + stmt: stmt, + }, + ep: &insertExecutionPlanner{ + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), + }, } } -func (p *insertPlanner) GetPlan(s *compiler.InsertStmt) (*vm.ExecutionPlan, error) { - executionPlan := vm.NewExecutionPlan(p.catalog.GetVersion()) - executionPlan.Explain = s.Explain - rootPageNumber, err := p.catalog.GetRootPageNumber(s.TableName) +func (ip *insertPlanner) QueryPlan() (*QueryPlan, error) { + p := ip.qp + rootPage, err := p.catalog.GetRootPageNumber(p.stmt.TableName) if err != nil { return nil, errTableNotExist } - catalogColumnNames, err := p.catalog.GetColumns(s.TableName) + catalogColumnNames, err := p.catalog.GetColumns(p.stmt.TableName) if err != nil { return nil, err } - cursorId := 1 - commands := []vm.Command{} - commands = append(commands, &vm.InitCmd{P2: 1}) - commands = append(commands, &vm.TransactionCmd{P2: 1}) - commands = append(commands, &vm.OpenWriteCmd{P1: cursorId, P2: rootPageNumber}) - - if err := checkValuesMatchColumns(s); err != nil { + if err := checkValuesMatchColumns(p.stmt); err != nil { return nil, err } - - pkColumn, err := p.catalog.GetPrimaryKeyColumn(s.TableName) + pkColumn, err := p.catalog.GetPrimaryKeyColumn(p.stmt.TableName) if err != nil { return nil, err } - for valueIdx := range len(s.ColValues) / len(s.ColNames) { + insertNode := &insertNode{ + rootPage: rootPage, + catalogColumnNames: catalogColumnNames, + pkColumn: pkColumn, + colNames: p.stmt.ColNames, + colValues: p.stmt.ColValues, + } + p.queryPlan = insertNode + ip.ep.queryPlan = insertNode + return newQueryPlan(insertNode, p.stmt.ExplainQueryPlan), nil +} + +func (ip *insertPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if ip.qp.queryPlan == nil { + _, err := ip.QueryPlan() + if err != nil { + return nil, err + } + } + ep := ip.ep + cursorId := 1 + 
ep.executionPlan.Append(&vm.InitCmd{P2: 1}) + ep.executionPlan.Append(&vm.TransactionCmd{P2: 1}) + ep.executionPlan.Append(&vm.OpenWriteCmd{P1: cursorId, P2: ep.queryPlan.rootPage}) + + for valueIdx := range len(ep.queryPlan.colValues) / len(ep.queryPlan.colNames) { keyRegister := 1 statementIDIdx := -1 - if pkColumn != "" { - statementIDIdx = slices.IndexFunc(s.ColNames, func(s string) bool { - return s == pkColumn + if ep.queryPlan.pkColumn != "" { + statementIDIdx = slices.IndexFunc(ep.queryPlan.colNames, func(s string) bool { + return s == ep.queryPlan.pkColumn }) } if statementIDIdx == -1 { - commands = append(commands, &vm.NewRowIdCmd{P1: rootPageNumber, P2: keyRegister}) + ep.executionPlan.Append(&vm.NewRowIdCmd{P1: ep.queryPlan.rootPage, P2: keyRegister}) } else { - rowId, err := strconv.Atoi(s.ColValues[statementIDIdx+valueIdx*len(s.ColNames)]) + rowId, err := strconv.Atoi(ep.queryPlan.colValues[statementIDIdx+valueIdx*len(ep.queryPlan.colNames)]) if err != nil { return nil, err } - integerCmdIdx := len(commands) + 2 - commands = append(commands, &vm.NotExistsCmd{P1: rootPageNumber, P2: integerCmdIdx, P3: rowId}) - commands = append(commands, &vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}) - commands = append(commands, &vm.IntegerCmd{P1: rowId, P2: keyRegister}) + integerCmdIdx := len(ep.executionPlan.Commands) + 2 + ep.executionPlan.Append(&vm.NotExistsCmd{P1: ep.queryPlan.rootPage, P2: integerCmdIdx, P3: rowId}) + ep.executionPlan.Append(&vm.HaltCmd{P1: 1, P4: "pk unique constraint violated"}) + ep.executionPlan.Append(&vm.IntegerCmd{P1: rowId, P2: keyRegister}) } registerIdx := keyRegister - for _, catalogColumnName := range catalogColumnNames { - if catalogColumnName != "" && catalogColumnName == pkColumn { + for _, catalogColumnName := range ep.queryPlan.catalogColumnNames { + if catalogColumnName != "" && catalogColumnName == ep.queryPlan.pkColumn { continue } registerIdx += 1 vIdx := -1 - for i, statementColumnName := range s.ColNames { + 
for i, statementColumnName := range ep.queryPlan.colNames { if statementColumnName == catalogColumnName { - vIdx = i + (valueIdx * len(s.ColNames)) + vIdx = i + (valueIdx * len(ep.queryPlan.colNames)) } } if vIdx == -1 { return nil, fmt.Errorf("%w %s", errMissingColumnName, catalogColumnName) } - commands = append(commands, &vm.StringCmd{P1: registerIdx, P4: s.ColValues[vIdx]}) + ep.executionPlan.Append(&vm.StringCmd{P1: registerIdx, P4: ep.queryPlan.colValues[vIdx]}) } - commands = append(commands, &vm.MakeRecordCmd{P1: 2, P2: registerIdx - 1, P3: registerIdx + 1}) - commands = append(commands, &vm.InsertCmd{P1: rootPageNumber, P2: registerIdx + 1, P3: keyRegister}) + ep.executionPlan.Append(&vm.MakeRecordCmd{P1: 2, P2: registerIdx - 1, P3: registerIdx + 1}) + ep.executionPlan.Append(&vm.InsertCmd{P1: ep.queryPlan.rootPage, P2: registerIdx + 1, P3: keyRegister}) } - commands = append(commands, &vm.HaltCmd{}) - executionPlan.Commands = commands - return executionPlan, nil + ep.executionPlan.Append(&vm.HaltCmd{}) + return ep.executionPlan, nil } func checkValuesMatchColumns(s *compiler.InsertStmt) error { diff --git a/planner/insert_test.go b/planner/insert_test.go index c1f1dec..4a8f073 100644 --- a/planner/insert_test.go +++ b/planner/insert_test.go @@ -77,7 +77,7 @@ func TestInsertWithoutPrimaryKey(t *testing.T) { } mockCatalog := &mockInsertCatalog{} mockCatalog.columnsReturn = []string{"first", "last"} - plan, err := NewInsert(mockCatalog).GetPlan(ast) + plan, err := NewInsert(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } @@ -117,7 +117,7 @@ func TestInsertWithPrimaryKey(t *testing.T) { columnsReturn: []string{"id", "first"}, pkColumnName: "id", } - plan, err := NewInsert(mockCatalog).GetPlan(ast) + plan, err := NewInsert(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } @@ -157,7 +157,7 @@ func TestInsertWithPrimaryKeyMiddleOrder(t *testing.T) { columnsReturn: 
[]string{"id", "first"}, pkColumnName: "id", } - plan, err := NewInsert(mockCatalog).GetPlan(ast) + plan, err := NewInsert(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } @@ -176,7 +176,7 @@ func TestInsertIntoNonExistingTable(t *testing.T) { ColValues: []string{}, } mockCatalog := &mockInsertCatalog{} - _, err := NewInsert(mockCatalog).GetPlan(ast) + _, err := NewInsert(mockCatalog, ast).ExecutionPlan() if !errors.Is(err, errTableNotExist) { t.Fatalf("expected err %s got err %s", errTableNotExist, err) } @@ -197,7 +197,7 @@ func TestInsertValuesNotMatchingColumns(t *testing.T) { }, } mockCatalog := &mockInsertCatalog{} - _, err := NewInsert(mockCatalog).GetPlan(ast) + _, err := NewInsert(mockCatalog, ast).ExecutionPlan() if !errors.Is(err, errValuesNotMatch) { t.Fatalf("expected err %s got err %s", errValuesNotMatch, err) } @@ -215,7 +215,7 @@ func TestInsertIntoNonExistingColumn(t *testing.T) { }, } mockCatalog := &mockInsertCatalog{} - _, err := NewInsert(mockCatalog).GetPlan(ast) + _, err := NewInsert(mockCatalog, ast).ExecutionPlan() if !errors.Is(err, errMissingColumnName) { t.Fatalf("expected err %s got err %s", errMissingColumnName, err) } diff --git a/planner/node.go b/planner/node.go new file mode 100644 index 0000000..48004d6 --- /dev/null +++ b/planner/node.go @@ -0,0 +1,69 @@ +package planner + +// logicalNode defines the interface for a node in the query plan tree. +type logicalNode interface { + children() []logicalNode + print() string +} + +// projectNode defines what columns should be projected. 
+type projectNode struct { + projections []projection + child logicalNode +} + +type projection struct { + isCount bool + colName string +} + +// scanNode represents a full scan on a table +type scanNode struct { + // tableName is the name of the table to be scanned + tableName string + // rootPage is the valid page number corresponding to the table + rootPage int + // scanColumns contains information about how the scan will project columns + scanColumns []scanColumn +} + +type scanColumn struct { + // isPrimaryKey means the column will be a key instead of a nth value. + isPrimaryKey bool + // colIdx is the nth column for non primary key values. + colIdx int +} + +// countNode represents a special optimization when a table needs a full count +// with no filtering or other projections. +type countNode struct { + // tableName is the name of the table to be scanned + tableName string + // rootPage is the valid page number corresponding to the table + rootPage int +} + +type joinNode struct { + left logicalNode + right logicalNode + operation string +} + +type createNode struct { + // objectName is the name of the index, trigger, or table. + objectName string + // objectType could be an index, trigger, or in this case a table. + objectType string + // tableName is name of the table this object is associated with. + tableName string + // schema is the json serialized schema definition for the object. + schema string +} + +type insertNode struct { + rootPage int + catalogColumnNames []string + pkColumn string + colNames []string + colValues []string +} diff --git a/planner/plan.go b/planner/plan.go new file mode 100644 index 0000000..becb5f9 --- /dev/null +++ b/planner/plan.go @@ -0,0 +1,153 @@ +package planner + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// QueryPlan contains the query plan tree. It is capable of converting the tree +// to a string representation for a query prefixed with `EXPLAIN QUERY PLAN`. 
+type QueryPlan struct { + plan string + root logicalNode + ExplainQueryPlan bool +} + +func newQueryPlan(root logicalNode, explainQueryPlan bool) *QueryPlan { + return &QueryPlan{ + root: root, + ExplainQueryPlan: explainQueryPlan, + } +} + +func (p *QueryPlan) ToString() string { + qp := &QueryPlan{} + qp.walk(p.root, 0) + qp.trimLeft() + return qp.connectSiblings() +} + +func (p *QueryPlan) walk(root logicalNode, depth int) { + p.visit(root, depth+1) + for _, c := range root.children() { + p.walk(c, depth+1) + } +} + +func (p *QueryPlan) visit(ln logicalNode, depth int) { + padding := "" + for i := 0; i < depth; i += 1 { + padding += " " + } + if depth == 1 { + padding += " ── " + } else { + padding += " └─ " + } + p.plan += fmt.Sprintf("%s%s\n", padding, ln.print()) +} + +func (p *QueryPlan) trimLeft() { + trimBy := 4 + newPlan := []string{} + for _, row := range strings.Split(p.plan, "\n") { + newRow := row + if len(row) >= trimBy { + newRow = row[trimBy:] + } + newPlan = append(newPlan, newRow) + } + p.plan = strings.Join(newPlan, "\n") +} + +func (p *QueryPlan) connectSiblings() string { + planMatrix := strings.Split(p.plan, "\n") + for rowIdx := len(planMatrix) - 1; 0 < rowIdx; rowIdx -= 1 { + row := planMatrix[rowIdx] + for charIdx, char := range row { + if char == '└' { + for backwardsRowIdx := rowIdx - 1; 0 < backwardsRowIdx; backwardsRowIdx -= 1 { + if len(planMatrix[backwardsRowIdx]) < charIdx { + continue + } + char, _ := utf8.DecodeRuneInString(planMatrix[backwardsRowIdx][charIdx:]) + if char == ' ' { + out := []rune(planMatrix[backwardsRowIdx]) + out[charIdx] = '|' + planMatrix[backwardsRowIdx] = string(out) + } + if char == '└' { + out := []rune(planMatrix[backwardsRowIdx]) + out[charIdx] = '├' + planMatrix[backwardsRowIdx] = string(out) + } + } + } + } + } + return strings.Join(planMatrix, "\n") +} + +func (p *projectNode) print() string { + list := "(" + for i, proj := range p.projections { + list += proj.print() + if i+1 < len(p.projections) { + 
list += ", " + } + } + list += ")" + return "project" + list +} + +func (p *projection) print() string { + if p.isCount { + return "count(*)" + } + return p.colName +} + +func (s *scanNode) print() string { + return fmt.Sprintf("scan table %s", s.tableName) +} + +func (c *countNode) print() string { + return fmt.Sprintf("count table %s", c.tableName) +} + +func (j *joinNode) print() string { + return fmt.Sprint(j.operation) +} + +func (c *createNode) print() string { + return fmt.Sprintf("create table %s", c.tableName) +} + +func (i *insertNode) print() string { + return "insert" +} + +func (p *projectNode) children() []logicalNode { + return []logicalNode{p.child} +} + +func (s *scanNode) children() []logicalNode { + return []logicalNode{} +} + +func (c *countNode) children() []logicalNode { + return []logicalNode{} +} + +func (j *joinNode) children() []logicalNode { + return []logicalNode{j.left, j.right} +} + +func (c *createNode) children() []logicalNode { + return []logicalNode{} +} + +func (i *insertNode) children() []logicalNode { + return []logicalNode{} +} diff --git a/planner/plan_test.go b/planner/plan_test.go new file mode 100644 index 0000000..2fcbaf4 --- /dev/null +++ b/planner/plan_test.go @@ -0,0 +1,48 @@ +package planner + +import "testing" + +func TestExplainQueryPlan(t *testing.T) { + root := &projectNode{ + projections: []projection{ + {colName: "id"}, + {colName: "first_name"}, + {colName: "last_name"}, + }, + child: &joinNode{ + operation: "join", + left: &joinNode{ + operation: "join", + left: &scanNode{ + tableName: "foo", + }, + right: &joinNode{ + operation: "join", + left: &scanNode{ + tableName: "baz", + }, + right: &scanNode{ + tableName: "buzz", + }, + }, + }, + right: &scanNode{ + tableName: "bar", + }, + }, + } + qp := newQueryPlan(root, true) + formattedResult := qp.ToString() + expectedResult := "" + + " ── project(id, first_name, last_name)\n" + + " └─ join\n" + + " ├─ join\n" + + " | ├─ scan table foo\n" + + " | └─ join\n" + + " 
| ├─ scan table baz\n" + + " | └─ scan table buzz\n" + + " └─ scan table bar\n" + if formattedResult != expectedResult { + t.Fatalf("got\n%s\nwant\n%s", formattedResult, expectedResult) + } +} diff --git a/planner/select.go b/planner/select.go index 456261f..3b36eb9 100644 --- a/planner/select.go +++ b/planner/select.go @@ -1,18 +1,8 @@ package planner -// TODO -// This planner will eventually consist of smaller parts. -// 1. Something like a binder may be necessary which would validate the values -// in the statement make sense given the current schema. -// 2. A logical query planner which would transform the ast into a relational -// algebra like structure. This structure would allow for optimizations like -// predicate push down. -// 3. Perhaps a physical planner which would maybe take into account statistics -// and indexes. -// Somewhere in these structures would be the ability to print a query plan that -// is higher level than the bytecode operations. A typical explain tree. - import ( + "errors" + "github.com/chirst/cdb/compiler" "github.com/chirst/cdb/vm" ) @@ -25,68 +15,207 @@ type selectCatalog interface { GetPrimaryKeyColumn(tableName string) (string, error) } +// selectPlanner is capable of generating a logical query plan and a physical +// execution plan for a select statement. The planners within are separated by +// their responsibility. Notice a statement or catalog is not shared with +// the execution planner. This is by design since the logical query planner also +// performs binding. type selectPlanner struct { + qp *selectQueryPlanner + ep *selectExecutionPlanner +} + +// selectQueryPlanner converts an AST to a logical query plan. Along the way it +// also validates the AST makes sense with the catalog (a process known as +// binding). +type selectQueryPlanner struct { + // catalog contains the schema catalog selectCatalog + // stmt contains the AST + stmt *compiler.SelectStmt + // queryPlan contains the logical plan. 
The root node must be a projection. + queryPlan *projectNode } -func NewSelect(catalog selectCatalog) *selectPlanner { +// selectExecutionPlanner converts logical nodes in a query plan tree to +// bytecode that can be run by the vm. +type selectExecutionPlanner struct { + // queryPlan contains the logical plan. This node is populated by calling + // the QueryPlan method. + queryPlan *projectNode + // executionPlan contains the execution plan for the vm. This is built by + // calling ExecutionPlan. + executionPlan *vm.ExecutionPlan +} + +// NewSelect returns an instance of a select planner for the given AST. +func NewSelect(catalog selectCatalog, stmt *compiler.SelectStmt) *selectPlanner { return &selectPlanner{ - catalog: catalog, + qp: &selectQueryPlanner{ + catalog: catalog, + stmt: stmt, + }, + ep: &selectExecutionPlanner{ + executionPlan: vm.NewExecutionPlan( + catalog.GetVersion(), + stmt.Explain, + ), + }, } } -func (p *selectPlanner) GetPlan(s *compiler.SelectStmt) (*vm.ExecutionPlan, error) { - executionPlan := vm.NewExecutionPlan(p.catalog.GetVersion()) - executionPlan.Explain = s.Explain - resultHeader := []string{} - cols, err := p.catalog.GetColumns(s.From.TableName) +// QueryPlan generates the query plan tree for the planner. +func (p *selectPlanner) QueryPlan() (*QueryPlan, error) { + tableName := p.qp.stmt.From.TableName + rootPageNumber, err := p.qp.catalog.GetRootPageNumber(tableName) if err != nil { return nil, err } - if s.ResultColumn.All { - resultHeader = append(resultHeader, cols...) 
- } else if s.ResultColumn.Count { - resultHeader = append(resultHeader, "") + var child logicalNode + if p.qp.stmt.ResultColumn.All { + scanColumns, err := p.qp.getScanColumns() + if err != nil { + return nil, err + } + child = &scanNode{ + tableName: tableName, + rootPage: rootPageNumber, + scanColumns: scanColumns, + } + } else { + child = &countNode{ + tableName: tableName, + rootPage: rootPageNumber, + } } - rootPage, err := p.catalog.GetRootPageNumber(s.From.TableName) + projections, err := p.qp.getProjections() if err != nil { return nil, err } - cursorId := 1 - commands := []vm.Command{} - commands = append(commands, &vm.InitCmd{P2: 1}) - commands = append(commands, &vm.TransactionCmd{P2: 0}) - commands = append(commands, &vm.OpenReadCmd{P1: cursorId, P2: rootPage}) - if s.ResultColumn.All { - rwc := &vm.RewindCmd{P1: cursorId} - commands = append(commands, rwc) - pkColName, err := p.catalog.GetPrimaryKeyColumn(s.From.TableName) + p.qp.queryPlan = &projectNode{ + projections: projections, + child: child, + } + p.ep.queryPlan = p.qp.queryPlan + return newQueryPlan(p.qp.queryPlan, p.qp.stmt.ExplainQueryPlan), nil +} + +func (p *selectQueryPlanner) getScanColumns() ([]scanColumn, error) { + pkColName, err := p.catalog.GetPrimaryKeyColumn(p.stmt.From.TableName) + if err != nil { + return nil, err + } + cols, err := p.catalog.GetColumns(p.stmt.From.TableName) + if err != nil { + return nil, err + } + scanColumns := []scanColumn{} + idx := 0 + for _, c := range cols { + if c == pkColName { + scanColumns = append(scanColumns, scanColumn{ + isPrimaryKey: c == pkColName, + }) + } else { + scanColumns = append(scanColumns, scanColumn{ + colIdx: idx, + }) + idx += 1 + } + } + return scanColumns, nil +} + +func (p *selectQueryPlanner) getProjections() ([]projection, error) { + if p.stmt.ResultColumn.All { + cols, err := p.catalog.GetColumns(p.stmt.From.TableName) if err != nil { return nil, err } - registerIdx := 1 - gap := 0 - colIdx := 0 + projections := 
[]projection{} for _, c := range cols { - if c == pkColName { - commands = append(commands, &vm.RowIdCmd{P1: cursorId, P2: registerIdx}) - } else { - commands = append(commands, &vm.ColumnCmd{P1: cursorId, P2: colIdx, P3: registerIdx}) - colIdx += 1 - } - registerIdx += 1 - gap += 1 + projections = append(projections, projection{ + colName: c, + }) } - commands = append(commands, &vm.ResultRowCmd{P1: 1, P2: gap}) - commands = append(commands, &vm.NextCmd{P1: cursorId, P2: 4}) - commands = append(commands, &vm.HaltCmd{}) - rwc.P2 = len(commands) - 1 - } else { - commands = append(commands, &vm.CountCmd{P1: cursorId, P2: 1}) - commands = append(commands, &vm.ResultRowCmd{P1: 1, P2: 1}) - commands = append(commands, &vm.HaltCmd{}) + return projections, nil + } else if p.stmt.ResultColumn.Count { + return []projection{ + { + isCount: true, + }, + }, nil } - executionPlan.Commands = commands - executionPlan.ResultHeader = resultHeader - return executionPlan, nil + return nil, errors.New("unhandled projection") +} + +// ExecutionPlan returns the bytecode execution plan for the planner. Calling +// QueryPlan is not a prerequisite to this method as it will be called by +// ExecutionPlan if needed. 
+func (sp *selectPlanner) ExecutionPlan() (*vm.ExecutionPlan, error) { + if sp.qp.queryPlan == nil { + _, err := sp.QueryPlan() + if err != nil { + return nil, err + } + } + p := sp.ep + p.resultHeader() + p.buildInit() + + switch c := p.queryPlan.child.(type) { + case *scanNode: + if err := p.buildScan(c); err != nil { + return nil, err + } + case *countNode: + p.buildOptimizedCountScan(c) + default: + panic("unhandled node") + } + p.executionPlan.Append(&vm.HaltCmd{}) + return p.executionPlan, nil +} + +func (p *selectExecutionPlanner) resultHeader() { + resultHeader := []string{} + for _, p := range p.queryPlan.projections { + resultHeader = append(resultHeader, p.colName) + } + p.executionPlan.ResultHeader = resultHeader +} + +func (p *selectExecutionPlanner) buildInit() { + p.executionPlan.Append(&vm.InitCmd{P2: 1}) + p.executionPlan.Append(&vm.TransactionCmd{P2: 0}) +} + +func (p *selectExecutionPlanner) buildScan(n *scanNode) error { + const cursorId = 1 + p.executionPlan.Append(&vm.OpenReadCmd{P1: cursorId, P2: n.rootPage}) + + rwc := &vm.RewindCmd{P1: cursorId} + p.executionPlan.Append(rwc) + + for i, c := range n.scanColumns { + register := i + 1 + if c.isPrimaryKey { + p.executionPlan.Append(&vm.RowIdCmd{P1: cursorId, P2: register}) + } else { + p.executionPlan.Append(&vm.ColumnCmd{P1: cursorId, P2: c.colIdx, P3: register}) + } + } + p.executionPlan.Append(&vm.ResultRowCmd{P1: 1, P2: len(n.scanColumns)}) + + p.executionPlan.Append(&vm.NextCmd{P1: cursorId, P2: 4}) + + rwc.P2 = len(p.executionPlan.Commands) + return nil +} + +func (p *selectExecutionPlanner) buildOptimizedCountScan(n *countNode) { + const cursorId = 1 + p.executionPlan.Append(&vm.OpenReadCmd{P1: cursorId, P2: n.rootPage}) + p.executionPlan.Append(&vm.CountCmd{P1: cursorId, P2: 1}) + p.executionPlan.Append(&vm.ResultRowCmd{P1: 1, P2: 1}) } diff --git a/planner/select_test.go b/planner/select_test.go index 808c12d..5263175 100644 --- a/planner/select_test.go +++ b/planner/select_test.go @@ 
-57,7 +57,7 @@ func TestGetPlan(t *testing.T) { mockCatalog := &mockSelectCatalog{} mockCatalog.primaryKeyColumnName = "id" mockCatalog.columns = []string{"name", "id", "age"} - plan, err := NewSelect(mockCatalog).GetPlan(ast) + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } @@ -91,7 +91,7 @@ func TestGetPlanPKMiddleOrdinal(t *testing.T) { } mockCatalog := &mockSelectCatalog{} mockCatalog.primaryKeyColumnName = "id" - plan, err := NewSelect(mockCatalog).GetPlan(ast) + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } @@ -121,7 +121,7 @@ func TestGetCountAggregate(t *testing.T) { }, } mockCatalog := &mockSelectCatalog{} - plan, err := NewSelect(mockCatalog).GetPlan(ast) + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } @@ -154,7 +154,7 @@ func TestGetPlanNoPrimaryKey(t *testing.T) { }, } mockCatalog := &mockSelectCatalog{} - plan, err := NewSelect(mockCatalog).GetPlan(ast) + plan, err := NewSelect(mockCatalog, ast).ExecutionPlan() if err != nil { t.Errorf("expected no err got err %s", err) } diff --git a/vm/vm.go b/vm/vm.go index 7198d3a..45448f1 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -74,12 +74,17 @@ type ExecutionPlan struct { Version string } -func NewExecutionPlan(version string) *ExecutionPlan { +func NewExecutionPlan(version string, explain bool) *ExecutionPlan { return &ExecutionPlan{ Version: version, + Explain: explain, } } +func (e *ExecutionPlan) Append(command Command) { + e.Commands = append(e.Commands, command) +} + // Execute performs the execution plan provided. If the execution plan is an // explain Execute does not execute the plan. 
If the plan is out of date with // the system catalog Execute will return ErrVersionChanged in the ExecuteResult diff --git a/vm/vm_test.go b/vm/vm_test.go index 99dfb28..57a860b 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -14,7 +14,7 @@ func TestExec(t *testing.T) { log.Fatal(err.Error()) } vm := New(kv) - ep := NewExecutionPlan(kv.GetCatalog().GetVersion()) + ep := NewExecutionPlan(kv.GetCatalog().GetVersion(), false) ep.Commands = []Command{ &InitCmd{P2: 1}, &TransactionCmd{}, @@ -40,7 +40,7 @@ func TestExecReturnsVersionErr(t *testing.T) { vm := New(kv) t.Run("for read", func(t *testing.T) { - ep := NewExecutionPlan("FakeVersion") + ep := NewExecutionPlan("FakeVersion", false) ep.Commands = []Command{ &InitCmd{P2: 1}, &TransactionCmd{P2: 0}, @@ -54,7 +54,7 @@ func TestExecReturnsVersionErr(t *testing.T) { }) t.Run("for write", func(t *testing.T) { - ep := NewExecutionPlan("FakeVersion") + ep := NewExecutionPlan("FakeVersion", false) ep.Commands = []Command{ &InitCmd{P2: 1}, &TransactionCmd{P2: 1},