diff --git a/experiments/03_CLIP_TO_ONNX.ipynb b/experiments/03_CLIP_TO_ONNX.ipynb
new file mode 100644
index 00000000..d943d27b
--- /dev/null
+++ b/experiments/03_CLIP_TO_ONNX.ipynb
@@ -0,0 +1,323 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f7f8e989-cdc9-475e-918d-af20530fcfe6",
+   "metadata": {
+    "is_executing": true
+   },
+   "outputs": [],
+   "source": [
+    "!pip3 install -q torch transformers optimum[onnxruntime] pillow"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e8c9e276-58f2-45a4-af40-6d7bacc30eec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from pathlib import Path\n",
+    "from typing import Optional, Dict, Union, Tuple\n",
+    "\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "from PIL import Image\n",
+    "from transformers import (\n",
+    "    CLIPVisionModelWithProjection,\n",
+    "    CLIPTextModelWithProjection,\n",
+    "    CLIPImageProcessor,\n",
+    "    CLIPTokenizerFast,\n",
+    ")\n",
+    "from transformers.models.clip.modeling_clip import (\n",
+    "    CLIPTextModelOutput,\n",
+    "    CLIPVisionModelOutput,\n",
+    "    CLIPModel,\n",
+    ")\n",
+    "from optimum.onnxruntime import ORTModelForCustomTasks\n",
+    "from optimum.exporters.onnx.model_configs import CLIPTextWithProjectionOnnxConfig, ViTOnnxConfig\n",
+    "from optimum.exporters.onnx import export_models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bbceb40f-22cd-4d92-be6e-fe14f16f7bc2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_id = \"openai/clip-vit-base-patch32\"\n",
+    "output_dir = \"split-clip-onnx\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f96ff7fb-518e-405e-a7e0-46f836ffdec8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CLIPVisionModelWithProjectionOnnxConfig(ViTOnnxConfig):\n",
+    "    @property\n",
+    "    def outputs(self) -> Dict[str, Dict[int, str]]:\n",
+    "        return {\n",
+    "            \"image_embeds\": {0: \"batch_size\"},\n",
+    "        }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "18863f0b-6bd5-463f-bebc-38bf40d51b9c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CLIPTextModelWithProjectionAndAttentionOnnxConfig(CLIPTextWithProjectionOnnxConfig):\n",
+    "    @property\n",
+    "    def inputs(self) -> Dict[str, Dict[int, str]]:\n",
+    "        return {\n",
+    "            \"input_ids\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
+    "            \"attention_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
+    "        }"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5d37b16a-ec30-40e1-8404-0f0d51abfa76",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CLIPTextModelWithProjectionNormalized(CLIPTextModelWithProjection):\n",
+    "    def forward(\n",
+    "        self,\n",
+    "        input_ids: Optional[torch.Tensor] = None,\n",
+    "        attention_mask: Optional[torch.Tensor] = None,\n",
+    "        position_ids: Optional[torch.Tensor] = None,\n",
+    "        output_attentions: Optional[bool] = None,\n",
+    "        output_hidden_states: Optional[bool] = None,\n",
+    "        return_dict: Optional[bool] = None,\n",
+    "    ) -> Union[Tuple, CLIPTextModelOutput]:\n",
+    "        text_outputs = super().forward(\n",
+    "            input_ids,\n",
+    "            attention_mask,\n",
+    "            position_ids,\n",
+    "            output_attentions,\n",
+    "            output_hidden_states,\n",
+    "            return_dict,\n",
+    "        )\n",
+    "        normalized_text_embeds = text_outputs.text_embeds / text_outputs.text_embeds.norm(\n",
+    "            p=2, dim=-1, keepdim=True\n",
+    "        )\n",
+    "        return CLIPTextModelOutput(\n",
+    "            text_embeds=normalized_text_embeds,\n",
+    "            last_hidden_state=text_outputs.last_hidden_state,\n",
+    "            hidden_states=text_outputs.hidden_states,\n",
+    "            attentions=text_outputs.attentions,\n",
+    "        )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2cd05d33-5f4f-4b36-aa2c-43fd97d5061d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CLIPVisionModelWithProjectionNormalized(CLIPVisionModelWithProjection):\n",
+    "    def forward(\n",
+    "        self,\n",
+    "        pixel_values: Optional[torch.FloatTensor] = None,\n",
+    "        output_attentions: Optional[bool] = None,\n",
+    "        output_hidden_states: Optional[bool] = None,\n",
+    "        return_dict: Optional[bool] = None,\n",
+    "    ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
+    "        vision_outputs = super().forward(pixel_values, output_attentions, output_hidden_states, return_dict)\n",
+    "        normalized_image_embeds = vision_outputs.image_embeds / vision_outputs.image_embeds.norm(\n",
+    "            p=2, dim=-1, keepdim=True\n",
+    "        )\n",
+    "        return CLIPVisionModelOutput(\n",
+    "            image_embeds=normalized_image_embeds,\n",
+    "            last_hidden_state=vision_outputs.last_hidden_state,\n",
+    "            hidden_states=vision_outputs.hidden_states,\n",
+    "            attentions=vision_outputs.attentions,\n",
+    "        )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fb5e5617-f8ce-4dcf-9148-68ddd91854c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_model = CLIPTextModelWithProjectionNormalized.from_pretrained(model_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7f578131-c6bb-460e-8200-d0b7f0aa4135",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vision_model = CLIPVisionModelWithProjectionNormalized.from_pretrained(model_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c57db080-642f-4b62-96ad-75af9c0ab277",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_config = CLIPTextModelWithProjectionAndAttentionOnnxConfig(text_model.config)\n",
+    "vision_config = CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cdb01aab-0dff-4fd5-9297-d16599274fdb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_model.config.save_pretrained(f\"./{output_dir}/text\")\n",
+    "vision_model.config.save_pretrained(f\"./{output_dir}/image\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "133bcd2c-57ae-4132-a691-3d129057f275",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "export_models(\n",
+    "    models_and_onnx_configs={\n",
+    "        \"text_model\": (text_model, text_config),\n",
+    "        \"vision_model\": (vision_model, vision_config),\n",
+    "    },\n",
+    "    output_dir=Path(f\"./{output_dir}\"),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fd9167b9-0d00-4a24-8f5e-41392d923b95",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "os.rename(f\"./{output_dir}/text_model.onnx\", f\"./{output_dir}/text/model.onnx\")\n",
+    "os.rename(f\"./{output_dir}/vision_model.onnx\", f\"./{output_dir}/image/model.onnx\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2281cc75-1027-4dbe-a7a5-86ee1dad4c3e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ort_vision_model = ORTModelForCustomTasks.from_pretrained(\n",
+    "    f\"./{output_dir}/image\", config=vision_config\n",
+    ")\n",
+    "image_processor = CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
+    "image_input = image_processor(images=Image.open(\"assets/image.jpeg\"), return_tensors=\"pt\")\n",
+    "\n",
+    "with torch.inference_mode():\n",
+    "    image_outputs = ort_vision_model(**image_input)\n",
+    "image_processor.save_pretrained(f\"./{output_dir}/image\")"
"image_processor.save_pretrained(f\"./{output_dir}/image\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ce61045-b1e7-4b62-a47b-4bcb27ac7288", + "metadata": {}, + "outputs": [], + "source": [ + "ort_text_model = ORTModelForCustomTasks.from_pretrained(f\"./{output_dir}/text\", config=text_config)\n", + "text_processor = CLIPTokenizerFast.from_pretrained(\"openai/clip-vit-base-patch32\")\n", + "text_input = text_processor(\"What am I using?\", return_tensors=\"pt\")\n", + "\n", + "with torch.inference_mode():\n", + " text_outputs = ort_text_model(**text_input)\n", + "text_processor.save_pretrained(f\"./{output_dir}/text\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee1e3d13-6884-4aa3-a6ab-aaa41fc17134", + "metadata": {}, + "outputs": [], + "source": [ + "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n", + "inputs = {**text_input, **image_input}\n", + "clip_model.eval()\n", + "with torch.inference_mode():\n", + " gt_output = clip_model(**inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec07ce55-a0e7-42d2-b1f4-ba1370ab45b7", + "metadata": {}, + "outputs": [], + "source": [ + "print(np.allclose(gt_output.text_embeds.numpy(), text_outputs.text_embeds, atol=1e-6))\n", + "print(np.allclose(gt_output.image_embeds.numpy(), image_outputs.image_embeds, atol=1e-6))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9b15597-3054-40f1-a080-fea19d378d88", + "metadata": {}, + "outputs": [], + "source": [ + "token = \"\"\n", + "# create_repo(repo_id='Qdrant/clip-ViT-B-32-vision', exist_ok=True, token=token)\n", + "# create_repo(repo_id='Qdrant/clip-ViT-B-32-text', exist_ok=True, token=token)\n", + "\n", + "ort_text_model.push_to_hub(\n", + " save_directory=f\"./{output_dir}/text/\",\n", + " repository_id=\"Qdrant/clip-ViT-B-32-text\",\n", + " use_auth_token=token,\n", + ")\n", + "ort_vision_model.push_to_hub(\n", + " save_directory=f\"./{output_dir}/image\",\n", + " repository_id=\"Qdrant/clip-ViT-B-32-vision\",\n", + " use_auth_token=token,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/experiments/assets/image.jpeg b/experiments/assets/image.jpeg new file mode 100644 index 00000000..e131e8ec Binary files /dev/null and b/experiments/assets/image.jpeg differ