Skip to content

Commit

Permalink
Merge branch 'main' into cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
giordano committed Jan 30, 2025
2 parents 264ec4b + 5c67d0a commit f845940
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 47 deletions.
166 changes: 128 additions & 38 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "src/enzyme_ad/jax/Passes/Passes.h"
#include "llvm/Support/TargetSelect.h"
#include "shardy/dialect/sdy/ir/dialect.h"

#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "stablehlo/dialect/ChloOps.h"
Expand All @@ -59,6 +58,7 @@
#include "tsl/profiler/lib/traceme.h"
#include "xla/tsl/profiler/rpc/client/capture_profile.h"
#include "xla/tsl/profiler/rpc/profiler_server.h"
#include "xla/python/profiler_utils.h"

#include "xla/python/ifrt/hlo/hlo_program.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
Expand All @@ -68,6 +68,10 @@

#include "llvm-c/TargetMachine.h"

// shardy
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/integrations/c/attributes.h"

// IFRT
#include "xla/python/ifrt/array.h"
#include "xla/python/ifrt/client.h"
Expand Down Expand Up @@ -274,13 +278,10 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
}

// xla/python/xla.cc 390
extern "C" PjRtClient *MakeGPUClient(int node_id, int num_nodes,
int *allowed_devices,
int num_allowed_devices,
double memory_fraction,
bool preallocate,
const char *platform_name,
const char **error) {
extern "C" PjRtClient *
MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
int num_allowed_devices, double memory_fraction, bool preallocate,
const char *platform_name, const char **error) {
GpuClientOptions options;
// options.kv_store = "etcd";
// options.allocator_config =
Expand Down Expand Up @@ -359,11 +360,11 @@ extern "C" PjRtClient *MakeTPUClient(const char *tpu_path, const char **error) {
LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
if (pluginLoad == nullptr)
return nullptr;

auto tpu_status = InitializePjrtPlugin("tpu", error);
if (tpu_status)
return nullptr;

RegisterProfiler(pluginLoad);
return GetCApiClient("TPU");
}

Expand Down Expand Up @@ -456,9 +457,10 @@ std::vector<int64_t> col_major(int64_t dim) {
return minor_to_major;
}

extern "C" void ReactantLLVMParseCommandLineOptions(int argc, const char *const *argv,
const char *Overview) {
llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview),
extern "C" void ReactantLLVMParseCommandLineOptions(int argc,
const char *const *argv,
const char *Overview) {
llvm::cl::ParseCommandLineOptions(argc, argv, StringRef(Overview),
&llvm::nulls());
}

Expand All @@ -478,9 +480,7 @@ extern "C" int32_t ReactantCudaDriverGetVersion() {
ReactantHandleCuResult(cuDriverGetVersion(&data));
return data;
}
extern "C" int32_t ReactantHermeticCudaGetVersion() {
return CUDA_VERSION;
}
// Report the CUDA toolkit version this binary was compiled against,
// i.e. the CUDA_VERSION macro from the (hermetic) CUDA headers.
extern "C" int32_t ReactantHermeticCudaGetVersion() {
  return CUDA_VERSION;
}
#else
// Non-CUDA build fallback: no driver is present, so report version 0.
extern "C" int32_t ReactantCudaDriverGetVersion() {
  return 0;
}
// Non-CUDA build fallback: no hermetic CUDA toolkit, so report version 0.
extern "C" int32_t ReactantHermeticCudaGetVersion() {
  return 0;
}
Expand Down Expand Up @@ -533,6 +533,18 @@ extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) {

// C ABI destructor for a PjRtClient previously handed across the boundary.
// A null pointer is a no-op (delete on nullptr is well-defined).
extern "C" void FreeClient(PjRtClient *client) {
  delete client;
}

// C ABI accessor: unwrap device->local_device_id() to a raw int64.
extern "C" int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) {
  const auto local_id = device->local_device_id();
  return local_id.value();
}

