Re-introduce "XLA_USE_32BIT_LONG" flag
rpsilva-aws committed Jan 14, 2025
1 parent 1c89675 commit 6186c82
Showing 3 changed files with 105 additions and 56 deletions.
124 changes: 71 additions & 53 deletions test/test_data_type.py
@@ -1,73 +1,91 @@
 import os
 import sys
+import unittest
+
 import torch
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.utils.utils as xu
-import unittest
-
-
-def check_env_flag(name, default=''):
-  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
 
 
 class XlaDataTypeTest(unittest.TestCase):
 
-  def test_datatype_f32(self):
-    t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
-    t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
-    t3 = torch.div(t1, t2, rounding_mode='floor')
-    assert t3.dtype == torch.float
-
-    hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
-    device_data_hlo = hlo_text.split('\n')[1]
-    assert 'xla::device_data' in device_data_hlo, device_data_hlo
-    if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'):
-      assert 'bf16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'):
-      assert 'f16' in device_data_hlo, device_data_hlo
-    else:
-      assert 'f32' in device_data_hlo, device_data_hlo
-
-  def test_datatype_f64(self):
-    t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
-    t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device())
-    t3 = torch.div(t1, t2, rounding_mode='floor')
-    assert t3.dtype == torch.double
-
-    hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
-    device_data_hlo = hlo_text.split('\n')[1]
-    assert 'xla::device_data' in device_data_hlo, device_data_hlo
-    if check_env_flag('XLA_USE_BF16'):
-      assert 'bf16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_USE_FP16'):
-      assert 'f16' in device_data_hlo, device_data_hlo
-    elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag(
-        'XLA_DOWNCAST_FP16'):
-      assert 'f32' in device_data_hlo, device_data_hlo
-    else:
-      assert 'f64' in device_data_hlo, device_data_hlo
-
-  def test_module_to_dtype(self):
-    device = torch_xla.device()
-    linear = torch.nn.Linear(
-        5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
-    input = torch.randn(
-        10,
-        5,
-    ).to(device).to(torch.bfloat16)
-    xm.mark_step()
-    res = linear(input)
-
-    hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
-    res_hlo = hlo_text.split('\n')[-3]
-    assert 'bf16' in res_hlo, res_hlo
-
-    linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight
-                                                              ]).split('\n')[-3]
-    assert 'bf16' in linear_weight_hlo, linear_weight_hlo
+  def setUp(cls):
+    cls.original_env = {
+        'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'),
+        'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'),
+        'XLA_USE_FP16': os.environ.get('XLA_USE_FP16'),
+        'XLA_DOWNCAST_FP16': os.environ.get('XLA_DOWNCAST_FP16'),
+        'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG')
+    }
+
+  def tearDown(self):
+    for key, value in self.original_env.items():
+      if value is None:
+        os.environ.pop(key, None)
+      else:
+        os.environ[key] = value
+
+  def _set_env(self, **kwargs):
+    for key, value in kwargs.items():
+      os.environ[key] = value
+
+  def _test_datatype(self, dtype, expected_type, op):
+    t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
+    t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device())
+    t3 = op(t1, t2)
+    self.assertEqual(t3.dtype, dtype)
+
+    hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3])
+    device_data_hlo = hlo_text.split('\n')[2]
+    self.assertIn('xla::device_data', device_data_hlo)
+    self.assertIn(expected_type, device_data_hlo)
+
+  def test_datatype_use_bf16(self):
+    self._set_env(XLA_USE_BF16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'bf16', torch.floor_divide)
+
+  def test_datatype_use_fp16(self):
+    self._set_env(XLA_USE_FP16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'f16', torch.floor_divide)
+
+  def test_datatype_downcast_bf16(self):
+    self._set_env(XLA_DOWNCAST_BF16='1')
+    self._test_datatype(torch.double, 'bf16', torch.floor_divide)
+    self._test_datatype(torch.float, 'bf16', torch.floor_divide)
+
+  def test_datatype_downcast_fp16(self):
+    self._set_env(XLA_DOWNCAST_FP16='1')
+    self._test_datatype(torch.double, 'f16', torch.floor_divide)
+    self._test_datatype(torch.float, 'f16', torch.floor_divide)
+
+  def test_datatype_use_32bit_long(self):
+    self._set_env(XLA_USE_32BIT_LONG='1')
+    self._test_datatype(torch.int64, 's32', torch.add)
+    self._test_datatype(torch.uint64, 'u32', torch.add)
+
+  def test_module_to_dtype(self):
+    device = torch_xla.device()
+    linear = torch.nn.Linear(5, 10, dtype=torch.float32).to(device).to(torch.bfloat16)
+    input = torch.randn(10, 5).to(device).to(torch.bfloat16)
+    xm.mark_step()
+    res = linear(input)
+
+    hlo_text = torch_xla._XLAC._get_xla_tensors_text([res])
+    res_hlo = hlo_text.split('\n')[-3]
+    self.assertIn('bf16', res_hlo)
+
+    linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight]).split('\n')[-3]
+    self.assertIn('bf16', linear_weight_hlo)
 
 if __name__ == '__main__':
-  test = unittest.main()
-  sys.exit(0 if test.result.wasSuccessful() else 1)
+  suite = unittest.TestSuite()
+  suite.addTest(XlaDataTypeTest("test_datatype_use_bf16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_use_fp16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_downcast_fp16"))
+  suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long"))
+  suite.addTest(XlaDataTypeTest("test_module_to_dtype"))
+  runner = unittest.TextTestRunner(failfast=True)
+  result = runner.run(suite)
+  sys.exit(0 if result.wasSuccessful() else 1)
2 changes: 1 addition & 1 deletion torch_xla/__init__.py
@@ -179,7 +179,7 @@ def _setup_tpu_vm_library_path() -> bool:
 
 def _check_deprecated_env_var():
   deprecated_env_vars = [
-      'XLA_USE_FP16', 'XLA_DOWNCAST_FP16', 'XLA_USE_32BIT_LONG'
+      'XLA_USE_FP16', 'XLA_DOWNCAST_FP16'
   ]
   for env_var in deprecated_env_vars:
     if os.environ.get(env_var):
35 changes: 33 additions & 2 deletions torch_xla/csrc/dtype.cpp
@@ -30,6 +30,20 @@ bool ShouldDowncastToBF16() {
   return downcast_bf16;
 }
 
+std::optional<bool> ShouldUse32BitLong() {
+  const char* env_name = "XLA_USE_32BIT_LONG";
+  if (std::getenv(env_name) == nullptr) {
+    return std::nullopt;
+  }
+  bool use_32bit_long = runtime::sys_util::GetEnvBool(env_name, false);
+  if (use_32bit_long) {
+    std::cout
+        << "XLA_USE_32BIT_LONG will be deprecated after the 2.6 release\n";
+    TF_LOG(INFO) << "Using 32bit integers for kLong values";
+  }
+  return use_32bit_long;
+}
+
 bool UseBF16() {
   static bool use_bf16 = ShouldUseBF16();
   return use_bf16;
@@ -40,6 +54,11 @@ bool DowncastBF16() {
   return downcast_bf16;
 }
 
+std::optional<bool> Use32BitLong() {
+  static std::optional<bool> use_32bit_long = ShouldUse32BitLong();
+  return use_32bit_long;
+}
+
 }  // namespace
 
 at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -142,12 +161,24 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
     case xla::PrimitiveType::S16:
       return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
                                         : xla::PrimitiveType::S16;
-    case xla::PrimitiveType::S64:
+    case xla::PrimitiveType::S64: {
+      std::optional<bool> use_32bit_long = Use32BitLong();
+      if (use_32bit_long.has_value()) {
+        return *use_32bit_long ? xla::PrimitiveType::S32
+                               : xla::PrimitiveType::S64;
+      }
       return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
                                         : xla::PrimitiveType::S64;
-    case xla::PrimitiveType::U64:
+    }
+    case xla::PrimitiveType::U64: {
+      std::optional<bool> use_32bit_long = Use32BitLong();
+      if (use_32bit_long.has_value()) {
+        return *use_32bit_long ? xla::PrimitiveType::U32
+                               : xla::PrimitiveType::U64;
+      }
       return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
                                         : xla::PrimitiveType::U64;
+    }
     case xla::PrimitiveType::C128:
       return xla::PrimitiveType::C128;
     default:

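For reference, a minimal usage sketch of the re-introduced flag (not part of the commit): it mirrors test_datatype_use_32bit_long above and assumes an XLA device is available. Note that the environment variable has to be set before torch_xla first consults it, since dtype.cpp caches the value in a static.

  # Hypothetical usage sketch; set the flag before torch_xla reads it.
  import os
  os.environ['XLA_USE_32BIT_LONG'] = '1'

  import torch
  import torch_xla
  import torch_xla.core.xla_model as xm

  # With the flag on, int64 tensors are mapped to 32-bit device integers.
  t = torch.tensor([2, 3], dtype=torch.int64, device=xm.xla_device())
  hlo_text = torch_xla._XLAC._get_xla_tensors_text([t])
  print('s32' in hlo_text)  # expected True with the flag on, 's64' otherwise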