diff --git a/src/scheduling.jl b/src/scheduling.jl index ae04929..73aa7d8 100644 --- a/src/scheduling.jl +++ b/src/scheduling.jl @@ -52,13 +52,25 @@ end struct Event data_flow::DataFlowGraph - store::MetaDiGraph + store::Dict{Int, Dagger.DTask} event_number::Int function Event(data_flow::DataFlowGraph, event_number::Int = 0) - new(data_flow, MetaDiGraph(data_flow.graph), event_number) + new(data_flow, Dict{Int, Dagger.DTask}(), event_number) end end +function put_result!(event::Event, index::Int, result::Dagger.DTask) + return event.store[index] = result +end + +function get_result(event::Event, index::Int)::Dagger.DTask + return event.store[index] +end + +function get_results(event::Event, vertices::Vector{Int}) + return get_result.(Ref(event), vertices) +end + function notify_graph_finalization(notifications::RemoteChannel, graph_id::Int, terminating_results...) println("Graph $graph_id: all tasks in the graph finished!") @@ -66,10 +78,6 @@ function notify_graph_finalization(notifications::RemoteChannel, graph_id::Int, println("Graph $graph_id: notified!") end -function get_promises(graph::MetaDiGraph, vertices::Vector) - return [get_prop(graph, v, :res_data) for v in vertices] -end - function is_terminating_alg(graph::AbstractGraph, vertex_id::Int) successor_dataobjects = outneighbors(graph, vertex_id) is_terminating(vertex) = isempty(outneighbors(graph, vertex)) @@ -78,7 +86,7 @@ end function schedule_algorithm(event::Event, vertex_id::Int, coefficients::Union{Dagger.Shard, Nothing}) - incoming_data = get_promises(event.store, inneighbors(event.store, vertex_id)) + incoming_data = get_results(event, inneighbors(event.data_flow.graph, vertex_id)) algorithm = get_algorithm(event.data_flow, vertex_id) if isnothing(coefficients) alg_helper(data...) = algorithm(data...; coefficients = missing) @@ -92,11 +100,12 @@ function schedule_graph(event::Event, coefficients::Union{Dagger.Shard, Nothing} terminating_results = Dagger.DTask[] for vertex_id in event.data_flow.algorithm_indices res = schedule_algorithm(event, vertex_id, coefficients) - set_prop!(event.store, vertex_id, :res_data, res) + put_result!(event, vertex_id, res) for v in outneighbors(event.data_flow.graph, vertex_id) - set_prop!(event.store, v, :res_data, res) + put_result!(event, v, res) end - is_terminating_alg(event.data_flow.graph, vertex_id) && push!(terminating_results, res) + is_terminating_alg(event.data_flow.graph, vertex_id) && + push!(terminating_results, res) end return terminating_results diff --git a/test/scheduling.jl b/test/scheduling.jl index 09251b6..21a3293 100644 --- a/test/scheduling.jl +++ b/test/scheduling.jl @@ -63,7 +63,7 @@ end end function get_tid(node_id::String)::Int - task = get_prop(event.store, graph[node_id, :node_id], :res_data) + task = FrameworkDemo.get_result(event, graph[node_id, :node_id]) return task_to_tid[task.uid] end