diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index adc3b8b5cc..87a75e49cb 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -1,4 +1,15 @@ using Logging, ChainRulesCore +@non_differentiable Base.CoreLogging.current_logger(args...) +function ChainRulesCore.rrule( + rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, + ::typeof(Base.CoreLogging.with_logger), + f::Function, + logger::AbstractLogger +) + return Base.CoreLogging.with_logger(logger) do + ChainRulesCore.rrule_via_ad(rc, f) + end +end """ $(TYPEDEF) """ @@ -41,7 +52,6 @@ end AggregateLogger(logger::AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(),0 , 0.0, 0.0, ReentrantLock(), logger) function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id, file, line; kwargs...) - @ignore_derivatives begin if convert(LogLevel, level) == LogLevel(-1) && haskey(kwargs, :progress) pr = kwargs[:progress] if trylock(l.lock) || (pr == "done" && lock(l.lock)===nothing) @@ -84,7 +94,6 @@ function Logging.handle_message(l::AggregateLogger, level, message, _module, gro end end Logging.handle_message(l.logger, level, message, _module, group, id, file, line; kwargs...) - end # ignore_derivatives end Logging.shouldlog(l::AggregateLogger, args...) = Logging.shouldlog(l.logger, args...) Logging.min_enabled_level(l::AggregateLogger) = Logging.min_enabled_level(l.logger) @@ -132,7 +141,8 @@ function __solve(prob::AbstractEnsembleProblem, if get(kwargs, :progress, false) name = get(kwargs, :progress_name, "Ensemble") for i in 1:trajectories - @logmsg(LogLevel(-1), "$name #$i", _id=Symbol("SciMLBase_$i"), progress=0) + msg = "$name #$i" # avoid try in logmsg ruining AD + @logmsg(LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0) end end