// C ABI accessor: unwrap device->global_device_id() to a raw int64.
extern "C" int64_t PjRtDeviceGetGlobalDeviceId(PjRtDevice *device) {
  const auto global_id = device->global_device_id();
  return global_id.value();
}

// C ABI accessor: unwrap device->local_hardware_id() to a raw int64.
extern "C" int64_t PjRtDeviceGetLocalHardwareId(PjRtDevice *device) {
  const auto hardware_id = device->local_hardware_id();
  return hardware_id.value();
}

#include "xla/service/custom_call_target_registry.h"
extern "C" void RegisterCustomCallTarget(const char *name, void *address,
const char *platform) {
Expand Down Expand Up @@ -584,21 +596,28 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
/* Note that this */
extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
MlirModule cmod,
int device_ordinal,
int num_replicas,
int num_partitions,
bool use_shardy_partitioner) {
int *global_ordinals,
int num_global_ordinals) {
auto program =
std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));

CompileOptions options;

if (device_ordinal >= 0) {
options.executable_build_options.set_device_ordinal(device_ordinal);
// https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
int device_count = client->addressable_device_count();

options.executable_build_options.set_num_replicas(device_count);
options.executable_build_options.set_num_partitions(1);

xla::DeviceAssignment device_assignment(device_count, 1);
for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {
int ordinal = global_ordinals[device_id];
if (ordinal < 0) {
continue;
}
device_assignment(ordinal, 0) = device_id;
}
options.executable_build_options.set_num_replicas(num_replicas);
options.executable_build_options.set_num_partitions(num_partitions);
options.executable_build_options.set_use_shardy_partitioner(use_shardy_partitioner);
options.executable_build_options.set_device_assignment(device_assignment);

auto addressable_devices = client->addressable_devices();
if (!addressable_devices.empty()) {
Expand All @@ -609,8 +628,7 @@ extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
assert(device_ordinal < addressable_devices.size());
auto stats = addressable_devices[device_ordinal]->GetAllocatorStats();
if (stats.ok() && stats->bytes_limit) {
options.executable_build_options.set_device_memory_size(
*stats->bytes_limit);
options.executable_build_options.set_device_memory_size(*stats->bytes_limit);
}
}
auto exec =
Expand All @@ -627,12 +645,72 @@ extern "C" uint8_t FutureIsReady(FutureType *Future) {

// Block the calling thread until the given future's result is available.
extern "C" void FutureAwait(FutureType *Future) {
  Future->Await();
}

// C ABI entry point: run a compiled executable on one specific device.
//
// op_args[0..num_args) are the input buffers; is_arg_donatable[i] == 0 marks
// input i as non-donatable. On return, op_results[0..num_results) receive
// ownership of the output buffers. *futures is set to 1 iff the runtime
// produced a completion future, in which case future_results[0..num_results)
// each receive a heap-allocated copy of that single future (caller frees).
// Aborts the process if the runtime returns an unexpected result count.
extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
                                  PjRtBuffer **op_args, PjRtDevice *device,
                                  uint8_t *is_arg_donatable, int num_results,
                                  PjRtBuffer **op_results, uint8_t *futures,
                                  FutureType **future_results) {
  // Gather the raw input buffer pointers for the PJRT call.
  std::vector<PjRtBuffer *> arguments;
  arguments.reserve(num_args);
  for (int i = 0; i < num_args; ++i)
    arguments.push_back(op_args[i]);

  // Execution options: untupled outputs, plus the set of inputs the runtime
  // must NOT donate (everything the caller did not explicitly mark donatable).
  ExecuteOptions exec_options;
  exec_options.untuple_result = true;
  for (int i = 0; i < num_args; ++i)
    if (!is_arg_donatable[i])
      exec_options.non_donatable_input_indices.insert(i);

  // Ask the runtime to also hand back a completion future.
  std::optional<PjRtFuture<>> maybe_future;
  auto outputs = MyValueOrThrow(exec->ExecuteSharded(
      arguments, device, exec_options, maybe_future, /*fill_future=*/true));

  // The runtime must produce exactly the number of outputs the caller sized
  // its arrays for; anything else would corrupt memory below.
  if (outputs.size() != num_results) {
    llvm::errs() << "Error: results.size()=" << outputs.size()
                 << " does not match num_results=" << num_results << "\n";
    std::abort();
  }

  // Fan the single completion future out to one copy per result slot.
  *futures = maybe_future.has_value();
  if (maybe_future.has_value()) {
    for (int i = 0; i < num_results; ++i)
      future_results[i] = new FutureType(*maybe_future);
  }

  // Transfer ownership of each output buffer to the caller's array.
  for (int i = 0; i < num_results; ++i)
    op_results[i] = outputs[i].release();
}

extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
PjRtBuffer **op_args, uint8_t *is_arg_donatable,
int num_results, PjRtBuffer **op_results,
uint8_t *futures, FutureType **future_results) {
std::vector<std::vector<PjRtBuffer *>> argument_handles;
argument_handles.emplace_back(op_args, op_args + num_args);
auto client = exec->client();
int num_devices = client->addressable_device_count();

// Ensure argument_handles is structured as num_devices x num_args
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);

