diff --git a/pytti_5_beta.ipynb b/pytti_5_beta.ipynb
index 57f5dd8..b2dd270 100644
--- a/pytti_5_beta.ipynb
+++ b/pytti_5_beta.ipynb
@@ -333,30 +333,25 @@
         "except ModuleNotFoundError:\n",
         "  everything_installed = False\n",
         "def install_everything():\n",
-        "  !pip install tensorflow==1.15.2\n",
-        "  !pip install transformers &> /dev/null \n",
-        "  !pip install PyGLM &> /dev/null\n",
-        "  !pip install ftfy regex tqdm omegaconf pytorch-lightning &> /dev/null\n",
-        "  !pip install kornia &> /dev/null\n",
-        "  !pip install einops &> /dev/null\n",
-        "  !pip install imageio-ffmpeg &> /dev/null\n",
-        "  !pip install adjustText exrex bunch &> /dev/null\n",
-        "  !pip install matplotlib-label-lines &> /dev/null\n",
         "  !git clone https://github.com/openai/CLIP.git &> /dev/null\n",
         "  !git clone https://github.com/CompVis/taming-transformers.git &> /dev/null\n",
         "  if not path_exists('./pytti'):\n",
-        "    !git clone --branch p5 https://github.com/pytti-tools/pytti-core.git pytti &> /dev/null\n",
+        "    !git clone --branch fix_chdir_imports https://github.com/pytti-tools/pytti-core.git pytti &> /dev/null\n",
         "  else:\n",
         "    !rm -r pytti\n",
-        "    !git clone --branch p5 https://github.com/pytti-tools/pytti-core.git pytti\n",
+        "    !git clone --branch fix_chdir_imports https://github.com/pytti-tools/pytti-core.git pytti\n",
+        "  !pip install -r pytti/requirements.txt\n",
         "  !git clone https://github.com/shariqfarooq123/AdaBins.git &> /dev/null\n",
         "  !git clone https://github.com/zacjiang/GMA.git &> /dev/null\n",
-        "  !mkdir -p AdaBins/pretrained\n",
-        "  if not path_exists('AdaBins/pretrained/AdaBins_nyu.pt'):\n",
+        "  !touch AdaBins/__init__.py\n",
+        "  !touch GMA/__init__.py\n",
+        "  !touch GMA/core/__init__.py\n",
+        "  !mkdir -p ./pretrained\n",
+        "  if not path_exists('./pretrained/AdaBins_nyu.pt'):\n",
         "    !gdown https://drive.google.com/uc?id=1lvyZZbC9NLcS8a__YPcUP7rDiIpbRpoF\n",
         "    if not path_exists('AdaBins_nyu.pt'):\n",
         "      !gdown https://drive.google.com/uc?id=1zgGJrkFkJbRouqMaWArXE4WF_rhj-pxW\n",
-        "    !mv AdaBins_nyu.pt AdaBins/pretrained/AdaBins_nyu.pt\n",
+        "    !mv AdaBins_nyu.pt ./pretrained/AdaBins_nyu.pt\n",
         "  \n",
         "  from pytti.Notebook import change_tqdm_color\n",
         "  change_tqdm_color()\n",
