/**
 * Graph operations, including an iterative, Pregel-style message-passing operator.
 *
 * @tparam VD the vertex attribute type
 * @tparam ED the edge attribute type
 */
class GraphOps[VD, ED] {

  /**
   * Runs an iterative vertex program over the graph.
   *
   * Every vertex first receives `initialMsg` through `vprog`. After that, each iteration
   * runs `vprog` on the vertices that received a message, merges the updated values back
   * into the graph, and computes the next round of messages with `sendMsg` / `mergeMsg`.
   * The loop ends when an iteration produces no messages or after `maxIter` iterations,
   * whichever comes first.
   *
   * @tparam A the message type
   * @param initialMsg the message handed to `vprog` for every vertex before the loop starts
   * @param maxIter upper bound on the number of loop iterations (defaults to `Int.MaxValue`)
   * @param activeDir edge direction, relative to the vertices that received a message in the
   *                  current iteration, passed to `mapReduceTriplets` to restrict where
   *                  `sendMsg` is invoked (defaults to `EdgeDirection.Out`)
   * @param vprog vertex program: takes the vertex id, its current value and the incoming
   *              message, and returns the new vertex value
   * @param sendMsg produces the `(destination vertex id, message)` pairs for an edge triplet
   * @param mergeMsg passed to `mapReduceTriplets` as the reduce function over messages
   * @return the graph carrying the vertex values produced by the last executed iteration
   */
  def pregel[A]
      (initialMsg: A,
       maxIter: Int = Int.MaxValue,
       activeDir: EdgeDirection = EdgeDirection.Out)
      (vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       mergeMsg: (A, A) => A)
    : Graph[VD, ED] = {
    // Receive the initial message at each vertex
    var g = mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
    // Compute the messages
    var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
    var activeMessages = messages.count()
    // Loop until no messages remain or maxIter iterations have been run
    var i = 0
    while (activeMessages > 0 && i < maxIter) {
      // Receive the messages: ---------------------------------------------------------------------
      // Run the vertex program on all vertices that receive messages
      val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
      // Merge the new vertex values back into the graph
      g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache()
      // Send Messages: ----------------------------------------------------------------------------
      // Vertices that didn't receive a message above don't appear in newVerts and therefore don't
      // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked
      // on edges in the activeDir of vertices in newVerts
      messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDir))).cache()
      activeMessages = messages.count()
      i += 1
    }
    // The declared result is the graph as it stands once the loop has finished
    g
  }
}
以下範例示範如何使用 Pregel 運算子來表達單源最短路徑(single-source shortest path)的運算。
import org.apache.spark.graphx._
// Import random graph generation library
import org.apache.spark.graphx.util.GraphGenerators
// Build a generated graph of 100 vertices and convert every edge attribute to a Double,
// which the example below treats as the distance along that edge
val graph: Graph[Int, Double] =
  GraphGenerators
    .logNormalGraph(sc, numVertices = 100)
    .mapEdges(edge => edge.attr.toDouble)

// The vertex all distances are measured from
val sourceId: VertexId = 42

// Starting state: the source vertex is at distance 0.0, every other vertex at infinity
val initialGraph = graph.mapVertices { (vid, _) =>
  if (vid == sourceId) 0.0 else Double.PositiveInfinity
}
val sssp = initialGraph.pregel(Double.PositiveInfinity)(
(id, dist, newDist) => math.min(dist, newDist), // Vertex Program
triplet => { // Send Message
if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
} else {
(a,b) => math.min(a,b) // Merge Message