// Distribute arguments across devices
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
argument_handles[device_idx].reserve(num_args);
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
// Assuming op_args is a flat array of size num_devices * num_args
// where arguments for each device are contiguous
argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]);
}
}

ExecuteOptions options;

Expand All @@ -641,31 +719,43 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
options.non_donatable_input_indices.insert((int)i);
}
options.untuple_result = true;

std::optional<std::vector<FutureType>> returned_futures;
auto results = MyValueOrThrow(
exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer *>>>(
argument_handles),
options, returned_futures));

assert(results.size() == 1);
assert(results.size() == num_devices);

if (results[0].size() != num_results) {
llvm::errs() << " results.size()=" << results.size()
<< " num_results=" << num_results << "\n";
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
if (results[device_idx].size() != num_results) {
llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size()
<< " num_results=" << num_results << "\n";
}
assert(results[device_idx].size() == num_results);
}
assert(results[0].size() == num_results);

// Handle returned futures
if (returned_futures) {
*futures = true;
assert(returned_futures->size() == num_results);
for (size_t i = 0; i < num_results; i++) {
future_results[i] = new FutureType((*returned_futures)[i]);
assert(returned_futures->size() == num_devices * num_results);
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
int flat_index = device_idx * num_results + result_idx;
future_results[flat_index] = new FutureType((*returned_futures)[flat_index]);
}
}
} else {
*futures = false;
}

for (size_t i = 0; i < num_results; i++) {
op_results[i] = results[0][i].release();
// Copy results into the output buffers
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
int flat_index = device_idx * num_results + result_idx;
op_results[flat_index] = results[device_idx][result_idx].release();
}
}
}

