Skip to content

Commit

Permalink
more AD fixing attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 2, 2023
1 parent 15acc1b commit 8db1dfe
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
using Logging, ChainRulesCore
@non_differentiable Base.CoreLogging.current_logger(args...)
function ChainRulesCore.rrule(

Check warning on line 3 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L3

Added line #L3 was not covered by tests
rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(Base.CoreLogging.with_logger),
f::Function,
logger::AbstractLogger
)
return Base.CoreLogging.with_logger(logger) do
ChainRulesCore.rrule_via_ad(rc, f)

Check warning on line 10 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L9-L10

Added lines #L9 - L10 were not covered by tests
end
end
"""
$(TYPEDEF)
"""
Expand Down Expand Up @@ -41,7 +52,6 @@ end
AggregateLogger(logger::AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(),0 , 0.0, 0.0, ReentrantLock(), logger)

function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id, file, line; kwargs...)
@ignore_derivatives begin
if convert(LogLevel, level) == LogLevel(-1) && haskey(kwargs, :progress)
pr = kwargs[:progress]
if trylock(l.lock) || (pr == "done" && lock(l.lock)===nothing)
Expand Down Expand Up @@ -84,7 +94,6 @@ function Logging.handle_message(l::AggregateLogger, level, message, _module, gro
end
end
Logging.handle_message(l.logger, level, message, _module, group, id, file, line; kwargs...)

Check warning on line 96 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L96

Added line #L96 was not covered by tests
end # ignore_derivatives
end
Logging.shouldlog(l::AggregateLogger, args...) = Logging.shouldlog(l.logger, args...)

Check warning on line 98 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L98

Added line #L98 was not covered by tests
Logging.min_enabled_level(l::AggregateLogger) = Logging.min_enabled_level(l.logger)
Expand Down Expand Up @@ -132,7 +141,8 @@ function __solve(prob::AbstractEnsembleProblem,
if get(kwargs, :progress, false)
name = get(kwargs, :progress_name, "Ensemble")
for i in 1:trajectories
@logmsg(LogLevel(-1), "$name #$i", _id=Symbol("SciMLBase_$i"), progress=0)
msg = "$name #$i" # avoid try in logmsg ruining AD
@logmsg(LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0)
end

Check warning on line 146 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L142-L146

Added lines #L142 - L146 were not covered by tests
end

Expand Down

0 comments on commit 8db1dfe

Please sign in to comment.