@@ -652,472 +647,35 @@
     {
       "cell_type": "code",
       "execution_count": null,
-      "metadata": {
-        "cellView": "form",
-        "id": "hqJ6vY2z3rR8"
-      },
+      "metadata": {},
       "outputs": [],
       "source": [
-        "#@title 2.3 Run it!\n",
-        "#@markdown pytti is 1000% percent better code than VQLIPSE, so have a look at the code. \n",
-        "#@markdown You just might understand what's going on.\n",
-        "import torch\n",
+        "from pytti.workhorse import _main, TB_LOGDIR \n",
         "\n",
-        "from os.path import exists as path_exists\n",
-        "if path_exists('/content/drive/MyDrive/pytti_test'):\n",
-        "  %cd /content/drive/MyDrive/pytti_test\n",
-        "  drive_mounted = True\n",
-        "else:\n",
-        "  drive_mounted = False\n",
-        "try:\n",
-        "  from pytti.Notebook import *\n",
-        "except ModuleNotFoundError:\n",
-        "  if drive_mounted:\n",
-        "    #THIS IS NOT AN ERROR. This is the code that would\n",
-        "    #make an error if something were wrong.\n",
-        "    raise RuntimeError('ERROR: please run setup (step 1.3).')\n",
-        "  else:\n",
-        "    #THIS IS NOT AN ERROR. This is the code that would\n",
-        "    #make an error if something were wrong.\n",
-        "    raise RuntimeError('WARNING: drive is not mounted.\\nERROR: please run setup (step 1.3).')\n",
-        "change_tqdm_color()\n",
-        "import sys\n",
-        "sys.path.append('./AdaBins')\n",
-        "\n",
-        "try:\n",
-        "  from pytti import Perceptor\n",
-        "except ModuleNotFoundError:\n",
-        "  if drive_mounted:\n",
-        "    #THIS IS NOT AN ERROR. This is the code that would\n",
-        "    #make an error if something were wrong.\n",
-        "    raise RuntimeError('ERROR: please run setup (step 1.3).')\n",
-        "  else:\n",
-        "    #THIS IS NOT AN ERROR. This is the code that would\n",
-        "    #make an error if something were wrong.\n",
-        "    raise RuntimeError('WARNING: drive is not mounted.\\nERROR: please run setup (step 1.3).')\n",
-        "print(\"Loading pytti...\")\n",
-        "from pytti.Image import PixelImage, RGBImage, VQGANImage\n",
-        "from pytti.ImageGuide import DirectImageGuide\n",
-        "from pytti.Perceptor.Embedder import HDMultiClipEmbedder\n",
-        "from pytti.Perceptor.Prompt import parse_prompt\n",
-        "from pytti.LossAug import TVLoss, HSVLoss, OpticalFlowLoss, TargetFlowLoss\n",
-        "from pytti.Transforms import zoom_2d, zoom_3d, apply_flow\n",
-        "from pytti import *\n",
-        "from pytti.LossAug.DepthLoss import init_AdaBins\n",
-        "print(\"pytti loaded.\")\n",
-        "\n",
-        "import torch, gc, glob, subprocess, warnings, re, math, json\n",
-        "import numpy as np\n",
-        "from IPython import display\n",
-        "from PIL import Image, ImageEnhance\n",
-        "\n",
-        "from torchvision.transforms import functional as TF\n",
-        "\n",
-        "#display settings, because usability counts\n",
-        "#warnings.filterwarnings(\"error\", category=UserWarning)\n",
-        "%matplotlib inline \n",
-        "import matplotlib.pyplot as plt\n",
-        "import seaborn as sns\n",
-        "sns.set()\n",
-        "import pandas as pd\n",
-        "plt.style.use('bmh')\n",
-        "pd.options.display.max_columns = None\n",
-        "pd.options.display.width = 175\n",
-        "\n",
-        "latest = -1\n",
-        "#@markdown check `batch_mode` to run batch settings\n",
-        "batch_mode = False #@param{type:\"boolean\"}\n",
-        "if batch_mode:\n",
-        "  try:\n",
-        "    batch_list\n",
-        "  except NameError:\n",
-        "    raise RuntimeError(\"ERROR: no batch settings. Please run 'batch settings' cell at the bottom of the page to use batch mode.\")\n",
-        "else:\n",
-        "  try:\n",
-        "    params\n",
-        "  except NameError:\n",
-        "    raise RuntimeError(\"ERROR: no parameters. Please run parameters (step 2.1).\")\n",
-        "#@markdown check `restore` to restore from a previous run\n",
-        "restore = False#@param{type:\"boolean\"}\n",
-        "#@markdown check `reencode` if you are restoring with a modified image or modified image settings\n",
-        "reencode = False#@param{type:\"boolean\"}\n",
-        "#@markdown which run to restore\n",
-        "restore_run = latest #@param{type:\"raw\"}\n",
-        "if restore and restore_run == latest:\n",
-        "  _, restore_run = get_last_file(f'backup/{params.file_namespace}', \n",
-        "                                 f'^(?P<pre>{re.escape(params.file_namespace)}\\\\(?)(?P<index>\\\\d*)(?P<post>\\\\)?_\\\\d+\\\\.bak)$')\n",
-        "\n",
-        "def do_run():\n",
-        "  clear_rotoscopers()#what a silly name\n",
-        "  vram_profiling(params.approximate_vram_usage)\n",
-        "  reset_vram_usage()\n",
-        "  global CLIP_MODEL_NAMES\n",
-        "  #@markdown which frame to restore from\n",
-        "  restore_frame =  latest#@param{type:\"raw\"}\n",
-        "\n",
-        "  #set up seed for deterministic RNG\n",
-        "  if params.seed is not None:\n",
-        "    torch.manual_seed(params.seed)\n",
-        "\n",
-        "  #load CLIP\n",
-        "  load_clip(params)\n",
-        "  embedder = HDMultiClipEmbedder(cutn=params.cutouts, \n",
-        "                                 cut_pow = params.cut_pow, \n",
-        "                                 padding = params.cutout_border,\n",
-        "                                 border_mode = params.border_mode)\n",
-        "  \n",
-        "  #load scenes\n",
-        "  with vram_usage_mode('Text Prompts'):\n",
-        "    print('Loading prompts...')\n",
-        "    prompts = [[parse_prompt(embedder, p.strip()) \n",
-        "              for p in (params.scene_prefix + stage + params.scene_suffix).strip().split('|') if p.strip()]\n",
-        "              for stage in params.scenes.split('||') if stage]\n",
-        "    print('Prompts loaded.')\n",
-        "\n",
-        "  #load init image\n",
-        "  if params.init_image != '':\n",
-        "    init_image_pil = Image.open(fetch(params.init_image)).convert('RGB')\n",
-        "    init_size = init_image_pil.size\n",
-        "    #automatic aspect ratio matching\n",
-        "    if params.width == -1:\n",
-        "      params.width = int(params.height*init_size[0]/init_size[1])\n",
-        "    if params.height == -1:\n",
-        "      params.height = int(params.width*init_size[1]/init_size[0])\n",
-        "  else:\n",
-        "    init_image_pil = None\n",
-        "\n",
-        "  #video source\n",
-        "  if params.animation_mode == \"Video Source\":\n",
-        "    print(f'loading {params.video_path}...')\n",
-        "    video_frames = get_frames(params.video_path)\n",
-        "    params.pre_animation_steps = max(params.steps_per_frame, params.pre_animation_steps)\n",
-        "    if init_image_pil is None:\n",
-        "      init_image_pil = Image.fromarray(video_frames.get_data(0)).convert('RGB')\n",
-        "      #enhancer = ImageEnhance.Contrast(init_image_pil)\n",
-        "      #init_image_pil = enhancer.enhance(2)\n",
-        "      init_size = init_image_pil.size\n",
-        "      if params.width == -1:\n",
-        "        params.width = int(params.height*init_size[0]/init_size[1])\n",
-        "      if params.height == -1:\n",
-        "        params.height = int(params.width*init_size[1]/init_size[0])\n",
-        "\n",
-        "  #set up image\n",
-        "  if params.image_model == \"Limited Palette\":\n",
-        "    img = PixelImage(*format_params(params,\n",
-        "                     'width', 'height', 'pixel_size', \n",
-        "                     'palette_size', 'palettes', 'gamma', \n",
-        "                     'hdr_weight', 'palette_normalization_weight'))\n",
-        "    img.encode_random(random_pallet = params.random_initial_palette)\n",
-        "    if params.target_palette.strip() != '':\n",
-        "      img.set_pallet_target(Image.open(fetch(params.target_palette)).convert('RGB'))\n",
-        "    else:\n",
-        "      img.lock_pallet(params.lock_palette)\n",
-        "  elif params.image_model == \"Unlimited Palette\":\n",
-        "    img = RGBImage(params.width, params.height, params.pixel_size)\n",
-        "    img.encode_random()\n",
-        "  elif params.image_model == \"VQGAN\":\n",
-        "    VQGANImage.init_vqgan(params.vqgan_model)\n",
-        "    img = VQGANImage(params.width, params.height, params.pixel_size)\n",
-        "    img.encode_random()\n",
-        "\n",
-        "  loss_augs = []\n",
-        "\n",
-        "  if init_image_pil is not None:\n",
-        "    if not restore:\n",
-        "      print(\"Encoding image...\")\n",
-        "      img.encode_image(init_image_pil)\n",
-        "      print(\"Encoded Image:\")\n",
-        "      display.display(img.decode_image())\n",
-        "    #set up init image prompt\n",
-        "    init_augs = ['direct_init_weight']\n",
-        "    init_augs = [build_loss(x,params[x],f'init image ({params.init_image})', img, init_image_pil) \n",
-        "                  for x in init_augs if params[x] not in ['','0']]\n",
-        "    loss_augs.extend(init_augs)\n",
-        "    if params.semantic_init_weight not in ['','0']:\n",
-        "      semantic_init_prompt = parse_prompt(embedder, \n",
-        "                                    f\"init image [{params.init_image}]:{params.semantic_init_weight}\", \n",
-        "                                    init_image_pil)\n",
-        "      prompts[0].append(semantic_init_prompt)\n",
-        "    else:\n",
-        "      semantic_init_prompt = None\n",
-        "  else:\n",
-        "    init_augs, semantic_init_prompt = [], None\n",
-        "\n",
-        "  #other image prompts\n",
-        "\n",
-        "  loss_augs.extend(type(img).get_preferred_loss().TargetImage(p.strip(), img.image_shape, is_path = True) \n",
-        "                   for p in params.direct_image_prompts.split('|') if p.strip())\n",
-        "\n",
-        "  #stabilization\n",
-        "\n",
-        "  stabilization_augs = ['direct_stabilization_weight',\n",
-        "                        'depth_stabilization_weight',\n",
-        "                        'edge_stabilization_weight']\n",
-        "  stabilization_augs = [build_loss(x,params[x],'stabilization',\n",
-        "                                   img, init_image_pil) \n",
-        "                        for x in stabilization_augs if params[x] not in ['','0']]\n",
-        "  loss_augs.extend(stabilization_augs)\n",
-        "  \n",
-        "  if params.semantic_stabilization_weight not in ['0','']:\n",
-        "    last_frame_semantic = parse_prompt(embedder, \n",
-        "                                       f\"stabilization:{params.semantic_stabilization_weight}\", \n",
-        "                                       init_image_pil if init_image_pil else img.decode_image())\n",
-        "    last_frame_semantic.set_enabled(init_image_pil is not None)\n",
-        "    for scene in prompts:\n",
-        "      scene.append(last_frame_semantic)\n",
-        "  else:\n",
-        "    last_frame_semantic = None\n",
-        "\n",
-        "  #optical flow\n",
-        "  if params.animation_mode == 'Video Source':\n",
-        "    if params.flow_stabilization_weight == '':\n",
-        "      params.flow_stabilization_weight = '0'\n",
-        "    optical_flows = [OpticalFlowLoss.TargetImage(f\"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}\", \n",
-        "                                                 img.image_shape) \n",
-        "                     for i in range(params.flow_long_term_samples + 1)]\n",
-        "    for optical_flow in optical_flows:\n",
-        "      optical_flow.set_enabled(False)\n",
-        "    loss_augs.extend(optical_flows)\n",
-        "  elif params.animation_mode == '3D' and params.flow_stabilization_weight not in ['0','']:\n",
-        "    optical_flows = [TargetFlowLoss.TargetImage(f\"optical flow stabilization:{params.flow_stabilization_weight}\", \n",
-        "                                                img.image_shape)]\n",
-        "    for optical_flow in optical_flows:\n",
-        "      optical_flow.set_enabled(False)\n",
-        "    loss_augs.extend(optical_flows)\n",
-        "  else:\n",
-        "    optical_flows = []\n",
-        "  #other loss augs\n",
-        "  if params.smoothing_weight != 0:\n",
-        "    loss_augs.append(TVLoss(weight = params.smoothing_weight))\n",
-        "  \n",
-        "  #set up filespace\n",
-        "  subprocess.run(['mkdir','-p',f'images_out/{params.file_namespace}'])\n",
-        "  subprocess.run(['mkdir','-p',f'backup/{params.file_namespace}'])\n",
-        "  if restore:\n",
-        "    base_name = params.file_namespace if restore_run == 0 else f'{params.file_namespace}({restore_run})'\n",
-        "  elif not params.allow_overwrite:\n",
-        "    #finds the next available base_name to save files with. Why did I do this with regex? \n",
-        "    _, i = get_next_file(f'images_out/{params.file_namespace}', \n",
-        "                         f'^(?P
{re.escape(params.file_namespace)}\\\\(?)(?P\\\\d*)(?P\\\\)?_1\\\\.png)$',\n",
-        "                         [f\"{params.file_namespace}_1.png\",f\"{params.file_namespace}(1)_1.png\"])\n",
-        "    base_name = params.file_namespace if i == 0 else f'{params.file_namespace}({i})'\n",
-        "  else:\n",
-        "    base_name = params.file_namespace\n",
-        "\n",
-        "  #restore\n",
-        "  if restore:\n",
-        "    if not reencode:\n",
-        "      if restore_frame == latest:\n",
-        "        filename, restore_frame = get_last_file(f'backup/{params.file_namespace}', \n",
-        "                                                f'^(?P
{re.escape(base_name)}_)(?P\\\\d*)(?P\\\\.bak)$')\n",
-        "      else: \n",
-        "        filename = f'{base_name}_{restore_frame}.bak'\n",
-        "      print(\"restoring from\", filename)\n",
-        "      img.load_state_dict(torch.load(f'backup/{params.file_namespace}/{filename}'))\n",
-        "    else:#reencode\n",
-        "      if restore_frame == latest:\n",
-        "        filename, restore_frame = get_last_file(f'images_out/{params.file_namespace}', \n",
-        "                                                f'^(?P
{re.escape(base_name)}_)(?P\\\\d*)(?P\\\\.png)$')\n",
-        "      else: \n",
-        "        filename = f'{base_name}_{restore_frame}.png'\n",
-        "      print(\"restoring from\", filename)\n",
-        "      img.encode_image(Image.open(f'images_out/{params.file_namespace}/{filename}').convert('RGB'))\n",
-        "    i = restore_frame*params.save_every\n",
-        "  else:\n",
-        "    i = 0\n",
+        "%load_ext tensorboard"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "%tensorboard --logdir $TB_LOGDIR "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "#@title 2.3 Run it!\n",
+        "from omegaconf import OmegaConf\n",
+        "cfg = OmegaConf.create(dict(params))\n",
         "\n",
