Skip to content

Commit

Permalink
Merge pull request #35 from mastodon-sc/compute-node-mapping-review
Browse files Browse the repository at this point in the history
Compute node mapping review
  • Loading branch information
stefanhahmann authored Sep 5, 2023
2 parents f45fd7f + e0c2923 commit b4cb6f9
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 104 deletions.
42 changes: 0 additions & 42 deletions src/main/java/org/mastodon/mamut/treesimilarity/FlowNetwork.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import org.apache.commons.lang3.tuple.Pair;
import org.mastodon.mamut.treesimilarity.tree.Tree;
import org.mastodon.mamut.treesimilarity.tree.TreeUtils;
import org.mastodon.mamut.treesimilarity.util.FlowNetwork;
import org.mastodon.mamut.treesimilarity.util.NodeMapping;
import org.mastodon.mamut.treesimilarity.util.NodeMappings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -130,6 +133,23 @@ else if ( tree2 == null )
return zhang.compute( tree1, tree2 );
}

/**
* Calculates a mapping between nodes in the given two trees ({@code tree1} and {@code tree2}) that links the nodes from the two trees, which have the minimum tree edit distance to each other.<p>
* The required minimum tree edit distance is calculated using the Zhang unordered edit distance.
* @param tree1 The first tree.
* @param tree2 The second tree.
* @param costFunction The cost function.
* @return The mapping between nodes.
*/
public static < T > Map< Tree< T >, Tree< T > > nodeMapping( Tree< T > tree1, Tree< T > tree2, BiFunction< T, T, Double > costFunction )
{
if ( tree1 == null || tree2 == null )
return Collections.emptyMap();

NodeMapping< T > mapping = new ZhangUnorderedTreeEditDistance<>( tree1, tree2, costFunction ).treeMapping( tree1, tree2 );
return mapping.asMap();
}

