From 9a32a4acbfba4bb61248e206fa404bc4f0af96c3 Mon Sep 17 00:00:00 2001 From: lucylq Date: Tue, 13 Aug 2024 08:36:47 -0700 Subject: [PATCH] [executorch] Preprocess export test Differential Revision: D61047506 Pull Request resolved: https://github.com/pytorch/executorch/pull/4651 --- examples/models/flamingo/__init__.py | 0 examples/models/flamingo/export_preprocess.py | 19 ++ .../models/flamingo/export_preprocess_lib.py | 87 +++++++ .../models/flamingo/install_requirements.sh | 9 + examples/models/flamingo/test_preprocess.py | 244 ++++++++++++++++++ 5 files changed, 359 insertions(+) create mode 100644 examples/models/flamingo/__init__.py create mode 100644 examples/models/flamingo/export_preprocess.py create mode 100644 examples/models/flamingo/export_preprocess_lib.py create mode 100644 examples/models/flamingo/install_requirements.sh create mode 100644 examples/models/flamingo/test_preprocess.py diff --git a/examples/models/flamingo/__init__.py b/examples/models/flamingo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/models/flamingo/export_preprocess.py b/examples/models/flamingo/export_preprocess.py new file mode 100644 index 0000000000..c5a930c88c --- /dev/null +++ b/examples/models/flamingo/export_preprocess.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from export_preprocess_lib import export_preprocess, lower_to_executorch_preprocess + + +def main(): + ep = export_preprocess() + et = lower_to_executorch_preprocess(ep) + + with open("preprocess.pte", "wb") as file: + et.write_to_file(file) + + +if __name__ == "__main__": + main() diff --git a/examples/models/flamingo/export_preprocess_lib.py b/examples/models/flamingo/export_preprocess_lib.py new file mode 100644 index 0000000000..736116de8b --- /dev/null +++ b/examples/models/flamingo/export_preprocess_lib.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
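+
+# Export helpers for torchtune's CLIP image preprocessing model (_CLIPImageTransform):
+# capture it with torch.export and lower it to an ExecuTorch program.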
+ +from typing import Dict, List, Optional, Tuple + +import torch +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge +from executorch.exir.program._program import ExecutorchProgramManager + +from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa + +from torch.export import Dim, ExportedProgram +from torchtune.models.clip.inference._transforms import _CLIPImageTransform + +from .passes.replace_custom_ops_with_aten_ops_pass import ( + ReplaceCustomOpsWithAtenOpsPass, +) + + +def get_example_inputs() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + image = torch.ones(3, 800, 600) + target_size = torch.tensor([448, 336]) + canvas_size = torch.tensor([448, 448]) + return (image, target_size, canvas_size) + + +def get_dynamic_shapes() -> Dict[str, Dict[int, Dim]]: + img_h = Dim("img_h", min=1, max=4000) + img_w = Dim("img_w", min=1, max=4000) + + dynamic_shapes = { + "image": {1: img_h, 2: img_w}, + "target_size": None, + "canvas_size": None, + } + return dynamic_shapes + + +def export_preprocess( + resample: str = "bilinear", + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + max_num_tiles: int = 4, + tile_size: int = 224, + antialias: bool = False, +) -> ExportedProgram: + + # Instantiate eager model. + image_transform_model = _CLIPImageTransform( + resample=resample, + image_mean=image_mean, + image_std=image_std, + max_num_tiles=max_num_tiles, + tile_size=tile_size, + antialias=antialias, + ) + + # Replace non-exportable ops with custom ops. + image_transform_model.pad = torch.ops.preprocess.pad.default + image_transform_model.tile_crop = torch.ops.preprocess.tile_crop.default + + # Export. + example_inputs = get_example_inputs() + dynamic_shapes = get_dynamic_shapes() + ep = torch.export.export( + image_transform_model, + example_inputs, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + return ep + + +def lower_to_executorch_preprocess( + exported_program: ExportedProgram, +) -> ExecutorchProgramManager: + edge_program = to_edge( + exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False) + ) + # Replace custom ops with aten ops. + edge_program = edge_program.transform([ReplaceCustomOpsWithAtenOpsPass()]) + + et_program = edge_program.to_executorch(ExecutorchBackendConfig()) + return et_program diff --git a/examples/models/flamingo/install_requirements.sh b/examples/models/flamingo/install_requirements.sh new file mode 100644 index 0000000000..0bcf302ca9 --- /dev/null +++ b/examples/models/flamingo/install_requirements.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Install torchtune nightly for model definitions. +pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir diff --git a/examples/models/flamingo/test_preprocess.py b/examples/models/flamingo/test_preprocess.py new file mode 100644 index 0000000000..896a01655e --- /dev/null +++ b/examples/models/flamingo/test_preprocess.py @@ -0,0 +1,244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import unittest
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import PIL
+import torch
+
+from parameterized import parameterized
+from PIL import Image
+
+from torchtune.models.clip.inference._transforms import (
+    _CLIPImageTransform,
+    CLIPImageTransform,
+)
+
+from torchtune.modules.transforms import (
+    find_supported_resolutions,
+    get_canvas_best_fit,
+    get_inscribed_size,
+)
+from torchvision.transforms.v2 import functional as F
+
+from .export_preprocess_lib import export_preprocess
+
+
+@dataclass
+class PreprocessConfig:
+    image_mean: Optional[List[float]] = None
+    image_std: Optional[List[float]] = None
+    resize_to_max_canvas: bool = True
+    resample: str = "bilinear"
+    antialias: bool = False
+    tile_size: int = 224
+    max_num_tiles: int = 4
+    possible_resolutions: Optional[List[Tuple[int, int]]] = None
+
+
+class TestImageTransform(unittest.TestCase):
+    """
+    This unittest checks that the exported image transform model produces the
+    same output as the reference model.
+
+    Reference model: CLIPImageTransform
+    https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
+    Eager and exported models: _CLIPImageTransform
+    https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
+    """
+
+    def setUp(self):
+        np.random.seed(0)
+
+    def prepare_inputs(
+        self, image: Image.Image, config: PreprocessConfig
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Prepare inputs for the eager and exported models:
+        - Convert the PIL image to a tensor.
+        - Calculate the best resolution: a canvas with height and width divisible by tile_size.
+        - Calculate the inscribed size: the size of the image inscribed within best_resolution,
+          without distortion.
+
+        These calculations are done by the reference model inside __init__ and __call__:
+        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
+        """
+        image_tensor = F.to_dtype(
+            F.grayscale_to_rgb_image(F.to_image(image)), scale=True
+        )
+
+        # Calculate possible resolutions.
+        possible_resolutions = config.possible_resolutions
+        if possible_resolutions is None:
+            possible_resolutions = find_supported_resolutions(
+                max_num_tiles=config.max_num_tiles, tile_size=config.tile_size
+            )
+        possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)
+
+        # Limit resizing.
+        max_size = None if config.resize_to_max_canvas else config.tile_size
+
+        # Find the best canvas to fit the image without distortion.
+        best_resolution = get_canvas_best_fit(
+            image=image_tensor,
+            possible_resolutions=possible_resolutions,
+            resize_to_max_canvas=config.resize_to_max_canvas,
+        )
+        best_resolution = torch.tensor(best_resolution)
+
+        # Find the dimensions of the image, such that it is inscribed within best_resolution
+        # without distortion.
+        inscribed_size = get_inscribed_size(
+            image_tensor.shape[-2:], best_resolution, max_size
+        )
+        inscribed_size = torch.tensor(inscribed_size)
+
+        return image_tensor, inscribed_size, best_resolution
+
+    # This test setup mirrors the one in torchtune:
+    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
+    # The values are slightly different, as torchtune uses antialias=True,
+    # and this test uses antialias=False, which is exportable (has a portable kernel).
+ @parameterized.expand( + [ + ( + (100, 400, 3), # image_size + torch.Size([2, 3, 224, 224]), # expected shape + False, # resize_to_max_canvas + [0.2230, 0.1763], # expected_tile_means + [1.0, 1.0], # expected_tile_max + [0.0, 0.0], # expected_tile_min + [1, 2], # expected_aspect_ratio + ), + ( + (1000, 300, 3), # image_size + torch.Size([4, 3, 224, 224]), # expected shape + True, # resize_to_max_canvas + [0.5005, 0.4992, 0.5004, 0.1651], # expected_tile_means + [0.9976, 0.9940, 0.9936, 0.9906], # expected_tile_max + [0.0037, 0.0047, 0.0039, 0.0], # expected_tile_min + [4, 1], # expected_aspect_ratio + ), + ( + (200, 200, 3), # image_size + torch.Size([4, 3, 224, 224]), # expected shape + True, # resize_to_max_canvas + [0.5012, 0.5020, 0.5010, 0.4991], # expected_tile_means + [0.9921, 0.9925, 0.9969, 0.9908], # expected_tile_max + [0.0056, 0.0069, 0.0059, 0.0032], # expected_tile_min + [2, 2], # expected_aspect_ratio + ), + ( + (600, 200, 3), # image_size + torch.Size([3, 3, 224, 224]), # expected shape + False, # resize_to_max_canvas + [0.4472, 0.4468, 0.3031], # expected_tile_means + [1.0, 1.0, 1.0], # expected_tile_max + [0.0, 0.0, 0.0], # expected_tile_min + [3, 1], # expected_aspect_ratio + ), + ] + ) + def test_preprocess( + self, + image_size: Tuple[int], + expected_shape: torch.Size, + resize_to_max_canvas: bool, + expected_tile_means: List[float], + expected_tile_max: List[float], + expected_tile_min: List[float], + expected_ar: List[int], + ) -> None: + config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas) + + reference_model = CLIPImageTransform( + image_mean=config.image_mean, + image_std=config.image_std, + resize_to_max_canvas=config.resize_to_max_canvas, + resample=config.resample, + antialias=config.antialias, + tile_size=config.tile_size, + max_num_tiles=config.max_num_tiles, + possible_resolutions=None, + ) + + eager_model = _CLIPImageTransform( + image_mean=config.image_mean, + image_std=config.image_std, + resample=config.resample, + antialias=config.antialias, + tile_size=config.tile_size, + max_num_tiles=config.max_num_tiles, + ) + + exported_model = export_preprocess( + image_mean=config.image_mean, + image_std=config.image_std, + resample=config.resample, + antialias=config.antialias, + tile_size=config.tile_size, + max_num_tiles=config.max_num_tiles, + ) + + # Prepare image input. + image = ( + np.random.randint(0, 256, np.prod(image_size)) + .reshape(image_size) + .astype(np.uint8) + ) + image = PIL.Image.fromarray(image) + + # Run reference model. + reference_output = reference_model(image=image) + reference_image = reference_output["image"] + reference_ar = reference_output["aspect_ratio"].tolist() + + # Check output shape and aspect ratio matches expected values. + self.assertEqual(reference_image.shape, expected_shape) + self.assertEqual(reference_ar, expected_ar) + + # Check pixel values within expected range [0, 1] + self.assertTrue(0 <= reference_image.min() <= reference_image.max() <= 1) + + # Check mean, max, and min values of the tiles match expected values. + for i, tile in enumerate(reference_image): + self.assertAlmostEqual( + tile.mean().item(), expected_tile_means[i], delta=1e-4 + ) + self.assertAlmostEqual(tile.max().item(), expected_tile_max[i], delta=1e-4) + self.assertAlmostEqual(tile.min().item(), expected_tile_min[i], delta=1e-4) + + # Check num tiles matches the product of the aspect ratio. 
+        expected_num_tiles = reference_ar[0] * reference_ar[1]
+        self.assertEqual(expected_num_tiles, reference_image.shape[0])
+
+        # Pre-work for the eager and exported models. The reference model performs these
+        # calculations and passes the result to _CLIPImageTransform, the exportable model.
+        image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
+            image=image, config=config
+        )
+
+        # Run the eager and exported models.
+        eager_image, eager_ar = eager_model(
+            image_tensor, inscribed_size, best_resolution
+        )
+        eager_ar = eager_ar.tolist()
+
+        exported_image, exported_ar = exported_model.module()(
+            image_tensor, inscribed_size, best_resolution
+        )
+        exported_ar = exported_ar.tolist()
+
+        # Check that the eager and exported outputs match the reference model.
+        self.assertTrue(torch.allclose(reference_image, eager_image))
+        self.assertTrue(torch.allclose(reference_image, exported_image))
+
+        self.assertEqual(reference_ar, eager_ar)
+        self.assertEqual(reference_ar, exported_ar)
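
A minimal runtime sketch, assuming the ExecuTorch pybindings are built
(_load_for_executorch from executorch.extension.pybindings.portable_lib) and the
operators used by the lowered program are registered. The file name and inputs
follow export_preprocess.py and get_example_inputs() above; the outputs are assumed
to be the tiled image and its aspect ratio, matching the eager model.

    import torch
    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    # Load the program written by export_preprocess.py.
    module = _load_for_executorch("preprocess.pte")

    # Inputs mirror get_example_inputs(): image, target_size, canvas_size.
    image = torch.ones(3, 800, 600)
    target_size = torch.tensor([448, 336])
    canvas_size = torch.tensor([448, 448])

    # forward() takes the input sequence and returns the program outputs.
    tiles, aspect_ratio = module.forward((image, target_size, canvas_size))
    print(tiles.shape, aspect_ratio)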