Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Apr 14, 2024
1 parent 302e574 commit c667790
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 7 additions & 7 deletions brainpy/_src/math/object_transform/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ def test1(self):
hh = bp.dyn.HH(1)
hh.reset()

tree = jax.tree_structure(hh)
leaves = jax.tree_leaves(hh)
tree = jax.tree.structure(hh)
leaves = jax.tree.leaves(hh)

print(tree)
print(leaves)
print(jax.tree_unflatten(tree, leaves))
print(jax.tree.unflatten(tree, leaves))
print()


Expand Down Expand Up @@ -281,13 +281,13 @@ def not_close(x, y):
def all_close(x, y):
assert bm.allclose(x, y)

jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)

random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)

obj.load_state_dict(random_state)
jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)



Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ tqdm
pathos
taichi
numba
braincore
braintools


# test requirements
Expand Down

0 comments on commit c667790

Please sign in to comment.