Skip to content

Commit

Permalink
Fix a bug in the DAVIS first eval in the torch_tapir_demo colab.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700959920
Change-Id: I21766d126ee94a21fe5ab77d155f3e3f5756680b
  • Loading branch information
yangyi02 committed Nov 28, 2024
1 parent 69056d9 commit 900ac92
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 89 deletions.
166 changes: 83 additions & 83 deletions colabs/causal_tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
},
"outputs": [],
"source": [
"# @title Install code and dependencies {form-width: \"25%\"}\n",
"# @title Install Code and Dependencies {form-width: \"25%\"}\n",
"!pip install git+https://github.com/google-deepmind/tapnet.git"
]
},
Expand Down Expand Up @@ -322,6 +322,88 @@
"media.show_video(video_viz, fps=10)"
]
},
{
"metadata": {
"id": "VZdUm7PfFlYN"
},
"cell_type": "code",
"source": [
"# @title Download TAPVid-DAVIS Dataset {form-width: \"25%\"}\n",
"!wget https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip\n",
"!unzip tapvid_davis.zip"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {
"id": "4QWKHTjyFlYN"
},
"cell_type": "code",
"source": [
"# @title DAVIS First Eval on 256x256 Resolution {form-width: \"25%\"}\n",
"%%time\n",
"\n",
"from tapnet import evaluation_datasets\n",
"\n",
"davis_dataset = evaluation_datasets.create_davis_dataset(\n",
" 'tapvid_davis/tapvid_davis.pkl', query_mode='first'\n",
")\n",
"\n",
"summed_scalars = None\n",
"for sample_idx, sample in enumerate(davis_dataset):\n",
" sample = sample['davis']\n",
" frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)\n",
" height, width = frames.shape[1:3]\n",
" query_points = sample['query_points'][0]\n",
"\n",
" # Extract features for the query point.\n",
" query_features = online_model_init(frames[None], query_points[None])\n",
"\n",
" causal_state = tapir.construct_initial_causal_state(\n",
" query_points.shape[0], len(query_features.resolutions) - 1\n",
" )\n",
"\n",
" # Predict point tracks frame by frame\n",
" predictions = []\n",
" for i in range(frames.shape[0]):\n",
" # Note: we add a batch dimension.\n",
" tracks, visibles, causal_state = online_model_predict(\n",
" frames=frames[None, i : i + 1],\n",
" query_features=query_features,\n",
" causal_context=causal_state,\n",
" )\n",
" predictions.append({'tracks': tracks, 'visibles': visibles})\n",
"\n",
" tracks = np.concatenate([x['tracks'][0] for x in predictions], axis=1)\n",
" visibles = np.concatenate([x['visibles'][0] for x in predictions], axis=1)\n",
" occluded = ~visibles\n",
"\n",
" query_points = sample['query_points'][0]\n",
"\n",
" scalars = evaluation_datasets.compute_tapvid_metrics(\n",
" query_points[None],\n",
" sample['occluded'],\n",
" sample['target_points'],\n",
" occluded[None],\n",
" tracks[None],\n",
" query_mode='first',\n",
" )\n",
" scalars = jax.tree.map(lambda x: np.array(np.sum(x, axis=0)), scalars)\n",
" print(sample_idx, scalars)\n",
"\n",
" if summed_scalars is None:\n",
" summed_scalars = scalars\n",
" else:\n",
" summed_scalars = jax.tree.map(np.add, summed_scalars, scalars)\n",
"\n",
" num_samples = sample_idx + 1\n",
" mean_scalars = jax.tree.map(lambda x: x / num_samples, summed_scalars)\n",
" print(mean_scalars)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -430,88 +512,6 @@
"media.show_video(video_viz, fps=10)"
]
},
{
"metadata": {
"id": "9ogryCtJV_sA"
},
"cell_type": "code",
"source": [
"# @title Download TAPVid-DAVIS Dataset {form-width: \"25%\"}\n",
"!wget https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip\n",
"!unzip tapvid_davis.zip"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {
"id": "aRU0r0yCV_sA"
},
"cell_type": "code",
"source": [
"# @title DAVIS first eval on 256x256 resolution {form-width: \"25%\"}\n",
"%%time\n",
"\n",
"from tapnet import evaluation_datasets\n",
"\n",
"davis_dataset = evaluation_datasets.create_davis_dataset(\n",
" 'tapvid_davis/tapvid_davis.pkl', query_mode='first'\n",
")\n",
"\n",
"summed_scalars = None\n",
"for sample_idx, sample in enumerate(davis_dataset):\n",
" sample = sample['davis']\n",
" frames = np.round((sample['video'][0] + 1) / 2 * 255).astype(np.uint8)\n",
" height, width = frames.shape[1:3]\n",
" query_points = sample['query_points'][0]\n",
"\n",
" # Extract features for the query point.\n",
" query_features = online_model_init(frames[None], query_points[None])\n",
"\n",
" causal_state = tapir.construct_initial_causal_state(\n",
" query_points.shape[0], len(query_features.resolutions) - 1\n",
" )\n",
"\n",
" # Predict point tracks frame by frame\n",
" predictions = []\n",
" for i in range(frames.shape[0]):\n",
" # Note: we add a batch dimension.\n",
" tracks, visibles, causal_state = online_model_predict(\n",
" frames=frames[None, i : i + 1],\n",
" query_features=query_features,\n",
" causal_context=causal_state,\n",
" )\n",
" predictions.append({'tracks': tracks, 'visibles': visibles})\n",
"\n",
" tracks = np.concatenate([x['tracks'][0] for x in predictions], axis=1)\n",
" visibles = np.concatenate([x['visibles'][0] for x in predictions], axis=1)\n",
" occluded = ~visibles\n",
"\n",
" query_points = sample['query_points'][0]\n",
"\n",
" scalars = evaluation_datasets.compute_tapvid_metrics(\n",
" query_points[None],\n",
" sample['occluded'],\n",
" sample['target_points'],\n",
" occluded[None],\n",
" tracks[None],\n",
" query_mode='first',\n",
" )\n",
" scalars = jax.tree.map(lambda x: np.array(np.sum(x, axis=0)), scalars)\n",
" print(sample_idx, scalars)\n",
"\n",
" if summed_scalars is None:\n",
" summed_scalars = scalars\n",
" else:\n",
" summed_scalars = jax.tree.map(np.add, summed_scalars, scalars)\n",
"\n",
" num_samples = sample_idx + 1\n",
" mean_scalars = jax.tree.map(lambda x: x / num_samples, summed_scalars)\n",
" print(mean_scalars)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
4 changes: 2 additions & 2 deletions colabs/tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
},
"outputs": [],
"source": [
"# @title Install code and dependencies {form-width: \"25%\"}\n",
"# @title Install Code and Dependencies {form-width: \"25%\"}\n",
"!pip install git+https://github.com/google-deepmind/tapnet.git"
]
},
Expand Down Expand Up @@ -280,7 +280,7 @@
},
"cell_type": "code",
"source": [
"# @title DAVIS first eval on 256x256 resolution {form-width: \"25%\"}\n",
"# @title DAVIS First Eval on 256x256 Resolution {form-width: \"25%\"}\n",
"%%time\n",
"\n",
"from tapnet import evaluation_datasets\n",
Expand Down
4 changes: 2 additions & 2 deletions colabs/torch_causal_tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
},
"outputs": [],
"source": [
"# @title Install code and dependencies {form-width: \"25%\"}\n",
"# @title Install Code and Dependencies {form-width: \"25%\"}\n",
"!pip install git+https://github.com/google-deepmind/tapnet.git"
]
},
Expand Down Expand Up @@ -350,7 +350,7 @@
},
"cell_type": "code",
"source": [
"# @title DAVIS first eval on 256x256 resolution {form-width: \"25%\"}\n",
"# @title DAVIS First Eval on 256x256 Resolution {form-width: \"25%\"}\n",
"%%time\n",
"\n",
"from tapnet import evaluation_datasets\n",
Expand Down
5 changes: 3 additions & 2 deletions colabs/torch_tapir_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
},
"outputs": [],
"source": [
"# @title Install code and dependencies {form-width: \"25%\"}\n",
"# @title Install Code and Dependencies {form-width: \"25%\"}\n",
"!pip install 'tapnet[torch] @ git+https://github.com/google-deepmind/tapnet.git'"
]
},
Expand Down Expand Up @@ -291,7 +291,7 @@
},
"cell_type": "code",
"source": [
"# @title DAVIS first eval on 256x256 resolution {form-width: \"25%\"}\n",
"# @title DAVIS First Eval on 256x256 Resolution {form-width: \"25%\"}\n",
"%%time\n",
"\n",
"from tapnet import evaluation_datasets\n",
Expand All @@ -313,6 +313,7 @@
"\n",
" tracks = tracks.cpu().detach().numpy()\n",
" visibles = visibles.cpu().detach().numpy()\n",
" query_points = query_points.cpu().detach().numpy()\n",
" occluded = ~visibles\n",
"\n",
" scalars = evaluation_datasets.compute_tapvid_metrics(\n",
Expand Down

0 comments on commit 900ac92

Please sign in to comment.