From 5f91eb3c0592fcdba4f78f07b942a39510d4f56d Mon Sep 17 00:00:00 2001 From: Keunhong Park Date: Thu, 16 Sep 2021 04:15:31 -0700 Subject: [PATCH] Add figures. --- .../figures/hypernerf_ap_ds_figure.ipynb | 1045 ++++++++++++ .../figures/hypernerf_optim_latent.ipynb | 148 ++ .../figures/level_set_visualization.ipynb | 117 ++ .../figures/nerfies_2d_experiments.ipynb | 921 +++++++++++ notebooks/figures/nerfies_eval_skeleton.ipynb | 912 +++++++++++ notebooks/figures/sdf_2d.ipynb | 1419 +++++++++++++++++ 6 files changed, 4562 insertions(+) create mode 100644 notebooks/figures/hypernerf_ap_ds_figure.ipynb create mode 100644 notebooks/figures/hypernerf_optim_latent.ipynb create mode 100644 notebooks/figures/level_set_visualization.ipynb create mode 100644 notebooks/figures/nerfies_2d_experiments.ipynb create mode 100644 notebooks/figures/nerfies_eval_skeleton.ipynb create mode 100644 notebooks/figures/sdf_2d.ipynb diff --git a/notebooks/figures/hypernerf_ap_ds_figure.ipynb b/notebooks/figures/hypernerf_ap_ds_figure.ipynb new file mode 100644 index 0000000..84a8a13 --- /dev/null +++ b/notebooks/figures/hypernerf_ap_ds_figure.ipynb @@ -0,0 +1,1045 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hypernerf_ap_ds_figure", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "id": "EOnw38FUTbLI" + }, + "source": [ + "# @title Imports\n", + "\n", + "from dataclasses import dataclass\n", + "from pprint import pprint\n", + "from typing import Any, List, Callable, Dict, Sequence\n", + "\n", + "import numpy as np\n", + "\n", + "import jax\n", + "from jax.config import config as jax_config\n", + "import jax.numpy as jnp\n", + "from jax import grad, jit, vmap\n", + "from jax import random\n", + "\n", + "import flax\n", + "import flax.linen as nn\n", + "from flax import jax_utils\n", + "from flax import optim\n", + "from flax.metrics import tensorboard\n", + "from flax.training import checkpoints\n", + "\n", + "from absl import logging\n", + "\n", + "# Monkey patch logging.\n", + "def myprint(msg, *args, **kwargs):\n", + " print(msg % args)\n", + "\n", + "logging.info = myprint \n", + "logging.warn = myprint\n", + "\n", + "\n", + "import gin\n", + "gin.enter_interactive_mode()\n", + "\n", + "\n", + "def scatter_points(points, **kwargs):\n", + " \"\"\"Convenience function for plotting points in Plotly.\"\"\"\n", + " return go.Scatter3d(\n", + " x=points[:, 0],\n", + " y=points[:, 1],\n", + " z=points[:, 2],\n", + " mode='markers',\n", + " **kwargs,\n", + " )\n", + "\n", + "\n", + "from IPython.core.display import display, HTML, Latex\n", + "\n", + "\n", + "def Markdown(text):\n", + " IPython.core.display._display_mimetype('text/markdown', [text], raw=True)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "wJvpDubUf9Nr" + }, + "source": [ + "print(jax.devices())" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "m_ngWEcyCUQR" + }, + "source": [ + "# @title Utilities\n", + "import contextlib\n", + "\n", + "\n", + "@contextlib.contextmanager\n", + "def plot_to_array(height, width, rows=1, cols=1, dpi=100, no_axis=False):\n", + " \"\"\"A context manager that plots to a numpy array.\n", + "\n", + " When the context manager exits the output array will be populated with an\n", + " image of the plot.\n", + "\n", + " Usage:\n", + " ```\n", + " with plot_to_array(480, 640, 2, 2) as (fig, axes, out_image):\n", + " axes[0][0].plot(...)\n", + " ```\n", + " Args:\n", + " height: the height of the canvas\n", + " width: the width of the canvas\n", + " rows: the number of axis rows\n", + " cols: the number of axis columns\n", + " dpi: the DPI to render at\n", + " no_axis: if True will hide the axes of the plot\n", + "\n", + " Yields:\n", + " A 3-tuple of: a pyplot Figure, array of Axes, and the output np.ndarray.\n", + " \"\"\"\n", + " out_array = np.empty((height, width, 3), dtype=np.uint8)\n", + " fig, axes = plt.subplots(\n", + " rows, cols, figsize=(width / dpi, height / dpi), dpi=dpi)\n", + " if no_axis:\n", + " for ax in fig.axes:\n", + " ax.margins(0, 0)\n", + " ax.axis('off')\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + "\n", + " yield fig, axes, out_array\n", + "\n", + " # If we haven't already shown or saved the plot, then we need to\n", + " # draw the figure first...\n", + " fig.tight_layout(pad=0)\n", + " fig.canvas.draw()\n", + "\n", + " # Now we can save it to a numpy array.\n", + " data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n", + " data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n", + " plt.close()\n", + "\n", + " np.copyto(out_array, data)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "JIL7vtr9orwb" + }, + "source": [ + "rng = random.PRNGKey(20200823)\n", + "# Shift the numpy random seed by host_id() to shuffle data loaded by different\n", + "# hosts.\n", + "np.random.seed(20201473 + jax.host_id())" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "SnjlLfRyRXaZ" + }, + "source": [ + "dataset_name = '' # @param {type:\"string\"}\n", + "data_dir = gpath.GPath(dataset_name) \n", + "print('data_dir: ', data_dir)\n", + "# assert data_dir.exists()\n", + "\n", + "exp_dir = '' # @param {type:\"string\"}\n", + "exp_dir = gpath.GPath(exp_dir)\n", + "print('exp_dir: ', exp_dir)\n", + "assert exp_dir.exists()\n", + "\n", + "config_path = exp_dir / 'config.gin'\n", + "print('config_path', config_path)\n", + "assert config_path.exists()\n", + "\n", + "checkpoint_dir = exp_dir / 'checkpoints'\n", + "print('checkpoint_dir: ', checkpoint_dir)\n", + "assert checkpoint_dir.exists()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Brw2WgO7pj5d" + }, + "source": [ + "# @title Load configuration.\n", + "reload()\n", + "\n", + "import IPython\n", + "from IPython.display import display, HTML\n", + "\n", + "\n", + "def config_line_predicate(l):\n", + " return (\n", + " 'ExperimentConfig.camera_type' not in l\n", + " and 'preload_data' not in l\n", + " # and 'metadata_at_density' not in l\n", + " # and 'hyper_grad_loss_weight' not in l\n", + " )\n", + "\n", + "\n", + "print(config_path)\n", + "\n", + "with config_path.open('rt') as f:\n", + " lines = filter(config_line_predicate, f.readlines())\n", + " gin_config = ''.join(lines)\n", + "\n", + "gin.parse_config(gin_config)\n", + "\n", + "exp_config = configs.ExperimentConfig()\n", + "\n", + "train_config = configs.TrainConfig(\n", + " batch_size=2048,\n", + " hyper_sheet_alpha_schedule=None,\n", + ")\n", + "eval_config = configs.EvalConfig(\n", + " chunk=4096,\n", + ")\n", + "dummy_model = models.NerfModel({}, 0, 0)\n", + "\n", + "display(HTML(Markdown(gin.config.markdown(gin.config_str()))))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Q6nu1E1s8Heo" + }, + "source": [ + "datasource = exp_config.datasource_cls(\n", + " data_dir=data_dir,\n", + " image_scale=exp_config.image_scale,\n", + " random_seed=exp_config.random_seed,\n", + " # Enable metadata based on model needs.\n", + " use_warp_id=dummy_model.use_warp,\n", + " use_appearance_id=(\n", + " dummy_model.nerf_embed_key == 'appearance'\n", + " or dummy_model.hyper_embed_key == 'appearance'),\n", + " use_camera_id=dummy_model.nerf_embed_key == 'camera',\n", + " use_time=dummy_model.warp_embed_key == 'time')\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "leESeBqh0RSQ" + }, + "source": [ + "# @title Load model\n", + "reload()\n", + "\n", + "devices = jax.devices()\n", + "rng, key = random.split(rng)\n", + "params = {}\n", + "model, params['model'] = models.construct_nerf(\n", + " key,\n", + " batch_size=train_config.batch_size,\n", + " embeddings_dict=datasource.embeddings_dict,\n", + " near=datasource.near,\n", + " far=datasource.far)\n", + "\n", + "optimizer_def = optim.Adam(0.0)\n", + "if train_config.use_weight_norm:\n", + " optimizer_def = optim.WeightNorm(optimizer_def)\n", + "optimizer = optimizer_def.create(params)\n", + "state = model_utils.TrainState(\n", + " optimizer=optimizer,\n", + " warp_alpha=0.0)\n", + "scalar_params = training.ScalarParams(\n", + " learning_rate=0.0,\n", + " elastic_loss_weight=0.0,\n", + " background_loss_weight=train_config.background_loss_weight)\n", + "try:\n", + " state_dict = checkpoints.restore_checkpoint(checkpoint_dir, None)\n", + " state = state.replace(\n", + " optimizer=flax.serialization.from_state_dict(state.optimizer, state_dict['optimizer']),\n", + " warp_alpha=state_dict['warp_alpha'])\n", + "except KeyError:\n", + " # Load legacy checkpoints.\n", + " optimizer = optimizer_def.create(params['model'])\n", + " state = model_utils.TrainState(optimizer=optimizer)\n", + " state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n", + " state = state.replace(optimizer=state.optimizer.replace(target={'model': state.optimizer.target}))\n", + "step = state.optimizer.state.step + 1\n", + "state = jax_utils.replicate(state, devices=devices)\n", + "del params" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "YgaryK1QVQaF" + }, + "source": [ + "# @title Render function.\n", + "import functools\n", + "reload()\n", + "\n", + "use_warp = True # @param{type: 'boolean'}\n", + "use_points = False # @param{type: 'boolean'}\n", + "\n", + "params = jax_utils.unreplicate(state.optimizer.target)\n", + "\n", + "\n", + "def _model_fn(key_0, key_1, params, rays_dict, warp_extras):\n", + " out = model.apply({'params': params},\n", + " rays_dict,\n", + " warp_extras,\n", + " rngs={\n", + " 'coarse': key_0,\n", + " 'fine': key_1\n", + " },\n", + " mutable=False,\n", + " metadata_encoded=True,\n", + " return_points=use_points,\n", + " return_weights=use_points,\n", + " use_warp=use_warp)\n", + " return jax.lax.all_gather(out, axis_name='batch')\n", + "\n", + "pmodel_fn = jax.pmap(\n", + " # Note rng_keys are useless in eval mode since there's no randomness.\n", + " _model_fn,\n", + " # key0, key1, params, rays_dict, warp_extras\n", + " in_axes=(0, 0, 0, 0, 0),\n", + " devices=devices,\n", + " donate_argnums=(3,), # Donate the 'rays' argument.\n", + " axis_name='batch',\n", + ")\n", + "\n", + "render_fn = functools.partial(evaluation.render_image,\n", + " model_fn=pmodel_fn,\n", + " device_count=len(devices),\n", + " chunk=8192)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "a2jkqa5EGiZL" + }, + "source": [ + "# @title Latent code utils\n", + "\n", + "def get_hyper_code(params, item_id):\n", + " appearance_id = datasource.get_appearance_id(item_id)\n", + " metadata = {\n", + " 'warp': jnp.array([appearance_id], jnp.uint32),\n", + " 'appearance': jnp.array([appearance_id], jnp.uint32),\n", + " }\n", + " return model.apply({'params': params['model']},\n", + " metadata,\n", + " method=model.encode_hyper_embed)\n", + "\n", + "\n", + "def get_appearance_code(params, item_id):\n", + " appearance_id = datasource.get_appearance_id(item_id)\n", + " metadata = {\n", + " 'appearance': jnp.array([appearance_id], jnp.uint32),\n", + " }\n", + " return model.apply({'params': params['model']},\n", + " metadata,\n", + " method=model.encode_nerf_embed)\n", + "\n", + "\n", + "def get_warp_code(params, item_id):\n", + " warp_id = datasource.get_warp_id(item_id)\n", + " metadata = {\n", + " 'warp': jnp.array([warp_id], jnp.uint32),\n", + " }\n", + " return model.apply({'params': params['model']},\n", + " metadata,\n", + " method=model.encode_warp_embed)\n", + "\n", + "def get_codes(item_id):\n", + " appearance_code = None\n", + " if model.use_rgb_condition:\n", + " appearance_code = get_appearance_code(params, item_id)\n", + " \n", + " warp_codes = None\n", + " if model.use_warp:\n", + " warp_code = get_warp_code(params, item_id)\n", + " \n", + " hyper_codes = None\n", + " if model.has_hyper:\n", + " hyper_code = get_hyper_code(params, item_id)\n", + " \n", + " return appearance_code, warp_code, hyper_code\n", + "\n", + "\n", + "def make_batch(camera, appearance_code=None, warp_code=None, hyper_code=None):\n", + " batch = datasets.camera_to_rays(camera)\n", + " batch_shape = batch['origins'][..., 0].shape\n", + " metadata = {}\n", + " if appearance_code is not None:\n", + " appearance_code = appearance_code.squeeze(0)\n", + " metadata['encoded_nerf'] = jnp.broadcast_to(\n", + " appearance_code[None, None, :], (*batch_shape, appearance_code.shape[-1]))\n", + " if warp_code is not None:\n", + " metadata['encoded_warp'] = jnp.broadcast_to(\n", + " warp_code[None, None, :], (*batch_shape, warp_code.shape[-1]))\n", + " batch['metadata'] = metadata\n", + "\n", + " if hyper_code is not None:\n", + " batch['metadata']['encoded_hyper'] = jnp.broadcast_to(\n", + " hyper_code[None, None, :], (*batch_shape, hyper_code.shape[-1]))\n", + " \n", + " return batch" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "-K19Z2Tsmlxe" + }, + "source": [ + "# @title Manual crop\n", + "\n", + "render_scale = 0.5\n", + "target_rgb = image_utils.downsample_image(datasource.load_rgb(datasource.train_ids[0]), int(1/render_scale))\n", + "top, bottom, left, right = 2 * np.array([89, 75, 32, 26]) # K\n", + "\n", + "# top, bottom, left, right = 2 * np.array([60, 70, 14, 10]) # R\n", + "# top, bottom, left, right = 0, 30, 68, 68 # lemon\n", + "# top, bottom, left, right = 40, 100, 2, 40 # slice-banana\n", + "target_rgb = target_rgb[top:-bottom, left:-right]\n", + "print(target_rgb.shape)\n", + "media.show_image(target_rgb)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0Zj3uIsjJfKO" + }, + "source": [ + "## Hyper grid." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "c92t15kuKanf" + }, + "source": [ + "# @title Sample points and metadata\n", + "\n", + "item_id = datasource.train_ids[0]\n", + "camera = datasource.load_camera(item_id).scale(render_scale)\n", + "camera.crop_image_domain()\n", + "batch = make_batch(camera, *get_codes(item_id))\n", + "origins = batch['origins']\n", + "directions = batch['directions']\n", + "metadata = batch['metadata']\n", + "z_vals, points = model_utils.sample_along_rays(\n", + " rng, origins[None, ...], directions[None, ...], \n", + " model.num_coarse_samples,\n", + " model.near, \n", + " model.far, \n", + " model.use_stratified_sampling,\n", + " model.use_linear_disparity)\n", + "points = points.reshape((-1, 3))\n", + "points = random.permutation(rng, points)[:8096*4]\n", + "print(points.shape)\n", + "\n", + "warp_metadata = random.randint(\n", + " key, (points.shape[0], 1), 0, model.num_warp_embeds, dtype=jnp.uint32)\n", + "warp_embed = model.apply({'params': params['model']},\n", + " {model.warp_embed_key: warp_metadata},\n", + " method=model.encode_warp_embed)\n", + "# warp_embed = jnp.broadcast_to(\n", + "# warp_embed[:, jnp.newaxis, :],\n", + "# shape=(*points.shape[:-1], warp_embed.shape[-1]))\n", + "if model.has_hyper_embed:\n", + " hyper_metadata = random.randint(\n", + " key, (points.shape[0], 1), 0, model.num_hyper_embeds, dtype=jnp.uint32)\n", + " hyper_embed_key = (model.warp_embed_key if model.hyper_use_warp_embed\n", + " else model.hyper_embed_key)\n", + " hyper_embed = model.apply({'params': params['model']},\n", + " {hyper_embed_key: hyper_metadata},\n", + " method=model.encode_hyper_embed)\n", + " # hyper_embed = jnp.broadcast_to(\n", + " # hyper_embed[:, jnp.newaxis, :],\n", + " # shape=(*batch_shape, hyper_embed.shape[-1]))\n", + "else:\n", + " hyper_embed = None\n", + "\n", + "map_fn = functools.partial(model.apply, method=model.map_points)\n", + "warped_points, _ = map_fn(\n", + " {'params': params['model']}, \n", + " points[:, None], hyper_embed[:, None], warp_embed[:, None], \n", + " jax_utils.unreplicate(state.extra_params))\n", + "hyper_points = np.array(warped_points[..., 3:].squeeze())\n", + "print(hyper_points.shape)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "MvdO4IoUhSu-" + }, + "source": [ + "# umin, vmin = hyper_points.min(axis=0)\n", + "# umax, vmax = hyper_points.max(axis=0)\n", + "umin, vmin = np.percentile(hyper_points[..., :2], 20, axis=0)\n", + "umax, vmax = np.percentile(hyper_points[..., :2], 99, axis=0)\n", + "umin, vmin, umax, vmax" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "l0k0Ciz_Wuk6" + }, + "source": [ + "n = 7\n", + "uu, vv = np.meshgrid(np.linspace(umin, umax, n), np.linspace(vmin, vmax, n))\n", + "hyper_grid = np.stack([uu, vv], axis=-1)\n", + "hyper_grid[0, 0], hyper_grid[-1, -1]" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2Rv8WW3kGibq" + }, + "source": [ + "import itertools\n", + "import gc\n", + "gc.collect()\n", + "\n", + "grid_frames = []\n", + "\n", + "camera = datasource.load_camera(item_id).scale(render_scale)\n", + "camera = camera.crop_image_domain(left, right, top, bottom)\n", + "\n", + "batch = make_batch(camera, *get_codes(item_id))\n", + "batch_shape = batch['origins'][..., 0].shape\n", + "for i, j in itertools.product(range(n), range(n)):\n", + " hyper_point = jnp.array(hyper_grid[i, j])\n", + " # hyper_point = jnp.concatenate([hyper_point, jnp.zeros((6,))])\n", + " hyper_point = jnp.broadcast_to(\n", + " hyper_point[None, None, :], \n", + " (*batch_shape, hyper_point.shape[-1]))\n", + " batch['metadata']['hyper_point'] = hyper_point\n", + " \n", + " render = render_fn(state, batch, rng=rng)\n", + " pred_rgb = np.array(render['rgb'])\n", + " pred_depth_med = np.array(render['med_depth'])\n", + " pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n", + " del render\n", + " \n", + " media.show_images([pred_rgb, pred_depth_viz])\n", + " grid_frames.append({\n", + " 'rgb': pred_rgb,\n", + " 'depth': pred_depth_med,\n", + " })\n", + "\n", + "media.show_images([f['rgb'] for f in grid_frames], columns=n)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "nejSHA2WyAQD" + }, + "source": [ + "from matplotlib import pyplot as plt\n", + "from mpl_toolkits.axes_grid1 import ImageGrid\n", + "import numpy as np\n", + "\n", + "fig = plt.figure(figsize=(24., 24.))\n", + "grid = ImageGrid(fig, 111, # similar to subplot(111)\n", + " nrows_ncols=(n, n), # creates 2x2 grid of axes\n", + " axes_pad=0.1, # pad between axes in inch.\n", + " )\n", + "\n", + "images = [f['rgb'] for f in grid_frames]\n", + "for ax, im in zip(grid, images):\n", + " # Iterating over the grid returns the Axes.\n", + " ax.imshow(im)\n", + " ax.set_axis_off()\n", + " ax.margins(x=0, y=0)\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " ax.set_aspect('equal')\n", + "fig.tight_layout(pad=0)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "4FC-MsLz5dAM" + }, + "source": [ + "from scipy import interpolate\n", + "\n", + "num_samples = 200\n", + "# points = np.random.uniform(0, 1, size=(10, 2))\n", + "rng = random.PRNGKey(3)\n", + "# points = random.uniform(rng, (20, 2))\n", + "points = np.array([\n", + " [0.2, 0.1],\n", + " [0.2, 0.8],\n", + " [0.8, 0.8],\n", + " [0.8, 0.1],\n", + " [0.5, 0.1],\n", + " [0.2, 0.4],\n", + " [0.5, 0.7],\n", + " [0.8, 0.7],\n", + " [0.6, 0.2],\n", + " [0.2, 0.1],\n", + "])\n", + "t = np.arange(len(points))\n", + "xs = np.linspace(0, len(points) - 1, num_samples)\n", + "cs = interpolate.CubicSpline(t, points, bc_type='periodic')\n", + "\n", + "interp_points = cs(xs).astype(np.float32)\n", + "fig, ax = plt.subplots()\n", + "ax.scatter(interp_points[:, 0], interp_points[:, 1], s=2)\n", + "ax.scatter(points[:, 0], points[:, 1])\n", + "ax.set_aspect('equal')" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lQUcIumi7Ts8" + }, + "source": [ + "interp_hyper_points = np.stack([(umax - umin) * interp_points[:, 0] + umin, (vmax - vmin) * interp_points[:, 1] + vmin], axis=-1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "A7NYM0Rs_Tie" + }, + "source": [ + "## Make Orbit Cameras" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "fuZGlM90_UZ-" + }, + "source": [ + "ref_cameras = utils.parallel_map(datasource.load_camera, datasource.all_ids)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eMsYdVEDXcFA" + }, + "source": [ + "## Select Keyframes and Interpolate Codes" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "71g9fqINo8sC" + }, + "source": [ + "# @title Show training frames to choose IDs\n", + "target_ids = datasource.train_ids[::4]\n", + "target_rgbs = utils.parallel_map(\n", + " lambda i: image_utils.downsample_image(datasource.load_rgb(i), int(1/render_scale)), \n", + " target_ids)\n", + "media.show_images(target_rgbs, titles=target_ids, columns=20)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oSZraLc2Y7QN" + }, + "source": [ + "## Render" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "GU9PaEHmX_QU" + }, + "source": [ + "# @title Latent code functions\n", + "\n", + "reload()\n", + "\n", + "\n", + "def get_hyper_code(params, item_id):\n", + " appearance_id = datasource.get_appearance_id(item_id)\n", + " metadata = {\n", + " 'warp': jnp.array([appearance_id], jnp.uint32),\n", + " 'appearance': jnp.array([appearance_id], jnp.uint32),\n", + " }\n", + " return model.apply({'params': params['model']},\n", + " metadata,\n", + " method=model.encode_hyper_embed)\n", + "\n", + "\n", + "def get_appearance_code(params, item_id):\n", + " appearance_id = datasource.get_appearance_id(item_id)\n", + " metadata = {\n", + " 'appearance': jnp.array([appearance_id], jnp.uint32),\n", + " }\n", + " return model.apply({'params': params['model']},\n", + " metadata,\n", + " method=model.encode_nerf_embed)\n", + "\n", + "\n", + "def get_warp_code(params, item_id):\n", + " warp_id = datasource.get_warp_id(item_id)\n", + " metadata = {\n", + " 'warp': jnp.array([warp_id], jnp.uint32),\n", + " }\n", + " return model.apply({'params': params['model']},\n", + " metadata,\n", + " method=model.encode_warp_embed)\n", + "\n", + "\n", + "params = jax_utils.unreplicate(state.optimizer.target)\n", + "if model.use_rgb_condition:\n", + " test_appearance_code = get_appearance_code(params, datasource.train_ids[0])\n", + " print('appearance code:', test_appearance_code)\n", + "\n", + "if model.use_warp:\n", + " test_warp_code = get_warp_code(params, datasource.train_ids[0])\n", + " print('warp code:', test_warp_code)\n", + "\n", + "if model.has_hyper:\n", + " test_hyper_code = get_hyper_code(params, datasource.train_ids[0])\n", + " print('hyper code:', test_hyper_code)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "3uFsdK2naA4D" + }, + "source": [ + "# @title Render function.\n", + "import functools\n", + "reload()\n", + "\n", + "use_warp = True # @param{type: 'boolean'}\n", + "use_points = False # @param{type: 'boolean'}\n", + "\n", + "params = jax_utils.unreplicate(state.optimizer.target)\n", + "\n", + "\n", + "def _model_fn(key_0, key_1, params, rays_dict, warp_extras):\n", + " out = model.apply({'params': params},\n", + " rays_dict,\n", + " warp_extras,\n", + " rngs={\n", + " 'coarse': key_0,\n", + " 'fine': key_1\n", + " },\n", + " mutable=False,\n", + " metadata_encoded=True,\n", + " return_points=use_points,\n", + " return_weights=use_points,\n", + " use_warp=use_warp)\n", + " return jax.lax.all_gather(out, axis_name='batch')\n", + "\n", + "pmodel_fn = jax.pmap(\n", + " # Note rng_keys are useless in eval mode since there's no randomness.\n", + " _model_fn,\n", + " # key0, key1, params, rays_dict, warp_extras\n", + " in_axes=(0, 0, 0, 0, 0),\n", + " devices=devices,\n", + " donate_argnums=(3,), # Donate the 'rays' argument.\n", + " axis_name='batch',\n", + ")\n", + "\n", + "render_fn = functools.partial(evaluation.render_image,\n", + " model_fn=pmodel_fn,\n", + " device_count=len(devices),\n", + " chunk=8192)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "fK7FZ1q3GLyi" + }, + "source": [ + "item_ids = ['001428']\n", + "item_ids = ['000082']\n", + "# item_ids = ['000457']\n", + "# item_ids = ['000429']\n", + "# item_ids = ['000610'] # ricardo\n", + "render_scale = 1.0\n", + "\n", + "media.show_images([datasource.load_rgb(x) for x in item_ids], titles=item_ids)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "MquILATfGL2x" + }, + "source": [ + "\n", + "import gc\n", + "gc.collect()\n", + "\n", + "base_camera = datasource.load_camera(item_ids[0]).scale(render_scale)\n", + "# base_camera = datasource.load_camera('000037').scale(render_scale)\n", + "base_camera = datasource.load_camera('000389').scale(render_scale)\n", + "# orbit_cameras = [c.scale(render_scale) for c in make_orbit_cameras(360)] \n", + "# base_camera = orbit_cameras[270]\n", + "\n", + "out_frames = []\n", + "for i, item_id in enumerate(item_ids):\n", + " camera = base_camera\n", + " print(f'>>> Rendering ID {item_id} <<<')\n", + " appearance_code = get_appearance_code(params, item_id).squeeze() if model.use_nerf_embed else None\n", + " warp_code = get_warp_code(params, item_id).squeeze() if model.use_warp else None\n", + " hyper_code = get_hyper_code(params, item_id).squeeze() if model.has_hyper_embed else None\n", + " batch = make_batch(camera, appearance_code, warp_code, hyper_code)\n", + "\n", + " render = render_fn(state, batch, rng=rng)\n", + " pred_rgb = np.array(render['rgb'])\n", + " pred_depth_med = np.array(render['med_depth'])\n", + " pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n", + "\n", + " media.show_images([pred_rgb, pred_depth_viz])\n", + " out_frames.append({\n", + " 'rgb': pred_rgb,\n", + " 'depth': pred_depth_med,\n", + " 'med_points': np.array(render['med_points']),\n", + " })\n", + " del batch, render" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "_8t2T_f4IKFB" + }, + "source": [ + "from skimage.color import hsv2rgb\n", + "\n", + "def sinebow(h):\n", + " f = lambda x : np.sin(np.pi * x)**2\n", + " return np.stack([f(3/6-h), f(5/6-h), f(7/6-h)], -1)\n", + "\n", + "\n", + "def colorize_flow(u, v, phase=0, freq=1):\n", + " coords = np.stack([u, v], axis=-1)\n", + " mag = np.linalg.norm(coords, axis=-1) / np.sqrt(2)\n", + " angle = np.arctan2(-v, -u) / np.pi / (2/freq)\n", + " print(angle.min(), angle.max())\n", + " # return viz.colorize(np.log(mag+1e-6), cmap='gray')\n", + " colorwheel = sinebow(angle + phase/360*np.pi)\n", + " # brightness = mag[..., None] ** 1.414\n", + " brightness = mag[..., None] ** 1.0\n", + " # brightness = (25 * np.cbrt(mag[..., None]*100) - 17)/100\n", + " # brightness = (((mag[..., None]*100 + 17)/25)**3)/100\n", + " bg = np.ones_like(colorwheel) * 0.5\n", + " # bg = np.ones_like(colorwheel) * 0.0\n", + " return colorwheel * brightness + bg * (1.0 - brightness)\n", + "\n", + " \n", + "def visualize_hyper_points(frame):\n", + " hyper_points = frame['med_points'].squeeze()[..., 3:]\n", + " uu = (hyper_points[..., 0] - umin) / (umax - umin)\n", + " vv = (hyper_points[..., 1] - vmin) / (vmax - vmin)\n", + " normalized_hyper_points = np.stack([uu, vv], axis=-1)\n", + " normalized_hyper_points = (normalized_hyper_points - 0.5) * 2.0\n", + " print(normalized_hyper_points.min(), normalized_hyper_points.max())\n", + " return colorize_flow(normalized_hyper_points[..., 0], normalized_hyper_points[..., 1])\n", + "\n", + "\n", + "uu = np.linspace(-1, 1, 256)\n", + "vv = np.linspace(-1, 1, 256)\n", + "uu, vv = np.meshgrid(uu, vv)\n", + "\n", + "media.show_image(colorize_flow(uu, vv))\n", + "\n", + "\n", + "# media.show_image(visualize_hyper_points(out_frames[0]))\n", + "for frame in out_frames:\n", + " pred_rgb = frame['rgb']\n", + " pred_depth = frame['depth']\n", + " # depth_viz = viz.colorize(1/pred_depth.squeeze(), cmin=1.6, cmax=3.0, cmap='turbo', invert=False)\n", + " depth_viz = viz.colorize(1/pred_depth.squeeze(), cmin=1.6, cmax=2.3, cmap='turbo', invert=False)\n", + " hyper_viz = visualize_hyper_points(out_frames[0])\n", + " media.show_images([pred_rgb, depth_viz, hyper_viz])\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Zp6vJhCEKCSn" + }, + "source": [ + "uu = np.linspace(-1, 1, 1024)\n", + "vv = np.linspace(-1, 1, 1024)\n", + "uu, vv = np.meshgrid(uu, vv)\n", + "\n", + "media.show_image(colorize_flow(uu, vv))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "GxYcw_ZiirkW" + }, + "source": [ + "from PIL import Image, ImageDraw\n", + "\n", + "\n", + "def crop_circle(img, width=3, color=(0, 0, 0)):\n", + " img = Image.fromarray(image_utils.image_to_uint8(img))\n", + " h,w=img.size\n", + " \n", + " # Create same size alpha layer with circle\n", + " alpha = Image.new('L', img.size,0)\n", + " draw = ImageDraw.Draw(alpha)\n", + " draw.pieslice([0,0,h,w],0,360,fill=255)\n", + " # Convert alpha Image to numpy array\n", + " npAlpha=np.array(alpha)\n", + "\n", + " draw = ImageDraw.Draw(img)\n", + " draw.arc([0, 0, h, w], 0, 360, fill=tuple(color), width=width)\n", + " npImage=np.array(img)\n", + " \n", + " # Add alpha layer to RGB\n", + " npImage=np.dstack((npImage,npAlpha))\n", + " return image_utils.image_to_float32(npImage)\n", + "\n", + "\n", + "media.show_image(crop_circle(images[0]))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "C0v2AYnlIXnr" + }, + "source": [ + "from matplotlib import pyplot as plt\n", + "from mpl_toolkits.axes_grid1 import ImageGrid\n", + "import numpy as np\n", + "\n", + "fig = plt.figure(figsize=(24., 24.))\n", + "grid = ImageGrid(fig, 111, # similar to subplot(111)\n", + " nrows_ncols=(n, n), # creates 2x2 grid of axes\n", + " axes_pad=0.1, # pad between axes in inch.\n", + " )\n", + "\n", + "uu = np.linspace(-1, 1, 7)\n", + "vv = np.linspace(-1, 1, 7)\n", + "uu, vv = np.meshgrid(uu, vv)\n", + "grid_colors = image_utils.image_to_uint8(colorize_flow(uu, vv))\n", + "grid_colors = grid_colors.reshape((-1, 3))\n", + "\n", + "images = [f['rgb'] for f in grid_frames]\n", + "for i, (ax, im) in enumerate(zip(grid, images)):\n", + " # Iterating over the grid returns the Axes.\n", + " color = tuple(grid_colors[i])\n", + " ax.imshow(crop_circle(im, width=14, color=color))\n", + " ax.set_axis_off()\n", + " ax.margins(x=0, y=0)\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " ax.set_aspect('equal')\n", + "fig.tight_layout(pad=0)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "YTMDxXK4Ig3K" + }, + "source": [ + "(np.array([75, 140, 40, 100], dtype=np.float)*0.9).round()" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/notebooks/figures/hypernerf_optim_latent.ipynb b/notebooks/figures/hypernerf_optim_latent.ipynb new file mode 100644 index 0000000..6d0ae98 --- /dev/null +++ b/notebooks/figures/hypernerf_optim_latent.ipynb @@ -0,0 +1,148 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "hypernerf_optim_latent", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "id": "-QWb-snnOf1I" + }, + "source": [ + "def apply_model(rng, state, batch):\n", + " fine_key, coarse_key = random.split(rng, 2)\n", + " model_out = model.apply(\n", + " {'params': state.optimizer.target['model']}, \n", + " batch,\n", + " extra_params=state.extra_params,\n", + " metadata_encoded=True,\n", + " rngs={'fine': fine_key, 'coarse': coarse_key})\n", + " return model_out\n", + "\n", + "\n", + "def loss_fn(rng, state, target, batch):\n", + " batch['metadata'] = jax.tree_map(lambda x: x.reshape((1, -1)), \n", + " target['metadata'])\n", + " model_out = apply_model(rng, state, batch)['fine']\n", + " # loss = ((model_out['rgb'] - batch['rgb']) ** 2).mean(axis=-1)\n", + " loss = jnp.abs(model_out['rgb'] - batch['rgb']).mean(axis=-1)\n", + " return loss.mean()\n", + "\n", + "\n", + "def optim_step(rng, state, optimizer, batch):\n", + " rng, key = random.split(rng, 2)\n", + " grad_fn = jax.value_and_grad(loss_fn, argnums=2)\n", + " loss, grad = grad_fn(key, state, optimizer.target, batch)\n", + " grad = jax.lax.pmean(grad, axis_name='batch')\n", + " loss = jax.lax.pmean(loss, axis_name='batch')\n", + "\n", + " optimizer = optimizer.apply_gradient(grad)\n", + "\n", + " return rng, loss, optimizer\n", + "\n", + "\n", + "p_optim_step = jax.pmap(optim_step, axis_name='batch')\n", + "\n", + "key = random.PRNGKey(0)\n", + "key = key + jax.process_index()\n", + "keys = random.split(key, jax.local_device_count())\n", + "\n", + "optimizer_def = optim.Adam(0.1)\n", + "init_metadata = evaluation.encode_metadata(\n", + " model, \n", + " jax_utils.unreplicate(state.optimizer.target['model']), \n", + " jax.tree_map(lambda x: x[0, 0], data['metadata']))\n", + "# init_metadata = jax.tree_map(lambda x: x[0], init_metadata)\n", + "# Initialize to zero.\n", + "init_metadata = jax.tree_map(lambda x: jnp.zeros_like(x), init_metadata)\n", + "optimizer = optimizer_def.create({'metadata': init_metadata})\n", + "optimizer = jax_utils.replicate(optimizer, jax.local_devices())\n", + "devices = jax.local_devices()\n", + "batch_size = 1024\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "gzqX4gTWOf3b" + }, + "source": [ + "metadata_progression = []\n", + "\n", + "for i in range(25):\n", + " batch_inds = random.choice(keys[0], np.arange(train_data['rgb'].shape[0]), replace=False, shape=(batch_size,))\n", + " batch = jax.tree_map(lambda x: x[batch_inds, ...], train_data)\n", + " batch = datasets.prepare_data(batch)\n", + " keys, loss, optimizer = p_optim_step(keys, state, optimizer, batch)\n", + " loss = jax_utils.unreplicate(loss)\n", + " metadata_progression.append(jax.tree_map(lambda x: np.array(x), jax_utils.unreplicate(optimizer.target['metadata'])))\n", + " print(f'train_loss = {loss.item():.04f}')\n", + " del batch" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "-s9pCeiqZhKi" + }, + "source": [ + "frames = []\n", + "for metadata in metadata_progression:\n", + "# metadata = jax_utils.unreplicate(optimizer.target['metadata'])\n", + " camera = datasource.load_camera(target_id).scale(1.0)\n", + " batch = make_batch(camera, None, metadata['encoded_warp'], metadata['encoded_hyper'])\n", + " render = render_fn(state, batch, rng=rng)\n", + " pred_rgb = np.array(render['rgb'])\n", + " pred_depth_med = np.array(render['med_depth'])\n", + " pred_depth_viz = viz.colorize(1.0 / pred_depth_med.squeeze())\n", + " media.show_images([pred_rgb, pred_depth_viz])\n", + " frames.append({ \n", + " 'rgb': pred_rgb,\n", + " 'depth': pred_depth_med,\n", + " })\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tu8cC9gKe2BS" + }, + "source": [ + "media.show_image(data['rgb'])\n", + "media.show_videos([\n", + " [d['rgb'] for d in frames],\n", + " [viz.colorize(1/d['depth'].squeeze(), cmin=1.5, cmax=2.9) for d in frames],\n", + "], fps=10)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "EUpAEa4boPbw" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/notebooks/figures/level_set_visualization.ipynb b/notebooks/figures/level_set_visualization.ipynb new file mode 100644 index 0000000..c60fd5c --- /dev/null +++ b/notebooks/figures/level_set_visualization.ipynb @@ -0,0 +1,117 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "level_set_visualization.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "id": "_QO__KXIxS78" + }, + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import cm\n", + "import mediapy as media" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "fiVMLQ0bxmXQ" + }, + "source": [ + "def f(x, y):\n", + " r = np.sqrt(x**2 + y**2)\n", + " n = r*5+1\n", + " z = (np.abs(x)**n + np.abs(y)**n)**(1/n)\n", + " return z\n", + "\n", + "def g(x, y):\n", + " z = 1 - np.minimum(f(x - 0.5, y), f(x + 0.5, y))\n", + " return np.maximum(z, 0)\n", + "\n", + "n = 100\n", + "x, y = np.meshgrid(np.linspace(-1.5, 1.5, 2*n), np.linspace(-1, 1, n), indexing='xy')\n", + "plt.contourf(g(x, y))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Yv_ZrUADzXk7" + }, + "source": [ + "\n", + "n = 1000\n", + "x, y = np.meshgrid(np.linspace(-1.5, 1.5, 2*n), np.linspace(-1, 1, n), indexing='xy')\n", + "\n", + "fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={\"projection\": \"3d\"})\n", + "\n", + "z = g(x, y)\n", + "\n", + "z0s = [0.2, 0.5, 0.8]\n", + "colors = [cm.tab10(0), cm.tab10(1), cm.tab10(2)]\n", + "\n", + "ax.plot_surface(x, y, z, linewidth=0, antialiased=False, color='gray')\n", + "\n", + "x, y = np.meshgrid([-1.5, 1.5], [-1, 1], indexing='xy')\n", + "xv = np.array([-1.5, -1.5, 1.5, 1.5, -1.5])\n", + "yv = np.array([-1, 1, 1, -1, -1])\n", + "\n", + "for z0, color in zip(z0s, colors):\n", + " ax.plot3D(xv, yv, z0, zorder=10, color=color)\n", + "\n", + "ax.axis(False)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "d1pX7ypI3PE0" + }, + "source": [ + "vis = []\n", + "for z0, color in zip(z0s, colors):\n", + " mask = z > z0\n", + " vis.append(1-mask[:,:,None] * (1-np.array(color[:3])[None,None,:]))\n", + "\n", + "plt.figure(figsize=(8,8))\n", + "plt.imshow(np.concatenate(vis[::-1], 0))\n", + "plt.axis(False)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "gS1-LG-u6xZs" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/notebooks/figures/nerfies_2d_experiments.ipynb b/notebooks/figures/nerfies_2d_experiments.ipynb new file mode 100644 index 0000000..3156c68 --- /dev/null +++ b/notebooks/figures/nerfies_2d_experiments.ipynb @@ -0,0 +1,921 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nerfies_2d_experiments.ipynb", + "provenance": [ + { + "file_id": "1wdNjll4hs3b8yy7O0IKpXSXrXAyOVEaY", + "timestamp": 1631790666234 + }, + { + "file_id": "1Zv_LdAL82dfOruTngKaSLViawaa3SkL6", + "timestamp": 1615009722746 + } + ], + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", + "kind": "private" + } + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + } + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "id": "N_PcTdlrG3nu" + }, + "source": [ + "# @title Imports\n", + "\n", + "from dataclasses import dataclass\n", + "from pprint import pprint\n", + "from typing import Any, List, Callable, Dict, Sequence, Optional, Tuple\n", + "from io import BytesIO\n", + "from IPython.display import display, HTML\n", + "from base64 import b64encode\n", + "import PIL\n", + "import IPython\n", + "import tempfile\n", + "import imageio\n", + "\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "import tensorflow as tf\n", + "\n", + "import jax\n", + "from jax.config import config as jax_config\n", + "import jax.numpy as jnp\n", + "from jax import grad, jit, vmap\n", + "from jax import random\n", + "\n", + "import flax\n", + "import flax.linen as nn\n", + "# from flax import nn\n", + "from flax import jax_utils\n", + "from flax import optim\n", + "from flax.metrics import tensorboard\n", + "from flax.training import checkpoints\n", + "\n", + "from absl import logging\n", + "\n", + "# Monkey patch logging.\n", + "def myprint(msg, *args, **kwargs):\n", + " print(msg % args)\n", + "\n", + "logging.info = myprint \n", + "logging.warn = myprint" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "u4kI6OZ0MwPP" + }, + "source": [ + "## Utility Functions" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bph93J0PMyeI", + "cellView": "form" + }, + "source": [ + "# @title Dataset Utilities\n", + "\n", + "def prepare_tf_data(xs):\n", + " \"\"\"Convert a input batch from tf Tensors to numpy arrays.\"\"\"\n", + " local_device_count = jax.local_device_count()\n", + " def _prepare(x):\n", + " # Use _numpy() for zero-copy conversion between TF and NumPy.\n", + " x = x._numpy() # pylint: disable=protected-access\n", + "\n", + " # reshape (host_batch_size, height, width, 3) to\n", + " # (local_devices, device_batch_size, height, width, 3)\n", + " return x.reshape((local_device_count, -1) + x.shape[1:])\n", + "\n", + " return jax.tree_map(_prepare, xs)\n", + "\n", + "\n", + "def iterator_from_dataset(dataset: tf.data.Dataset,\n", + " batch_size: int,\n", + " repeat: bool = True,\n", + " prefetch_size: int = 0,\n", + " devices: Optional[Sequence[Any]] = None):\n", + " \"\"\"Create a data iterator that returns JAX arrays from a TF dataset.\n", + "\n", + " Args:\n", + " dataset: the dataset to iterate over.\n", + " batch_size: the batch sizes the iterator should return.\n", + " repeat: whether the iterator should repeat the dataset.\n", + " prefetch_size: the number of batches to prefetch to device.\n", + " devices: the devices to prefetch to.\n", + "\n", + " Returns:\n", + " An iterator that returns data batches.\n", + " \"\"\"\n", + " if repeat:\n", + " dataset = dataset.repeat()\n", + "\n", + " if batch_size > 0:\n", + " dataset = dataset.batch(batch_size)\n", + " it = map(prepare_tf_data, dataset)\n", + " else:\n", + " it = map(prepare_tf_data_unbatched, dataset)\n", + "\n", + " if prefetch_size > 0:\n", + " it = jax_utils.prefetch_to_device(it, prefetch_size, devices)\n", + "\n", + " return it\n" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ux0cH2TpQyEN" + }, + "source": [ + "## Create Data" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Ry0smCvRyttR" + }, + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from matplotlib import cm\n", + "import matplotlib as mpl\n", + "\n", + "\n", + "def fig_to_array(fig, height, width, dpi=100):\n", + " out_array = np.zeros((height, width, 4), dtype=np.uint8)\n", + " fig.set_size_inches((width / dpi, height / dpi))\n", + " fig.set_dpi(dpi)\n", + " for ax in fig.axes:\n", + " ax.margins(0, 0)\n", + " ax.axis('off')\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + "\n", + " # If we haven't already shown or saved the plot, then we need to\n", + " # draw the figure first...\n", + " fig.tight_layout(pad=0)\n", + " fig.canvas.draw()\n", + "\n", + " # Now we can save it to a numpy array.\n", + " data = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)\n", + " data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))\n", + " plt.close()\n", + "\n", + " np.copyto(out_array, data)\n", + " out_array = np.roll(out_array, -1, 2)\n", + " return out_array\n", + "\n", + "\n", + "def add_colorwheel(fig, cmap, lim, label):\n", + " display_axes = fig.add_axes([0.1,0.1,0.8,0.8], projection='polar', label=label)\n", + " display_axes._direction = 2*np.pi ## This is a nasty hack - using the hidden field to \n", + " ## multiply the values such that 1 become 2*pi\n", + " ## this field is supposed to take values 1 or -1 only!!\n", + " \n", + " norm = mpl.colors.Normalize(0.0, 2*np.pi)\n", + " \n", + " # Plot the colorbar onto the polar axis\n", + " # note - use orientation horizontal so that the gradient goes around\n", + " # the wheel rather than centre out\n", + " quant_steps = 2056\n", + " cb = mpl.colorbar.ColorbarBase(display_axes, cmap=cm.get_cmap(cmap,quant_steps),\n", + " norm=norm,\n", + " orientation='horizontal')\n", + " \n", + " # aesthetics - get rid of border and axis labels \n", + " cb.outline.set_visible(False) \n", + " display_axes.set_axis_off()\n", + " display_axes.set_rlim(lim)\n", + "\n", + " display_axes.margins(0, 0)\n", + " display_axes.axis('off')\n", + " display_axes.get_xaxis().set_visible(False)\n", + " display_axes.get_yaxis().set_visible(False)\n", + " display_axes.set_aspect(1)\n", + "\n", + "fig = plt.figure()\n", + "add_colorwheel(fig, 'tab20b', (-1.1, 1.0), 'a')\n", + "add_colorwheel(fig, 'twilight', (-2.0, 1), 'b')\n", + "add_colorwheel(fig, 'hsv', (-5.0, 1), 'c')\n", + "\n", + "colorwheel = fig_to_array(fig, 400, 400)\n", + "mediapy.show_image(colorwheel)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "v-_vZ8cfFtcv" + }, + "source": [ + "# import matplotlib\n", + "import scipy\n", + "from PIL import Image, ImageDraw, ImageFont, ImageOps\n", + "\n", + "def barron_colormap(n):\n", + " curve = lambda x, s, t: np.where(x