Skip to content

Commit

Permalink
Simplify gen_oplist_copy_from_core file (pytorch#3549)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3549

Looks like we just need the YAML file. That means we can get rid of gen_supported_mobile_models and write_selected_mobile_ops and all their dependencies.

Reviewed By: lucylq

Differential Revision: D57122006

fbshipit-source-id: 07a0aafb686237cae29d774e552eddccce797136
  • Loading branch information
mergennachin authored and facebook-github-bot committed May 13, 2024
1 parent 629b112 commit 9db0a69
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 176 deletions.
5 changes: 1 addition & 4 deletions codegen/tools/gen_all_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ def main(argv: List[Any]) -> None:
parser = argparse.ArgumentParser(description="Generate operator lists")
parser.add_argument(
"--output_dir",
help=(
"The directory to store the output yaml files (selected_mobile_ops.h, "
+ "selected_kernel_dtypes.h, selected_operators.yaml)"
),
help=("The directory to store the output yaml file (selected_operators.yaml)"),
required=True,
)
parser.add_argument(
Expand Down
174 changes: 2 additions & 172 deletions codegen/tools/gen_oplist_copy_from_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,138 +5,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This is a copy from //xplat/caffe2/tools/code_analyzer/gen_oplist.py
# TODO(mnachin): We will need to either simplify or remove this code altogether.
# This is necessary to remove dependency from pytorch core from ExecuTorch.
# This is a simplified copy from //xplat/caffe2/tools/code_analyzer/gen_oplist.py
import argparse
import json
import os
import sys
from functools import reduce
from typing import Any, List, Set
from typing import Any, List

import yaml
from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import (
combine_selective_builders,
SelectiveBuilder,
)

# One "if" branch of the generated should_include_kernel_dtype() body.
# Emitted once per kernel tag; $dtype_checks is an "||"-joined list of
# scalar-type comparisons built in get_selected_kernel_dtypes_code().
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)

# Skeleton of the generated selected_kernel_dtypes.h content: a constexpr
# predicate that maps (kernel tag, dtype) -> selected/not-selected. $body is
# filled with chained if_condition_template branches, or "return true;" when
# all dtypes are selected.
selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>
namespace at {
inline constexpr bool should_include_kernel_dtype(
const char *kernel_tag_str,
at::ScalarType scalar_type
) {
c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str);
$body
return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)

# Fixed header prepended to the generated selected_mobile_ops.h file.
selected_mobile_ops_preamble = """#pragma once
/**
* Generated by gen_selected_mobile_ops_header.py
*/
"""


def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Render the C++ source for should_include_kernel_dtype().

    When the build selects all operators or all non-operator selectives, the
    predicate trivially returns true; otherwise one "if" branch per kernel tag
    checks the selected scalar types for that tag.
    """
    # Only emit per-tag dtype checks in a fully selective build; any broader
    # selection keeps every dtype enabled.
    if (
        selective_builder.include_all_operators is not False
        or selective_builder.include_all_non_op_selectives is not False
    ):
        body = "return true;"
    else:
        branches = [
            if_condition_template.substitute(
                kernel_tag_name=tag,
                dtype_checks=" || ".join(
                    "scalar_type == at::ScalarType::" + dtype for dtype in dtypes
                ),
            )
            for tag, dtypes in selective_builder.kernel_metadata.items()
        ]
        body = " else ".join(branches)
    return selected_kernel_dtypes_h_template.substitute(body=body)


def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of all operators marked as root operators."""
    return {
        name
        for name, info in selective_builder.operators.items()
        if info.is_root_operator
    }


# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators
# 2. The selected kernel dtypes
def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write selected_mobile_ops.h to ``output_file_path``.

    The header contains, depending on the build mode: the selected root
    operators (TORCH_OPERATOR_WHITELIST), the selected custom classes and
    build features (tracing-based selective build only), and the kernel
    dtype predicate from get_selected_kernel_dtypes_code().
    """
    root_ops = extract_root_operators(selective_builder)
    custom_classes = selective_builder.custom_classes
    build_features = selective_builder.build_features
    with open(output_file_path, "wb") as out_file:
        sections = [selected_mobile_ops_preamble]
        # Only define the allowlist macros in a selective build; when they are
        # absent the corresponding macros trivially select everything.
        if not selective_builder.include_all_operators:
            sections.append(
                "#define TORCH_OPERATOR_WHITELIST "
                + ";".join(sorted(root_ops))
                + ";\n\n"
            )
        # Tracing-based selective build additionally narrows custom classes
        # and build features.
        if selective_builder.include_all_non_op_selectives is False:
            for macro, names in (
                ("TORCH_CUSTOM_CLASS_ALLOWLIST", custom_classes),
                ("TORCH_BUILD_FEATURE_ALLOWLIST", build_features),
            ):
                sections.append(
                    "#define " + macro + " " + ";".join(sorted(names)) + ";\n\n"
                )
        sections.append(get_selected_kernel_dtypes_code(selective_builder))
        out_file.write("".join(sections).encode("utf-8"))


def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of every operator known to the builder."""
    # Iterating a dict yields its keys; materialize them as a set.
    return {name for name in selective_builder.operators}


def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]:
    """Return the names of all operators flagged as used for training."""
    return {
        name
        for name, info in selective_builder.operators.items()
        if info.is_used_for_training
    }


