From aecbf3b1d2f66591825d1c34f4b42ed9ed49f93f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 5 May 2023 14:15:49 -0700 Subject: [PATCH] Update for new PJRT API with new Dependencies (#86) A recent update to the PJRT API requires an updated implementation. This includes factoring out DeviceDescription from the DeviceInstance behavior. --------- Co-authored-by: OpenXLA Dep Roller --- iree/integrations/pjrt/common/api_impl.cc | 90 +++++++++++++---------- iree/integrations/pjrt/common/api_impl.h | 77 +++++++++++++------ requirements.txt | 4 +- sync_deps.py | 6 +- 4 files changed, 114 insertions(+), 63 deletions(-) diff --git a/iree/integrations/pjrt/common/api_impl.cc b/iree/integrations/pjrt/common/api_impl.cc index 0a8b47f2..99f26c11 100644 --- a/iree/integrations/pjrt/common/api_impl.cc +++ b/iree/integrations/pjrt/common/api_impl.cc @@ -577,63 +577,83 @@ iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore, } //===----------------------------------------------------------------------===// -// DeviceInstance +// DeviceDescription //===----------------------------------------------------------------------===// -DeviceInstance::~DeviceInstance() = default; +DeviceDescription::~DeviceDescription() = default; -void DeviceInstance::BindApi(PJRT_Api* api) { - api->PJRT_Device_Id = +[](PJRT_Device_Id_Args* args) -> PJRT_Error* { - args->id = DeviceInstance::Unwrap(args->device)->client_id(); - return nullptr; - }; - api->PJRT_Device_ProcessIndex = - +[](PJRT_Device_ProcessIndex_Args* args) -> PJRT_Error* { - args->process_index = DeviceInstance::Unwrap(args->device)->process_index(); +void DeviceDescription::BindApi(PJRT_Api* api) { + api->PJRT_DeviceDescription_Id = + +[](PJRT_DeviceDescription_Id_Args* args) -> PJRT_Error* { + args->id = DeviceDescription::Unwrap(args->device_description)->client_id(); return nullptr; }; - api->PJRT_Device_IsAddressable = - +[](PJRT_Device_IsAddressable_Args* args) -> PJRT_Error* { - args->is_addressable = - DeviceInstance::Unwrap(args->device)->is_addressable(); + api->PJRT_DeviceDescription_ProcessIndex = + +[](PJRT_DeviceDescription_ProcessIndex_Args* args) -> PJRT_Error* { + args->process_index = + DeviceDescription::Unwrap(args->device_description)->process_index(); return nullptr; }; - - api->PJRT_Device_Attributes = - +[](PJRT_Device_Attributes_Args* args) -> PJRT_Error* { + api->PJRT_DeviceDescription_Attributes = + +[](PJRT_DeviceDescription_Attributes_Args* args) -> PJRT_Error* { // TODO: Implement something. args->num_attributes = 0; args->attributes = nullptr; return nullptr; }; - api->PJRT_Device_Kind = +[](PJRT_Device_Kind_Args* args) -> PJRT_Error* { - auto sv = DeviceInstance::Unwrap(args->device)->kind_string(); + api->PJRT_DeviceDescription_Kind = + +[](PJRT_DeviceDescription_Kind_Args* args) -> PJRT_Error* { + auto sv = + DeviceDescription::Unwrap(args->device_description)->kind_string(); args->device_kind = sv.data(); args->device_kind_size = sv.size(); return nullptr; }; - api->PJRT_Device_LocalHardwareId = - +[](PJRT_Device_LocalHardwareId_Args* args) -> PJRT_Error* { - args->local_hardware_id = - DeviceInstance::Unwrap(args->device)->local_hardware_id(); - return nullptr; - }; - api->PJRT_Device_DebugString = - +[](PJRT_Device_DebugString_Args* args) -> PJRT_Error* { - auto sv = DeviceInstance::Unwrap(args->device)->debug_string(); + api->PJRT_DeviceDescription_DebugString = + +[](PJRT_DeviceDescription_DebugString_Args* args) -> PJRT_Error* { + auto sv = + DeviceDescription::Unwrap(args->device_description)->debug_string(); args->debug_string = sv.data(); args->debug_string_size = sv.size(); return nullptr; }; - api->PJRT_Device_ToString = - +[](PJRT_Device_ToString_Args* args) -> PJRT_Error* { - auto sv = DeviceInstance::Unwrap(args->device)->user_string(); + api->PJRT_DeviceDescription_ToString = + +[](PJRT_DeviceDescription_ToString_Args* args) -> PJRT_Error* { + auto sv = + DeviceDescription::Unwrap(args->device_description)->user_string(); args->to_string = sv.data(); args->to_string_size = sv.size(); return nullptr; }; } +//===----------------------------------------------------------------------===// +// DeviceInstance +//===----------------------------------------------------------------------===// + +DeviceInstance::~DeviceInstance() = default; + +void DeviceInstance::BindApi(PJRT_Api* api) { + api->PJRT_Device_IsAddressable = + +[](PJRT_Device_IsAddressable_Args* args) -> PJRT_Error* { + args->is_addressable = + DeviceInstance::Unwrap(args->device)->is_addressable(); + return nullptr; + }; + api->PJRT_Device_LocalHardwareId = + +[](PJRT_Device_LocalHardwareId_Args* args) -> PJRT_Error* { + args->local_hardware_id = + DeviceInstance::Unwrap(args->device)->local_hardware_id(); + return nullptr; + }; + api->PJRT_Device_GetDescription = + +[](PJRT_Device_GetDescription_Args* args) -> PJRT_Error* { + args->device_description = reinterpret_cast( + DeviceInstance::Unwrap(args->device)->device_description()); + return nullptr; + }; +} + iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) { return IreeApi::hal_fence_create(/*capacity=*/2, client_.host_allocator(), out_fence); @@ -642,7 +662,7 @@ iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) { iree_status_t DeviceInstance::OpenDevice() { if (device_) return iree_ok_status(); IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_id( - driver_, /*device_id=*/info_->device_id, + driver_, /*device_id=*/info_.device_id(), /*param_count=*/0, /*params=*/nullptr, client_.host_allocator(), &device_)); IREE_RETURN_IF_ERROR( @@ -650,11 +670,6 @@ iree_status_t DeviceInstance::OpenDevice() { IREE_RETURN_IF_ERROR( iree_hal_semaphore_create(device(), 0ull, &transfer_timeline_)); - // Initialize debug strings. - user_string_ = std::string(info_->path.data, info_->path.size); - debug_string_ = std::string(info_->name.data, info_->name.size); - kind_string_ = std::string(info_->name.data, info_->name.size); - return iree_ok_status(); } @@ -1613,6 +1628,7 @@ void BindMonomorphicApi(PJRT_Api* api) { // Bind by object types. BufferInstance::BindApi(api); ClientInstance::BindApi(api); + DeviceDescription::BindApi(api); DeviceInstance::BindApi(api); ErrorInstance::BindApi(api); EventInstance::BindApi(api); diff --git a/iree/integrations/pjrt/common/api_impl.h b/iree/integrations/pjrt/common/api_impl.h index 65ca5f3f..6d287401 100644 --- a/iree/integrations/pjrt/common/api_impl.h +++ b/iree/integrations/pjrt/common/api_impl.h @@ -120,40 +120,79 @@ class BufferInstance { }; //===----------------------------------------------------------------------===// -// DeviceInstance +// DeviceDescription //===----------------------------------------------------------------------===// -class DeviceInstance { +class DeviceDescription { public: - DeviceInstance(int client_id, ClientInstance& client, - iree_hal_driver_t* driver, iree_hal_device_info_t* info) - : client_id_(client_id), client_(client), driver_(driver), info_(info) {} - ~DeviceInstance(); - operator PJRT_Device*() { return reinterpret_cast(this); } + DeviceDescription(int32_t client_id, iree_hal_device_info_t* info) + : client_id_(client_id), info_(info) { + // Initialize debug strings. + user_string_ = std::string(info_->path.data, info_->path.size); + debug_string_ = std::string(info_->name.data, info_->name.size); + kind_string_ = std::string(info_->name.data, info_->name.size); + } + ~DeviceDescription(); + operator PJRT_DeviceDescription*() { + return reinterpret_cast(this); + } static void BindApi(PJRT_Api* api); - static DeviceInstance* Unwrap(PJRT_Device* device) { - return reinterpret_cast(device); + + static DeviceDescription* Unwrap(PJRT_DeviceDescription* device) { + return reinterpret_cast(device); } + int64_t device_id() { return info_->device_id; } + // Since the PJRT device id is a simple int and the IREE device_id is // a pointer-sized value, we just assign a synthetic id. Currently, this // is the offset into the devices() array on the client. Will need to be // revisited if ever supporting re-scanning (but many things would seem to // need updates then). int client_id() { return client_id_; } - iree_hal_device_info_t* info() { return info_; } - // Not yet implemented but plumbed through. - bool is_addressable() { return true; } int process_index() { return 0; } - int local_hardware_id() { return -1; } // Various debug descriptions of the device. Backing string data owned by - // the device. + // the device description. std::string_view kind_string() { return kind_string_; } std::string_view debug_string() { return debug_string_; } std::string_view user_string() { return user_string_; } + private: + int client_id_; + iree_hal_device_info_t* info_; + + // Debug strings (owned by device description). + std::string kind_string_; + std::string debug_string_; + std::string user_string_; +}; + +//===----------------------------------------------------------------------===// +// DeviceInstance +//===----------------------------------------------------------------------===// + +class DeviceInstance { + public: + DeviceInstance(int client_id, ClientInstance& client, + iree_hal_driver_t* driver, iree_hal_device_info_t* info) + : client_(client), driver_(driver), info_(client_id, info) {} + ~DeviceInstance(); + operator PJRT_Device*() { return reinterpret_cast(this); } + static void BindApi(PJRT_Api* api); + + static DeviceInstance* Unwrap(PJRT_Device* device) { + return reinterpret_cast(device); + } + + static DeviceInstance* Unwrap(PJRT_DeviceDescription* device_description) { + return reinterpret_cast(device_description); + } + + bool is_addressable() { return true; } + int local_hardware_id() { return -1; } + // Copies a host buffer to the device. // See PJRT_Client_BufferFromHostBuffer iree_status_t HostBufferToDevice( @@ -166,6 +205,8 @@ class DeviceInstance { // TODO(laurenzo): Eagerly set up device to allow simple access. iree_status_t GetHalDevice(iree_hal_device_t** out_device); + DeviceDescription* device_description() { return &info_; } + // Only valid once device opened. iree_hal_semaphore_t* main_timeline() { return main_timeline_.get(); } @@ -184,7 +225,6 @@ class DeviceInstance { bool snapshot_initial_contents_now, bool* initial_contents_snapshotted, iree_hal_buffer_t** out_buffer); - int client_id_; ClientInstance& client_; iree_hal_driver_t* driver_; // Owned by client. iree::vm::ref device_; @@ -195,12 +235,7 @@ class DeviceInstance { iree::vm::ref transfer_now_fence_; // The timepoint of the last transfer. uint64_t last_transfer_timepoint_ = 0; - iree_hal_device_info_t* info_; - - // Debug strings (owned by device). - std::string kind_string_; - std::string debug_string_; - std::string user_string_; + DeviceDescription info_; }; //===----------------------------------------------------------------------===// diff --git a/requirements.txt b/requirements.txt index 085ae5e0..6371a028 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -f https://openxla.github.io/iree/pip-release-links.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -iree-compiler==20230427.502 -jaxlib==0.4.9.dev20230424 +iree-compiler==20230503.508 +jaxlib==0.4.9.dev20230503 -e ../jax diff --git a/sync_deps.py b/sync_deps.py index 87d4ee9a..2eb37664 100755 --- a/sync_deps.py +++ b/sync_deps.py @@ -7,9 +7,9 @@ ### Update with: openxla-workspace pin PINNED_VERSIONS = { - "iree": "6c2f27d83df9214b61a9d2ba2d9b77e54df52520", - "xla": "cf0515f724bdbd693b86c5c5b3e01e91eb6ef6be", - "jax": "0814b874d50784da22d4af47a54509495f94b8b6" + "iree": "0884b3ee6df3946c46d97855d9f78fa9b12b8f90", + "xla": "6c6744c8c5ded1c8e4af83296b508e55fe432e19", + "jax": "95e1e6d3efff585db5404b2fe5f116027af4acb2" } ORIGINS = {