Skip to content

Commit

Permalink
Fix cx in Hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Dec 3, 2024
1 parent 3b6a666 commit 99dfac8
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,31 +449,32 @@ end
b.v[col] = 1
end

function _gradient!(dx, f, x, y, obj_weight, cx)
function _gradient!(dx, , x, y, obj_weight, cx)
Enzyme.make_zero!(dx)
dcx = make_zero(cx)
res = Enzyme.autodiff(
Enzyme.Reverse,
f,
,
Enzyme.Active,
Enzyme.Duplicated(x, dx),
Enzyme.Const(y),
Enzyme.Const(obj_weight),
Enzyme.Const(cx)
Enzyme.Duplicated(cx, dcx)
)
return nothing
end

function _hvp!(res, f, x, v, y, obj_weight, cx)
# grad = Enzyme.make_zero(x)
function _hvp!(res, , x, v, y, obj_weight, cx)
dcx = make_zero(cx)
Enzyme.autodiff(
Enzyme.Forward,
_gradient!,
res,
Enzyme.Const(f),
Enzyme.Const(),
Enzyme.Duplicated(x, v),
Enzyme.Const(y),
Enzyme.Const(obj_weight),
Enzyme.Const(cx),
Enzyme.Duplicated(cx, dcx),
)
return nothing
end
Expand Down

0 comments on commit 99dfac8

Please sign in to comment.