Skip to content

Commit

Permalink
PR #22815: [PJRT C API] Ensure C Compliance for all C headers
Browse files Browse the repository at this point in the history
Imported from GitHub PR #22815

`<cstd*>` headers are C++ headers that wrap their `<std*.h>` counteparts in the std namespace and re-exports them as well.

It is meant to be consumed by C++ compilers, not C compilers.

Since this is a C API, this PR replaces usages of `<cstd*>` include statements by their C counterparts only for exported C api headers.

This PR supersedes #22082 and fixes it across the whole C API.
Copybara import of the project:

--
d2a1096 by Corentin Kerisit <[email protected]>:

[PJRT C API] Ensure C Compliance for all C headers

<cstd*> headers are C++ headers that wrap their <std*.h> counteparts
in the std namespace and re-exports them as well..

It is meant to be consumed by C++ compilers, not C compilers.

Since this is a C API, this PR replaces usages of <cstd*> include
statements by their C counterparts only for exported C api headers.

This PR supersedes #22082 and
fixes it across the whole C API.

--
f1f6eb6 by Corentin Kerisit <[email protected]>:

Add missing typedef when refering to structs

Merging this change closes #22815

COPYBARA_INTEGRATE_REVIEW=#22815 from cerisier:cerisir/fix-c-compatibility f1f6eb6
PiperOrigin-RevId: 728682193
  • Loading branch information
cerisier authored and Google-ML-Automation committed Feb 19, 2025
1 parent 6c41035 commit a41222c
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 33 deletions.
40 changes: 20 additions & 20 deletions xla/pjrt/c/pjrt_c_api_custom_partitioner_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_CUSTOM_PARTITIONER_EXTENSION_H_

#include <cstddef>
#include <cstdint>
#include <stddef.h>
#include <stdint.h>

#include "xla/pjrt/c/pjrt_c_api.h"

