From 99dfac8f63b838a2fe13d6ea4a22f3ebc410d828 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Tue, 3 Dec 2024 12:43:47 -0600 Subject: [PATCH] Fix cx in Hessian --- src/enzyme.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/enzyme.jl b/src/enzyme.jl index 96f516aa..4e2741f4 100644 --- a/src/enzyme.jl +++ b/src/enzyme.jl @@ -449,31 +449,32 @@ end b.v[col] = 1 end - function _gradient!(dx, f, x, y, obj_weight, cx) + function _gradient!(dx, ℓ, x, y, obj_weight, cx) Enzyme.make_zero!(dx) + dcx = make_zero(cx) res = Enzyme.autodiff( Enzyme.Reverse, - f, + ℓ, Enzyme.Active, Enzyme.Duplicated(x, dx), Enzyme.Const(y), Enzyme.Const(obj_weight), - Enzyme.Const(cx) + Enzyme.Duplicated(cx, dcx) ) return nothing end - function _hvp!(res, f, x, v, y, obj_weight, cx) - # grad = Enzyme.make_zero(x) + function _hvp!(res, ℓ, x, v, y, obj_weight, cx) + dcx = make_zero(cx) Enzyme.autodiff( Enzyme.Forward, _gradient!, res, - Enzyme.Const(f), + Enzyme.Const(ℓ), Enzyme.Duplicated(x, v), Enzyme.Const(y), Enzyme.Const(obj_weight), - Enzyme.Const(cx), + Enzyme.Duplicated(cx, dcx), ) return nothing end