-        "  #graphs\n",
-        "  if params.show_graphs:\n",
-        "    fig, axs = plt.subplots(4, 1, figsize=(21,13))\n",
-        "    axs  = np.asarray(axs).flatten()\n",
-        "    #fig.facecolor = (0,0,0)\n",
-        "  else:\n",
-        "    fig, axs = None, None\n",
-        "\n",
-        "  #make the main model object\n",
-        "  model = DirectImageGuide(img, embedder, lr = params.learning_rate)\n",
-        "\n",
-        "  #Update is called each step.\n",
-        "  def update(i, stage_i):\n",
-        "    #display\n",
-        "    if params.clear_every > 0 and i > 0 and i % params.clear_every == 0:\n",
-        "      display.clear_output()\n",
-        "    if params.display_every > 0 and i % params.display_every == 0:\n",
-        "      print(f\"Step {i} losses:\")\n",
-        "      if model.dataframe:\n",
-        "        print(model.dataframe[0].iloc[-1])\n",
-        "      if params.approximate_vram_usage:\n",
-        "        print(\"VRAM Usage:\")\n",
-        "        print_vram_usage()\n",
-        "      display_width = int(img.image_shape[0]*params.display_scale)\n",
-        "      display_height = int(img.image_shape[1]*params.display_scale)\n",
-        "      if stage_i > 0 and params.show_graphs:\n",
-        "        model.plot_losses(axs)\n",
-        "        im = img.decode_image()\n",
-        "        sidebyside = make_hbox(im.resize((display_width, display_height), Image.LANCZOS), fig)\n",
-        "        display.display(sidebyside)\n",
-        "      else:\n",
-        "        im = img.decode_image()\n",
-        "        display.display(im.resize((display_width, display_height), Image.LANCZOS))\n",
-        "      if params.show_palette and isinstance(img, PixelImage):\n",
-        "        print('Palette:')\n",
-        "        display.display(img.render_pallet())\n",
-        "    #save\n",
-        "    if i > 0 and params.save_every > 0 and i % params.save_every == 0:\n",
-        "      try:\n",
-        "        im\n",
-        "      except NameError:\n",
-        "        im = img.decode_image()\n",
-        "      n = i//params.save_every\n",
-        "      filename = f\"images_out/{params.file_namespace}/{base_name}_{n}.png\"\n",
-        "      im.save(filename)\n",
-        "      if params.backups > 0:\n",
-        "        filename = f\"backup/{params.file_namespace}/{base_name}_{n}.bak\"\n",
-        "        torch.save(img.state_dict(), filename)\n",
-        "        if n > params.backups:\n",
-        "          subprocess.run(['rm', f\"backup/{params.file_namespace}/{base_name}_{n-params.backups}.bak\"])\n",
-        "    #animate\n",
-        "    t = (i - params.pre_animation_steps)/(params.steps_per_frame*params.frames_per_second)\n",
-        "    set_t(t)\n",
-        "    if i >= params.pre_animation_steps:\n",
-        "      if (i - params.pre_animation_steps) % params.steps_per_frame == 0:\n",
-        "        print(f\"Time: {t:.4f} seconds\")\n",
-        "        update_rotoscopers(((i - params.pre_animation_steps)//params.steps_per_frame+1)*params.frame_stride)\n",
-        "        if params.reset_lr_each_frame:\n",
-        "          model.set_optim(None)\n",
-        "        if params.animation_mode == \"2D\":\n",
-        "          tx, ty = parametric_eval(params.translate_x), parametric_eval(params.translate_y)\n",
-        "          theta = parametric_eval(params.rotate_2d)\n",
-        "          zx, zy = parametric_eval(params.zoom_x_2d), parametric_eval(params.zoom_y_2d)\n",
-        "          next_step_pil = zoom_2d(img, \n",
-        "                                  (tx,ty), (zx,zy), theta, \n",
-        "                                  border_mode = params.infill_mode, sampling_mode = params.sampling_mode)\n",
-        "        elif params.animation_mode == \"3D\":\n",
-        "          try:\n",
-        "            im\n",
-        "          except NameError:\n",
-        "            im = img.decode_image()\n",
-        "          with vram_usage_mode('Optical Flow Loss'):\n",
-        "            flow, next_step_pil = zoom_3d(img, \n",
-        "                                        (params.translate_x,params.translate_y,params.translate_z_3d), params.rotate_3d, \n",
-        "                                        params.field_of_view, params.near_plane, params.far_plane,\n",
-        "                                        border_mode = params.infill_mode, sampling_mode = params.sampling_mode,\n",
-        "                                        stabilize = params.lock_camera)\n",
-        "            freeze_vram_usage()\n",
-        "            \n",
-        "          for optical_flow in optical_flows:\n",
-        "            optical_flow.set_last_step(im)\n",
-        "            optical_flow.set_target_flow(flow)\n",
-        "            optical_flow.set_enabled(True)\n",
-        "        elif params.animation_mode == \"Video Source\":\n",
-        "          frame_n = min((i - params.pre_animation_steps)*params.frame_stride//params.steps_per_frame, len(video_frames) - 1)\n",
-        "          next_frame_n = min(frame_n + params.frame_stride, len(video_frames) - 1)\n",
-        "          next_step_pil = Image.fromarray(video_frames.get_data(next_frame_n)).convert('RGB').resize(img.image_shape, Image.LANCZOS)\n",
-        "          for j, optical_flow in enumerate(optical_flows):\n",
-        "            old_frame_n = frame_n - (2**j - 1)*params.frame_stride\n",
-        "            save_n = i//params.save_every - (2**j - 1)\n",
-        "            if old_frame_n < 0 or save_n < 1:\n",
-        "              break\n",
-        "            current_step_pil = Image.fromarray(video_frames.get_data(old_frame_n)).convert('RGB').resize(img.image_shape, Image.LANCZOS)\n",
-        "            filename = f\"backup/{params.file_namespace}/{base_name}_{save_n}.bak\"\n",
-        "            filename = None if j == 0 else filename\n",
-        "            flow_im, mask_tensor = optical_flow.set_flow(current_step_pil, next_step_pil, \n",
-        "                                                        img, filename, \n",
-        "                                                        params.infill_mode, params.sampling_mode)\n",
-        "            optical_flow.set_enabled(True)\n",
-        "            #first flow is previous frame\n",
-        "            if j == 0:\n",
-        "              mask_accum = mask_tensor.detach()\n",
-        "              valid = mask_tensor.mean()\n",
-        "              print(\"valid pixels:\", valid)\n",
-        "              if params.reencode_each_frame or valid < .03:\n",
-        "                if isinstance(img, PixelImage) and valid >= .03:\n",
-        "                  img.lock_pallet()\n",
-        "                  img.encode_image(next_step_pil, smart_encode = False)\n",
-        "                  img.lock_pallet(params.lock_palette)\n",
-        "                else:\n",
-        "                  img.encode_image(next_step_pil)\n",
-        "                reencoded = True\n",
-        "              else:\n",
-        "                reencoded = False\n",
-        "            else:\n",
-        "              with torch.no_grad():\n",
-        "                optical_flow.set_mask((mask_tensor - mask_accum).clamp(0,1))\n",
-        "                mask_accum.add_(mask_tensor)\n",
-        "        if params.animation_mode != 'off':\n",
-        "          for aug in stabilization_augs:\n",
-        "            aug.set_comp(next_step_pil)\n",
-        "            aug.set_enabled(True)\n",
-        "          if last_frame_semantic is not None:\n",
-        "            last_frame_semantic.set_image(embedder, next_step_pil)\n",
-        "            last_frame_semantic.set_enabled(True)\n",
-        "          for aug in init_augs:\n",
-        "            aug.set_enabled(False)\n",
-        "          if semantic_init_prompt is not None:\n",
-        "            semantic_init_prompt.set_enabled(False)\n",
-        "            \n",
-        "      \n",
-        "  model.update = update\n",
-        "  \n",
-        "  print(f\"Settings saved to images_out/{params.file_namespace}/{base_name}_settings.txt\")\n",
-        "  save_settings(params, f\"images_out/{params.file_namespace}/{base_name}_settings.txt\")\n",
-        "\n",
-        "  skip_prompts = i // params.steps_per_scene\n",
-        "  skip_steps   = i %  params.steps_per_scene\n",
-        "  last_scene = prompts[0] if skip_prompts == 0 else prompts[skip_prompts - 1]\n",
-        "  for scene in prompts[skip_prompts:]:\n",
-        "    print(\"Running prompt:\", ' | '.join(map(str,scene)))\n",
-        "    i += model.run_steps(params.steps_per_scene-skip_steps, \n",
-        "                         scene, last_scene, loss_augs, \n",
-        "                         interp_steps = params.interpolation_steps,\n",
-        "                         i_offset = i, skipped_steps = skip_steps)\n",
-        "    skip_steps = 0\n",
-        "    model.clear_dataframe()\n",
-        "    last_scene = scene\n",
-        "  if fig:\n",
-        "    del fig, axs\n",
-        "\n",
-        "#if __name__ == '__main__':\n",
-        "try:\n",
-        "  gc.collect()\n",
-        "  torch.cuda.empty_cache()\n",
-        "  if batch_mode:\n",
-        "    if restore:\n",
-        "      settings_list = batch_list[restore_run:]\n",
-        "    else:\n",
-        "      settings_list = batch_list\n",
-        "      namespace = batch_list[0]['file_namespace']\n",
-        "      subprocess.run(['mkdir','-p',f'images_out/{namespace}'])\n",
-        "      save_batch(batch_list, f'images_out/{namespace}/{namespace}_batch settings.txt')\n",
-        "      print(f\"Batch settings saved to images_out/{namespace}/{namespace}_batch settings.txt\")\n",
-        "    for settings in settings_list:\n",
-        "      setting_string = json.dumps(settings)\n",
-        "      print(\"SETTINGS:\")\n",
-        "      print(setting_string)\n",
-        "      params = load_settings(setting_string)\n",
-        "      if params.animation_mode == '3D':\n",
-        "        init_AdaBins()\n",
-        "      params.allow_overwrite = False\n",
-        "      do_run()\n",
-        "      restore = False\n",
-        "      reencode = False\n",
-        "      gc.collect()\n",
-        "      torch.cuda.empty_cache()\n",
-        "  else:\n",
-        "    if params.animation_mode == '3D':\n",
-        "      pass\n",
-        "      #init_AdaBins()\n",
-        "    do_run()\n",
-        "    print(\"Complete.\")\n",
-        "    gc.collect()\n",
-        "    torch.cuda.empty_cache()\n",
-        "except KeyboardInterrupt:\n",
-        "  pass\n",
-        "except RuntimeError:\n",
-        "  print_vram_usage()\n",
-        "  raise\n",
-        "      \n",
-        "#print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))"
+        "# function wraps step 2.3 of the original p5 notebook\n",
+        "_main(cfg)"
       ]
     },
     {