private static < T > double distanceTreeToNull( Tree< T > tree2, BiFunction< T, T, Double > costFunction )
{
double distance = 0;
Expand Down Expand Up @@ -162,15 +182,6 @@ private ZhangUnorderedTreeEditDistance( final Tree< T > tree1, final Tree< T > t
forestDistances = new HashMap<>();
}

public static < T > Map< Tree< T >, Tree< T > > nodeMapping( Tree< T > tree1, Tree< T > tree2, BiFunction< T, T, Double > costFunction )
{
if ( tree1 == null || tree2 == null )
return Collections.emptyMap();

NodeMapping< T > matching = new ZhangUnorderedTreeEditDistance<>( tree1, tree2, costFunction ).treeMapping( tree1, tree2 );
return matching.asMap();
}

/**
* Calculate the Zhang edit distance between two (labeled) unordered trees.
*
Expand Down Expand Up @@ -312,9 +323,9 @@ private NodeMapping< T > insertOperationMapping( Tree< T > tree1, Tree< T > tree
double insertCostTree2 = insertCosts.get( tree2 ).treeCost;
return findBestMapping( tree2.getChildren(), child ->
{
NodeMapping< T > insertCosts = NodeMappings.empty( insertCostTree2 - this.insertCosts.get( child ).treeCost );
NodeMapping< T > insertMapping = NodeMappings.empty( insertCostTree2 - this.insertCosts.get( child ).treeCost );
NodeMapping< T > childMapping = treeMapping( tree1, child );
return NodeMappings.compose( insertCosts, childMapping );
return NodeMappings.compose( insertMapping, childMapping );
} );
}

Expand All @@ -329,9 +340,9 @@ private NodeMapping< T > deleteOperationMapping( Tree< T > tree1, Tree< T > tree
double deleteCostTree1 = deleteCosts.get( tree1 ).treeCost;
return findBestMapping( tree1.getChildren(), child ->
{
NodeMapping< T > deleteCosts = NodeMappings.empty( deleteCostTree1 - this.deleteCosts.get( child ).treeCost );
NodeMapping< T > deleteMapping = NodeMappings.empty( deleteCostTree1 - this.deleteCosts.get( child ).treeCost );
NodeMapping< T > childMapping = treeMapping( child, tree2 );
return NodeMappings.compose( deleteCosts, childMapping );
return NodeMappings.compose( deleteMapping, childMapping );
} );
}

Expand All @@ -345,9 +356,9 @@ private NodeMapping< T > forestInsertMapping( Tree< T > forest1, Tree< T > fores
double insertCostForest2 = insertCosts.get( forest2 ).forestCost;
return findBestMapping( forest2.getChildren(), child ->
{
NodeMapping< T > insertCosts = NodeMappings.empty( insertCostForest2 - this.insertCosts.get( child ).forestCost );
NodeMapping< T > insertMapping = NodeMappings.empty( insertCostForest2 - this.insertCosts.get( child ).forestCost );
NodeMapping< T > childMapping = forestMapping( forest1, child );
return NodeMappings.compose( insertCosts, childMapping );
return NodeMappings.compose( insertMapping, childMapping );
} );
}

Expand All @@ -361,9 +372,9 @@ private NodeMapping< T > forestDeleteMapping( Tree< T > forest1, Tree< T > fores
double deleteCostForest1 = deleteCosts.get( forest1 ).forestCost;
return findBestMapping( forest1.getChildren(), child ->
{
NodeMapping< T > deleteCosts = NodeMappings.empty( deleteCostForest1 - this.deleteCosts.get( child ).forestCost );
NodeMapping< T > deleteMapping = NodeMappings.empty( deleteCostForest1 - this.deleteCosts.get( child ).forestCost );
NodeMapping< T > childMapping = forestMapping( child, forest2 );
return NodeMappings.compose( deleteCosts, childMapping );
return NodeMappings.compose( deleteMapping, childMapping );
} );
}

Expand Down Expand Up @@ -396,30 +407,7 @@ private NodeMapping< T > minCostMaxFlow( final Tree< T > forest1, final Tree< T
String emptyTree1 = "empty1";
String emptyTree2 = "empty2";

FlowNetwork network = new FlowNetwork();
network.addVertices( Arrays.asList( source, sink, emptyTree1, emptyTree2 ) );
network.addVertices( childrenForest1 );
network.addVertices( childrenForest2 );

int n1 = childrenForest1.size();
int n2 = childrenForest2.size();
network.addEdge( source, emptyTree1, n2 - Math.min( n1, n2 ), 0 );
network.addEdge( emptyTree1, emptyTree2, Math.max( n1, n2 ) - Math.min( n1, n2 ), 0 ); // this edge is not needed
network.addEdge( emptyTree2, sink, n1 - Math.min( n1, n2 ), 0 );

for ( Tree< T > child1 : childrenForest1 )
{
network.addEdge( source, child1, 1, 0 );
network.addEdge( child1, emptyTree2, 1, deleteCosts.get( child1 ).treeCost );
for ( Tree< T > child2 : childrenForest2 )
network.addEdge( child1, child2, 1, treeMapping( child1, child2 ).getCost() );
}

for ( Tree< T > child2 : childrenForest2 )
{
network.addEdge( child2, sink, 1, 0 );
network.addEdge( emptyTree1, child2, 1, insertCosts.get( child2 ).treeCost );
}
FlowNetwork network = buildFlowNetwork( source, sink, emptyTree1, emptyTree2, childrenForest1, childrenForest2 );

network.solveMaxFlowMinCost( source, sink );

Expand All @@ -441,6 +429,38 @@ private NodeMapping< T > minCostMaxFlow( final Tree< T > forest1, final Tree< T
return NodeMappings.compose( childMappings );
}

private FlowNetwork buildFlowNetwork(
String source, String sink, String emptyTree1, String emptyTree2, Collection< Tree< T > > childrenForest1,
Collection< Tree< T > > childrenForest2
)
{
FlowNetwork network = new FlowNetwork();
network.addVertices( Arrays.asList( source, sink, emptyTree1, emptyTree2 ) );
network.addVertices( childrenForest1 );
network.addVertices( childrenForest2 );

int numberOfChildrenForest1 = childrenForest1.size();
int numberOfChildrenForest2 = childrenForest2.size();
int minNumberOfChildren = Math.min( numberOfChildrenForest1, numberOfChildrenForest2 );
network.addEdge( source, emptyTree1, numberOfChildrenForest2 - minNumberOfChildren, 0 );
network.addEdge( emptyTree2, sink, numberOfChildrenForest1 - minNumberOfChildren, 0 );

for ( Tree< T > child1 : childrenForest1 )
{
network.addEdge( source, child1, 1, 0 );
network.addEdge( child1, emptyTree2, 1, deleteCosts.get( child1 ).treeCost );
for ( Tree< T > child2 : childrenForest2 )
network.addEdge( child1, child2, 1, treeMapping( child1, child2 ).getCost() );
}

for ( Tree< T > child2 : childrenForest2 )
{
network.addEdge( child2, sink, 1, 0 );
network.addEdge( emptyTree1, child2, 1, insertCosts.get( child2 ).treeCost );
}
return network;
}

/**
* Returns true if the flow value equal to 1. Returns false if the flow value equal to 0.
* Throws an {@link AssertionError} if the flow value is neither 0 nor 1.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package org.mastodon.mamut.treesimilarity.util;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import org.jgrapht.alg.interfaces.MinimumCostFlowAlgorithm;
import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.graph.SimpleDirectedWeightedGraph;

/**
* A utility class that encapsulates a graph ({@link SimpleDirectedWeightedGraph}), a corresponding map of edge capacities and a minimum cost flow solution for this graph.<p>
*
* @author Matthias Arzt
* @author Stefan Hahmann
*/
public class FlowNetwork
{

private final SimpleDirectedWeightedGraph< Object, DefaultWeightedEdge > graph = new SimpleDirectedWeightedGraph<>( DefaultWeightedEdge.class );

private final Map< DefaultWeightedEdge, Integer > capacities = new HashMap<>();

private MinimumCostFlowAlgorithm.MinimumCostFlow< DefaultWeightedEdge > flow;

/**
* Adds a collection of vertices to the graph.
* @param vertices vertices to add
*/
public void addVertices( Collection< ? > vertices )
{
vertices.forEach( graph::addVertex );
}

/**
* Adds an edge to the graph with the given capacity and weight.
* @param source source vertex
* @param target target vertex
* @param capacity capacity of the edge
* @param weight weight of the edge
*/
public void addEdge( Object source, Object target, int capacity, double weight )
{
DefaultWeightedEdge e1 = graph.addEdge( source, target );
graph.setEdgeWeight( e1, weight );
capacities.put( e1, capacity );
}

/**
* Solves the maximum flow minimum cost problem on the graph for the given source and sink.
* @param source source vertex
* @param sink sink vertex
*/
public void solveMaxFlowMinCost( Object source, Object sink )
{
flow = JGraphtTools.maxFlowMinCost( graph, capacities, source, sink );
}

/**
* Returns the flow on the edge from source to target.<p>
* NB: The flow is only defined after {@link #solveMaxFlowMinCost(Object, Object)} has been called at least once.
* @param source source vertex
* @param target target vertex
* @return the flow on the edge from source to target
*/
public double getFlow( Object source, Object target )
{
if ( flow == null )
throw new IllegalStateException( "Flow is not defined. Call solveMaxFlowMinCost() first." );
return flow.getFlow( graph.getEdge( source, target ) );
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.mastodon.mamut.treesimilarity;
package org.mastodon.mamut.treesimilarity.util;

import org.jgrapht.alg.flow.PushRelabelMFImpl;
import org.jgrapht.alg.flow.mincost.CapacityScalingMinimumCostFlow;
Expand All @@ -21,14 +21,14 @@ private JGraphtTools()
}

/**
* Computes a maximum (source, sink)-flow of minimum cost and returns its cost.
* G is a digraph with edge costs and capacities. There is a source node s and a sink node t. This function finds a maximum flow from s to t whose total cost is minimized.
* Computes a maximum (source, sink)-flow of minimum cost and returns it.
* Assuming {@code graph} is a digraph with edge costs and capacities. There is a source node s and a sink node t. This function finds a maximum flow from s to t whose total cost is minimized.
*
* @param graph a directed graph with edge costs (i.e. edge weights)
* @param capacities a map from edges to their capacities
* @param source the source node
* @param sink the sink node
* @return the minimum cost of the maximum flow
* @return the maximum flow of minimum cost
*/
public static < V > MinimumCostFlowAlgorithm.MinimumCostFlow< DefaultWeightedEdge > maxFlowMinCost( final SimpleDirectedWeightedGraph< V, DefaultWeightedEdge > graph,
final Map< DefaultWeightedEdge, Integer > capacities, final V source, final V sink )
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package org.mastodon.mamut.treesimilarity;
package org.mastodon.mamut.treesimilarity.util;

import java.util.HashMap;
import java.util.Map;

import org.mastodon.mamut.treesimilarity.ZhangUnorderedTreeEditDistance;
import org.mastodon.mamut.treesimilarity.tree.Tree;

/**
Expand All @@ -15,7 +16,7 @@
*
* @see NodeMappings
*/
interface NodeMapping< T >
public interface NodeMapping< T >
{

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package org.mastodon.mamut.treesimilarity;
package org.mastodon.mamut.treesimilarity.util;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.mastodon.mamut.treesimilarity.ZhangUnorderedTreeEditDistance;
import org.mastodon.mamut.treesimilarity.tree.Tree;

/**
* Utility class for {@link ZhangUnorderedTreeEditDistance} that provides
* static factory methods for the easy creation of {@link NodeMapping}s.
*/
class NodeMappings
public class NodeMappings
{
private NodeMappings()
{
Expand All @@ -37,9 +38,9 @@ public static < T > NodeMapping< T > singleton( double cost, Tree< T > tree1, Tr
}

/**
* @return A {@link NodeMapping} that represents a composed that contains
* all the map entries of the given {@code children}. The costs of the
* composed mapping is the sum of the costs of the children.
* @return A {@link NodeMapping} that represents a composed mapping that
* contains all the map entries of the given {@code children}. The costs of
* the composed mapping is the sum of the costs of the children.
*/
@SafeVarargs
public static < T > NodeMapping< T > compose( NodeMapping< T >... children )
Expand All @@ -48,9 +49,9 @@ public static < T > NodeMapping< T > compose( NodeMapping< T >... children )
}

/**
* @return A {@link NodeMapping} that represents a composed that contains
* all the map entries of the given {@code children}. The costs of the
* composed mapping is the sum of the costs of the children.
* @return A {@link NodeMapping} that represents a composed mapping that
* contains all the map entries of the given {@code children}. The costs of
* the composed mapping is the sum of the costs of the children.
*/
public static < T > NodeMapping< T > compose( List< NodeMapping< T > > children )
{
Expand All @@ -71,10 +72,6 @@ public double getCost()
{
return cost;
}

@Override
public abstract void writeToMap( Map< Tree< T >, Tree< T > > map );

}

private static class EmptyNodeMapping< T > extends AbstractNodeMapping< T >
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package org.mastodon.mamut.treesimilarity;
package org.mastodon.mamut.treesimilarity.util;

import org.jgrapht.graph.DefaultWeightedEdge;
import org.jgrapht.graph.SimpleDirectedWeightedGraph;
import org.junit.Test;
import org.mastodon.mamut.treesimilarity.JGraphtTools;

import java.util.HashMap;
import java.util.Map;
Expand Down
Loading

0 comments on commit b4cb6f9

Please sign in to comment.