From 261bd7b7604389f2e74a3c87b740344dd5237cf5 Mon Sep 17 00:00:00 2001 From: gusiri Date: Mon, 17 Feb 2025 16:17:43 +0900 Subject: [PATCH 1/7] initial commit --- prover/protocol/distributed/distributed.go | 15 -- .../protocol/distributed/module_discoverer.go | 195 ++++++++++++++++++ 2 files changed, 195 insertions(+), 15 deletions(-) create mode 100644 prover/protocol/distributed/module_discoverer.go diff --git a/prover/protocol/distributed/distributed.go b/prover/protocol/distributed/distributed.go index 65798fea9..e76f554a3 100644 --- a/prover/protocol/distributed/distributed.go +++ b/prover/protocol/distributed/distributed.go @@ -4,7 +4,6 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/symbolic" ) type ModuleName = string @@ -25,20 +24,6 @@ type DistributedModule struct { GlobalLocal *wizard.CompiledIOP } -// ModuleDiscoverer a set of methods responsible for the horizontal splittings (i.e., splitting to modules) -type ModuleDiscoverer interface { - // Analyze is responsible for letting the module discoverer compute how to - // group best the columns into modules. - Analyze(comp *wizard.CompiledIOP) - NbModules() int - ModuleList() []ModuleName - FindModule(col ifaces.Column) ModuleName - // given a query and a module name it checks if the query is inside the module - ExpressionIsInModule(*symbolic.Expression, ModuleName) bool - QueryIsInModule(ifaces.Query, ModuleName) bool - ColumnIsInModule(col ifaces.Column, name ModuleName) bool -} - // This transforms the initial wizard. So it is not really the initial // wizard anymore. That means the caller can forget about "initialWizard" // after calling the function. diff --git a/prover/protocol/distributed/module_discoverer.go b/prover/protocol/distributed/module_discoverer.go new file mode 100644 index 000000000..c9ec54047 --- /dev/null +++ b/prover/protocol/distributed/module_discoverer.go @@ -0,0 +1,195 @@ +package distributed + +import ( + "sync" + + "github.com/consensys/linea-monorepo/prover/protocol/coin" + "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" + "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/symbolic" + "github.com/consensys/linea-monorepo/prover/utils" +) + +// ModuleDiscoverer implements the ModuleDiscovererInterface +type ModuleDiscoverer struct { + moduleMapping map[ifaces.Column]ModuleName + moduleNames []ModuleName + mutex sync.Mutex +} + +// ModuleDiscovererInterface defines methods for horizontal splitting (i.e., splitting into modules). +type ModuleDiscovererInterface interface { + // Analyze computes how to group columns into modules. + Analyze(comp *wizard.CompiledIOP) + NbModules() int + ModuleList() []ModuleName + FindModule(col ifaces.Column) ModuleName + ExpressionIsInModule(*symbolic.Expression, ModuleName) bool + QueryIsInModule(ifaces.Query, ModuleName) bool + ColumnIsInModule(col ifaces.Column, name ModuleName) bool +} + +// NewModuleDiscoverer initializes and returns a new ModuleDiscoverer instance. +func NewModuleDiscoverer() *ModuleDiscoverer { + return &ModuleDiscoverer{ + moduleMapping: make(map[ifaces.Column]ModuleName), + moduleNames: []ModuleName{}, + } +} + +// Analyze clusters columns into modules by iterating through global constraints (QueriesNoParams). +// Columns sharing the same global constraints are grouped into the same module. +func (md *ModuleDiscoverer) Analyze(comp *wizard.CompiledIOP) { + md.mutex.Lock() + defer md.mutex.Unlock() + + moduleIndex := 0 + columnClusters := make(map[ModuleName]map[ifaces.Column]bool) + + for _, qName := range comp.QueriesNoParams.AllUnignoredKeys() { + query := comp.QueriesNoParams.Data(qName) + + // Determine the columns connected by this global constraint + connectedColumns := getColumnsFromQuery(query) + + // Check if these columns belong to an existing module + moduleFound := false + for moduleName, cluster := range columnClusters { + if sharesColumns(cluster, connectedColumns) { + mergeClusters(cluster, connectedColumns) + moduleFound = true + md.assignModule(moduleName, connectedColumns) + break + } + } + + // If no module matches, create a new one + if !moduleFound { + moduleName := ModuleName("Module_" + string(moduleIndex)) + moduleIndex++ + columnClusters[moduleName] = connectedColumns + md.moduleNames = append(md.moduleNames, moduleName) + md.assignModule(moduleName, connectedColumns) + } + } +} + +// NbModules returns the total number of discovered modules. +func (md *ModuleDiscoverer) NbModules() int { + md.mutex.Lock() + defer md.mutex.Unlock() + return len(md.moduleNames) +} + +// ModuleList returns the list of all module names. +func (md *ModuleDiscoverer) ModuleList() []ModuleName { + md.mutex.Lock() + defer md.mutex.Unlock() + return append([]ModuleName(nil), md.moduleNames...) +} + +// FindModule returns the module name for the given column. +func (md *ModuleDiscoverer) FindModule(col ifaces.Column) ModuleName { + md.mutex.Lock() + defer md.mutex.Unlock() + return md.moduleMapping[col] +} + +// ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module. +func (md *ModuleDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool { + board := expr.Board() + metadata := board.ListVariableMetadata() + + // by contradiction, if there is no metadata it belongs to the module. + if len(metadata) == 0 { + return true + } + + md.mutex.Lock() + defer md.mutex.Unlock() + + b := true + nCols := 0 + + for _, m := range metadata { + switch v := m.(type) { + case ifaces.Column: + if _, ok := v.(verifiercol.VerifierCol); !ok { + if !md.ColumnIsInModule(v, name) { + b = false + } + nCols++ + } + // The expression can involve random coins + case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: + // Do nothing + default: + utils.Panic("unknown type %T", metadata) + } + } + + if nCols == 0 { + panic("could not find any column in the expression") + } + return b +} + +// QueryIsInModule checks if the given query is inside the given module +func (md *ModuleDiscoverer) QueryIsInModule(query ifaces.Query, name ModuleName) bool { + md.mutex.Lock() + defer md.mutex.Unlock() + for _, col := range query.Columns() { + if md.FindModule(col) != name { + return false + } + } + return true +} + +// ColumnIsInModule checks that the given column is inside the given module. +func (md *ModuleDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + md.mutex.Lock() + defer md.mutex.Unlock() + return md.moduleMapping[col] == name +} + +// CoinIsInModule (placeholder): Extend logic to handle coins if needed. +func (md *ModuleDiscoverer) CoinIsInModule(coin ifaces.Coin, name ModuleName) bool { + // Logic for associating coins with modules goes here + return false +} + +// Utility: Extracts columns involved in a query. +func getColumnsFromQuery(query ifaces.Query) map[ifaces.Column]bool { + columns := make(map[ifaces.Column]bool) + for _, col := range query.Columns() { + columns[col] = true + } + return columns +} + +// Utility: Checks if two column clusters share any columns. +func sharesColumns(cluster map[ifaces.Column]bool, columns map[ifaces.Column]bool) bool { + for col := range columns { + if cluster[col] { + return true + } + } + return false +} + +// Utility: Merges columns into an existing cluster. +func mergeClusters(cluster map[ifaces.Column]bool, columns map[ifaces.Column]bool) { + for col := range columns { + cluster[col] = true + } +} + +// Utility: Assigns module information to columns in the mapping. +func (md *ModuleDiscoverer) assignModule(moduleName ModuleName, columns map[ifaces.Column]bool) { + for col := range columns { + md.moduleMapping[col] = moduleName + } +} From f3cb6596ac3e1f0f30d5eb6fae529af6428f9e14 Mon Sep 17 00:00:00 2001 From: gusiri Date: Tue, 18 Feb 2025 00:55:44 +0900 Subject: [PATCH 2/7] add module discoverer package --- .../{ => modulediscoverer}/module_discoverer.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) rename prover/protocol/distributed/{ => modulediscoverer}/module_discoverer.go (94%) diff --git a/prover/protocol/distributed/module_discoverer.go b/prover/protocol/distributed/modulediscoverer/module_discoverer.go similarity index 94% rename from prover/protocol/distributed/module_discoverer.go rename to prover/protocol/distributed/modulediscoverer/module_discoverer.go index c9ec54047..d110363b1 100644 --- a/prover/protocol/distributed/module_discoverer.go +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer.go @@ -1,10 +1,11 @@ -package distributed +package modulediscoverer import ( "sync" "github.com/consensys/linea-monorepo/prover/protocol/coin" "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" + "github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/variables" "github.com/consensys/linea-monorepo/prover/protocol/wizard" @@ -12,6 +13,9 @@ import ( "github.com/consensys/linea-monorepo/prover/utils" ) +// Type alias for distributed.ModuleName +type ModuleName = distributed.ModuleName + // ModuleDiscoverer implements the ModuleDiscovererInterface type ModuleDiscoverer struct { moduleMapping map[ifaces.Column]ModuleName @@ -61,13 +65,13 @@ func (md *ModuleDiscoverer) Analyze(comp *wizard.CompiledIOP) { mergeClusters(cluster, connectedColumns) moduleFound = true md.assignModule(moduleName, connectedColumns) - break + break // possible scenario: two clusters could be merged together } } // If no module matches, create a new one if !moduleFound { - moduleName := ModuleName("Module_" + string(moduleIndex)) + moduleName := ModuleName("Module_" + string(moduleIndex)) // don't name it moduleIndex++ columnClusters[moduleName] = connectedColumns md.moduleNames = append(md.moduleNames, moduleName) @@ -99,6 +103,9 @@ func (md *ModuleDiscoverer) FindModule(col ifaces.Column) ModuleName { // ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module. func (md *ModuleDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool { + + // get columns from an expression + board := expr.Board() metadata := board.ListVariableMetadata() From 68decefbf815493eeba7477d350cc0c4989b49fb Mon Sep 17 00:00:00 2001 From: gusiri Date: Sat, 1 Mar 2025 07:17:55 +0900 Subject: [PATCH 3/7] add module discoverer --- .../modulediscoverer/module_discoverer.go | 428 ++++++++++++------ .../module_discoverer_test.go | 162 +++++++ 2 files changed, 456 insertions(+), 134 deletions(-) create mode 100644 prover/protocol/distributed/modulediscoverer/module_discoverer_test.go diff --git a/prover/protocol/distributed/modulediscoverer/module_discoverer.go b/prover/protocol/distributed/modulediscoverer/module_discoverer.go index d110363b1..87bc7dc71 100644 --- a/prover/protocol/distributed/modulediscoverer/module_discoverer.go +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer.go @@ -1,202 +1,362 @@ package modulediscoverer import ( + "fmt" "sync" - "github.com/consensys/linea-monorepo/prover/protocol/coin" - "github.com/consensys/linea-monorepo/prover/protocol/column/verifiercol" - "github.com/consensys/linea-monorepo/prover/protocol/distributed" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" - "github.com/consensys/linea-monorepo/prover/protocol/variables" + "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" "github.com/consensys/linea-monorepo/prover/symbolic" - "github.com/consensys/linea-monorepo/prover/utils" ) -// Type alias for distributed.ModuleName -type ModuleName = distributed.ModuleName +type ModuleName string -// ModuleDiscoverer implements the ModuleDiscovererInterface -type ModuleDiscoverer struct { - moduleMapping map[ifaces.Column]ModuleName - moduleNames []ModuleName - mutex sync.Mutex -} - -// ModuleDiscovererInterface defines methods for horizontal splitting (i.e., splitting into modules). -type ModuleDiscovererInterface interface { +// ModuleDiscoverer defines methods responsible for horizontal splitting (i.e., splitting into modules). +type ModuleDiscoverer interface { // Analyze computes how to group columns into modules. Analyze(comp *wizard.CompiledIOP) - NbModules() int + // ModuleList returns the list of module names. ModuleList() []ModuleName + // FindModule returns the module corresponding to a column. FindModule(col ifaces.Column) ModuleName + // NewSizeOf returns the split-size of a column in the module. + NewSizeOf(col ifaces.Column) int + ExpressionIsInModule(*symbolic.Expression, ModuleName) bool QueryIsInModule(ifaces.Query, ModuleName) bool ColumnIsInModule(col ifaces.Column, name ModuleName) bool } -// NewModuleDiscoverer initializes and returns a new ModuleDiscoverer instance. -func NewModuleDiscoverer() *ModuleDiscoverer { - return &ModuleDiscoverer{ - moduleMapping: make(map[ifaces.Column]ModuleName), - moduleNames: []ModuleName{}, +// DisjointSet represents a union-find data structure, which efficiently groups elements (columns) +// into disjoint sets (modules). It supports fast union and find operations with path compression. +type DisjointSet struct { + parent map[ifaces.Column]ifaces.Column // Maps a column to its representative parent. + rank map[ifaces.Column]int // Stores the rank (tree depth) for optimization. +} + +// NewDisjointSet initializes a new DisjointSet with empty mappings. +func NewDisjointSet() *DisjointSet { + return &DisjointSet{ + parent: make(map[ifaces.Column]ifaces.Column), + rank: make(map[ifaces.Column]int), } } -// Analyze clusters columns into modules by iterating through global constraints (QueriesNoParams). -// Columns sharing the same global constraints are grouped into the same module. -func (md *ModuleDiscoverer) Analyze(comp *wizard.CompiledIOP) { - md.mutex.Lock() - defer md.mutex.Unlock() +// Find returns the representative (root) of a column using path compression for optimization. +// Path compression ensures that the structure remains nearly flat, reducing the time complexity to O(α(n)), +// where α(n) is the inverse Ackermann function, which is nearly constant in practice. +// +// Example: +// Suppose we have the following sets: +// +// A -> B -> C (C is the root) +// D -> E -> F (F is the root) +// +// Calling Find(A) will compress the path so that: +// +// A -> C +// B -> C +// C remains the root +// +// Similarly, calling Find(D) will compress the path so that: +// +// D -> F +// E -> F +// F remains the root +func (ds *DisjointSet) Find(col ifaces.Column) ifaces.Column { + if _, exists := ds.parent[col]; !exists { + ds.parent[col] = col + ds.rank[col] = 0 + } + if ds.parent[col] != col { + ds.parent[col] = ds.Find(ds.parent[col]) + } + return ds.parent[col] +} - moduleIndex := 0 - columnClusters := make(map[ModuleName]map[ifaces.Column]bool) +// Union merges two sets by linking the root of one to the root of another, optimizing with rank. +// The smaller tree is always attached to the larger tree to keep the depth minimal. +// +// Time Complexity: O(α(n)) (nearly constant due to path compression and union by rank). +// +// Example: +// Suppose we have: +// +// Set 1: A -> B (B is the root) +// Set 2: C -> D (D is the root) +// +// Calling Union(A, C) will merge the sets: +// +// If B has a higher rank than D: +// D -> B +// C -> D -> B +// If D has a higher rank than B: +// B -> D +// A -> B -> D +// If B and D have equal rank: +// D -> B (or B -> D) +// Rank of the new root increases by 1 +func (ds *DisjointSet) Union(col1, col2 ifaces.Column) { + root1 := ds.Find(col1) + root2 := ds.Find(col2) - for _, qName := range comp.QueriesNoParams.AllUnignoredKeys() { - query := comp.QueriesNoParams.Data(qName) - - // Determine the columns connected by this global constraint - connectedColumns := getColumnsFromQuery(query) - - // Check if these columns belong to an existing module - moduleFound := false - for moduleName, cluster := range columnClusters { - if sharesColumns(cluster, connectedColumns) { - mergeClusters(cluster, connectedColumns) - moduleFound = true - md.assignModule(moduleName, connectedColumns) - break // possible scenario: two clusters could be merged together - } + if root1 != root2 { + if ds.rank[root1] > ds.rank[root2] { + ds.parent[root2] = root1 + } else if ds.rank[root1] < ds.rank[root2] { + ds.parent[root1] = root2 + } else { + ds.parent[root2] = root1 + ds.rank[root1]++ } + } +} + +// Module represents a set of columns grouped by constraints. +type Module struct { + moduleName ModuleName + ds *DisjointSet // Uses a disjoint set to track relationships among columns. + size int + numColumns int +} + +// Discoverer tracks modules using DisjointSet. +type Discoverer struct { + mutex sync.Mutex + modules []*Module + moduleNames []ModuleName + columnsToModule map[ifaces.Column]ModuleName +} - // If no module matches, create a new one - if !moduleFound { - moduleName := ModuleName("Module_" + string(moduleIndex)) // don't name it - moduleIndex++ - columnClusters[moduleName] = connectedColumns - md.moduleNames = append(md.moduleNames, moduleName) - md.assignModule(moduleName, connectedColumns) +// NewDiscoverer initializes a new Discoverer. +func NewDiscoverer() *Discoverer { + return &Discoverer{ + modules: []*Module{}, + moduleNames: []ModuleName{}, + columnsToModule: make(map[ifaces.Column]ModuleName), + } +} + +// CreateModule initializes a new module with a disjoint set and populates it with columns. +func (disc *Discoverer) CreateModule(columns []ifaces.Column) *Module { + module := &Module{ + moduleName: ModuleName(fmt.Sprintf("Module_%d", len(disc.modules))), + ds: NewDisjointSet(), + } + for _, col := range columns { + module.ds.parent[col] = col + module.ds.rank[col] = 0 + fmt.Println("Assigned parent for column:", col) + } + for i := 0; i < len(columns); i++ { + for j := i + 1; j < len(columns); j++ { + module.ds.Union(columns[i], columns[j]) } } + fmt.Println("Final parent map for module:", module.moduleName, module.ds.parent) + disc.moduleNames = append(disc.moduleNames, module.moduleName) + disc.modules = append(disc.modules, module) + return module } -// NbModules returns the total number of discovered modules. -func (md *ModuleDiscoverer) NbModules() int { - md.mutex.Lock() - defer md.mutex.Unlock() - return len(md.moduleNames) +// MergeModules merges a list of overlapping modules into a single module. +func (disc *Discoverer) MergeModules(modules []*Module, moduleCandidates *[]*Module) *Module { + if len(modules) == 0 { + return nil + } + + // Select the first module as the base + mergedModule := modules[0] + + // Merge all remaining modules into the base + for _, module := range modules[1:] { + for col := range module.ds.parent { + mergedModule.ds.Union(mergedModule.ds.Find(col), col) + } + + // Remove merged module from moduleCandidates + *moduleCandidates = removeModule(*moduleCandidates, module) + } + + return mergedModule } -// ModuleList returns the list of all module names. -func (md *ModuleDiscoverer) ModuleList() []ModuleName { - md.mutex.Lock() - defer md.mutex.Unlock() - return append([]ModuleName(nil), md.moduleNames...) +// AddColumnsToModule adds columns to an existing module. +func (disc *Discoverer) AddColumnsToModule(module *Module, columns []ifaces.Column) { + for _, col := range columns { + module.ds.parent[col] = col + module.ds.rank[col] = 0 + module.ds.Union(module.ds.Find(columns[0]), col) // Union with the first column + } } -// FindModule returns the module name for the given column. -func (md *ModuleDiscoverer) FindModule(col ifaces.Column) ModuleName { - md.mutex.Lock() - defer md.mutex.Unlock() - return md.moduleMapping[col] +// Helper function to remove a module from the slice +func removeModule(modules []*Module, target *Module) []*Module { + var updatedModules []*Module + for _, mod := range modules { + if mod != target { + updatedModules = append(updatedModules, mod) + } + } + return updatedModules } -// ExpressionIsInModule checks that all the columns (except verifiercol) in the expression are from the given module. -func (md *ModuleDiscoverer) ExpressionIsInModule(expr *symbolic.Expression, name ModuleName) bool { +// Analyze processes columns and assigns them to modules. - // get columns from an expression +// {1,2,3,4,5} +// {100} +// {6,7,8} +// {9,10} +// {3,6,20} +// {2,99} - board := expr.Board() - metadata := board.ListVariableMetadata() +// Processing: +// First Iteration - {1,2,3,4,5} +// No existing module. +// Create Module_0 → {1,2,3,4,5} +// Assign columns {1,2,3,4,5} to Module_0. - // by contradiction, if there is no metadata it belongs to the module. - if len(metadata) == 0 { - return true - } +// Second Iteration - {100} +// No overlap with existing modules. +// Create Module_1 → {100} +// Assign {100} to Module_1. + +// Third Iteration - {6,7,8} +// No overlap. +// Create Module_2 → {6,7,8} +// Assign {6,7,8} to Module_2. + +// Fourth Iteration - {9,10} +// No overlap. +// Create Module_3 → {9,10} +// Assign {9,10} to Module_3. - md.mutex.Lock() - defer md.mutex.Unlock() +// Fifth Iteration - {3,6,20} +// {3} is in Module_0, {6} is in Module_2 → Overlap detected. +// Merge Module_0 and Module_2 into Module_0. +// Module_0 now contains {1,2,3,4,5,6,7,8,20}. +// Remove Module_2 from moduleCandidates. +// Assign {3,6,20} to Module_0. - b := true - nCols := 0 +// Sixth Iteration - {2,99} +// {2} is in Module_0 → Overlap detected. +// Add {99} to Module_0. +// Module_0 now contains {1,2,3,4,5,6,7,8,20,99}. +// Assign {2,99} to Module_0. - for _, m := range metadata { - switch v := m.(type) { - case ifaces.Column: - if _, ok := v.(verifiercol.VerifierCol); !ok { - if !md.ColumnIsInModule(v, name) { - b = false - } - nCols++ +// Final Modules: +// Module_0 → {1,2,3,4,5,6,7,8,20,99} +// Module_1 → {100} +// Module_3 → {9,10} + +func (disc *Discoverer) Analyze(comp *wizard.CompiledIOP) { + disc.mutex.Lock() + defer disc.mutex.Unlock() + + moduleCandidates := []*Module{} + + for _, qName := range comp.QueriesNoParams.AllUnignoredKeys() { + cs, ok := comp.QueriesNoParams.Data(qName).(query.GlobalConstraint) + if !ok { + continue // Skip non-global constraints + } + + columns := getColumnsFromQuery(cs) + overlappingModules := []*Module{} + + // Find overlapping modules + for _, module := range moduleCandidates { + if HasOverlap(module, columns) { + overlappingModules = append(overlappingModules, module) } - // The expression can involve random coins - case coin.Info, variables.X, variables.PeriodicSample, ifaces.Accessor: - // Do nothing - default: - utils.Panic("unknown type %T", metadata) + } + + var assignedModule *Module + + // Merge if necessary + if len(overlappingModules) > 0 { + assignedModule = disc.MergeModules(overlappingModules, &moduleCandidates) + disc.AddColumnsToModule(assignedModule, columns) + } else { + // Create a new module + assignedModule = disc.CreateModule(columns) + moduleCandidates = append(moduleCandidates, assignedModule) } } - if nCols == 0 { - panic("could not find any column in the expression") + // Assign final module names after all processing + for _, module := range moduleCandidates { + for col := range module.ds.parent { + disc.columnsToModule[col] = module.moduleName + } } - return b } -// QueryIsInModule checks if the given query is inside the given module -func (md *ModuleDiscoverer) QueryIsInModule(query ifaces.Query, name ModuleName) bool { - md.mutex.Lock() - defer md.mutex.Unlock() - for _, col := range query.Columns() { - if md.FindModule(col) != name { - return false +// getColumnsFromQuery extracts columns from a global constraint query. +func getColumnsFromQuery(q ifaces.Query) []ifaces.Column { + gc, ok := q.(query.GlobalConstraint) + if !ok { + return nil // Not a global constraint, return nil + } + + // Extract columns from the constraint expression + var columns []ifaces.Column + board := gc.Expression.Board() + for _, metadata := range board.ListVariableMetadata() { + if col, ok := metadata.(ifaces.Column); ok { + columns = append(columns, col) } } - return true + + return columns +} + +// assignModule assigns a module name to a set of columns. +func (disc *Discoverer) assignModule(moduleName ModuleName, columns []ifaces.Column) { + for _, col := range columns { + disc.columnsToModule[col] = moduleName + } } -// ColumnIsInModule checks that the given column is inside the given module. -func (md *ModuleDiscoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { - md.mutex.Lock() - defer md.mutex.Unlock() - return md.moduleMapping[col] == name +// NewSizeOf returns the size (length) of a column. +func (disc *Discoverer) NewSizeOf(col ifaces.Column) int { + return col.Size() } -// CoinIsInModule (placeholder): Extend logic to handle coins if needed. -func (md *ModuleDiscoverer) CoinIsInModule(coin ifaces.Coin, name ModuleName) bool { - // Logic for associating coins with modules goes here - return false +// ModuleList returns a list of all module names. +func (disc *Discoverer) ModuleList() []ModuleName { + disc.mutex.Lock() + defer disc.mutex.Unlock() + return disc.moduleNames } -// Utility: Extracts columns involved in a query. -func getColumnsFromQuery(query ifaces.Query) map[ifaces.Column]bool { - columns := make(map[ifaces.Column]bool) - for _, col := range query.Columns() { - columns[col] = true +// ModuleOf returns the module name for a given column. +func (disc *Discoverer) ModuleOf(col ifaces.Column) ModuleName { + disc.mutex.Lock() + defer disc.mutex.Unlock() + + if moduleName, exists := disc.columnsToModule[col]; exists { + return moduleName } - return columns + return "" } -// Utility: Checks if two column clusters share any columns. -func sharesColumns(cluster map[ifaces.Column]bool, columns map[ifaces.Column]bool) bool { - for col := range columns { - if cluster[col] { +// HasOverlap checks if a module shares at least one column with a set of columns. +func HasOverlap(module *Module, columns []ifaces.Column) bool { + for _, col := range columns { + fmt.Println("Checking column:", col, "against module:", module.moduleName) + if _, exists := module.ds.parent[col]; exists { + fmt.Println("Overlap found between:", col, "and module:", module.moduleName) return true } } return false } -// Utility: Merges columns into an existing cluster. -func mergeClusters(cluster map[ifaces.Column]bool, columns map[ifaces.Column]bool) { - for col := range columns { - cluster[col] = true - } -} - -// Utility: Assigns module information to columns in the mapping. -func (md *ModuleDiscoverer) assignModule(moduleName ModuleName, columns map[ifaces.Column]bool) { - for col := range columns { - md.moduleMapping[col] = moduleName - } +// NbModules returns the total number of discovered modules. +func (disc *Discoverer) NbModules() int { + disc.mutex.Lock() + defer disc.mutex.Unlock() + return len(disc.moduleNames) } diff --git a/prover/protocol/distributed/modulediscoverer/module_discoverer_test.go b/prover/protocol/distributed/modulediscoverer/module_discoverer_test.go new file mode 100644 index 000000000..0f53ec2aa --- /dev/null +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer_test.go @@ -0,0 +1,162 @@ +package modulediscoverer_test + +import ( + "fmt" + "sync" + "testing" + + "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" + "github.com/stretchr/testify/assert" +) + +// This should be re-written for ifaces.Column instead of smartvectors.Smartvector +type ModuleName string + +type DisjointSetTest struct { + parent map[smartvectors.SmartVector]smartvectors.SmartVector + rank map[smartvectors.SmartVector]int +} + +func NewDisjointSetTest() *DisjointSetTest { + return &DisjointSetTest{ + parent: make(map[smartvectors.SmartVector]smartvectors.SmartVector), + rank: make(map[smartvectors.SmartVector]int), + } +} + +func (ds *DisjointSetTest) Find(vec smartvectors.SmartVector) smartvectors.SmartVector { + if _, exists := ds.parent[vec]; !exists { + ds.parent[vec] = vec + ds.rank[vec] = 0 + } + if ds.parent[vec] != vec { + ds.parent[vec] = ds.Find(ds.parent[vec]) + } + return ds.parent[vec] +} + +func (ds *DisjointSetTest) Union(vec1, vec2 smartvectors.SmartVector) { + root1 := ds.Find(vec1) + root2 := ds.Find(vec2) + + if root1 != root2 { + if ds.rank[root1] > ds.rank[root2] { + ds.parent[root2] = root1 + } else if ds.rank[root1] < ds.rank[root2] { + ds.parent[root1] = root2 + } else { + ds.parent[root2] = root1 + ds.rank[root1]++ + } + } +} + +type Module struct { + moduleName ModuleName + ds *DisjointSetTest +} + +type Discoverer struct { + mutex sync.Mutex + modules []*Module + moduleNames []ModuleName + columnsToModule map[smartvectors.SmartVector]ModuleName +} + +func NewDiscovererTest() *Discoverer { + return &Discoverer{ + modules: []*Module{}, + moduleNames: []ModuleName{}, + columnsToModule: make(map[smartvectors.SmartVector]ModuleName), + } +} + +func (disc *Discoverer) ModuleList() []ModuleName { + disc.mutex.Lock() + defer disc.mutex.Unlock() + return disc.moduleNames +} + +func (disc *Discoverer) assignModule(moduleName ModuleName, vectors []smartvectors.SmartVector) { + for _, vec := range vectors { + disc.columnsToModule[vec] = moduleName + } +} + +func (disc *Discoverer) CreateModule(vectors []smartvectors.SmartVector) *Module { + module := &Module{ + moduleName: ModuleName(fmt.Sprintf("Module_%d", len(disc.modules))), + ds: NewDisjointSetTest(), + } + + for _, vec := range vectors { + module.ds.parent[vec] = vec + module.ds.rank[vec] = 0 + } + + // Union all vectors together in the module + for i := 0; i < len(vectors); i++ { + for j := i + 1; j < len(vectors); j++ { + module.ds.Union(vectors[i], vectors[j]) + } + } + + fmt.Println("Final parent map for module:", module.moduleName, module.ds.parent) + disc.modules = append(disc.modules, module) + return module +} + +func HasOverlap(module *Module, vectors []smartvectors.SmartVector) bool { + for _, vec := range vectors { + fmt.Println("Checking vector:", vec, "against module:", module.moduleName) + if _, exists := module.ds.parent[vec]; exists { + fmt.Println("Overlap found between:", vec, "and module:", module.moduleName) + return true + } else { + fmt.Println("Vector:", vec, "NOT found in module:", module.moduleName) + } + } + return false +} + +func TestUnion(t *testing.T) { + ds := NewDisjointSetTest() + vec1 := smartvectors.ForTest(1) + vec2 := smartvectors.ForTest(2) + vec3 := smartvectors.ForTest(3) + ds.Union(vec1, vec2) + assert.Equal(t, ds.Find(vec1), ds.Find(vec2), "vec1 and vec2 should have the same root") + ds.Union(vec2, vec3) + assert.Equal(t, ds.Find(vec1), ds.Find(vec3), "vec1, vec2, and vec3 should have the same root") +} + +func TestCreateModule(t *testing.T) { + disc := NewDiscovererTest() + vectors := []smartvectors.SmartVector{smartvectors.ForTest(1), smartvectors.ForTest(2), smartvectors.ForTest(3)} + module := disc.CreateModule(vectors) + + assert.NotNil(t, module, "Module should not be nil") + assert.Equal(t, 1, len(disc.modules), "Discoverer should have one module") + for _, vec := range vectors { + assert.Equal(t, module.ds.Find(vec), module.ds.Find(vectors[0]), "All vectors should belong to the same set") + } +} +func TestHasOverlap(t *testing.T) { + disc := NewDiscovererTest() + + vec1 := smartvectors.ForTest(1) + vec2 := smartvectors.ForTest(2) + vec3 := smartvectors.ForTest(3) + vec4 := smartvectors.ForTest(4) + vec5 := smartvectors.ForTest(5) + vec6 := smartvectors.ForTest(6) + vec7 := smartvectors.ForTest(7) + + module1 := disc.CreateModule([]smartvectors.SmartVector{vec1, vec2}) + module2 := disc.CreateModule([]smartvectors.SmartVector{vec3, vec4}) + candidates := []smartvectors.SmartVector{vec2, vec5} + + assert.False(t, HasOverlap(module1, []smartvectors.SmartVector{vec6, vec7}), "module1 should NOT overlap with unrelated vectors") + assert.True(t, HasOverlap(module1, candidates), "module1 should overlap with candidates") + assert.False(t, HasOverlap(module2, candidates), "module2 should NOT overlap with candidates") +} From 9e1bdf16089b8953210717d651bec229fdfd55de Mon Sep 17 00:00:00 2001 From: gusiri Date: Sat, 1 Mar 2025 07:35:15 +0900 Subject: [PATCH 4/7] add missing implementation --- .../modulediscoverer/module_discoverer.go | 58 ++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/prover/protocol/distributed/modulediscoverer/module_discoverer.go b/prover/protocol/distributed/modulediscoverer/module_discoverer.go index 87bc7dc71..35969af5e 100644 --- a/prover/protocol/distributed/modulediscoverer/module_discoverer.go +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer.go @@ -7,7 +7,6 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/ifaces" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" - "github.com/consensys/linea-monorepo/prover/symbolic" ) type ModuleName string @@ -23,9 +22,17 @@ type ModuleDiscoverer interface { // NewSizeOf returns the split-size of a column in the module. NewSizeOf(col ifaces.Column) int - ExpressionIsInModule(*symbolic.Expression, ModuleName) bool QueryIsInModule(ifaces.Query, ModuleName) bool + // it return true if it can find any column from the given slice in the module + SliceIsInModule([]ifaces.Column, ModuleName) bool + // it checks if the given column is in the given module ColumnIsInModule(col ifaces.Column, name ModuleName) bool + // it adds all the unassigned columns in the slice to the given module. + UpdateDiscoverer([]ifaces.Column, ModuleName) + // it return the module associated with the column, if it is already captured + HasModule(col ifaces.Column) (ModuleName, bool) + // return the columns from the module + ListColumns(ModuleName) []ifaces.Column } // DisjointSet represents a union-find data structure, which efficiently groups elements (columns) @@ -354,6 +361,53 @@ func HasOverlap(module *Module, columns []ifaces.Column) bool { return false } +func (disc *Discoverer) QueryIsInModule(query ifaces.Query, name ModuleName) bool { + // Extract columns from the query + columns := getColumnsFromQuery(query) + + // Check if any of the columns belong to the module + return disc.SliceIsInModule(columns, name) +} + +func (disc *Discoverer) SliceIsInModule(columns []ifaces.Column, name ModuleName) bool { + for _, col := range columns { + if disc.ColumnIsInModule(col, name) { + return true + } + } + return false +} + +func (disc *Discoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) bool { + if moduleName, exists := disc.columnsToModule[col]; exists { + return moduleName == name + } + return false +} + +func (disc *Discoverer) UpdateDiscoverer(columns []ifaces.Column, name ModuleName) { + for _, col := range columns { + if _, exists := disc.columnsToModule[col]; !exists { + disc.columnsToModule[col] = name + } + } +} + +func (disc *Discoverer) HasModule(col ifaces.Column) (ModuleName, bool) { + moduleName, exists := disc.columnsToModule[col] + return moduleName, exists +} + +func (disc *Discoverer) ListColumns(name ModuleName) []ifaces.Column { + var columns []ifaces.Column + for col, moduleName := range disc.columnsToModule { + if moduleName == name { + columns = append(columns, col) + } + } + return columns +} + // NbModules returns the total number of discovered modules. func (disc *Discoverer) NbModules() int { disc.mutex.Lock() From b3fb3590aec25f07f7cf8e89ac1c6451504b6e45 Mon Sep 17 00:00:00 2001 From: gusiri Date: Sat, 1 Mar 2025 08:01:05 +0900 Subject: [PATCH 5/7] fix UpdateDiscoverer --- .../modulediscoverer/module_discoverer.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/prover/protocol/distributed/modulediscoverer/module_discoverer.go b/prover/protocol/distributed/modulediscoverer/module_discoverer.go index 35969af5e..60fc68283 100644 --- a/prover/protocol/distributed/modulediscoverer/module_discoverer.go +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer.go @@ -386,11 +386,20 @@ func (disc *Discoverer) ColumnIsInModule(col ifaces.Column, name ModuleName) boo } func (disc *Discoverer) UpdateDiscoverer(columns []ifaces.Column, name ModuleName) { - for _, col := range columns { - if _, exists := disc.columnsToModule[col]; !exists { - disc.columnsToModule[col] = name + // Find the module corresponding to the given name + var targetModule *Module + for _, module := range disc.modules { + if module.moduleName == name { + targetModule = module + break } } + + // If the module is found, add the columns + if targetModule != nil { + disc.AddColumnsToModule(targetModule, columns) + disc.assignModule(name, columns) + } } func (disc *Discoverer) HasModule(col ifaces.Column) (ModuleName, bool) { From b143aa05a6373523772def8b052134703a6dc8e3 Mon Sep 17 00:00:00 2001 From: gusiri Date: Sat, 1 Mar 2025 08:03:20 +0900 Subject: [PATCH 6/7] clean up logs --- .../distributed/modulediscoverer/module_discoverer.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/prover/protocol/distributed/modulediscoverer/module_discoverer.go b/prover/protocol/distributed/modulediscoverer/module_discoverer.go index 60fc68283..5bdf87627 100644 --- a/prover/protocol/distributed/modulediscoverer/module_discoverer.go +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer.go @@ -154,14 +154,13 @@ func (disc *Discoverer) CreateModule(columns []ifaces.Column) *Module { for _, col := range columns { module.ds.parent[col] = col module.ds.rank[col] = 0 - fmt.Println("Assigned parent for column:", col) } for i := 0; i < len(columns); i++ { for j := i + 1; j < len(columns); j++ { module.ds.Union(columns[i], columns[j]) } } - fmt.Println("Final parent map for module:", module.moduleName, module.ds.parent) + disc.moduleNames = append(disc.moduleNames, module.moduleName) disc.modules = append(disc.modules, module) return module @@ -352,9 +351,7 @@ func (disc *Discoverer) ModuleOf(col ifaces.Column) ModuleName { // HasOverlap checks if a module shares at least one column with a set of columns. func HasOverlap(module *Module, columns []ifaces.Column) bool { for _, col := range columns { - fmt.Println("Checking column:", col, "against module:", module.moduleName) if _, exists := module.ds.parent[col]; exists { - fmt.Println("Overlap found between:", col, "and module:", module.moduleName) return true } } From 49f3634c691a3af7f723533f9e3708ed0cf718f7 Mon Sep 17 00:00:00 2001 From: gusiri Date: Sat, 1 Mar 2025 08:09:42 +0900 Subject: [PATCH 7/7] clean up code --- .../modulediscoverer/module_discoverer.go | 330 +++++++++--------- 1 file changed, 165 insertions(+), 165 deletions(-) diff --git a/prover/protocol/distributed/modulediscoverer/module_discoverer.go b/prover/protocol/distributed/modulediscoverer/module_discoverer.go index 5bdf87627..f16b6a64d 100644 --- a/prover/protocol/distributed/modulediscoverer/module_discoverer.go +++ b/prover/protocol/distributed/modulediscoverer/module_discoverer.go @@ -35,91 +35,6 @@ type ModuleDiscoverer interface { ListColumns(ModuleName) []ifaces.Column } -// DisjointSet represents a union-find data structure, which efficiently groups elements (columns) -// into disjoint sets (modules). It supports fast union and find operations with path compression. -type DisjointSet struct { - parent map[ifaces.Column]ifaces.Column // Maps a column to its representative parent. - rank map[ifaces.Column]int // Stores the rank (tree depth) for optimization. -} - -// NewDisjointSet initializes a new DisjointSet with empty mappings. -func NewDisjointSet() *DisjointSet { - return &DisjointSet{ - parent: make(map[ifaces.Column]ifaces.Column), - rank: make(map[ifaces.Column]int), - } -} - -// Find returns the representative (root) of a column using path compression for optimization. -// Path compression ensures that the structure remains nearly flat, reducing the time complexity to O(α(n)), -// where α(n) is the inverse Ackermann function, which is nearly constant in practice. -// -// Example: -// Suppose we have the following sets: -// -// A -> B -> C (C is the root) -// D -> E -> F (F is the root) -// -// Calling Find(A) will compress the path so that: -// -// A -> C -// B -> C -// C remains the root -// -// Similarly, calling Find(D) will compress the path so that: -// -// D -> F -// E -> F -// F remains the root -func (ds *DisjointSet) Find(col ifaces.Column) ifaces.Column { - if _, exists := ds.parent[col]; !exists { - ds.parent[col] = col - ds.rank[col] = 0 - } - if ds.parent[col] != col { - ds.parent[col] = ds.Find(ds.parent[col]) - } - return ds.parent[col] -} - -// Union merges two sets by linking the root of one to the root of another, optimizing with rank. -// The smaller tree is always attached to the larger tree to keep the depth minimal. -// -// Time Complexity: O(α(n)) (nearly constant due to path compression and union by rank). -// -// Example: -// Suppose we have: -// -// Set 1: A -> B (B is the root) -// Set 2: C -> D (D is the root) -// -// Calling Union(A, C) will merge the sets: -// -// If B has a higher rank than D: -// D -> B -// C -> D -> B -// If D has a higher rank than B: -// B -> D -// A -> B -> D -// If B and D have equal rank: -// D -> B (or B -> D) -// Rank of the new root increases by 1 -func (ds *DisjointSet) Union(col1, col2 ifaces.Column) { - root1 := ds.Find(col1) - root2 := ds.Find(col2) - - if root1 != root2 { - if ds.rank[root1] > ds.rank[root2] { - ds.parent[root2] = root1 - } else if ds.rank[root1] < ds.rank[root2] { - ds.parent[root1] = root2 - } else { - ds.parent[root2] = root1 - ds.rank[root1]++ - } - } -} - // Module represents a set of columns grouped by constraints. type Module struct { moduleName ModuleName @@ -145,69 +60,6 @@ func NewDiscoverer() *Discoverer { } } -// CreateModule initializes a new module with a disjoint set and populates it with columns. -func (disc *Discoverer) CreateModule(columns []ifaces.Column) *Module { - module := &Module{ - moduleName: ModuleName(fmt.Sprintf("Module_%d", len(disc.modules))), - ds: NewDisjointSet(), - } - for _, col := range columns { - module.ds.parent[col] = col - module.ds.rank[col] = 0 - } - for i := 0; i < len(columns); i++ { - for j := i + 1; j < len(columns); j++ { - module.ds.Union(columns[i], columns[j]) - } - } - - disc.moduleNames = append(disc.moduleNames, module.moduleName) - disc.modules = append(disc.modules, module) - return module -} - -// MergeModules merges a list of overlapping modules into a single module. -func (disc *Discoverer) MergeModules(modules []*Module, moduleCandidates *[]*Module) *Module { - if len(modules) == 0 { - return nil - } - - // Select the first module as the base - mergedModule := modules[0] - - // Merge all remaining modules into the base - for _, module := range modules[1:] { - for col := range module.ds.parent { - mergedModule.ds.Union(mergedModule.ds.Find(col), col) - } - - // Remove merged module from moduleCandidates - *moduleCandidates = removeModule(*moduleCandidates, module) - } - - return mergedModule -} - -// AddColumnsToModule adds columns to an existing module. -func (disc *Discoverer) AddColumnsToModule(module *Module, columns []ifaces.Column) { - for _, col := range columns { - module.ds.parent[col] = col - module.ds.rank[col] = 0 - module.ds.Union(module.ds.Find(columns[0]), col) // Union with the first column - } -} - -// Helper function to remove a module from the slice -func removeModule(modules []*Module, target *Module) []*Module { - var updatedModules []*Module - for _, mod := range modules { - if mod != target { - updatedModules = append(updatedModules, mod) - } - } - return updatedModules -} - // Analyze processes columns and assigns them to modules. // {1,2,3,4,5} @@ -299,6 +151,86 @@ func (disc *Discoverer) Analyze(comp *wizard.CompiledIOP) { } } +// CreateModule initializes a new module with a disjoint set and populates it with columns. +func (disc *Discoverer) CreateModule(columns []ifaces.Column) *Module { + module := &Module{ + moduleName: ModuleName(fmt.Sprintf("Module_%d", len(disc.modules))), + ds: NewDisjointSet(), + } + for _, col := range columns { + module.ds.parent[col] = col + module.ds.rank[col] = 0 + } + for i := 0; i < len(columns); i++ { + for j := i + 1; j < len(columns); j++ { + module.ds.Union(columns[i], columns[j]) + } + } + + disc.moduleNames = append(disc.moduleNames, module.moduleName) + disc.modules = append(disc.modules, module) + return module +} + +// MergeModules merges a list of overlapping modules into a single module. +func (disc *Discoverer) MergeModules(modules []*Module, moduleCandidates *[]*Module) *Module { + if len(modules) == 0 { + return nil + } + + // Select the first module as the base + mergedModule := modules[0] + + // Merge all remaining modules into the base + for _, module := range modules[1:] { + for col := range module.ds.parent { + mergedModule.ds.Union(mergedModule.ds.Find(col), col) + } + + // Remove merged module from moduleCandidates + *moduleCandidates = removeModule(*moduleCandidates, module) + } + + return mergedModule +} + +// AddColumnsToModule adds columns to an existing module. +func (disc *Discoverer) AddColumnsToModule(module *Module, columns []ifaces.Column) { + for _, col := range columns { + module.ds.parent[col] = col + module.ds.rank[col] = 0 + module.ds.Union(module.ds.Find(columns[0]), col) // Union with the first column + } +} + +// Helper function to remove a module from the slice +func removeModule(modules []*Module, target *Module) []*Module { + var updatedModules []*Module + for _, mod := range modules { + if mod != target { + updatedModules = append(updatedModules, mod) + } + } + return updatedModules +} + +// assignModule assigns a module name to a set of columns. +func (disc *Discoverer) assignModule(moduleName ModuleName, columns []ifaces.Column) { + for _, col := range columns { + disc.columnsToModule[col] = moduleName + } +} + +// HasOverlap checks if a module shares at least one column with a set of columns. +func HasOverlap(module *Module, columns []ifaces.Column) bool { + for _, col := range columns { + if _, exists := module.ds.parent[col]; exists { + return true + } + } + return false +} + // getColumnsFromQuery extracts columns from a global constraint query. func getColumnsFromQuery(q ifaces.Query) []ifaces.Column { gc, ok := q.(query.GlobalConstraint) @@ -318,13 +250,6 @@ func getColumnsFromQuery(q ifaces.Query) []ifaces.Column { return columns } -// assignModule assigns a module name to a set of columns. -func (disc *Discoverer) assignModule(moduleName ModuleName, columns []ifaces.Column) { - for _, col := range columns { - disc.columnsToModule[col] = moduleName - } -} - // NewSizeOf returns the size (length) of a column. func (disc *Discoverer) NewSizeOf(col ifaces.Column) int { return col.Size() @@ -348,16 +273,6 @@ func (disc *Discoverer) ModuleOf(col ifaces.Column) ModuleName { return "" } -// HasOverlap checks if a module shares at least one column with a set of columns. -func HasOverlap(module *Module, columns []ifaces.Column) bool { - for _, col := range columns { - if _, exists := module.ds.parent[col]; exists { - return true - } - } - return false -} - func (disc *Discoverer) QueryIsInModule(query ifaces.Query, name ModuleName) bool { // Extract columns from the query columns := getColumnsFromQuery(query) @@ -420,3 +335,88 @@ func (disc *Discoverer) NbModules() int { defer disc.mutex.Unlock() return len(disc.moduleNames) } + +// DisjointSet represents a union-find data structure, which efficiently groups elements (columns) +// into disjoint sets (modules). It supports fast union and find operations with path compression. +type DisjointSet struct { + parent map[ifaces.Column]ifaces.Column // Maps a column to its representative parent. + rank map[ifaces.Column]int // Stores the rank (tree depth) for optimization. +} + +// NewDisjointSet initializes a new DisjointSet with empty mappings. +func NewDisjointSet() *DisjointSet { + return &DisjointSet{ + parent: make(map[ifaces.Column]ifaces.Column), + rank: make(map[ifaces.Column]int), + } +} + +// Find returns the representative (root) of a column using path compression for optimization. +// Path compression ensures that the structure remains nearly flat, reducing the time complexity to O(α(n)), +// where α(n) is the inverse Ackermann function, which is nearly constant in practice. +// +// Example: +// Suppose we have the following sets: +// +// A -> B -> C (C is the root) +// D -> E -> F (F is the root) +// +// Calling Find(A) will compress the path so that: +// +// A -> C +// B -> C +// C remains the root +// +// Similarly, calling Find(D) will compress the path so that: +// +// D -> F +// E -> F +// F remains the root +func (ds *DisjointSet) Find(col ifaces.Column) ifaces.Column { + if _, exists := ds.parent[col]; !exists { + ds.parent[col] = col + ds.rank[col] = 0 + } + if ds.parent[col] != col { + ds.parent[col] = ds.Find(ds.parent[col]) + } + return ds.parent[col] +} + +// Union merges two sets by linking the root of one to the root of another, optimizing with rank. +// The smaller tree is always attached to the larger tree to keep the depth minimal. +// +// Time Complexity: O(α(n)) (nearly constant due to path compression and union by rank). +// +// Example: +// Suppose we have: +// +// Set 1: A -> B (B is the root) +// Set 2: C -> D (D is the root) +// +// Calling Union(A, C) will merge the sets: +// +// If B has a higher rank than D: +// D -> B +// C -> D -> B +// If D has a higher rank than B: +// B -> D +// A -> B -> D +// If B and D have equal rank: +// D -> B (or B -> D) +// Rank of the new root increases by 1 +func (ds *DisjointSet) Union(col1, col2 ifaces.Column) { + root1 := ds.Find(col1) + root2 := ds.Find(col2) + + if root1 != root2 { + if ds.rank[root1] > ds.rank[root2] { + ds.parent[root2] = root1 + } else if ds.rank[root1] < ds.rank[root2] { + ds.parent[root1] = root2 + } else { + ds.parent[root2] = root1 + ds.rank[root1]++ + } + } +}