Expand All @@ -27,24 +27,24 @@ extern "C" {

#define PJRT_API_CUSTOM_PARTITIONER_EXTENSION_VERSION 1

struct JAX_CustomCallPartitioner_string {
typedef struct JAX_CustomCallPartitioner_string {
const char* data;
size_t size;
};
} JAX_CustomCallPartitioner_string;

struct JAX_CustomCallPartitioner_aval {
typedef struct JAX_CustomCallPartitioner_aval {
JAX_CustomCallPartitioner_string shape;
bool has_sharding;
JAX_CustomCallPartitioner_string sharding;
};
} JAX_CustomCallPartitioner_aval;

// General callback information containing api versions, the result error
// message and the cleanup function to free any temporary memory that is backing
// the results. Arguments are always owned by the caller, and results are owned
// by the cleanup_fn. These should never be used directly. Args and results
// should be serialized via the PopulateArgs, ReadArgs, PopulateResults,
// ConsumeResults functions defined below.
struct JAX_CustomCallPartitioner_version_and_error {
typedef struct JAX_CustomCallPartitioner_version_and_error {
int64_t api_version;
void* data; // out
// cleanup_fn cleans up any returned results. The caller must finish with all
Expand All @@ -53,9 +53,9 @@ struct JAX_CustomCallPartitioner_version_and_error {
bool has_error;
PJRT_Error_Code code; // out
JAX_CustomCallPartitioner_string error_msg; // out
};
} JAX_CustomCallPartitioner_version_and_error;

struct JAX_CustomCallPartitioner_Partition_Args {
typedef struct JAX_CustomCallPartitioner_Partition_Args {
JAX_CustomCallPartitioner_version_and_error header;

size_t num_args;
Expand All @@ -67,9 +67,9 @@ struct JAX_CustomCallPartitioner_Partition_Args {
JAX_CustomCallPartitioner_string mlir_module;
JAX_CustomCallPartitioner_string* args_sharding;
JAX_CustomCallPartitioner_string result_sharding;
};
} JAX_CustomCallPartitioner_Partition_Args;

struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args {
typedef struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args {
JAX_CustomCallPartitioner_version_and_error header;

size_t num_args;
Expand All @@ -79,32 +79,32 @@ struct JAX_CustomCallPartitioner_InferShardingFromOperands_Args {

bool has_result_sharding;
JAX_CustomCallPartitioner_string result_sharding;
};
} JAX_CustomCallPartitioner_InferShardingFromOperands_Args;

struct JAX_CustomCallPartitioner_PropagateUserSharding_Args {
typedef struct JAX_CustomCallPartitioner_PropagateUserSharding_Args {
JAX_CustomCallPartitioner_version_and_error header;

JAX_CustomCallPartitioner_string backend_config;

JAX_CustomCallPartitioner_string result_shape;

JAX_CustomCallPartitioner_string result_sharding; // inout
};
} JAX_CustomCallPartitioner_PropagateUserSharding_Args;

struct JAX_CustomCallPartitioner_Callbacks {
typedef struct JAX_CustomCallPartitioner_Callbacks {
int64_t version;
void* private_data;
void (*dtor)(JAX_CustomCallPartitioner_Callbacks* data);
void (*partition)(JAX_CustomCallPartitioner_Callbacks* data,
void (*dtor)(struct JAX_CustomCallPartitioner_Callbacks* data);
void (*partition)(struct JAX_CustomCallPartitioner_Callbacks* data,
JAX_CustomCallPartitioner_Partition_Args* args);
void (*infer_sharding)(
JAX_CustomCallPartitioner_Callbacks* data,
struct JAX_CustomCallPartitioner_Callbacks* data,
JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args);
void (*propagate_user_sharding)(
JAX_CustomCallPartitioner_Callbacks* data,
struct JAX_CustomCallPartitioner_Callbacks* data,
JAX_CustomCallPartitioner_PropagateUserSharding_Args* args);
bool can_side_effecting_have_replicated_sharding;
};
} JAX_CustomCallPartitioner_Callbacks;

struct PJRT_Register_Custom_Partitioner_Args {
size_t struct_size;
Expand Down
7 changes: 3 additions & 4 deletions xla/pjrt/c/pjrt_c_api_ffi_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ limitations under the License.
#define XLA_PJRT_C_PJRT_C_API_FFI_EXTENSION_H_

#include <stddef.h>

#include <cstdint>
#include <stdint.h>

#include "xla/pjrt/c/pjrt_c_api.h"

Expand Down Expand Up @@ -49,11 +48,11 @@ typedef PJRT_Error* PJRT_FFI_TypeID_Register(

// User-data that will be forwarded to the FFI handlers. Deleter is optional,
// and can be nullptr. Deleter will be called when the context is destroyed.
struct PJRT_FFI_UserData {
typedef struct PJRT_FFI_UserData {
int64_t type_id;
void* data;
void (*deleter)(void* data);
};
} PJRT_FFI_UserData;

struct PJRT_FFI_UserData_Add_Args {
size_t struct_size;
Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/c/pjrt_c_api_layouts_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_LAYOUTS_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_LAYOUTS_EXTENSION_H_

#include <cstddef>
#include <cstdint>
#include <stddef.h>
#include <stdint.h>

#include "xla/pjrt/c/pjrt_c_api.h"

Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_memory_descriptions_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_MEMORY_DESCRIPTIONS_EXTENSION_H_

#include <cstddef>
#include <stddef.h>

#include "xla/pjrt/c/pjrt_c_api.h"

Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/c/pjrt_c_api_profiler_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_PROFILER_EXTENSION_H_

#include <cstddef>
#include <cstdint>
#include <stddef.h>
#include <stdint.h>

#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
Expand Down
3 changes: 1 addition & 2 deletions xla/pjrt/c/pjrt_c_api_stream_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ limitations under the License.
#define XLA_PJRT_C_PJRT_C_API_STREAM_EXTENSION_H_

#include <stddef.h>

#include <cstdint>
#include <stdint.h>

#include "xla/pjrt/c/pjrt_c_api.h"

Expand Down
4 changes: 2 additions & 2 deletions xla/pjrt/c/pjrt_c_api_triton_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License.
#ifndef XLA_PJRT_C_PJRT_C_API_TRITON_EXTENSION_H_
#define XLA_PJRT_C_PJRT_C_API_TRITON_EXTENSION_H_

#include <cstddef>
#include <cstdint>
#include <stddef.h>
#include <stdint.h>

#include "xla/pjrt/c/pjrt_c_api.h"

Expand Down

0 comments on commit a41222c

Please sign in to comment.