Skip to content

Commit

Permalink
Fix malformed tests in test_usm_ndarray_dlpack
Browse files Browse the repository at this point in the history
These tests would fail on machines with more than 2 devices for a given platform due to an incorrect asusmption that the DLPack device ID would match that of the cached root devices, of which only 2 are kept per platform
  • Loading branch information
ndgrigorian committed Jan 31, 2025
1 parent bfd4d57 commit 7ef0e44
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_dlpack_device(usm_type, all_root_devices):
assert type(dev) is tuple
assert len(dev) == 2
assert dev[0] == device_oneAPI
assert sycl_dev == all_root_devices[dev[1]]
assert dev[1] == sycl_dev.get_device_id()


def test_dlpack_exporter(typestr, usm_type, all_root_devices):
Expand Down Expand Up @@ -834,15 +834,15 @@ def test_sycl_device_to_dldevice(all_root_devices):
assert type(dev) is tuple
assert len(dev) == 2
assert dev[0] == device_oneAPI
assert dev[1] == all_root_devices.index(sycl_dev)
assert dev[1] == sycl_dev.get_device_id()


def test_dldevice_to_sycl_device(all_root_devices):
for sycl_dev in all_root_devices:
dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__()
dev = dpt.dldevice_to_sycl_device(dldev)
assert type(dev) is dpctl.SyclDevice
assert dev == all_root_devices[dldev[1]]
assert dev.get_device_id() == sycl_dev.get_device_id()


def test_dldevice_conversion_arg_validation():
Expand Down

0 comments on commit 7ef0e44

Please sign in to comment.