[executorch] Preprocess export test
Differential Revision: D61047506

Pull Request resolved: pytorch#4651
lucylq authored Aug 13, 2024
1 parent 2654f59 commit 9a32a4a
Showing 5 changed files with 359 additions and 0 deletions.
Empty file.
19 changes: 19 additions & 0 deletions examples/models/flamingo/export_preprocess.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from export_preprocess_lib import export_preprocess, lower_to_executorch_preprocess


def main():
    ep = export_preprocess()
    et = lower_to_executorch_preprocess(ep)

    with open("preprocess.pte", "wb") as file:
        et.write_to_file(file)


if __name__ == "__main__":
    main()
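Once preprocess.pte is written, it can be loaded and executed with the ExecuTorch Python runtime bindings. A minimal sketch (not part of this diff; it assumes the pybindings are built in your install and that _load_for_executorch is the loading entry point):

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the lowered program produced by export_preprocess.py.
module = _load_for_executorch("preprocess.pte")

# Inputs mirror get_example_inputs() in export_preprocess_lib.py.
image = torch.ones(3, 800, 600)
target_size = torch.tensor([448, 336])
canvas_size = torch.tensor([448, 448])

# forward takes a sequence of inputs and returns the outputs:
# the tiled image and its aspect ratio.
tiles, aspect_ratio = module.forward((image, target_size, canvas_size))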
87 changes: 87 additions & 0 deletions examples/models/flamingo/export_preprocess_lib.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple

import torch
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
from executorch.exir.program._program import ExecutorchProgramManager

from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa

from torch.export import Dim, ExportedProgram
from torchtune.models.clip.inference._transforms import _CLIPImageTransform

from .passes.replace_custom_ops_with_aten_ops_pass import (
    ReplaceCustomOpsWithAtenOpsPass,
)


def get_example_inputs() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    image = torch.ones(3, 800, 600)
    target_size = torch.tensor([448, 336])
    canvas_size = torch.tensor([448, 448])
    return (image, target_size, canvas_size)


def get_dynamic_shapes() -> Dict[str, Optional[Dict[int, Dim]]]:
    img_h = Dim("img_h", min=1, max=4000)
    img_w = Dim("img_w", min=1, max=4000)

    # Only the image's spatial dims are dynamic; target_size and
    # canvas_size are static, so they map to None.
    dynamic_shapes = {
        "image": {1: img_h, 2: img_w},
        "target_size": None,
        "canvas_size": None,
    }
    return dynamic_shapes


def export_preprocess(
    resample: str = "bilinear",
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
    max_num_tiles: int = 4,
    tile_size: int = 224,
    antialias: bool = False,
) -> ExportedProgram:

    # Instantiate eager model.
    image_transform_model = _CLIPImageTransform(
        resample=resample,
        image_mean=image_mean,
        image_std=image_std,
        max_num_tiles=max_num_tiles,
        tile_size=tile_size,
        antialias=antialias,
    )

    # Replace non-exportable ops with custom ops; these are swapped back to
    # their aten equivalents after export by ReplaceCustomOpsWithAtenOpsPass
    # (see lower_to_executorch_preprocess below).
    image_transform_model.pad = torch.ops.preprocess.pad.default
    image_transform_model.tile_crop = torch.ops.preprocess.tile_crop.default

    # Export.
    example_inputs = get_example_inputs()
    dynamic_shapes = get_dynamic_shapes()
    ep = torch.export.export(
        image_transform_model,
        example_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
    )
    return ep


def lower_to_executorch_preprocess(
    exported_program: ExportedProgram,
) -> ExecutorchProgramManager:
    edge_program = to_edge(
        exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
    )
    # Replace custom ops with aten ops.
    edge_program = edge_program.transform([ReplaceCustomOpsWithAtenOpsPass()])

    et_program = edge_program.to_executorch(ExecutorchBackendConfig())
    return et_program
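Because img_h and img_w are exported as dynamic dims, the resulting ExportedProgram should accept any image whose spatial dims fall within [1, 4000], not just the 800x600 example input. A quick sanity check, as a sketch (not part of this diff):

import torch

from export_preprocess_lib import export_preprocess

ep = export_preprocess()

# The example input is 3 x 800 x 600; a 3 x 1024 x 768 image should
# trace through the same graph thanks to the dynamic img_h/img_w dims.
image = torch.ones(3, 1024, 768)
target_size = torch.tensor([448, 336])
canvas_size = torch.tensor([448, 448])
tiles, aspect_ratio = ep.module()(image, target_size, canvas_size)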
9 changes: 9 additions & 0 deletions examples/models/flamingo/install_requirements.sh
@@ -0,0 +1,9 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Install torchtune nightly for model definitions.
pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir
244 changes: 244 additions & 0 deletions examples/models/flamingo/test_preprocess.py
@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import PIL
import torch

from parameterized import parameterized
from PIL import Image

from torchtune.models.clip.inference._transforms import (
    _CLIPImageTransform,
    CLIPImageTransform,
)

from torchtune.modules.transforms import (
    find_supported_resolutions,
    get_canvas_best_fit,
    get_inscribed_size,
)
from torchvision.transforms.v2 import functional as F

from .export_preprocess_lib import export_preprocess


@dataclass
class PreprocessConfig:
    image_mean: Optional[List[float]] = None
    image_std: Optional[List[float]] = None
    resize_to_max_canvas: bool = True
    resample: str = "bilinear"
    antialias: bool = False
    tile_size: int = 224
    max_num_tiles: int = 4
    # Without an annotation this would be a plain class attribute rather
    # than a dataclass field, so annotate it explicitly.
    possible_resolutions: Optional[List[Tuple[int, int]]] = None


class TestImageTransform(unittest.TestCase):
    """
    This unittest checks that the exported image transform model produces the
    same output as the reference model.

    Reference model: CLIPImageTransform
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
    Eager and exported models: _CLIPImageTransform
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
    """

    def setUp(self):
        np.random.seed(0)

    def prepare_inputs(
        self, image: Image.Image, config: PreprocessConfig
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for the eager and exported models:
        - Convert the PIL image to a tensor.
        - Calculate the best resolution: a canvas with height and width divisible by tile_size.
        - Calculate the inscribed size: the size of the image inscribed within best_resolution,
          without distortion.
        These calculations are done by the reference model inside __init__ and __call__:
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
        """
        image_tensor = F.to_dtype(
            F.grayscale_to_rgb_image(F.to_image(image)), scale=True
        )

        # Calculate possible resolutions.
        possible_resolutions = config.possible_resolutions
        if possible_resolutions is None:
            possible_resolutions = find_supported_resolutions(
                max_num_tiles=config.max_num_tiles, tile_size=config.tile_size
            )
        possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)

        # Limit resizing.
        max_size = None if config.resize_to_max_canvas else config.tile_size

        # Find the best canvas to fit the image without distortion.
        best_resolution = get_canvas_best_fit(
            image=image_tensor,
            possible_resolutions=possible_resolutions,
            resize_to_max_canvas=config.resize_to_max_canvas,
        )
        best_resolution = torch.tensor(best_resolution)

        # Find the dimensions of the image such that it is inscribed within
        # best_resolution without distortion.
        inscribed_size = get_inscribed_size(
            image_tensor.shape[-2:], best_resolution, max_size
        )
        inscribed_size = torch.tensor(inscribed_size)

        return image_tensor, inscribed_size, best_resolution

    # This test setup mirrors the one in torchtune:
    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
    # The values are slightly different, as torchtune uses antialias=True,
    # and this test uses antialias=False, which is exportable (has a portable kernel).
    @parameterized.expand(
        [
            (
                (100, 400, 3),  # image_size
                torch.Size([2, 3, 224, 224]),  # expected shape
                False,  # resize_to_max_canvas
                [0.2230, 0.1763],  # expected_tile_means
                [1.0, 1.0],  # expected_tile_max
                [0.0, 0.0],  # expected_tile_min
                [1, 2],  # expected_aspect_ratio
            ),
            (
                (1000, 300, 3),  # image_size
                torch.Size([4, 3, 224, 224]),  # expected shape
                True,  # resize_to_max_canvas
                [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
                [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
                [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
                [4, 1],  # expected_aspect_ratio
            ),
            (
                (200, 200, 3),  # image_size
                torch.Size([4, 3, 224, 224]),  # expected shape
                True,  # resize_to_max_canvas
                [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
                [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
                [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
                [2, 2],  # expected_aspect_ratio
            ),
            (
                (600, 200, 3),  # image_size
                torch.Size([3, 3, 224, 224]),  # expected shape
                False,  # resize_to_max_canvas
                [0.4472, 0.4468, 0.3031],  # expected_tile_means
                [1.0, 1.0, 1.0],  # expected_tile_max
                [0.0, 0.0, 0.0],  # expected_tile_min
                [3, 1],  # expected_aspect_ratio
            ),
        ]
    )
    def test_preprocess(
        self,
        image_size: Tuple[int, int, int],
        expected_shape: torch.Size,
        resize_to_max_canvas: bool,
        expected_tile_means: List[float],
        expected_tile_max: List[float],
        expected_tile_min: List[float],
        expected_ar: List[int],
    ) -> None:
        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)

        reference_model = CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resize_to_max_canvas=config.resize_to_max_canvas,
            resample=config.resample,
            antialias=config.antialias,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
            possible_resolutions=None,
        )

        eager_model = _CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            antialias=config.antialias,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
        )

        exported_model = export_preprocess(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            antialias=config.antialias,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
        )

        # Prepare image input.
        image = (
            np.random.randint(0, 256, np.prod(image_size))
            .reshape(image_size)
            .astype(np.uint8)
        )
        image = PIL.Image.fromarray(image)

        # Run the reference model.
        reference_output = reference_model(image=image)
        reference_image = reference_output["image"]
        reference_ar = reference_output["aspect_ratio"].tolist()

        # Check that the output shape and aspect ratio match the expected values.
        self.assertEqual(reference_image.shape, expected_shape)
        self.assertEqual(reference_ar, expected_ar)

        # Check that pixel values are within the expected range [0, 1].
        self.assertTrue(0 <= reference_image.min() <= reference_image.max() <= 1)

        # Check that the mean, max, and min values of the tiles match the expected values.
        for i, tile in enumerate(reference_image):
            self.assertAlmostEqual(
                tile.mean().item(), expected_tile_means[i], delta=1e-4
            )
            self.assertAlmostEqual(tile.max().item(), expected_tile_max[i], delta=1e-4)
            self.assertAlmostEqual(tile.min().item(), expected_tile_min[i], delta=1e-4)

        # Check that the number of tiles matches the product of the aspect ratio.
        expected_num_tiles = reference_ar[0] * reference_ar[1]
        self.assertEqual(expected_num_tiles, reference_image.shape[0])

        # Pre-work for the eager and exported models. The reference model performs
        # these calculations and passes the result to _CLIPImageTransform, the
        # exportable model.
        image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
            image=image, config=config
        )

        # Run the eager and exported models.
        eager_image, eager_ar = eager_model(
            image_tensor, inscribed_size, best_resolution
        )
        eager_ar = eager_ar.tolist()

        exported_image, exported_ar = exported_model.module()(
            image_tensor, inscribed_size, best_resolution
        )
        exported_ar = exported_ar.tolist()

        # Check that the eager and exported models match the reference model.
        self.assertTrue(torch.allclose(reference_image, eager_image))
        self.assertTrue(torch.allclose(reference_image, exported_image))

        self.assertEqual(reference_ar, eager_ar)
        self.assertEqual(reference_ar, exported_ar)

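Note: since test_preprocess.py imports export_preprocess_lib relatively, it must run as part of its package. One plausible invocation from the repository root (an assumption about the local setup, not something this diff specifies) is python -m pytest examples/models/flamingo/test_preprocess.py, after running install_requirements.sh.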