Skip to content

Commit

Permalink
Update colabs to fix bugs and improve code quality.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701021157
Change-Id: Id21a42c98d6e7f254c70b476a5ba409e351252f5
  • Loading branch information
yangyi02 committed Nov 28, 2024
1 parent 900ac92 commit 3d5142c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 23 deletions.
9 changes: 2 additions & 7 deletions colabs/causal_tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,15 @@
"source": [
"# @title Imports {form-width: \"25%\"}\n",
"%matplotlib widget\n",
"import functools\n",
"\n",
"from google.colab import output\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
"from tapnet.models import tapir_model\n",
"from tapnet.utils import model_utils\n",
"from tapnet.utils import transforms\n",
"from tapnet.utils import viz_utils\n",
"from tqdm import tqdm\n",
"import tree\n",
"\n",
"output.enable_custom_widget_manager()"
]
Expand Down Expand Up @@ -259,7 +254,6 @@
"!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
"\n",
"video = media.read_video(\"tapnet/examplar_videos/horsejump-high.mp4\")\n",
"height, width = video.shape[1:3]\n",
"media.show_video(video, fps=10)"
]
},
Expand Down Expand Up @@ -315,6 +309,7 @@
"visibles = np.concatenate([x['visibles'][0] for x in predictions], axis=1)\n",
"\n",
"# Visualize sparse point tracks\n",
"height, width = video.shape[1:3]\n",
"tracks = transforms.convert_grid_coordinates(\n",
" tracks, (resize_width, resize_height), (width, height)\n",
")\n",
Expand Down Expand Up @@ -354,7 +349,6 @@
"for sample_idx, sample in enumerate(davis_dataset):\n",
" sample = sample['davis']\n",
" frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)\n",
" height, width = frames.shape[1:3]\n",
" query_points = sample['query_points'][0]\n",
"\n",
" # Extract features for the query point.\n",
Expand Down Expand Up @@ -478,6 +472,7 @@
"\n",
"frames = media.resize_video(video, (resize_height, resize_width))\n",
"query_points = convert_select_points_to_query_points(0, select_points)\n",
"height, width = video.shape[1:3]\n",
"query_points = transforms.convert_grid_coordinates(\n",
" query_points,\n",
" (1, height, width),\n",
Expand Down
8 changes: 3 additions & 5 deletions colabs/tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,12 @@
"source": [
"# @title Imports {form-width: \"25%\"}\n",
"%matplotlib widget\n",
"import functools\n",
"from google.colab import output\n",
"import jax\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
"# matplotlib.use('Agg')\n",
"\n",
"from tapnet.models import tapir_model\n",
"from tapnet.utils import model_utils\n",
"from tapnet.utils import transforms\n",
Expand Down Expand Up @@ -169,7 +166,6 @@
"!wget -P tapnet/examplar_videos http://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
"\n",
"video = media.read_video(\"tapnet/examplar_videos/horsejump-high.mp4\")\n",
"height, width = video.shape[1:3]\n",
"media.show_video(video, fps=10)"
]
},
Expand Down Expand Up @@ -254,6 +250,7 @@
"visibles = np.array(visibles)\n",
"\n",
"# Visualize sparse point tracks\n",
"height, width = video.shape[1:3]\n",
"tracks = transforms.convert_grid_coordinates(\n",
" tracks, (resize_width, resize_height), (width, height)\n",
")\n",
Expand Down Expand Up @@ -293,7 +290,6 @@
"for sample_idx, sample in enumerate(davis_dataset):\n",
" sample = sample['davis']\n",
" frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)\n",
" height, width = frames.shape[1:3]\n",
" query_points = sample['query_points'][0]\n",
"\n",
" tracks, visibles = inference(frames, query_points)\n",
Expand Down Expand Up @@ -383,6 +379,7 @@
"visibles = np.concatenate(all_visibles, axis=0)\n",
"\n",
"# Visualize sparse point tracks\n",
"height, width = video.shape[1:3]\n",
"tracks = transforms.convert_grid_coordinates(\n",
" tracks, (resize_width, resize_height), (width, height)\n",
")\n",
Expand Down Expand Up @@ -468,6 +465,7 @@
"query_points = convert_select_points_to_query_points(\n",
" select_frame, select_points\n",
")\n",
"height, width = video.shape[1:3]\n",
"query_points = transforms.convert_grid_coordinates(\n",
" query_points,\n",
" (1, height, width),\n",
Expand Down
7 changes: 1 addition & 6 deletions colabs/torch_causal_tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,13 @@
"# @title Imports {form-width: \"25%\"}\n",
"%matplotlib widget\n",
"from google.colab import output\n",
"import haiku as hk\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
"from tapnet.torch import tapir_model\n",
"from tapnet.utils import transforms\n",
"from tapnet.utils import viz_utils\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import tree\n",
"\n",
"output.enable_custom_widget_manager()"
]
Expand Down Expand Up @@ -144,7 +140,6 @@
"!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
"\n",
"video = media.read_video(\"tapnet/examplar_videos/horsejump-high.mp4\")\n",
"height, width = video.shape[1:3]\n",
"media.show_video(video, fps=10)"
]
},
Expand Down Expand Up @@ -324,6 +319,7 @@
"tracks = tracks.cpu().numpy()\n",
"visibles = visibles.cpu().numpy()\n",
"# Visualize sparse point tracks\n",
"height, width = video.shape[1:3]\n",
"tracks = transforms.convert_grid_coordinates(\n",
" tracks, (resize_width, resize_height), (width, height)\n",
")\n",
Expand Down Expand Up @@ -363,7 +359,6 @@
"for sample_idx, sample in enumerate(davis_dataset):\n",
" sample = sample['davis']\n",
" frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)\n",
" height, width = frames.shape[1:3]\n",
" query_points = sample['query_points'][0]\n",
"\n",
" frames = torch.tensor(frames).to(device)\n",
Expand Down
7 changes: 2 additions & 5 deletions colabs/torch_tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@
"# @title Imports {form-width: \"25%\"}\n",
"%matplotlib widget\n",
"from google.colab import output\n",
"import haiku as hk\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"import mediapy as media\n",
"import numpy as np\n",
Expand All @@ -107,7 +105,6 @@
"from tapnet.utils import viz_utils\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import tree\n",
"\n",
"output.enable_custom_widget_manager()"
]
Expand Down Expand Up @@ -172,7 +169,6 @@
"def inference(frames, query_points, model):\n",
" # Preprocess video to match model inputs format\n",
" frames = preprocess_frames(frames)\n",
" num_frames, height, width = frames.shape[0:3]\n",
" query_points = query_points.float()\n",
" frames, query_points = frames[None], query_points[None]\n",
"\n",
Expand Down Expand Up @@ -204,7 +200,6 @@
"!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
"\n",
"video = media.read_video(\"tapnet/examplar_videos/horsejump-high.mp4\")\n",
"height, width = video.shape[1:3]\n",
"media.show_video(video, fps=10)"
]
},
Expand Down Expand Up @@ -265,6 +260,7 @@
"tracks = tracks.cpu().detach().numpy()\n",
"visibles = visibles.cpu().detach().numpy()\n",
"# Visualize sparse point tracks\n",
"height, width = video.shape[1:3]\n",
"tracks = transforms.convert_grid_coordinates(\n",
" tracks, (resize_width, resize_height), (width, height)\n",
")\n",
Expand Down Expand Up @@ -417,6 +413,7 @@
"query_points = convert_select_points_to_query_points(\n",
" select_frame, select_points\n",
")\n",
"height, width = video.shape[1:3]\n",
"query_points = transforms.convert_grid_coordinates(\n",
" query_points,\n",
" (1, height, width),\n",
Expand Down

0 comments on commit 3d5142c

Please sign in to comment.