def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
ops = []
Expand All @@ -153,49 +34,6 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N
)


def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None:
    """Write SupportedMobileModelsRegistration.cpp into ``output_dir``.

    Collects the md5 hashes of supported model assets from each model dict's
    "debug_info" entry (new-style rules only) and renders them into a C++
    registration source file.
    """
    supported_mobile_models_source = """/*
* Generated by gen_oplist.py
*/
#include "fb/supported_mobile_models/SupportedMobileModels.h"
struct SupportedMobileModelCheckerRegistry {{
SupportedMobileModelCheckerRegistry() {{
auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton();
ref.set_supported_md5_hashes(std::unordered_set<std::string>{{
{supported_hashes_template}
}});
}}
}};
// This is a global object, initializing which causes the registration to happen.
SupportedMobileModelCheckerRegistry register_model_versions;
"""

    # Gather md5 hashes from every model dict that carries new-style debug info.
    md5_hashes = set()
    for entry in model_dicts:
        if "debug_info" not in entry:
            continue
        debug_info = json.loads(entry["debug_info"][0])
        if not debug_info["is_new_style_rule"]:
            continue
        for asset_info in debug_info["asset_info"].values():
            md5_hashes.update(asset_info["md5_hash"])

    rendered_hashes = "".join(f'"{md5}",\n' for md5 in md5_hashes)
    target_path = os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp")
    with open(target_path, "wb") as out_file:
        source = supported_mobile_models_source.format(
            supported_hashes_template=rendered_hashes
        )
        out_file.write(source.encode("utf-8"))


def main(argv: List[Any]) -> None:
"""This binary generates 3 files:
Expand Down Expand Up @@ -258,9 +96,6 @@ def main(argv: List[Any]) -> None:

selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

# While we have the model_dicts generate the supported mobile models api
gen_supported_mobile_models(model_dicts, options.output_dir)

# We may have 0 selective builders since there may not be any viable
# pt_operator_library rule marked as a dep for the pt_operator_registry rule.
# This is potentially an error, and we should probably raise an assertion
Expand All @@ -283,11 +118,6 @@ def main(argv: List[Any]) -> None:
).encode("utf-8"),
)

write_selected_mobile_ops(
os.path.join(options.output_dir, "selected_mobile_ops.h"),
selective_builder,
)


if __name__ == "__main__":
main(sys.argv[1:])

0 comments on commit 9db0a69

Please sign in to comment.