Skip to content

Commit

Permalink
feat: add patches and update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 17, 2024
1 parent c98f91f commit 77b77ef
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
<!-- # JaxADi -->
<!-- TODO: ADD PATCHES -->

[![CI](https://img.shields.io/github/actions/workflow/status/based-robotics/jaxadi/build.yaml?branch=master)](https://github.com/based-robotics/jaxadi/actions)
[![PyPI version](https://img.shields.io/pypi/v/jaxadi?color=blue)](https://pypi.org/project/jaxadi/)
[![PyPI downloads](https://img.shields.io/pypi/dm/jaxadi?color=blue)](https://pypistats.org/packages/jaxadi)

<p align="center">
<!-- Placeholder for a cool logo -->
<img src="https://github.com/based-robotics/jaxadi/blob/master/_assets/_logo.png?raw=true" alt="JAXADI Logo" width="400"/>
</p>


**JaxADi** is a powerful Python library designed to bridge the gap between `casadi.Function` and JAX-compatible functions. By leveraging the strengths of both CasADi and JAX, JAXADI opens up exciting opportunities for building highly efficient, batchable code that can be executed seamlessly across CPUs, GPUs, and TPUs.

JAXADI can be particularly useful in scenarios involving:
Expand All @@ -16,17 +18,17 @@ JAXADI can be particularly useful in scenarios involving:
- Machine learning models with complex dynamics
- Large-scale numerical optimizations


## Installation

You can install JAXADI using pip:

<!-- Change once it will be realeased -->

```bash
pip install jaxadi
```

For a complete environment setup, we recommend using Conda/Mamba:
For a complete environment setup for examples, we recommend using Conda/Mamba:

```bash
mamba env create -f environment.yml
Expand Down Expand Up @@ -67,24 +69,23 @@ output = jax_fn(input_x, input_y)

JAXADI comes with several examples to help you get started:

1. [Basic Translation](examples/00_translate.py): Learn how to translate CasADi functions to JAX.
1. [Basic Translation](examples/00_translate.py): Learn how to translate CasADi functions to JAX.

2. [Lowering Operations](examples/01_lower.py): Understand the lowering process in JaxADi.
2. [Lowering Operations](examples/01_lower.py): Understand the lowering process in JaxADi.

3. [Function Conversion](examples/02_convert.py): See how to fully convert CasADi functions to JAX.
3. [Function Conversion](examples/02_convert.py): See how to fully convert CasADi functions to JAX.

4. [Pendulum Rollout](examples/03_pendulum_rollout.py): Batched rollout of the nonlinear passive nonlinear pendulum
4. [Pendulum Rollout](examples/03_pendulum_rollout.py): Batched rollout of the nonlinear passive nonlinear pendulum

5. [Pinocchio Integration](examples/04_pinocchio.py): Explore how to convert Pinocchio-based CasADi functions to JAX.
5. [Pinocchio Integration](examples/04_pinocchio.py): Explore how to convert Pinocchio-based CasADi functions to JAX.

6. [MJX Comparison](examples/05_mjx.py): Compare the transformed Pinnocchio forward kinematics with one provided by Mujoco MJX

> **Note**: To run the Pinocchio and MJX examples, ensure you have them properly installed in your environment.
## Performance Benchmarks

<!-- ## Performance Benchmarks
(Consider adding a section about performance comparisons between CasADi and JAXADI-translated functions) -->
The process of benchmarking and evaluating the performance of Jaxadi is described in the [benchmarks](benchmarks/README.md) directory.

<!-- ## Contributing
Expand All @@ -98,5 +99,4 @@ This project draws inspiration from [cusadi](https://github.com/se-hwan/cusadi),

For questions, issues, or suggestions, please [open an issue](https://github.com/based-robotics/jaxadi/issues) on our GitHub repository.


We hope JAXADI empowers your numerical computing and optimization tasks! Happy coding!
10 changes: 8 additions & 2 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Benchmarking

This directory contains a set of benchmarks that we use to evaluate the performance of the `jaxadi` in comparison to `cusadi`.
In order to evaluate the performance of the `jaxadi` library vs `cusadi` we have tried to reproduce the benchmarks from the `cusadi` library first and faced some issues with the proper `cuda` installation.

Due to the difficulty of installation and `cuda` dependencies, we were able to reproduce the tests in the `colab` environment only.
![meme](https://preview.redd.it/explain-please-v0-ma2mz5wxftod1.jpeg?auto=webp&s=2b90dfa3b12e064f54333e1080b3dabbad914f48)

Adding the complexity of setup of benchmarks of [cusadi](https://github.com/se-hwan/cusadi) we have copied and modified the benchmarks to be able to run them in the [`colab` environment](https://colab.research.google.com/github/based-robotics/jaxadi/blob/feature%2Fbenchmarking/benchmarks/jaxadi_vs_cusadi.ipynb) side by side (CUDA vs Jax).

Due limitations we cover only the functions with less than 1e3 operations. All of them are located in the [data](data) directory.

One may run the benchmarks in the colab environment and get the [cusadi results](cuda_benchmark_results.npz) and [jaxadi results](jax_benchmark_results.npz) for comparison.
3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ dependencies:
- pip:
- robot_descriptions
- jax
- mujoco
- mujoco_mjx
- jaxadi

0 comments on commit 77b77ef

Please sign in to comment.