Skip to content

Commit

Permalink
[HLO-OPT] Tool : register HWI passes fro hlo/transforms/ directory
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721847940
  • Loading branch information
abhigunj authored and Google-ML-Automation committed Feb 3, 2025
1 parent 6743f54 commit 7ea5677
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 9 deletions.
22 changes: 20 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,21 @@ jobs:
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest",
kokoro_job_name: "xla-linux-x86-cpu",
pretty_name: "XLA Linux x86 CPU",
repo: "openxla/xla",
},
{
pool: "linux-arm64-c4a-16",
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest",
kokoro_job_name: "xla-linux-arm64-cpu",
pretty_name: "XLA Linux ARM64 CPU",
repo: "openxla/xla",
},
{
pool: "linux-x86-n2-16",
container: "gcr.io/tensorflow-sigs/build:latest-python3.11",
kokoro_job_name: "jax-linux-x86-cpu",
pretty_name: "JAX Linux x86 CPU",
repo: "jax-ml/jax",
}
]
name: ${{ matrix.job_info.pretty_name }}
Expand All @@ -53,10 +62,19 @@ jobs:
shell: bash
timeout-minutes: 30
steps:
- name: "Checking out repository"
- name: "Checking out openxla/xla"
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
with:
path: "openxla/xla"
- name: Checking out ${{ matrix.job_info.repo }}
if: ${{ matrix.job_info.repo != 'openxla/xla' }}
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
with:
repository: ${{ matrix.job_info.repo }}
path: ${{ matrix.job_info.repo }}
- name: "Run build.py"
working-directory: ${{ matrix.job_info.repo }}
env:
# TODO(ddunleavy): refactor build.py to not depend on this env var
KOKORO_JOB_NAME: ${{ matrix.job_info.kokoro_job_name }}
run: build_tools/ci/build.py
run: $GITHUB_WORKSPACE/openxla/xla/build_tools/ci/build.py
32 changes: 28 additions & 4 deletions build_tools/ci/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
_KOKORO_ARTIFACTS_DIR = os.environ.get(
"KOKORO_ARTIFACTS_DIR", "$KOKORO_ARTIFACTS_DIR"
)
_GITHUB_WORKSPACE = os.environ.get("GITHUB_WORKSPACE", "$GITHUB_WORKSPACE")


def retry(
Expand Down Expand Up @@ -96,6 +97,7 @@ class BuildType(enum.Enum):
MACOS_CPU_X86 = enum.auto()

JAX_CPU = enum.auto()
JAX_CPU_SELF_HOSTED = enum.auto()
JAX_GPU = enum.auto()

TENSORFLOW_CPU = enum.auto()
Expand Down Expand Up @@ -167,10 +169,14 @@ def commands(self) -> List[List[str]]:
if self.repo != "openxla/xla":
_, repo_name = self.repo.split("/")

# pyformat:disable
cmds.append(["git", "clone", "--depth=1",
f"https://github.com/{self.repo}", f"./github/{repo_name}"])
# pyformat:enable
if "self_hosted" not in self.type_.name.lower():
cmds.append([
"git",
"clone",
"--depth=1",
f"https://github.com/{self.repo}",
f"./github/{repo_name}",
])

cmds.extend(self.extra_setup_commands)

Expand Down Expand Up @@ -392,6 +398,23 @@ def nvidia_gpu_build_with_compute_capability(
),
)

_JAX_CPU_SELF_HOSTED_BUILD = Build(
type_=BuildType.JAX_CPU_SELF_HOSTED,
repo="google/jax",
image_url=None,
configs=("rbe_linux_x86_64",),
target_patterns=("//tests:cpu_tests", "//tests:backend_independent_tests"),
test_env=dict(
JAX_NUM_GENERATED_CASES=25,
JAX_SKIP_SLOW_TESTS=1,
),
options=dict(
**_DEFAULT_BAZEL_OPTIONS,
override_repository=f"xla={_GITHUB_WORKSPACE}/openxla/xla",
repo_env="HERMETIC_PYTHON_VERSION=3.12",
),
)