Expand Down
20 changes: 14 additions & 6 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ cc_library(
],

) + [
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
# "@xla//xla:xla.pb.cc",
"@xla//xla:xla_data.pb.cc",
Expand Down Expand Up @@ -438,7 +438,6 @@ cc_library(
"-Wl,-exported_symbol,_RegisterCustomCallTarget",
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
"-Wl,-exported_symbol,_ReactantThrowError",
"-Wl,-exported_symbol,_ReactantHandleCuResult",
"-Wl,-exported_symbol,_CreateProfilerSession",
Expand All @@ -450,8 +449,13 @@ cc_library(
"-Wl,-exported_symbol,_ProfilerActivityEnd",
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",
"-Wl,-exported_symbol,_XLAExecuteSharded",
"-Wl,-exported_symbol,_ClientGetPlatformName",
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
]}),
deps = [
"@enzyme//:EnzymeMLIR",
Expand Down Expand Up @@ -545,12 +549,15 @@ cc_library(
"@xla//xla/backends/profiler/cpu:host_tracer_impl",
"@xla//xla/backends/profiler/cpu:metadata_collector",
"@xla//xla/backends/profiler/cpu:metadata_utils",
"@xla//xla/backends/profiler/tpu:tpu_tracer",
"@xla//xla/python:profiler_utils",

"@tsl//tsl/platform:env_impl",
"@xla//xla/stream_executor:stream_executor_impl",
"@xla//xla/mlir/utils:type_util",
"@stablehlo//:stablehlo_capi_objects",
"@stablehlo//:chlo_capi_objects",
"@shardy//shardy/integrations/c:sdy_capi_objects",
"@com_google_absl//absl/hash:hash",
"@com_google_absl//absl/log:initialize",
"@com_google_absl//absl/log:globals",
Expand Down Expand Up @@ -832,12 +839,13 @@ genrule(
"@llvm-project//mlir:AsyncPassIncGen_filegroup",
"@llvm-project//mlir:GPUPassIncGen_filegroup",
"@stablehlo//:stablehlo/integrations/c/StablehloAttributes.h",
"@shardy//shardy/integrations/c:attributes.h",
"//:Project.toml",
"//:Manifest.toml",
"//:wrap.toml",
"//:missing_defs.jl",
"//:make.jl"
],
outs = ["libMLIR_h.jl"],
cmd = "$$JULIA \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$@\"",
cmd = "$$JULIA \"--project=$(location //:Project.toml)\" \"$(location //:make.jl)\" \"$(location @llvm-project//mlir:include/mlir-c/Bindings/Python/Interop.h)\" \"$(location @llvm-project//llvm:include/llvm-c/Support.h)\" \"$(locations @llvm-project//mlir:ConversionPassIncGen_filegroup)\" \"$(location @stablehlo//:stablehlo/integrations/c/StablehloAttributes.h)\" \"$(location @shardy//shardy/integrations/c:attributes.h)\" \"$@\"",
)
9 changes: 7 additions & 2 deletions deps/ReactantExtra/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ let options = deepcopy(options)

genarg = first(eachsplit(ARGS[3], " "))

gen_include_dir = joinpath(splitpath(genarg)[1:(end - 3)]...)
gen_include_dir = joinpath(splitpath(genarg)[1:(end - 4)]...)

hlo_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)
hlo_include_dir = joinpath(splitpath(ARGS[end - 2])[1:(end - 1)]...)

sdy_include_dir = joinpath(splitpath(ARGS[end - 1])[1:(end - 1)]...)

append!(
args,
Expand All @@ -33,6 +35,8 @@ let options = deepcopy(options)
gen_include_dir,
"-I",
hlo_include_dir,
"-I",
sdy_include_dir,
"-x",
"c++",
],
Expand All @@ -41,6 +45,7 @@ let options = deepcopy(options)
headers = [
detect_headers(include_dir, args, Dict(), endswith("Python/Interop.h"))...,
detect_headers(hlo_include_dir, args, Dict())...,
detect_headers(sdy_include_dir, args, Dict())...,
]

ctx = create_context(headers, args, options)
Expand Down
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ export default defineConfig({
{ text: "NVVM", link: "/api/nvvm" },
{ text: "TPU", link: "/api/tpu" },
{ text: "Triton", link: "/api/triton" },
{ text: "Shardy", link: "/api/shardy" },
],
},
{
Expand Down Expand Up @@ -140,6 +141,7 @@ export default defineConfig({
{ text: "NVVM", link: "/api/nvvm" },
{ text: "TPU", link: "/api/tpu" },
{ text: "Triton", link: "/api/triton" },
{ text: "Shardy", link: "/api/shardy" },
],
},
{
Expand Down
Loading

0 comments on commit f845940

Please sign in to comment.