forked from openxla/xla
-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ROCm] Add script to run multi gpu tests #25
Merged
i-chaochen
merged 2 commits into
rocm-jaxlib-v0.4.28-qa
from
rocm-jaxlib-v0.4.28-qa-multi-gpu-tets
Jul 4, 2024
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#!/usr/bin/env bash | ||
# Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# ============================================================================== | ||
|
||
set -e | ||
set -x | ||
|
||
N_BUILD_JOBS=$(grep -c ^processor /proc/cpuinfo) | ||
# If rocm-smi exists locally (it should) use it to find | ||
# out how many GPUs we have to test with. | ||
rocm-smi -i | ||
STATUS=$? | ||
if [ $STATUS -ne 0 ]; then TF_GPU_COUNT=1; else | ||
TF_GPU_COUNT=$(rocm-smi -i|grep 'Device ID' |grep 'GPU' |wc -l) | ||
fi | ||
if [[ $TF_GPU_COUNT -lt 4 ]]; then | ||
echo "Found only ${TF_GPU_COUNT} gpus, multi-gpu tests need atleast 4 gpus." | ||
exit | ||
fi | ||
|
||
TF_TESTS_PER_GPU=1 | ||
N_TEST_JOBS=$(expr ${TF_GPU_COUNT} \* ${TF_TESTS_PER_GPU}) | ||
hsharsha marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
echo "" | ||
echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} concurrent test job(s)." | ||
echo "" | ||
|
||
# First positional argument (if any) specifies the ROCM_INSTALL_DIR | ||
if [[ -n $1 ]]; then | ||
ROCM_INSTALL_DIR=$1 | ||
else | ||
if [[ -z "${ROCM_PATH}" ]]; then | ||
ROCM_INSTALL_DIR=/opt/rocm-6.0.2 | ||
else | ||
ROCM_INSTALL_DIR=$ROCM_PATH | ||
fi | ||
fi | ||
|
||
export PYTHON_BIN_PATH=`which python3` | ||
export TF_NEED_ROCM=1 | ||
export ROCM_PATH=$ROCM_INSTALL_DIR | ||
TAGS_FILTER="-oss_excluded,-oss_serial" | ||
UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{60,70,80,86,89,90}{,-only})" | ||
TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}" | ||
|
||
bazel \ | ||
test \ | ||
--config=rocm \ | ||
--build_tag_filters=${TAGS_FILTER} \ | ||
--test_tag_filters=${TAGS_FILTER} \ | ||
--test_timeout=920,2400,7200,9600 \ | ||
--test_sharding_strategy=disabled \ | ||
--test_output=errors \ | ||
--flaky_test_attempts=3 \ | ||
--keep_going \ | ||
--local_test_jobs=${N_TEST_JOBS} \ | ||
--test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \ | ||
--test_env=TF_GPU_COUNT=$TF_GPU_COUNT \ | ||
--action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \ | ||
--action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \ | ||
-- //xla/tests:collective_ops_test_e2e_gpu \ | ||
//xla/tests:collective_ops_test_gpu \ | ||
//xla/tests:replicated_io_feed_test_gpu \ | ||
//xla/tools/multihost_hlo_runner:functional_hlo_runner_test_gpu \ | ||
//xla/pjrt/distributed:topology_util_test \ | ||
//xla/pjrt/distributed:client_server_test |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a space between
atleast
? :)