_JAX_GPU_BUILD = Build(
type_=BuildType.JAX_GPU,
repo="google/jax",
Expand Down Expand Up @@ -476,6 +499,7 @@ def nvidia_gpu_build_with_compute_capability(
"tensorflow/xla/tensorflow/gpu/build_gpu": _TENSORFLOW_GPU_BUILD,
"xla-linux-x86-cpu": _CPU_X86_SELF_HOSTED_BUILD,
"xla-linux-arm64-cpu": _CPU_ARM64_SELF_HOSTED_BUILD,
"jax-linux-x86-cpu": _JAX_CPU_SELF_HOSTED_BUILD,
}


Expand Down
5 changes: 5 additions & 0 deletions build_tools/ci/golden_commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ docker exec xla_ci bazel test --build_tag_filters= --test_tag_filters= --config=
docker exec xla_ci bazel analyze-profile profile.json.gz
docker stop xla_ci
# END BuildType.JAX_CPU
# BEGIN BuildType.JAX_CPU_SELF_HOSTED
parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --repo_env=HERMETIC_PYTHON_VERSION=3.12 --nobuild -- //tests:cpu_tests //tests:backend_independent_tests
bazel test --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=$GITHUB_WORKSPACE/openxla/xla --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- //tests:cpu_tests //tests:backend_independent_tests
bazel analyze-profile profile.json.gz
# END BuildType.JAX_CPU_SELF_HOSTED
# BEGIN BuildType.JAX_GPU
$KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html
git clone --depth=1 https://github.com/google/jax ./github/jax
Expand Down
38 changes: 37 additions & 1 deletion xla/hlo/tools/hlo_opt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,40 @@ cc_library(
# Includes a macro to register a provider.
cc_library(
name = "opt_lib",
testonly = True,
srcs = ["opt_lib.cc"],
hdrs = ["opt_lib.h"],
deps = [
"//xla:literal_pool",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/analysis:indexed_array_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass_pipeline",
"//xla/hlo/transforms:add_original_value",
"//xla/hlo/transforms:bfloat16_propagation",
"//xla/hlo/transforms:convert_memory_placement_to_internal_annotations",
"//xla/hlo/transforms:defuser",
"//xla/hlo/transforms:despecializer",
"//xla/hlo/transforms:host_offload_legalize",
"//xla/hlo/transforms:host_offloader",
"//xla/hlo/transforms:host_offloading_prepare",
"//xla/hlo/transforms:literal_canonicalizer",
"//xla/hlo/transforms:memory_space_propagation",
"//xla/hlo/transforms:operand_upcaster",
"//xla/hlo/transforms:sharding_format_picker",
"//xla/hlo/transforms:while_loop_trip_count_annotator",
"//xla/hlo/transforms/collectives:all_gather_broadcast_reorder",
"//xla/hlo/transforms/collectives:all_gather_combiner",
"//xla/hlo/transforms/collectives:all_gather_cse",
"//xla/hlo/transforms/collectives:all_reduce_combiner",
"//xla/hlo/transforms/collectives:all_reduce_contiguous",
"//xla/hlo/transforms/collectives:async_collective_creator",
"//xla/hlo/transforms/collectives:collective_quantizer",
"//xla/hlo/transforms/collectives:collective_transformation_reorderer",
"//xla/hlo/transforms/collectives:collectives_schedule_linearizer",
"//xla/hlo/transforms/collectives:convert_async_collectives_to_sync",
"//xla/hlo/transforms/collectives:infeed_token_propagation",
"//xla/hlo/transforms/expanders:cholesky_expander",
"//xla/hlo/transforms/expanders:comparison_expander",
"//xla/hlo/transforms/expanders:convolution_4d_expander",
Expand All @@ -56,21 +78,34 @@ cc_library(
"//xla/hlo/transforms/expanders:rng_expander",
"//xla/hlo/transforms/expanders:stable_sort_expander",
"//xla/hlo/transforms/expanders:stochastic_convert_decomposer",
"//xla/hlo/transforms/simplifiers:algebraic_simplifier",
"//xla/hlo/transforms/simplifiers:all_reduce_folder",
"//xla/hlo/transforms/simplifiers:ar_crs_combiner",
"//xla/hlo/transforms/simplifiers:batch_dot_simplification",
"//xla/hlo/transforms/simplifiers:bfloat16_conversion_folding",
"//xla/hlo/transforms/simplifiers:broadcast_canonicalizer",
"//xla/hlo/transforms/simplifiers:conditional_canonicalizer",
"//xla/hlo/transforms/simplifiers:convert_mover",
"//xla/hlo/transforms/simplifiers:convolution_group_converter",
"//xla/hlo/transforms/simplifiers:dot_dimension_merger",
"//xla/hlo/transforms/simplifiers:dot_merger",
"//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier",
"//xla/hlo/transforms/simplifiers:flatten_call_graph",
"//xla/hlo/transforms/simplifiers:float_normalization",
"//xla/hlo/transforms/simplifiers:fusion_constant_sinking",
"//xla/hlo/transforms/simplifiers:gather_simplifier",
"//xla/hlo/transforms/simplifiers:hlo_computation_deduplicator",
"//xla/hlo/transforms/simplifiers:hlo_constant_folding",
"//xla/hlo/transforms/simplifiers:hlo_constant_splitter",
"//xla/hlo/transforms/simplifiers:hlo_dce",
"//xla/hlo/transforms/simplifiers:hlo_element_type_converter",
"//xla/hlo/transforms/simplifiers:hlo_memory_scheduler",
"//xla/hlo/transforms/simplifiers:host_memory_transfer_asyncifier",
"//xla/hlo/transforms/simplifiers:instruction_hoister",
"//xla/hlo/transforms/simplifiers:optimize_input_output_buffer_alias",
"//xla/hlo/transforms/simplifiers:reshape_mover",
"//xla/hlo/transforms/simplifiers:result_caster",
"//xla/hlo/transforms/simplifiers:root_instruction_sinker",
"//xla/hlo/transforms/simplifiers:simplify_fp_conversions",
"//xla/hlo/transforms/simplifiers:slice_sinker",
"//xla/hlo/transforms/simplifiers:sort_simplifier",
Expand All @@ -79,9 +114,11 @@ cc_library(
"//xla/hlo/transforms/simplifiers:tuple_simplifier",
"//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination",
"//xla/hlo/transforms/tests:dummy_passes",
"//xla/service:buffer_value",
"//xla/service:float_support",
"//xla/service:platform_util",
"//xla/stream_executor/platform:initialize",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
Expand All @@ -90,6 +127,5 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@tsl//tsl/platform:statusor",
],
)
Loading

0 comments on commit 7ea5677

Please sign in to comment.