Skip to content

Commit

Permalink
Support batch axes + broadcasting for viser.transforms (#208)
Browse files Browse the repository at this point in the history
* Support batch axes + broadcasting for `viser.transforms`

* Nits

* Sync docs

* Fix type
  • Loading branch information
brentyi authored May 6, 2024
1 parent f32a152 commit 36ae586
Show file tree
Hide file tree
Showing 14 changed files with 587 additions and 356 deletions.
3 changes: 3 additions & 0 deletions docs/source/examples/02_gui.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ Examples of basic GUI elements that we can create, read from, and write to.
max=100,
step=1,
initial_value=(0, 30, 100),
marks=((0, "0"), (50, "5"), (70, "7"), 99),
)
gui_slider_positions = server.add_gui_slider(
"# sliders",
Expand Down Expand Up @@ -122,6 +123,8 @@ Examples of basic GUI elements that we can create, read from, and write to.
gui_text.visible = not gui_checkbox_hide.value
gui_button.visible = not gui_checkbox_hide.value
gui_rgb.disabled = gui_checkbox_disable.value
gui_button.disabled = gui_checkbox_disable.value
gui_upload_button.disabled = gui_checkbox_disable.value
# Update the number of handles in the multi-slider.
if gui_slider_positions.value != len(gui_multi_slider.value):
Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/03_gui_callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ we get updates.
gui_plane.on_update(lambda _: update_plane())
with server.add_gui_folder("Control", expand_by_default=False):
with server.add_gui_folder("Control"):
gui_show_frame = server.add_gui_checkbox("Show Frame", initial_value=True)
gui_show_everything = server.add_gui_checkbox(
"Show Everything", initial_value=True
Expand Down
11 changes: 4 additions & 7 deletions docs/source/examples/08_smpl_visualizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,10 @@ See here for download instructions:
# Compute SMPL outputs.
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=np.stack(
[
tf.SO3.exp(np.array(x.value)).as_matrix()
for x in gui_elements.gui_joints
],
axis=0,
),
joint_rotmats=tf.SO3.exp(
# (num_joints, 3)
np.array([x.value for x in gui_elements.gui_joints])
).as_matrix(),
)
server.add_mesh_simple(
"/human",
Expand Down
11 changes: 6 additions & 5 deletions docs/source/examples/20_scene_pointer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``.
if len(hit_pos) == 0:
return
client.remove_scene_pointer_callback()
# Get the first hit position (based on distance from the ray origin).
hit_pos = min(hit_pos, key=lambda x: onp.linalg.norm(x - origin))
Expand All @@ -85,10 +86,9 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``.
)
hit_pos_handles.append(hit_pos_handle)
@client.on_scene_pointer_done
@client.on_scene_pointer_removed
def _():
click_button_handle.disabled = False
client.remove_scene_pointer_callback()
# Tests "rect-select" scenepointerevent.
paint_button_handle = client.add_gui_button("Paint mesh", icon=viser.Icon.PAINT)
Expand All @@ -99,6 +99,8 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``.
@client.on_scene_pointer(event_type="rect-select")
def _(message: viser.ScenePointerEvent) -> None:
client.remove_scene_pointer_callback()
global mesh_handle
camera = message.client.camera
Expand All @@ -108,7 +110,7 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``.
R_camera_world = tf.SE3.from_rotation_and_translation(
tf.SO3(camera.wxyz), camera.position
).inverse()
vertices = mesh.vertices
vertices = cast(onp.ndarray, mesh.vertices)
vertices = (R_mesh_world.as_matrix() @ vertices.T).T
vertices = (
R_camera_world.as_matrix()
Expand Down Expand Up @@ -141,10 +143,9 @@ To get the demo data, see ``./assets/download_dragon_mesh.sh``.
position=(0.0, 0.0, 0.0),
)
@client.on_scene_pointer_done
@client.on_scene_pointer_removed
def _():
paint_button_handle.disabled = False
client.remove_scene_pointer_callback()
# Button to clear spheres.
clear_button_handle = client.add_gui_button("Clear scene", icon=viser.Icon.X)
Expand Down
83 changes: 83 additions & 0 deletions docs/source/examples/23_plotly.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
.. Comment: this file is automatically generated by `update_example_docs.py`.
It should not be modified manually.
Plotly.
==========================================


Examples of visualizing plotly plots in Viser.



.. code-block:: python
:linenos:
import time
import numpy as onp
import plotly.express as px
import plotly.graph_objects as go
import viser
from PIL import Image
def create_sinusoidal_wave(t: float) -> go.Figure:
"""Create a sinusoidal wave plot, starting at time t."""
x_data = onp.linspace(t, t + 6 * onp.pi, 50)
y_data = onp.sin(x_data) * 10
fig = px.line(
x=list(x_data),
y=list(y_data),
labels={"x": "x", "y": "sin(x)"},
title="Sinusoidal Wave",
)
# this sets the margins to be tight around the title.
fig.layout.title.automargin = True # type: ignore
fig.update_layout(
margin=dict(l=20, r=20, t=20, b=20),
) # Reduce plot margins.
return fig
def main() -> None:
server = viser.ViserServer()
# Plot type 1: Line plot.
line_plot_time = 0.0
line_plot = server.add_gui_plotly(figure=create_sinusoidal_wave(line_plot_time))
# Plot type 2: Image plot.
fig = px.imshow(Image.open("assets/Cal_logo.png"))
fig.update_layout(
margin=dict(l=20, r=20, t=20, b=20),
)
server.add_gui_plotly(figure=fig, aspect=1.0)
# Plot type 3: 3D Scatter plot.
fig = px.scatter_3d(
px.data.iris(),
x="sepal_length",
y="sepal_width",
z="petal_width",
color="species",
)
fig.update_layout(legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01))
fig.update_layout(
margin=dict(l=20, r=20, t=20, b=20),
)
server.add_gui_plotly(figure=fig, aspect=1.0)
while True:
# Update the line plot.
line_plot_time += 0.1
line_plot.figure = create_sinusoidal_wave(line_plot_time)
time.sleep(0.01)
if __name__ == "__main__":
main()
11 changes: 4 additions & 7 deletions examples/08_smpl_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,10 @@ def main(model_path: Path) -> None:
# Compute SMPL outputs.
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=np.stack(
[
tf.SO3.exp(np.array(x.value)).as_matrix()
for x in gui_elements.gui_joints
],
axis=0,
),
joint_rotmats=tf.SO3.exp(
# (num_joints, 3)
np.array([x.value for x in gui_elements.gui_joints])
).as_matrix(),
)
server.add_mesh_simple(
"/human",
Expand Down
Loading

0 comments on commit 36ae586

Please sign in to comment.