Re-introduce "XLA_USE_32BIT_LONG" flag (#8589)
rpsilva-aws authored Jan 17, 2025
Parent: 1dd4969 · Commit: fccd395
Showing 7 changed files with 75 additions and 57 deletions.
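
Below is a minimal sketch of how the re-introduced flag is exercised, mirroring the test_datatype_use_32bit_long case added to test/test_data_type.py by this commit (a standard PyTorch/XLA install is assumed; the flag is read once and cached, so it must be set before the first XLA computation):

import os
os.environ["XLA_USE_32BIT_LONG"] = "1"  # read once and cached (see dtype.cpp below)

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t1 = torch.tensor([2, 3], dtype=torch.int64, device=device)
t2 = torch.tensor([2, 3], dtype=torch.int64, device=device)
t3 = torch.add(t1, t2)

# The torch-level dtype stays int64; only the lowered XLA type is downcast.
assert t3.dtype == torch.int64
hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_line = hlo_text.split('\n')[2]  # same line the new test inspects
assert 's32' in device_data_line  # 's64' when the flag is unset

The same gate maps torch.uint64 values to u32, as covered by the same test.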
test/neuron/run_tests.sh: 4 changes (1 addition, 3 deletions)
@@ -160,9 +160,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_test "$CDIR/dynamo/test_dynamo_config.py"
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
#run_test "$CDIR/test_data_type.py"
run_use_bf16 "$CDIR/test_data_type.py"
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/test_data_type.py"
#run_test "$CDIR/test_fp8.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
test/neuron/test_neuron_data_types.py: 4 changes (2 additions, 2 deletions)
@@ -27,10 +27,10 @@ def test_datatypes(self):
(torch.double, "f32", torch.floor_divide),
(torch.int16, "s32", torch.add),
(torch.int32, "s32", torch.add),
(torch.int64, "s32", torch.add),
(torch.int64, "s64", torch.add),
(torch.uint16, "u32", torch.add),
(torch.uint32, "u32", torch.add),
(torch.uint64, "u32", torch.add)]
(torch.uint64, "u64", torch.add)]

for dtype, op_xla_dtype, op in test_cases:
with self.subTest(dtype=dtype, op_xla_dtype=op_xla_dtype, op=op):
test/run_tests.sh: 2 changes (0 additions, 2 deletions)
@@ -178,8 +178,6 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_dynamo_config.py"
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_test "$CDIR/test_data_type.py"
run_use_bf16 "$CDIR/test_data_type.py"
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/test_fp8.py"
run_xla_ir_debug run_test "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug run_test "$CDIR/test_env_var_mapper.py"
test/test_data_type.py: 95 changes (52 additions, 43 deletions)
@@ -1,73 +1,82 @@
import os
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import unittest


def check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
class XlaDataTypeTest(unittest.TestCase):

def setUp(cls):
cls.original_env = {
'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'),
'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'),
'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG')
}

class XlaDataTypeTest(unittest.TestCase):
def tearDown(self):
for key, value in self.original_env.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = value

def test_datatype_f32(self):
t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
t3 = torch.div(t1, t2, rounding_mode='floor')
assert t3.dtype == torch.float
def _set_env(self, **kwargs):
for key, value in kwargs.items():
os.environ[key] = value

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
assert 'xla::device_data' in device_data_hlo, device_data_hlo
if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
assert 'bf16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
assert 'f16' in device_data_hlo, device_data_hlo
else:
assert 'f32' in device_data_hlo, device_data_hlo

def test_datatype_f64(self):
t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
t3 = torch.div(t1, t2, rounding_mode='floor')
assert t3.dtype == torch.double
def _test_datatype(self, dtype, expected_type, op):
t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
t3 = op(t1, t2)
self.assertEqual(t3.dtype, dtype)

hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
device_data_hlo = hlo_text.split('\n')[1]
assert 'xla::device_data' in device_data_hlo, device_data_hlo
if check_env_flag('XLA_USE_BF16'):
assert 'bf16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_USE_FP16'):
assert 'f16' in device_data_hlo, device_data_hlo
elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
'XLA_DOWNCAST_FP16'):
assert 'f32' in device_data_hlo, device_data_hlo
else:
assert 'f64' in device_data_hlo, device_data_hlo
device_data_hlo = hlo_text.split('\n')[2]
self.assertIn('xla::device_data', device_data_hlo)
self.assertIn(expected_type, device_data_hlo)

def test_datatype_use_bf16(self):
self._set_env(XLA_USE_BF16='1')
self._test_datatype(torch.double, 'bf16', torch.floor_divide)
self._test_datatype(torch.float, 'bf16', torch.floor_divide)

def test_datatype_downcast_bf16(self):
self._set_env(XLA_DOWNCAST_BF16='1')
self._test_datatype(torch.double, 'bf16', torch.floor_divide)
self._test_datatype(torch.float, 'bf16', torch.floor_divide)

def test_datatype_use_32bit_long(self):
self._set_env(XLA_USE_32BIT_LONG='1')
self._test_datatype(torch.int64, 's32', torch.add)
self._test_datatype(torch.uint64, 'u32', torch.add)

def test_module_to_dtype(self):
device = torch_xla.device()
linear = torch.nn.Linear(
5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
input = torch.randn(
10,
5,
).to(device).to(torch.bfloat16)
input = torch.randn(10, 5).to(device).to(torch.bfloat16)
xm.mark_step()
res = linear(input)

hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
res_hlo = hlo_text.split('\n')[-3]
assert 'bf16' in res_hlo, res_hlo
self.assertIn('bf16', res_hlo)

linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
]).split('\n')[-3]
assert 'bf16' in linear_weight_hlo, linear_weight_hlo
self.assertIn('bf16', linear_weight_hlo)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
suite = unittest.TestSuite()
suite.addTest(XlaDataTypeTest("test_datatype_use_bf16"))
suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16"))
suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long"))
suite.addTest(XlaDataTypeTest("test_module_to_dtype"))
runner = unittest.TextTestRunner(failfast=True)
result = runner.run(suite)
sys.exit(0 if result.wasSuccessful() else 1)
test/tpu/run_tests.sh: 1 change (1 addition, 0 deletions)
@@ -45,6 +45,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xl
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"
run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py"
python3 "$TEST_CDIR/test_data_type.py"

# run examples, each test should takes <2 minutes
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
torch_xla/__init__.py: 4 changes (1 addition, 3 deletions)
@@ -178,9 +178,7 @@ def _setup_tpu_vm_library_path() -> bool:


def _check_deprecated_env_var():
deprecated_env_vars = [
'XLA_USE_FP16', 'XLA_DOWNCAST_FP16', 'XLA_USE_32BIT_LONG'
]
deprecated_env_vars = ['XLA_USE_FP16', 'XLA_DOWNCAST_FP16']
for env_var in deprecated_env_vars:
if os.environ.get(env_var):
warnings.warn(f"The environment variable '{env_var}' is deprecated "
torch_xla/csrc/dtype.cpp: 22 changes (18 additions, 4 deletions)
@@ -30,6 +30,17 @@ bool ShouldDowncastToBF16() {
return downcast_bf16;
}

bool ShouldUse32BitLong() {
bool use_32bit_long =
runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
if (use_32bit_long) {
std::cout
<< "XLA_USE_32BIT_LONG will be deprecated after the 2.6 release\n";
TF_LOG(INFO) << "Using 32bit integers for kLong values";
}
return use_32bit_long;
}

bool UseBF16() {
static bool use_bf16 = ShouldUseBF16();
return use_bf16;
@@ -40,6 +51,11 @@ bool DowncastBF16() {
return downcast_bf16;
}

bool Use32BitLong() {
static bool use_32bit_long = ShouldUse32BitLong();
return use_32bit_long;
}

} // namespace

at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -143,11 +159,9 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S16;
case xla::PrimitiveType::S64:
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
: xla::PrimitiveType::S64;
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
: xla::PrimitiveType::U64;
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return xla::PrimitiveType::C128;
default: