Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for pipeline scheduling #40

Merged
merged 4 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name = "FrameworkDemo"
uuid = "cfbf7e84-66d2-421e-b147-9edb7a8672d2"
authors = ["SmalRat <[email protected]>",
"Mateusz Jakub Fila <[email protected]>",
"hegner <[email protected]>",
"Josh Ott <[email protected]>"]
authors = ["SmalRat <[email protected]>", "Mateusz Jakub Fila <[email protected]>", "hegner <[email protected]>", "Josh Ott <[email protected]>"]
version = "0.1.0"

[deps]
Expand Down Expand Up @@ -34,22 +31,24 @@ Colors = "0.12"
Dagger = "0.18.11"
DaggerWebDash = "0.1.3"
DataFrames = "1.6"
Dates = "1.10"
Distributed = "1.10"
EzXML = "1.2"
FileIO = "1.16"
GraphViz = "0.2"
Graphs = "1"
LightGraphs = "1.3"
Logging = "1.10"
MetaGraphs = "0.7"
Plots = "1.40"
TimespanLogging = "0.1.0"
julia = "1.10"
Dates = "1.10"
Distributed = "1.10"
Printf = "1.10"
Test = "1.10"
TimespanLogging = "0.1.0"
julia = "1.10"

[extras]
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Logging", "Test"]
8 changes: 4 additions & 4 deletions bin/schedule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ function main()
max_concurrent = args["max-concurrent"]
fast = args["fast"]

@time "Pipeline execution" FrameworkDemo.run_events(graph;
event_count = event_count,
max_concurrent = max_concurrent,
fast = fast)
@time "Pipeline execution" FrameworkDemo.run_pipeline(graph;
event_count = event_count,
max_concurrent = max_concurrent,
fast = fast)

if !isnothing(args["dot-trace"])
logs = Dagger.fetch_logs!()
Expand Down
10 changes: 9 additions & 1 deletion src/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ end
function fetch_LocalEventLog()
ctx = Dagger.Sch.eager_context()
logs = Dagger.TimespanLogging.get_logs!(ctx.log_sink)
# str = Dagger.show_plan() - doesn't work (exist)
# str = Dagger.show_plan() - doesn't work (exist)
return logs
end

Expand Down Expand Up @@ -99,3 +99,11 @@ function save_logs(log_file, logs)
write(io, logs)
end
end

function dispatch_begin_msg(index)
"Dispatcher: scheduled graph $index"
end

function dispatch_end_msg(index)
"Dispatcher: finished graph $index"
end
17 changes: 10 additions & 7 deletions src/scheduling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ function calibrate_crunch(; fast::Bool = false)::Union{Dagger.Shard, Nothing}
return fast ? nothing : Dagger.@shard calculate_coefficients()
end

function run_events(graph::MetaDiGraph;
event_count::Int,
max_concurrent::Int,
fast::Bool = false)
function run_pipeline(graph::MetaDiGraph;
event_count::Int,
max_concurrent::Int,
fast::Bool = false)
graphs_tasks = Dict{Int, Dagger.DTask}()
notifications = RemoteChannel(() -> Channel{Int}(max_concurrent))
coefficients = FrameworkDemo.calibrate_crunch(; fast = fast)
Expand All @@ -95,15 +95,18 @@ function run_events(graph::MetaDiGraph;
while length(graphs_tasks) >= max_concurrent
finished_graph_id = take!(notifications)
delete!(graphs_tasks, finished_graph_id)
println("Dispatcher: graph finished - graph $finished_graph_id")
@info dispatch_end_msg(finished_graph_id)
end

terminating_results = FrameworkDemo.schedule_graph(graph, coefficients)
graphs_tasks[idx] = Dagger.@spawn notify_graph_finalization(notifications, idx,
terminating_results...)

println("Dispatcher: scheduled graph $idx")
@info dispatch_begin_msg(idx)
end

values(graphs_tasks) .|> wait
for (idx, future) in graphs_tasks
wait(future)
@info dispatch_end_msg(idx)
end
end
26 changes: 26 additions & 0 deletions test/scheduling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using FrameworkDemo
using Dagger
using Graphs
using MetaGraphs
using Logging

function get_alg_timeline(logs::Dict)
timeline = Dict{Int, Any}()
Expand Down Expand Up @@ -50,6 +51,7 @@ end
wait.(tasks)

logs = Dagger.fetch_logs!()
Dagger.disable_logging!()
@test !isnothing(logs)

task_to_tid = lock(Dagger.Sch.EAGER_ID_MAP) do id_map
Expand Down Expand Up @@ -88,4 +90,28 @@ end
@test get_tid("TransformerAB") ∈ get_deps("ConsumerE")
@test get_tid("TransformerAB") ∈ get_deps("ConsumerCD")
end

@testset "Pipeline" begin
event_count = 5

test_logger = TestLogger()
with_logger(test_logger) do
FrameworkDemo.run_pipeline(graph;
max_concurrent = 3,
event_count = event_count,
fast = is_fast)
end
@testset "Start message" begin
messages = for i in 1:event_count
@test any(record -> record.message == FrameworkDemo.dispatch_begin_msg(i),
test_logger.logs)
end
end
@testset "Finish message" begin
messages = for i in 1:event_count
@test any(record -> record.message == FrameworkDemo.dispatch_end_msg(i),
test_logger.logs)
end
end
end
end
Loading