diff --git a/src/enzyme.jl b/src/enzyme.jl index 96f516aa..4e2741f4 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -449,31 +449,32 @@ end b.v[col] = 1 end - function _gradient!(dx, f, x, y, obj_weight, cx) + function _gradient!(dx, ℓ, x, y, obj_weight, cx) Enzyme.make_zero!(dx) + dcx = make_zero(cx) res = Enzyme.autodiff( Enzyme.Reverse, - f, + ℓ, Enzyme.Active, Enzyme.Duplicated(x, dx), Enzyme.Const(y), Enzyme.Const(obj_weight), - Enzyme.Const(cx) + Enzyme.Duplicated(cx, dcx) ) return nothing end - function _hvp!(res, f, x, v, y, obj_weight, cx) - # grad = Enzyme.make_zero(x) + function _hvp!(res, ℓ, x, v, y, obj_weight, cx) + dcx = make_zero(cx) Enzyme.autodiff( Enzyme.Forward, _gradient!, res, - Enzyme.Const(f), + Enzyme.Const(ℓ), Enzyme.Duplicated(x, v), Enzyme.Const(y), Enzyme.Const(obj_weight), - Enzyme.Const(cx), + Enzyme.Duplicated(cx, dcx), ) return nothing end