diff --git a/lowtime/solver.py b/lowtime/solver.py index d246fd7..4f151da 100644 --- a/lowtime/solver.py +++ b/lowtime/solver.py @@ -435,11 +435,8 @@ def find_min_cut(self, dag: nx.DiGraph) -> tuple[set[int], set[int]]: profiling_min_cut_setup, ) - # We're done with constructing the DAG with only flow upper bounds. - # Find the maximum flow on this DAG. - try: - profiling_max_flow = time.time() - # ohjun: convert to Rust types + # Helper function for Rust interop + def format_rust_inputs(dag: nx.DiGraph) -> tuple[nx.NodeView, list[tuple[tuple[int, int], np.float64]]]: nodes = unbound_dag.nodes edges = [ ((u, v), cap) @@ -448,6 +445,29 @@ def find_min_cut(self, dag: nx.DiGraph) -> tuple[set[int], set[int]]: ).items() ] + # Helper function for Rust interop + # Note: this returns flows as float, not np.float64 + def reformat_rust_flow_to_dict(flow_vec: list[tuple[tuple[int, int], float]], dag: nx.DiGraph) -> dict[int, dict[int, float]]: + # ohjun: technicality. Rust's pathfinding::edmonds_karp doesn't + # return edges with 0 flow, but nx.max_flow does. So we fill in + # the 0s and empty nodes. + flow_dict = dict() + for u in dag.nodes: + flow_dict[u] = dict() + for v in dag.successors(u): + flow_dict[u][v] = 0.0 + + for (u, v), cap in flow_vec: + flow_dict[u][v] = cap + + return flow_dict + + # We're done with constructing the DAG with only flow upper bounds. + # Find the maximum flow on this DAG. + try: + profiling_max_flow = time.time() + nodes, edges = format_rust_inputs(unbound_dag) + profiling_data_transfer = time.time() rust_dag = lowtime_rust.PhillipsDessouky( nodes, s_prime_id, t_prime_id, edges @@ -459,24 +479,7 @@ def find_min_cut(self, dag: nx.DiGraph) -> tuple[set[int], set[int]]: ) rust_flow_vec = rust_dag.max_flow() - # _, flow_dict = nx.maximum_flow( - # unbound_dag, - # s_prime_id, - # t_prime_id, - # capacity="capacity", - # flow_func=edmonds_karp, - # ) - flow_dict = dict() - # ohjun: technicality. Rust's pathfinding::edmonds_karp doesn't - # return edges with 0 flow, but nx.max_flow does. So we fill in - # the 0s and empty nodes. - for u in unbound_dag.nodes: - flow_dict[u] = dict() - for v in unbound_dag.successors(u): - flow_dict[u][v] = 0.0 - - for (u, v), cap in rust_flow_vec: - flow_dict[u][v] = cap + flow_dict = reformat_rust_flow_to_dict(rust_flow_vec, unbound_dag) profiling_max_flow = time.time() - profiling_max_flow logger.info(