Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep Galley Plans Per Approximate Sparsity Pattern #679

Merged
merged 28 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5896d1a
add issimilar and get_cannonical_stats
Dec 19, 2024
7b69f89
cleanup
Dec 19, 2024
6761338
small fixes
Dec 19, 2024
4bbf23f
fix stats construction
Dec 19, 2024
092579c
last fix, definitely true this time
Dec 19, 2024
79d836d
fix NaiveStats constructor, again
Dec 20, 2024
2ed581d
small change to verbose passing
Dec 20, 2024
e835a12
fix deferred hash issue #664
Dec 20, 2024
4de8a9c
fix deferred equality check
Dec 20, 2024
85ac56e
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Dec 28, 2024
c53f9fb
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Dec 28, 2024
354db43
Merge remote-tracking branch 'origin/main' into kbd-make-galley-adapt…
willow-ahrens Dec 28, 2024
b5b2831
more accurate benchmark
willow-ahrens Dec 28, 2024
0c8ab41
add evaluation count to benchmark setup
willow-ahrens Dec 28, 2024
bff9d0d
rename GalleyExecutor to AdaptiveExecutor
Jan 2, 2025
5aa6d01
warn when the tag argument is given to the AdaptiveExecutor
Jan 2, 2025
816d974
remove with_scheduler issue
Jan 2, 2025
ffa437f
bug fix
Jan 2, 2025
cf19c69
lowering default threshold & warn on tag argument
Jan 2, 2025
d24a6af
update high-level benchmarks to set scheduler in setup
Jan 2, 2025
e31b572
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
kylebd99 Jan 2, 2025
5fd941b
small fix
Jan 3, 2025
f52839f
add compute overhead check
Jan 3, 2025
595f7b6
temporarily remove galley scheduler
Jan 3, 2025
910b033
add galley_scheduler back
Jan 6, 2025
1dbc64c
drop the warn on tag arg
Jan 6, 2025
bd2a62d
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Jan 6, 2025
fb7d3ca
Merge branch 'main' into kbd-make-galley-adaptive-to-inputs
willow-ahrens Jan 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,12 @@ for (scheduler_name, scheduler) in [
end

let
A = Tensor(Dense(SparseList(Element(0.0))), fsprand(1, 1, 1))
x = rand(1)
SUITE["high-level"]["einsum_spmv_call_overhead"][scheduler_name] = @benchmarkable(
begin
A, x = ($A, $x)
@einsum y[i] += A[i, j] * x[j]
end,
setup = (A = Tensor(Dense(SparseList(Element(0.0))), fsprand(1, 1, 1)); x = rand(1)),
evals = 1
)
end

Expand Down
2 changes: 1 addition & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ include("interface/einsum.jl")
include("Galley/Galley.jl")
using .Galley

export galley_scheduler
export galley_scheduler, GalleyOptimizer, GalleyExecutorCode, GalleyExecutor

@deprecate default fill_value
@deprecate redefault! set_fill_value!
Expand Down
4 changes: 2 additions & 2 deletions src/FinchLogic/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ function Base.:(==)(a::LogicNode, b::LogicNode)
elseif a.kind === immediate
return b.kind === immediate && a.val === b.val
elseif a.kind === deferred
return b.kind === deferred && a.val === b.val && a.type === b.type
return b.kind === deferred && a.ex === b.ex && a.type === b.type
elseif a.kind === field
return b.kind === field && a.name == b.name
elseif a.kind === alias
Expand All @@ -370,7 +370,7 @@ function Base.hash(a::LogicNode, h::UInt)
elseif istree(a)
return hash(a.kind, hash(a.children, h))
elseif a.kind === deferred
return hash(a.kind, hash(a.val, hash(a.type, h)))
return hash(a.kind, hash(a.ex, hash(a.type, h)))
else
error("unimplemented")
end
Expand Down
88 changes: 86 additions & 2 deletions src/Galley/FinchCompat/executor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,100 @@
julia_prgm
end

function Finch.set_options(ctx::GalleyOptimizer; estimator=DCStats)
function Finch.set_options(ctx::GalleyOptimizer; estimator=DCStats, verbose=false)
ctx.estimator=estimator
ctx.verbose=verbose
return ctx
end

"""
get_stats_dict(ctx::GalleyOptimizer, prgm)

Returns a dictionary mapping the location of input tensors in the program to their statistics objects.
"""
function get_stats_dict(ctx::GalleyOptimizer, prgm)
deferred_prgm = Finch.defer_tables(:prgm, prgm)
expr_stats_dict = Dict()
for node in PostOrderDFS(deferred_prgm)
if node.kind == table
expr_stats_dict[node.tns.ex] = ctx.estimator(node.tns.imm, [i.name for i in node.idxs])
end
end
return expr_stats_dict
end

"""
GalleyExecutor(ctx::GalleyOptimizer, tag=:global, verbose=false)

Executes a logic program by compiling it with the given compiler `ctx`. Compiled
codes are cached for each program structure. If the 'tag' argument is ':global', it maintains a set of plans
for inputs with different sparsity structures. In this case, it first checks the cache for a plan that
was compiled for similar inputs and only compiles if it doesn't find one. If the `tag` argument is anything else,
it will only compile once for that tag and will skip this search process.
"""
@kwdef struct GalleyExecutor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like the GalleyExecutor to have a configurable statistics similarity threshold, and to store the caches for each threshold and statistics type separately.

Also, I'd like to choose different executors for different compilation strategies (rather than using a sentinel tag value). Perhaps we can use the current executor for the "use first input strategy", and the GalleyExecutor for the "similar inputs" strategy. Then the GalleyExecutor wouldn't need a tag.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It already caches it per-statistics type, and I can do the same for the threshold by making it a member of the struct. Will do that.

  2. I like this idea. It seems clearer. In this case, we need better names. Maybe we call this one "AdaptiveExecutor" and the other one "TagExecutor"? And only schedulers which rely on statistics will be valid choices for the "AdaptiveExecutor".

I'll make these changes.

ctx::GalleyOptimizer
tag
verbose
end

Base.:(==)(a::GalleyExecutor, b::GalleyExecutor) = a.ctx == b.ctx && a.verbose == b.verbose
Base.hash(a::GalleyExecutor, h::UInt) = hash(GalleyExecutor, hash(a.ctx, hash(a.verbose, h)))

Check warning on line 82 in src/Galley/FinchCompat/executor.jl

View check run for this annotation

Codecov / codecov/patch

src/Galley/FinchCompat/executor.jl#L81-L82

Added lines #L81 - L82 were not covered by tests

GalleyExecutor(ctx::GalleyOptimizer; tag = :global, verbose = false) = GalleyExecutor(ctx, tag, verbose)
function Finch.set_options(ctx::GalleyExecutor; tag = ctx.tag, verbose = ctx.verbose, kwargs...)
GalleyExecutor(Finch.set_options(ctx.ctx; verbose=verbose, kwargs...), tag, verbose)
end

galley_codes = Dict()
function (ctx::GalleyExecutor)(prgm)
(f, code) = if ctx.tag == :global
cur_stats_dict = get_stats_dict(ctx.ctx, prgm)
stats_list = get!(galley_codes, (ctx.ctx, ctx.tag, Finch.get_structure(prgm)), [])
valid_match = nothing
for (stats_dict, f_code) in stats_list
if all(issimilar(cur_stats, stats_dict[cur_expr], 4) for (cur_expr, cur_stats) in cur_stats_dict)
valid_match = f_code
end
end
if isnothing(valid_match)
thunk = Finch.logic_executor_code(ctx.ctx, prgm)
valid_match = (eval(thunk), thunk)
push!(stats_list, (cur_stats_dict, valid_match))
end
valid_match
else
get!(galley_codes, (ctx.ctx, ctx.tag, Finch.get_structure(prgm))) do
thunk = Finch.logic_executor_code(ctx.ctx, prgm)
(eval(thunk), thunk)
end
end
if ctx.verbose
println("Executing:")
display(code)
end
return Base.invokelatest(f, prgm)
end

"""
GalleyExecutorCode(ctx)

Return the code that would normally be used by the GalleyExecutor to run a program.
"""
struct GalleyExecutorCode
ctx
end

function (ctx::GalleyExecutorCode)(prgm)
return Finch.logic_executor_code(ctx.ctx, prgm)

Check warning on line 129 in src/Galley/FinchCompat/executor.jl

View check run for this annotation

Codecov / codecov/patch

src/Galley/FinchCompat/executor.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end

"""
galley_scheduler(verbose = false, estimator=DCStats)

The galley scheduler uses the sparsity patterns of the inputs to optimize the computation.
The first set of inputs given to galley is used to optimize, and the `estimator` is used to
estimate the sparsity of intermediate computations during optimization.
"""
galley_scheduler(; verbose = false, estimator=DCStats) = Finch.LogicExecutor(GalleyOptimizer(verbose=verbose, estimator=estimator); verbose=verbose)
galley_scheduler(;verbose=false) = GalleyExecutor(GalleyOptimizer(;verbose=false); verbose=false)

2 changes: 1 addition & 1 deletion src/Galley/Galley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export PlanNode, Value, Index, Alias, Input, MapJoin, Aggregate, Materialize, Qu
export Scalar, Σ, Mat, Agg
export DCStats, NaiveStats, TensorDef, DC, insert_statistics
export naive, greedy, pruned, exact
export GalleyOptimizer, galley_scheduler
export GalleyOptimizer, GalleyExecutor, GalleyExecutorCode, galley_scheduler

IndexExpr = Symbol
TensorId = Symbol
Expand Down
Loading
Loading