diff --git a/pkg/sql/explain_bundle_test.go b/pkg/sql/explain_bundle_test.go index 0a75d1101231..829d5292d21b 100644 --- a/pkg/sql/explain_bundle_test.go +++ b/pkg/sql/explain_bundle_test.go @@ -276,11 +276,18 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R }) t.Run("foreign keys", func(t *testing.T) { + // All tables should be included in the stmt bundle, regardless of which + // one we query because all of them are considered "related" (even + // though we don't specify ON DELETE and ON UPDATE actions). + tableNames := []string{"parent", "child1", "child2", "grandchild1", "grandchild2"} r.Exec(t, "CREATE TABLE parent (pk INT PRIMARY KEY, v INT);") - r.Exec(t, "CREATE TABLE child (pk INT PRIMARY KEY, fk INT REFERENCES parent(pk));") + r.Exec(t, "CREATE TABLE child1 (pk INT PRIMARY KEY, fk INT REFERENCES parent(pk));") + r.Exec(t, "CREATE TABLE child2 (pk INT PRIMARY KEY, fk INT REFERENCES parent(pk));") + r.Exec(t, "CREATE TABLE grandchild1 (pk INT PRIMARY KEY, fk INT REFERENCES child1(pk));") + r.Exec(t, "CREATE TABLE grandchild2 (pk INT PRIMARY KEY, fk INT REFERENCES child2(pk));") contentCheck := func(name, contents string) error { if name == "schema.sql" { - for _, tableName := range []string{"parent", "child"} { + for _, tableName := range tableNames { if regexp.MustCompile("CREATE TABLE defaultdb.public."+tableName).FindString(contents) == "" { return errors.Newf( "could not find 'CREATE TABLE defaultdb.public.%s' in schema.sql:\n%s", tableName, contents) @@ -289,12 +296,12 @@ CREATE TABLE users(id UUID DEFAULT gen_random_uuid() PRIMARY KEY, promo_id INT R } return nil } - for _, tableName := range []string{"parent", "child"} { + for _, tableName := range tableNames { rows := r.QueryStr(t, "EXPLAIN ANALYZE (DEBUG) SELECT * FROM "+tableName) checkBundle( t, fmt.Sprint(rows), "child", contentCheck, false, /* expectErrors */ - base, plans, "stats-defaultdb.public.parent.sql", "stats-defaultdb.public.child.sql", - "distsql.html vec.txt vec-v.txt", + base, plans, "stats-defaultdb.public.parent.sql", "stats-defaultdb.public.child1.sql", "stats-defaultdb.public.child2.sql", + "stats-defaultdb.public.grandchild1.sql", "stats-defaultdb.public.grandchild2.sql", "distsql.html vec.txt vec-v.txt", ) } }) diff --git a/pkg/sql/opt/metadata.go b/pkg/sql/opt/metadata.go index 537eb1df9302..60cf0278e593 100644 --- a/pkg/sql/opt/metadata.go +++ b/pkg/sql/opt/metadata.go @@ -993,49 +993,43 @@ func (md *Metadata) getAllReferenceTables( var tableSet intsets.Fast var tableList []cat.DataSource var addForeignKeyReferencedTables func(tab cat.Table) + var addForeignKeyReferencingTables func(tab cat.Table) + // handleRelatedTables is a helper function that processes the given table + // if it hasn't been handled yet by adding all referenced and referencing + // table of the given one, including via transient (recursive) FK + // relationships. + handleRelatedTables := func(tabID cat.StableID) { + if !tableSet.Contains(int(tabID)) { + tableSet.Add(int(tabID)) + ds, _, err := catalog.ResolveDataSourceByID(ctx, cat.Flags{}, tabID) + if err != nil { + // This is a best-effort attempt to get all the tables, so don't + // error. + return + } + refTab, ok := ds.(cat.Table) + if !ok { + // This is a best-effort attempt to get all the tables, so don't + // error. + return + } + // We want to include all tables that we reference before adding + // ourselves, followed by all tables that reference us. + addForeignKeyReferencedTables(refTab) + tableList = append(tableList, ds) + addForeignKeyReferencingTables(refTab) + } + } addForeignKeyReferencedTables = func(tab cat.Table) { for i := 0; i < tab.OutboundForeignKeyCount(); i++ { tabID := tab.OutboundForeignKey(i).ReferencedTableID() - if !tableSet.Contains(int(tabID)) { - tableSet.Add(int(tabID)) - ds, _, err := catalog.ResolveDataSourceByID(ctx, cat.Flags{}, tabID) - if err != nil { - // This is a best-effort attempt to get all the tables, so don't error. - continue - } - refTab, ok := ds.(cat.Table) - if !ok { - // This is a best-effort attempt to get all the tables, so don't error. - continue - } - // We want to include all tables that we reference before adding - // ourselves. - addForeignKeyReferencedTables(refTab) - tableList = append(tableList, ds) - } + handleRelatedTables(tabID) } } - var addForeignKeyReferencingTables func(tab cat.Table) addForeignKeyReferencingTables = func(tab cat.Table) { for i := 0; i < tab.InboundForeignKeyCount(); i++ { tabID := tab.InboundForeignKey(i).OriginTableID() - if !tableSet.Contains(int(tabID)) { - tableSet.Add(int(tabID)) - ds, _, err := catalog.ResolveDataSourceByID(ctx, cat.Flags{}, tabID) - if err != nil { - // This is a best-effort attempt to get all the tables, so don't error. - continue - } - refTab, ok := ds.(cat.Table) - if !ok { - // This is a best-effort attempt to get all the tables, so don't error. - continue - } - // We want to include ourselves before all tables that reference - // us. - tableList = append(tableList, ds) - addForeignKeyReferencingTables(refTab) - } + handleRelatedTables(tabID) } } for i := range md.tables {