
Strange behavior of bm.hessian() #661

Closed · 3 tasks done
Dr-Chen-Xiaoyu opened this issue Apr 10, 2024 · 4 comments

Labels
bug Something isn't working

Comments

@Dr-Chen-Xiaoyu

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks.
  • If applicable, include full error messages/tracebacks.

Hi, Chaoming,

I am trying to use bm.hessian() to compute the Hessian matrix of a model's parameters with respect to a loss, just as I use bm.grad() for gradients.

```python
import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
print('bp version:', bp.__version__)
# bp version: 2.4.6.post5
bm.set_mode(bm.training_mode)
bm.random.seed(321)
```

```python
class RNN(bp.DynamicalSystem):
    def __init__(self, num_in, num_hidden):
        super(RNN, self).__init__()
        self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
        self.out = bp.dnn.Dense(num_hidden, 1)

    def update(self, x):
        return self.out(self.rnn(x))

# define the loss function
def lossfunc(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss

model = RNN(1, 2)
data_x = bm.random.rand(1, 1000, 1)
data_y = data_x + bm.random.randn(1, 1000, 1)
print(lossfunc(data_x, data_y))
# 1.0623081
```

It works well with bm.grad():

```python
lossgrad = bm.grad(lossfunc, grad_vars=model.train_vars(), return_value=True)
grad_vector = lossgrad(data_x, data_y)
print(grad_vector[0]['Dense0.W'])
# [[-0.03799846]
#  [-0.38051015]]
```

I expected a nested Hessian matrix, just like the 2nd example in https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html, but I get strange behavior with bm.hessian():

```python
losshess = bm.hessian(lossfunc, grad_vars=model.train_vars(), return_value=True)
hess_matrix = losshess(data_x, data_y)
print(hess_matrix[0]['Dense0.W'])
print(hess_matrix[1]['Dense0.W'])
# Dense0.W
# [[-0.03799846]
#  [-0.38051015]]
```
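For context, here is a minimal sketch of the nested structure I mean, in plain JAX rather than BrainPy (the loss and parameter names are made up for illustration): jax.hessian() on a function of a dict of parameters returns a nested dict of second derivatives.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss over a dict of parameters.
def loss(params):
    return jnp.sum(params['W'] ** 2) + jnp.sum(params['b'] ** 3)

params = {'W': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
hess = jax.hessian(loss)(params)

# hess is a nested dict: hess['W']['W'] has shape (2, 2),
# hess['b']['b'] has shape (1, 1), and the cross blocks are zero here.
```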

By the way, I would greatly appreciate it if some examples could be added to the documentation at https://brainpy.readthedocs.io/en/latest/apis/generated/brainpy.math.hessian.html 😊

Best,
Xiaoyu Chen, SJTU

@Dr-Chen-Xiaoyu Dr-Chen-Xiaoyu added the bug Something isn't working label Apr 10, 2024
@Dr-Chen-Xiaoyu
Author

Inspired by jax.hessian(), I tried modifying the loss function so that the parameters of interest become its inputs.

```python
def lossfunc_2(w, inputs, targets):
    model.out.W = w
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss
```

It seems to work:

```python
losshess = bm.hessian(lossfunc_2, argnums=0)
hess_matrix = losshess(bm.zeros((2, 1)), data_x, data_y)
print(hess_matrix.squeeze())
# [[0.03443691 0.02228835]
#  [0.02228835 0.2377306 ]]
```
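As a sanity check on a result like this, an autodiff Hessian can be compared against a central finite-difference estimate of the gradient. A minimal sketch in plain JAX (the function f and the step size are illustrative assumptions, not BrainPy code):

```python
import jax
import jax.numpy as jnp

# Toy function: gradient is w**3, so the Hessian is diag(3 * w**2).
def f(w):
    return jnp.sum(w ** 4) / 4.0

w0 = jnp.array([0.5, -1.0])
H = jax.hessian(f)(w0)

# Central finite differences of the gradient, one coordinate at a time.
eps = 1e-3
g = jax.grad(f)
eye = jnp.eye(2)
fd = jnp.stack([(g(w0 + eps * eye[i]) - g(w0 - eps * eye[i])) / (2 * eps)
                for i in range(2)])

# H and fd should both be close to diag([0.75, 3.0]).
```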

I guess runner.predict() could be transformed by BrainPy into a pure function, and then run in the original jax.hessian() style?
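That pure-function style can be sketched in plain JAX (the linear predict step below is a hypothetical stand-in for runner.predict, not BrainPy code): once the loss is a pure function of the weights, jax.hessian() applies directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure loss: weights are an explicit argument, no hidden state.
def pure_loss(w, x, y):
    pred = x @ w                          # stand-in for runner.predict
    return jnp.mean((pred - y) ** 2)

x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
H = jax.hessian(pure_loss)(jnp.zeros(2), x, y)

# For mean squared error the Hessian is constant: (2 / N) * x.T @ x.
```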

chaoming0625 added a commit that referenced this issue Apr 11, 2024
@chaoming0625
Collaborator

Thanks for the report. I have submitted a PR to fix the error. The new API now produces the same behavior as the functional jax.hessian(). Please try it after PR #662 has been merged.

@Dr-Chen-Xiaoyu
Author

> Thanks for the report. I have submitted a PR for fixing the error. Currently, the new API can produce the same behavior of the functional jax.hessian(). Please try the new API after the PR #662 has been merged.

Thanks~😊

chaoming0625 added a commit that referenced this issue Apr 14, 2024
* fix issue #661

* fix tests

* updates
@chaoming0625
Collaborator

PR #662 has been merged, so I am closing this issue. Feel free to reopen any time if there are additional questions.

Routhleck added a commit that referenced this issue May 12, 2024
This reverts commit 4bd1898.
chaoming0625 pushed a commit that referenced this issue May 14, 2024
…670)

* Revert "fix issue #661 (#662)"

This reverts commit 4bd1898.

* Replace

* Support jax==0.4.28

This reverts commit 59fb681.

* Fix JIT bugs and Replace deprecated functions