Skip to content

Commit

Permalink
Update for new PJRT API with new Dependencies (#86)
Browse files Browse the repository at this point in the history
A recent update to the PJRT API requires an updated implementation. This
includes factoring
out DeviceDescription from the DeviceInstance behavior.

---------

Co-authored-by: OpenXLA Dep Roller <[email protected]>
  • Loading branch information
rsuderman and iree-github-actions-bot authored May 5, 2023
1 parent c6a6070 commit aecbf3b
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 63 deletions.
90 changes: 53 additions & 37 deletions iree/integrations/pjrt/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -577,63 +577,83 @@ iree_status_t BufferInstance::AdvanceDoneFence(iree_hal_semaphore_t* semaphore,
}

//===----------------------------------------------------------------------===//
// DeviceInstance
// DeviceDescription
//===----------------------------------------------------------------------===//

DeviceInstance::~DeviceInstance() = default;
DeviceDescription::~DeviceDescription() = default;

void DeviceInstance::BindApi(PJRT_Api* api) {
api->PJRT_Device_Id = +[](PJRT_Device_Id_Args* args) -> PJRT_Error* {
args->id = DeviceInstance::Unwrap(args->device)->client_id();
return nullptr;
};
api->PJRT_Device_ProcessIndex =
+[](PJRT_Device_ProcessIndex_Args* args) -> PJRT_Error* {
args->process_index = DeviceInstance::Unwrap(args->device)->process_index();
void DeviceDescription::BindApi(PJRT_Api* api) {
api->PJRT_DeviceDescription_Id =
+[](PJRT_DeviceDescription_Id_Args* args) -> PJRT_Error* {
args->id = DeviceDescription::Unwrap(args->device_description)->client_id();
return nullptr;
};
api->PJRT_Device_IsAddressable =
+[](PJRT_Device_IsAddressable_Args* args) -> PJRT_Error* {
args->is_addressable =
DeviceInstance::Unwrap(args->device)->is_addressable();
api->PJRT_DeviceDescription_ProcessIndex =
+[](PJRT_DeviceDescription_ProcessIndex_Args* args) -> PJRT_Error* {
args->process_index =
DeviceDescription::Unwrap(args->device_description)->process_index();
return nullptr;
};

api->PJRT_Device_Attributes =
+[](PJRT_Device_Attributes_Args* args) -> PJRT_Error* {
api->PJRT_DeviceDescription_Attributes =
+[](PJRT_DeviceDescription_Attributes_Args* args) -> PJRT_Error* {
// TODO: Implement something.
args->num_attributes = 0;
args->attributes = nullptr;
return nullptr;
};
api->PJRT_Device_Kind = +[](PJRT_Device_Kind_Args* args) -> PJRT_Error* {
auto sv = DeviceInstance::Unwrap(args->device)->kind_string();
api->PJRT_DeviceDescription_Kind =
+[](PJRT_DeviceDescription_Kind_Args* args) -> PJRT_Error* {
auto sv =
DeviceDescription::Unwrap(args->device_description)->kind_string();
args->device_kind = sv.data();
args->device_kind_size = sv.size();
return nullptr;
};
api->PJRT_Device_LocalHardwareId =
+[](PJRT_Device_LocalHardwareId_Args* args) -> PJRT_Error* {
args->local_hardware_id =
DeviceInstance::Unwrap(args->device)->local_hardware_id();
return nullptr;
};
api->PJRT_Device_DebugString =
+[](PJRT_Device_DebugString_Args* args) -> PJRT_Error* {
auto sv = DeviceInstance::Unwrap(args->device)->debug_string();
api->PJRT_DeviceDescription_DebugString =
+[](PJRT_DeviceDescription_DebugString_Args* args) -> PJRT_Error* {
auto sv =
DeviceDescription::Unwrap(args->device_description)->debug_string();
args->debug_string = sv.data();
args->debug_string_size = sv.size();
return nullptr;
};
api->PJRT_Device_ToString =
+[](PJRT_Device_ToString_Args* args) -> PJRT_Error* {
auto sv = DeviceInstance::Unwrap(args->device)->user_string();
api->PJRT_DeviceDescription_ToString =
+[](PJRT_DeviceDescription_ToString_Args* args) -> PJRT_Error* {
auto sv =
DeviceDescription::Unwrap(args->device_description)->user_string();
args->to_string = sv.data();
args->to_string_size = sv.size();
return nullptr;
};
}

//===----------------------------------------------------------------------===//
// DeviceInstance
//===----------------------------------------------------------------------===//

DeviceInstance::~DeviceInstance() = default;

void DeviceInstance::BindApi(PJRT_Api* api) {
api->PJRT_Device_IsAddressable =
+[](PJRT_Device_IsAddressable_Args* args) -> PJRT_Error* {
args->is_addressable =
DeviceInstance::Unwrap(args->device)->is_addressable();
return nullptr;
};
api->PJRT_Device_LocalHardwareId =
+[](PJRT_Device_LocalHardwareId_Args* args) -> PJRT_Error* {
args->local_hardware_id =
DeviceInstance::Unwrap(args->device)->local_hardware_id();
return nullptr;
};
api->PJRT_Device_GetDescription =
+[](PJRT_Device_GetDescription_Args* args) -> PJRT_Error* {
args->device_description = reinterpret_cast<PJRT_DeviceDescription*>(
DeviceInstance::Unwrap(args->device)->device_description());
return nullptr;
};
}

iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) {
return IreeApi::hal_fence_create(/*capacity=*/2, client_.host_allocator(),
out_fence);
Expand All @@ -642,19 +662,14 @@ iree_status_t DeviceInstance::CreateFence(iree_hal_fence_t** out_fence) {
iree_status_t DeviceInstance::OpenDevice() {
if (device_) return iree_ok_status();
IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_id(
driver_, /*device_id=*/info_->device_id,
driver_, /*device_id=*/info_.device_id(),
/*param_count=*/0, /*params=*/nullptr, client_.host_allocator(),
&device_));
IREE_RETURN_IF_ERROR(
iree_hal_semaphore_create(device(), 0ull, &main_timeline_));
IREE_RETURN_IF_ERROR(
iree_hal_semaphore_create(device(), 0ull, &transfer_timeline_));

// Initialize debug strings.
user_string_ = std::string(info_->path.data, info_->path.size);
debug_string_ = std::string(info_->name.data, info_->name.size);
kind_string_ = std::string(info_->name.data, info_->name.size);

return iree_ok_status();
}

Expand Down Expand Up @@ -1613,6 +1628,7 @@ void BindMonomorphicApi(PJRT_Api* api) {
// Bind by object types.
BufferInstance::BindApi(api);
ClientInstance::BindApi(api);
DeviceDescription::BindApi(api);
DeviceInstance::BindApi(api);
ErrorInstance::BindApi(api);
EventInstance::BindApi(api);
Expand Down
77 changes: 56 additions & 21 deletions iree/integrations/pjrt/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,79 @@ class BufferInstance {
};

//===----------------------------------------------------------------------===//
// DeviceInstance
// DeviceDescription
//===----------------------------------------------------------------------===//

class DeviceInstance {
class DeviceDescription {
public:
DeviceInstance(int client_id, ClientInstance& client,
iree_hal_driver_t* driver, iree_hal_device_info_t* info)
: client_id_(client_id), client_(client), driver_(driver), info_(info) {}
~DeviceInstance();
operator PJRT_Device*() { return reinterpret_cast<PJRT_Device*>(this); }
DeviceDescription(int32_t client_id, iree_hal_device_info_t* info)
: client_id_(client_id), info_(info) {
// Initialize debug strings.
user_string_ = std::string(info_->path.data, info_->path.size);
debug_string_ = std::string(info_->name.data, info_->name.size);
kind_string_ = std::string(info_->name.data, info_->name.size);
}
~DeviceDescription();
operator PJRT_DeviceDescription*() {
return reinterpret_cast<PJRT_DeviceDescription*>(this);
}
static void BindApi(PJRT_Api* api);
static DeviceInstance* Unwrap(PJRT_Device* device) {
return reinterpret_cast<DeviceInstance*>(device);

static DeviceDescription* Unwrap(PJRT_DeviceDescription* device) {
return reinterpret_cast<DeviceDescription*>(device);
}

int64_t device_id() { return info_->device_id; }

// Since the PJRT device id is a simple int and the IREE device_id is
// a pointer-sized value, we just assign a synthetic id. Currently, this
// is the offset into the devices() array on the client. Will need to be
// revisited if ever supporting re-scanning (but many things would seem to
// need updates then).
int client_id() { return client_id_; }
iree_hal_device_info_t* info() { return info_; }

// Not yet implemented but plumbed through.
bool is_addressable() { return true; }
int process_index() { return 0; }
int local_hardware_id() { return -1; }

// Various debug descriptions of the device. Backing string data owned by
// the device.
// the device description.
std::string_view kind_string() { return kind_string_; }
std::string_view debug_string() { return debug_string_; }
std::string_view user_string() { return user_string_; }

private:
int client_id_;
iree_hal_device_info_t* info_;

// Debug strings (owned by device description).
std::string kind_string_;
std::string debug_string_;
std::string user_string_;
};

//===----------------------------------------------------------------------===//
// DeviceInstance
//===----------------------------------------------------------------------===//

class DeviceInstance {
public:
DeviceInstance(int client_id, ClientInstance& client,
iree_hal_driver_t* driver, iree_hal_device_info_t* info)
: client_(client), driver_(driver), info_(client_id, info) {}
~DeviceInstance();
operator PJRT_Device*() { return reinterpret_cast<PJRT_Device*>(this); }
static void BindApi(PJRT_Api* api);

static DeviceInstance* Unwrap(PJRT_Device* device) {
return reinterpret_cast<DeviceInstance*>(device);
}

static DeviceInstance* Unwrap(PJRT_DeviceDescription* device_description) {
return reinterpret_cast<DeviceInstance*>(device_description);
}

bool is_addressable() { return true; }
int local_hardware_id() { return -1; }

// Copies a host buffer to the device.
// See PJRT_Client_BufferFromHostBuffer
iree_status_t HostBufferToDevice(
Expand All @@ -166,6 +205,8 @@ class DeviceInstance {
// TODO(laurenzo): Eagerly set up device to allow simple access.
iree_status_t GetHalDevice(iree_hal_device_t** out_device);

DeviceDescription* device_description() { return &info_; }

// Only valid once device opened.
iree_hal_semaphore_t* main_timeline() { return main_timeline_.get(); }

Expand All @@ -184,7 +225,6 @@ class DeviceInstance {
bool snapshot_initial_contents_now, bool* initial_contents_snapshotted,
iree_hal_buffer_t** out_buffer);

int client_id_;
ClientInstance& client_;
iree_hal_driver_t* driver_; // Owned by client.
iree::vm::ref<iree_hal_device_t> device_;
Expand All @@ -195,12 +235,7 @@ class DeviceInstance {
iree::vm::ref<iree_hal_fence_t> transfer_now_fence_;
// The timepoint of the last transfer.
uint64_t last_transfer_timepoint_ = 0;
iree_hal_device_info_t* info_;

// Debug strings (owned by device).
std::string kind_string_;
std::string debug_string_;
std::string user_string_;
DeviceDescription info_;
};

//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-f https://openxla.github.io/iree/pip-release-links.html
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
iree-compiler==20230427.502
jaxlib==0.4.9.dev20230424
iree-compiler==20230503.508
jaxlib==0.4.9.dev20230503
-e ../jax
6 changes: 3 additions & 3 deletions sync_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
### Update with: openxla-workspace pin

PINNED_VERSIONS = {
"iree": "6c2f27d83df9214b61a9d2ba2d9b77e54df52520",
"xla": "cf0515f724bdbd693b86c5c5b3e01e91eb6ef6be",
"jax": "0814b874d50784da22d4af47a54509495f94b8b6"
"iree": "0884b3ee6df3946c46d97855d9f78fa9b12b8f90",
"xla": "6c6744c8c5ded1c8e4af83296b508e55fe432e19",
"jax": "95e1e6d3efff585db5404b2fe5f116027af4acb2"
}

ORIGINS = {
Expand Down

0 comments on commit aecbf3b

Please sign in to comment.