Skip to content

Commit

Permalink
release TAPIR clustering from RoboTAP
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561646394
Change-Id: I6fc2b265bb4ec18a9e79c485c31893b3f90ca857
  • Loading branch information
cdoersch authored and copybara-github committed Aug 31, 2023
1 parent 7b3a336 commit e8236e9
Show file tree
Hide file tree
Showing 3 changed files with 1,251 additions and 17 deletions.
170 changes: 170 additions & 0 deletions colabs/tapir_clustering.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MWPOsk-I8o69"
},
"outputs": [],
"source": [
"# @title Download Code {form-width: \"25%\"}\n",
"!git clone https://github.com/deepmind/tapnet.git"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OUfaseMw_hqJ"
},
"outputs": [],
"source": [
"# @title Install Dependencies {form-width: \"25%\"}\n",
"!pip install -r tapnet/requirements_inference.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dNWBx_DOHSSt"
},
"outputs": [],
"source": [
"# @title Download Model {form-width: \"25%\"}\n",
"\n",
"%mkdir tapnet/checkpoints\n",
"\n",
"!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/causal_tapir_checkpoint.npy\n",
"\n",
"%ls tapnet/checkpoints\n",
"\n",
"checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jtTNXUNCHVAL"
},
"outputs": [],
"source": [
"# @title Imports {form-width: \"25%\"}\n",
"%matplotlib widget\n",
"import functools\n",
"\n",
"import haiku as hk\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"import tree\n",
"\n",
"from tapnet import tapir_clustering\n",
"from tapnet.utils import transforms\n",
"from tapnet.utils import viz_utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0J9kVfSuHmqS"
},
"outputs": [],
"source": [
"# @title Load an Exemplar Video {form-width: \"25%\"}\n",
"\n",
"%mkdir tapnet/examplar_videos\n",
"\n",
"!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/robotap/for_clustering.mp4\n",
"\n",
"video = media.read_video('tapnet/examplar_videos/for_clustering.mp4')\n",
"height, width = video.shape[1:3]\n",
"media.show_video(video[::5], fps=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7Vjhi4PdJ2W-"
},
"outputs": [],
"source": [
"# @title Run TAPIR to extract point tracks {form-width: \"25%\"}\n",
"\n",
"demo_videos = {\"dummy_id\":video}\n",
"demo_episode_ids = list(demo_videos.keys())\n",
"track_dict = tapir_clustering.track_many_points(\n",
" demo_videos,\n",
" demo_episode_ids,\n",
" checkpoint_path,\n",
" point_batch_size=1024,\n",
" points_per_frame=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kU2yqJVTPgg-"
},
"outputs": [],
"source": [
"# @title Run the clustering {form-width: \"25%\"}\n",
"\n",
"clustered = tapir_clustering.compute_clusters(\n",
" track_dict['separation_tracks'],\n",
" track_dict['separation_visibility'],\n",
" track_dict['demo_episode_ids'],\n",
" track_dict['video_shape'],\n",
" track_dict['query_features'],\n",
" max_num_cats=12,\n",
" final_num_cats=7,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FCNCAeLVQ0r2"
},
"outputs": [],
"source": [
"# @title Display the inferred clusters {form-width: \"25%\"}\n",
"\n",
"separation_visibility_trim = clustered['separation_visibility']\n",
"separation_tracks_trim = clustered['separation_tracks']\n",
"\n",
"pointtrack_video = viz_utils.plot_tracks_v2(\n",
" (demo_videos[demo_episode_ids[0]]).astype(np.uint8),\n",
" separation_tracks_trim[demo_episode_ids[0]],\n",
" 1.0-separation_visibility_trim[demo_episode_ids[0]],\n",
" trackgroup=clustered['classes']\n",
")\n",
"media.show_video(pointtrack_video, fps=20)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Loading

0 comments on commit e8236e9

Please sign in to comment.