Skip to content

Commit

Permalink
fix: pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 3, 2024
1 parent e1110ea commit 6d6f912
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# jaxadi
# jaxadi
8 changes: 4 additions & 4 deletions examples/03_pinocchio.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@

# Evaluate the function performance
q_val = ca.np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0, 0])
jax_q_val = jnp.array([[0.1], [0.2], [0.3], [0.4], [
0.5], [0.6], [0.7], [0], [0]])
jax_q_val = jnp.array([[0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0], [0]])

print("Casadi evaluation:")
print(fk(q_val))
Expand All @@ -53,6 +52,8 @@
# Second part
# Casadi: Sequential Evaluation
N = int(1e7)


def casadi_sequential_evaluation():
for _ in range(N):
fk(q_val)
Expand All @@ -61,8 +62,7 @@ def casadi_sequential_evaluation():
# JAX: Vectorized Evaluation using vmap
jax_q_vals = jnp.tile(jax_q_val, (N, 1, 1)) # Create a batch of 100 inputs
print(jax_q_vals.shape)
jax_fn_vectorized = jax.vmap(jax_fn, in_axes=(
1,), out_axes=1) # Vectorize the function
jax_fn_vectorized = jax.vmap(jax_fn, in_axes=(1,), out_axes=1) # Vectorize the function

# Performance comparison
print("Performance comparison:")
Expand Down

0 comments on commit 6d6f912

Please sign in to comment.