From aa01cc52819b5092e13a99935919ee0eab8ec635 Mon Sep 17 00:00:00 2001 From: simeon Date: Fri, 20 Sep 2024 21:33:57 +0300 Subject: [PATCH] added demo notebook --- README.md | 6 +- examples/_demo.ipynb | 542 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 547 insertions(+), 1 deletion(-) create mode 100644 examples/_demo.ipynb diff --git a/README.md b/README.md index d4f8bf3..a0ae82e 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,14 @@ [![PyPI version](https://img.shields.io/pypi/v/jaxadi?color=blue)](https://pypi.org/project/jaxadi/) [![PyPI downloads](https://img.shields.io/pypi/dm/jaxadi?color=blue)](https://pypistats.org/packages/jaxadi) +

JAXADI Logo

-**JaxADi** is a powerful Python library designed to bridge the gap between `casadi.Function` and JAX-compatible functions. By leveraging the strengths of both CasADi and JAX, JAXADI opens up exciting opportunities for building highly efficient, batchable code that can be executed seamlessly across CPUs, GPUs, and TPUs. +**JaxADi** is a Python library designed to bridge the gap between `casadi.Function` and JAX-compatible functions. By leveraging the strengths of both CasADi and JAX, JAXADI opens up exciting opportunities for building highly efficient, batchable code that can be executed seamlessly across CPUs, GPUs, and TPUs. JAXADI can be particularly useful in scenarios involving: @@ -18,6 +19,9 @@ JAXADI can be particularly useful in scenarios involving: - Optimal control problems - Machine learning models with complex dynamics +Please go through the short demo to get fastly onboard + + ## Installation You can install JAXADI using pip: diff --git a/examples/_demo.ipynb b/examples/_demo.ipynb new file mode 100644 index 0000000..a85bfda --- /dev/null +++ b/examples/_demo.ipynb @@ -0,0 +1,542 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1lbr8ps99CvG" + }, + "source": [ + "\n", + "

\n", + " \"JAXADI\n", + "

\n", + "\n", + "Welcome to [JaxADi](https://github.com/based-robotics/jaxadi), a Python library designed to seamlessly bridge the gap between CasADi and JAX-compatible functions. By harnessing the power of both CasADi and JAX, JaxADi opens up a world of possibilities for creating highly efficient, batchable code that can be executed effortlessly across CPUs, GPUs, and TPUs.\n", + "\n", + "JaxADi shines in various scenarios, including:\n", + "\n", + "- Complex robotics simulations\n", + "- Challenging optimal control problems\n", + "- Machine learning models with intricate dynamics\n", + "\n", + "Let's dive in and explore the capabilities of JaxADi!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RqFh22rA-aNn" + }, + "source": [ + "# **Getting Started with JaxADi**\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZiPM5SudoAAz" + }, + "source": [ + "## **Installation**\n", + "\n", + "Getting JaxADi up and running is a breeze. Simply use pip to install the [package]((https://pypi.org/project/jaxadi/)):\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KDixGI1Nkp-F", + "outputId": "79004462-61bf-45ba-f44e-bc3ddfd56f82" + }, + "outputs": [], + "source": [ + "!pip install jaxadi" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Vo7v8I5j-kPb" + }, + "source": [ + "## **Basic Usage**\n", + "\n", + "JaxADi offers a straightforward and intuitive API. Let's start by defining an example CasADi function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "wCypWJi-9CHs", + "outputId": "f2c065ec-cfca-4b41-95b8-67cbfcc05a4f" + }, + "outputs": [], + "source": [ + "import casadi as cs\n", + "\n", + "# Define input variables\n", + "x = cs.SX.sym(\"x\", 3, 2)\n", + "y = cs.SX.sym(\"y\", 2, 2)\n", + "# Define a nonlinear function\n", + "z = x @ y # Matrix multiplication\n", + "z_squared = z * z # Element-wise squaring\n", + "z_sin = cs.sin(z) # Element-wise sine\n", + "result = z_squared + z_sin # Element-wise addition\n", + "# Create the CasADi function\n", + "casadi_fn = cs.Function(\"complex_nonlinear_func\", [x, y], [result])\n", + "casadi_fn" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Zi8ejdw_OPu" + }, + "source": [ + "Get JAX-compatible function string representation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 122 + }, + "id": "CKalPhfF9B-E", + "outputId": "d5e85ce9-fdab-4ef0-e41e-1d5bc79a3400" + }, + "outputs": [], + "source": [ + "from jaxadi import translate\n", + "\n", + "# Get JAX-compatible function string representation\n", + "jax_fn_string = translate(casadi_fn)\n", + "jax_fn_string" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "67ZF1_XN_5Vi" + }, + "source": [ + "\n", + "Define JAX function from CasADi one" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "My7Eejgc_tIc", + "outputId": "6c0df722-4402-4b5f-bb10-eeba74c0e772" + }, + "outputs": [], + "source": [ + "from jaxadi import convert\n", + "\n", + "# Define JAX function from CasADi one\n", + "jax_fn = convert(casadi_fn, compile=True)\n", + "jax_fn" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "glj5zPU0ALAm" + }, + "source": [ + "Now, let's verify that our JaxADi function produces the same results as the original CasADi function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GwqWQoRj_qVo", + "outputId": "93230a30-449c-4fa2-fc3e-0bd03aba0834" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from jax import numpy as jnp\n", + "\n", + "# Run compiled function\n", + "input_x = np.random.rand(3, 2)\n", + "input_y = np.random.rand(2, 2)\n", + "output_jaxadi = np.array(jax_fn(jnp.array(input_x), jnp.array(input_y)))\n", + "output_casadi = np.array(casadi_fn(input_x, input_y))\n", + "if np.allclose(output_jaxadi, output_casadi):\n", + " print(\"The outputs of casadi and jaxadi functions are same\")\n", + "else:\n", + " print(\"Something went wrong...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BzJ8MzGUB5oT" + }, + "source": [ + "# **JaxADi in Action: Pendulum Rollout Example**\n", + "\n", + "To showcase the power of JaxADi, let's dive into a practical example: simulating an uncontrolled pendulum. We'll compare the performance of CasADi and JAX implementations for batch simulations.\n", + "\n", + "First, let's set up our pendulum model:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "zf5N5StGUAQT" + }, + "outputs": [], + "source": [ + "# Static parameters\n", + "dt = 0.02\n", + "g = 9.81 # Acceleration due to gravity\n", + "L = 1.0 # Length of the pendulum\n", + "b = 0.1 # Damping coefficient\n", + "I = 1.0 # Moment of inertia" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lDCwXXnqUDCD" + }, + "source": [ + "Define pendulum model as CasADi function" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "i2UVZWzCUCzM" + }, + "outputs": [], + "source": [ + "state = cs.SX.sym(\"state\", 2)\n", + "theta, omega = state[0], state[1]\n", + "\n", + "theta_dot = omega\n", + "omega_dot = (-b * omega - (g / L) * cs.sin(theta)) / I\n", + "\n", + "next_theta = theta + theta_dot * dt\n", + "next_omega = omega + omega_dot * dt\n", + "\n", + "next_state = cs.vertcat(next_theta, next_omega)\n", + "casadi_pendulum = cs.Function(\"pendulum_model\", [state], [next_state])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tqoq5eMTUWfe" + }, + "source": [ + "Convert it to JAX:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "ITdfIt86US4R" + }, + "outputs": [], + "source": [ + "jax_model = convert(casadi_pendulum, compile=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iZEz1rW8UgH4" + }, + "source": [ + "Now, let's implement rollout functions for both CasADi and JaxADi:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "AXPIw1EdUpYM" + }, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "timesteps = 100\n", + "\n", + "\n", + "def casadi_sequential_rollout(initial_states):\n", + " batch_size = initial_states.shape[0]\n", + " rollout_states = np.zeros((timesteps + 1, batch_size, 2))\n", + "\n", + " rollout_states[0] = initial_states\n", + " for t in range(1, timesteps + 1):\n", + " rollout_states[t] = np.array([casadi_pendulum(state).full().flatten() for state in rollout_states[t - 1]])\n", + "\n", + " return rollout_states\n", + "\n", + "\n", + "@jax.jit\n", + "def jax_vectorized_rollout(initial_states):\n", + " def single_step(state):\n", + " return jnp.array(jax_model(state)).reshape(\n", + " 2,\n", + " )\n", + "\n", + " def scan_fn(carry, _):\n", + " next_state = jax.vmap(single_step)(carry)\n", + " return next_state, next_state\n", + "\n", + " _, rollout_states = jax.lax.scan(scan_fn, initial_states, None, length=timesteps)\n", + " return jnp.concatenate([initial_states[None, ...], rollout_states], axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c33TQFmLUymd" + }, + "source": [ + "Let's compare the performance of these two implementations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OggXan5lUqaO", + "outputId": "1de0f2ba-6e4b-483e-8221-6d119bb20812" + }, + "outputs": [], + "source": [ + "import timeit\n", + "\n", + "# Test parameters\n", + "batch_size = 4096\n", + "\n", + "\n", + "def generate_random_inputs(batch_size):\n", + " return np.random.uniform(-np.pi, np.pi, (batch_size, 2))\n", + "\n", + "\n", + "initial_states = generate_random_inputs(batch_size)\n", + "print(\"Warming up JAX...\")\n", + "_ = jax_vectorized_rollout(initial_states)\n", + "print(\"Warm-up complete. Let's roll!\")\n", + "\n", + "print(\"\\nPerformance Showdown:\")\n", + "initial_states = generate_random_inputs(batch_size)\n", + "\n", + "print(f\"CasADi sequential rollout ({batch_size} pendulums, {timesteps} timesteps):\")\n", + "casadi_time = timeit.timeit(lambda: casadi_sequential_rollout(initial_states), number=1)\n", + "print(f\"Time: {casadi_time:.4f} seconds\")\n", + "\n", + "print(f\"\\nJAX vectorized rollout ({batch_size} pendulums, {timesteps} timesteps):\")\n", + "jax_time = timeit.timeit(lambda: np.array(jax_vectorized_rollout(initial_states)), number=1)\n", + "print(f\"Time: {jax_time:.4f} seconds\")\n", + "\n", + "print(f\"\\nSpeedup factor: {casadi_time / jax_time:.2f}x\")\n", + "\n", + "# Verify results\n", + "print(\"\\nDouble-checking our results:\")\n", + "casadi_results = casadi_sequential_rollout(initial_states[:10])\n", + "jax_results = np.array(jax_vectorized_rollout(initial_states[:10]))\n", + "\n", + "print(\"First 10 rollouts match:\", np.allclose(casadi_results, jax_results, atol=1e-4))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KfMQhfBaGr21" + }, + "source": [ + "\n", + "# **JaxADi + Other Libraries: A Perfect Match**\n", + "\n", + "JaxADi plays well with other CasADi-oriented libraries. Let's see how we can use it with [liecasadi](https://github.com/ami-iit/liecasadi) to vectorize the `log` method for SO3 groups:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "InLmX2TnG5cz", + "outputId": "ec405bd6-caaf-4dd9-b6db-b6d20f78c0d5" + }, + "outputs": [], + "source": [ + "!pip install liecasadi" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a_13cLEcfhbz" + }, + "source": [ + "Let us form the casadi function that takes quaternion and returns tangent:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "48eeUS-NG8AK" + }, + "outputs": [], + "source": [ + "from liecasadi import SO3\n", + "\n", + "# Create SO3 object from quaternion\n", + "quat = cs.SX.sym(\"quaternion\", 4)\n", + "transform = SO3(xyzw=quat)\n", + "# Get the tangent via Log and convert this to function\n", + "tang_vec = transform.log().vec\n", + "tang_fn = cs.Function(\"tangent_function\", [quat], [tang_vec])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nwbaSs6TZpVI" + }, + "source": [ + "Generate JAX function to calculate the tangent" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "wU6GU92kXA84" + }, + "outputs": [], + "source": [ + "jax_tang = convert(tang_fn, compile=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x7z3jLlBnpsc" + }, + "source": [ + "Test the functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ezXWgM_pXJJP", + "outputId": "62c64e10-b673-40c4-bcf5-55df943c02ed" + }, + "outputs": [], + "source": [ + "quat_random = np.random.randn(4)\n", + "quat_random /= np.linalg.norm(quat_random)\n", + "print(np.array(tang_fn(quat_random)).reshape(3))\n", + "print(np.array(jax_tang(quat_random)).reshape(3))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OrKDscHoZsTp" + }, + "source": [ + "With this JAX-compatible function, you can now easily batch the log operation and perform your sample-based calculations efficiently!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mrZcrpIEZIrY" + }, + "source": [ + "# **Wrapping Up**\n", + "\n", + "We've just scratched the surface of what's possible with JaxADi. There's a whole world of CasADi-oriented libraries out there waiting to be supercharged with JaxADi. We encourage you to explore how you can [use JaxADi to transform Pinocchio](https://github.com/based-robotics/jaxadi/blob/master/examples/04_pinocchio.py) calculations, [compare it with MJX](https://github.com/based-robotics/jaxadi/blob/master/examples/04_pinocchio.py), and dive into our repository for more [examples](https://github.com/based-robotics/jaxadi/tree/master/examples).\n", + "\n", + "We're always on the lookout for exciting applications, such as parallelizable Model Predictive Control (MPC). If you come up with something cool, don't hesitate to share it with the community!\n", + "\n", + "If JaxADi helps you in your research, we'd be thrilled if you could cite it:\n", + "\n", + "```bibtex\n", + "@misc{jaxadi2024,\n", + " title = {JaxADi: Bridging CasADi and JAX for Efficient Numerical Computing},\n", + " author = {Alentev, Igor and Kozlov, Lev and Nedelchev, Simeon},\n", + " year = {2024},\n", + " url = {https://github.com/based-robotics/jaxadi},\n", + " note = {Accessed: [Insert Access Date]}\n", + "}\n", + "```\n", + "\n", + "\n", + "Got questions, issues, or brilliant ideas? We'd love to hear from you! [Open an issue](https://github.com/based-robotics/jaxadi/issues) on our GitHub repository, and let's make JaxADi even better together.\n", + "\n", + "We hope JaxADi supercharges your numerical computing and optimization tasks. Now go forth and compute efficiently! Happy coding!" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}