diff --git a/test/test_data_type.py b/test/test_data_type.py index 8e06e15b40a5..b497983135ae 100644 --- a/test/test_data_type.py +++ b/test/test_data_type.py @@ -1,73 +1,91 @@ import os +import sys +import unittest import torch import torch_xla import torch_xla.core.xla_model as xm import torch_xla.utils.utils as xu -import unittest +class XlaDataTypeTest(unittest.TestCase): + def setUp(cls): + cls.original_env = { + 'XLA_USE_BF16': os.environ.get('XLA_USE_BF16'), + 'XLA_DOWNCAST_BF16': os.environ.get('XLA_DOWNCAST_BF16'), + 'XLA_USE_FP16': os.environ.get('XLA_USE_FP16'), + 'XLA_DOWNCAST_FP16': os.environ.get('XLA_DOWNCAST_FP16'), + 'XLA_USE_32BIT_LONG': os.environ.get('XLA_USE_32BIT_LONG') + } -def check_env_flag(name, default=''): - return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + def tearDown(self): + for key, value in self.original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + def _set_env(self, **kwargs): + for key, value in kwargs.items(): + os.environ[key] = value -class XlaDataTypeTest(unittest.TestCase): + def _test_datatype(self, dtype, expected_type, op): + t1 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) + t2 = torch.tensor([2, 3], dtype=dtype, device=xm.xla_device()) + t3 = op(t1, t2) + self.assertEqual(t3.dtype, dtype) + + hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3]) + device_data_hlo = hlo_text.split('\n')[2] + self.assertIn('xla::device_data', device_data_hlo) + self.assertIn(expected_type, device_data_hlo) - def test_datatype_f32(self): - t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device()) - t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device()) - t3 = torch.div(t1, t2, rounding_mode='floor') - assert t3.dtype == torch.float + def test_datatype_use_bf16(self): + self._set_env(XLA_USE_BF16='1') + self._test_datatype(torch.double, 'bf16', torch.floor_divide) + self._test_datatype(torch.float, 'bf16', torch.floor_divide) - hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3]) - device_data_hlo = hlo_text.split('\n')[1] - assert 'xla::device_data' in device_data_hlo, device_data_hlo - if check_env_flag('XLA_USE_BF16') or check_env_flag('XLA_DOWNCAST_BF16'): - assert 'bf16' in device_data_hlo, device_data_hlo - elif check_env_flag('XLA_USE_FP16') or check_env_flag('XLA_DOWNCAST_FP16'): - assert 'f16' in device_data_hlo, device_data_hlo - else: - assert 'f32' in device_data_hlo, device_data_hlo + def test_datatype_use_fp16(self): + self._set_env(XLA_USE_FP16='1') + self._test_datatype(torch.double, 'bf16', torch.floor_divide) + self._test_datatype(torch.float, 'f16', torch.floor_divide) - def test_datatype_f64(self): - t1 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device()) - t2 = torch.tensor([2.0, 3.0], dtype=torch.double, device=xm.xla_device()) - t3 = torch.div(t1, t2, rounding_mode='floor') - assert t3.dtype == torch.double + def test_datatype_downcast_bf16(self): + self._set_env(XLA_DOWNCAST_BF16='1') + self._test_datatype(torch.double, 'bf16', torch.floor_divide) + self._test_datatype(torch.float, 'bf16', torch.floor_divide) - hlo_text = torch_xla._XLAC._get_xla_tensors_text([t3]) - device_data_hlo = hlo_text.split('\n')[1] - assert 'xla::device_data' in device_data_hlo, device_data_hlo - if check_env_flag('XLA_USE_BF16'): - assert 'bf16' in device_data_hlo, device_data_hlo - elif check_env_flag('XLA_USE_FP16'): - assert 'f16' in device_data_hlo, device_data_hlo - elif check_env_flag('XLA_DOWNCAST_BF16') or check_env_flag( - 'XLA_DOWNCAST_FP16'): - assert 'f32' in device_data_hlo, device_data_hlo - else: - assert 'f64' in device_data_hlo, device_data_hlo + def test_datatype_downcast_fp16(self): + self._set_env(XLA_DOWNCAST_FP16='1') + self._test_datatype(torch.double, 'f16', torch.floor_divide) + self._test_datatype(torch.float, 'f16', torch.floor_divide) - def test_module_to_dtype(self): - device = torch_xla.device() - linear = torch.nn.Linear( - 5, 10, dtype=torch.float32).to(device).to(torch.bfloat16) - input = torch.randn( - 10, - 5, - ).to(device).to(torch.bfloat16) - xm.mark_step() - res = linear(input) + def test_datatype_use_32bit_long(self): + self._set_env(XLA_USE_32BIT_LONG='1') + self._test_datatype(torch.int64, 's32', torch.add) + self._test_datatype(torch.uint64, 'u32', torch.add) - hlo_text = torch_xla._XLAC._get_xla_tensors_text([res]) - res_hlo = hlo_text.split('\n')[-3] - assert 'bf16' in res_hlo, res_hlo + def test_module_to_dtype(self): + device = torch_xla.device() + linear = torch.nn.Linear(5, 10, dtype=torch.float32).to(device).to(torch.bfloat16) + input = torch.randn(10, 5).to(device).to(torch.bfloat16) + xm.mark_step() + res = linear(input) - linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight - ]).split('\n')[-3] - assert 'bf16' in linear_weight_hlo, linear_weight_hlo + hlo_text = torch_xla._XLAC._get_xla_tensors_text([res]) + res_hlo = hlo_text.split('\n')[-3] + self.assertIn('bf16', res_hlo) + linear_weight_hlo = torch_xla._XLAC._get_xla_tensors_text([linear.weight]).split('\n')[-3] + self.assertIn('bf16', linear_weight_hlo) if __name__ == '__main__': - test = unittest.main() - sys.exit(0 if test.result.wasSuccessful() else 1) + suite = unittest.TestSuite() + suite.addTest(XlaDataTypeTest("test_datatype_use_bf16")) + suite.addTest(XlaDataTypeTest("test_datatype_use_fp16")) + suite.addTest(XlaDataTypeTest("test_datatype_downcast_bf16")) + suite.addTest(XlaDataTypeTest("test_datatype_downcast_fp16")) + suite.addTest(XlaDataTypeTest("test_datatype_use_32bit_long")) + suite.addTest(XlaDataTypeTest("test_module_to_dtype")) + runner = unittest.TextTestRunner(failfast=True) + result = runner.run(suite) + sys.exit(0 if result.wasSuccessful() else 1) \ No newline at end of file diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index e214e7a47a77..317e067a46ad 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -179,7 +179,7 @@ def _setup_tpu_vm_library_path() -> bool: def _check_deprecated_env_var(): deprecated_env_vars = [ - 'XLA_USE_FP16', 'XLA_DOWNCAST_FP16', 'XLA_USE_32BIT_LONG' + 'XLA_USE_FP16', 'XLA_DOWNCAST_FP16' ] for env_var in deprecated_env_vars: if os.environ.get(env_var): diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 759c045f8f20..6e1a92a403ea 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -30,6 +30,20 @@ bool ShouldDowncastToBF16() { return downcast_bf16; } +std::optional ShouldUse32BitLong() { + const char* env_name = "XLA_USE_32BIT_LONG"; + if (std::getenv(env_name) == nullptr) { + return std::nullopt; + } + bool use_32bit_long = runtime::sys_util::GetEnvBool(env_name, false); + if (use_32bit_long) { + std::cout + << "XLA_USE_32BIT_LONG will be deprecated after the 2.6 release\n"; + TF_LOG(INFO) << "Using 32bit integers for kLong values"; + } + return use_32bit_long; +} + bool UseBF16() { static bool use_bf16 = ShouldUseBF16(); return use_bf16; @@ -40,6 +54,11 @@ bool DowncastBF16() { return downcast_bf16; } +std::optional Use32BitLong() { + static std::optional use_32bit_long = ShouldUse32BitLong(); + return use_32bit_long; +} + } // namespace at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) { @@ -142,12 +161,24 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( case xla::PrimitiveType::S16: return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32 : xla::PrimitiveType::S16; - case xla::PrimitiveType::S64: + case xla::PrimitiveType::S64: { + std::optional use_32bit_long = Use32BitLong(); + if (use_32bit_long.has_value()) { + return *use_32bit_long ? xla::PrimitiveType::S32 + : xla::PrimitiveType::S64; + } return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64; - case xla::PrimitiveType::U64: + } + case xla::PrimitiveType::U64: { + std::optional use_32bit_long = Use32BitLong(); + if (use_32bit_long.has_value()) { + return *use_32bit_long ? xla::PrimitiveType::U32 + : xla::PrimitiveType::U64; + } return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64; + } case xla::PrimitiveType::C128: return xla::PrimitiveType::C128; default: