[SYCL] Allow alignment property to be used for group load/store
It makes it possible to provide the alignment<value> property to the
load/store operations, indicating the alignment of the pointer.
This allows avoiding expensive dynamic alignment checks.
againull committed Feb 4, 2025
1 parent 29044ff commit 1e118e5
Showing 3 changed files with 269 additions and 111 deletions.
@@ -442,6 +442,23 @@ so the implementation can rely on `get_max_local_range()` range size:

If the partition is uneven, the behavior is undefined.

== Alignment

The following property can be used to specify the alignment of the pointer.
It allows the implementation to avoid a dynamic alignment check.

```c++
namespace sycl::ext::oneapi::experimental {
struct alignment_key {
template <int K>
using value_t = property_value<alignment_key, std::integral_constant<int, K>>;
};

template<int K>
inline constexpr alignment_key::value_t<K> alignment;
} // namespace sycl::ext::oneapi::experimental
```
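
For context, an editorial sketch (not part of the diff): the alignment value is
encoded in the properties type, so it can be queried at compile time, which is
what the implementation changes below rely on.

```c++
#include <sycl/sycl.hpp>
namespace sycl_exp = sycl::ext::oneapi::experimental;

// Minimal sketch, assuming this extension's compile-time properties API:
// the alignment is part of the properties *type*, so it can drive
// `if constexpr` branches at no runtime cost.
using Props = decltype(sycl_exp::properties{sycl_exp::alignment<16>});
static_assert(Props::has_property<sycl_exp::alignment_key>());
static_assert(Props::get_property<sycl_exp::alignment_key>().value == 16);
```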

== Usage Example

The example shows the simplest case without local memory usage of blocked load
@@ -472,7 +489,7 @@ q.submit([&](sycl::handler& cgh) {
auto offset = g.get_group_id(0) * g.get_local_range(0) *
items_per_thread;
auto props = sycl_exp::properties{sycl_exp::contiguous_memory};
auto props = sycl_exp::properties{sycl_exp::contiguous_memory, sycl_exp::alignment<16>};
sycl_exp::group_load(g, input + offset, sycl::span{ data }, props);
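
The diff truncates the example here; the following self-contained sketch fills
in the surrounding boilerplate (queue setup, sizes, and the USM allocation are
illustrative assumptions, not taken from the excerpt):

```c++
#include <sycl/sycl.hpp>
#include <cstddef>
namespace sycl_exp = sycl::ext::oneapi::experimental;

int main() {
  constexpr std::size_t items_per_thread = 4;
  constexpr std::size_t local_size = 32;   // illustrative
  constexpr std::size_t global_size = 256; // illustrative
  sycl::queue q;

  // 16-byte alignment of the allocation matches the alignment<16> property.
  float *input = sycl::aligned_alloc_device<float>(
      16, global_size * items_per_thread, q);
  // ... initialize `input` ...

  q.submit([&](sycl::handler &cgh) {
     cgh.parallel_for(
         sycl::nd_range<1>{global_size, local_size},
         [=](sycl::nd_item<1> it) {
           auto g = it.get_group();
           float data[items_per_thread];
           auto offset =
               g.get_group_id(0) * g.get_local_range(0) * items_per_thread;
           // contiguous_memory + alignment<16> lets the implementation use
           // a block load with no runtime alignment check.
           auto props = sycl_exp::properties{sycl_exp::contiguous_memory,
                                             sycl_exp::alignment<16>};
           sycl_exp::group_load(g, input + offset, sycl::span{data}, props);
           // ... use `data` ...
         });
   }).wait();

  sycl::free(input, q);
}
```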
sycl/include/sycl/ext/oneapi/experimental/group_load_store.hpp (255 changes: 145 additions & 110 deletions)
@@ -181,6 +181,27 @@ struct BlockTypeInfo<BlockInfo<IteratorT, ElementsPerWorkItem, Blocked>> {
detail::ConvertToOpenCLType_t<vec<block_type, BlockInfoTy::num_blocks>>>;
};

template <typename IteratorT, int RequiredAlign, typename Properties,
typename = void>
struct is_statically_known_aligned {
using value_type =
remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;
static constexpr bool value = (alignof(value_type) >= RequiredAlign);
};

template <typename IteratorT, int RequiredAlign, typename Properties>
struct is_statically_known_aligned<IteratorT, RequiredAlign, Properties,
typename std::enable_if_t<
Properties::template has_property<alignment_key>()>> {
using value_type =
remove_decoration_t<typename std::iterator_traits<IteratorT>::value_type>;

static constexpr bool value =
(Properties::template get_property<alignment_key>().value >=
RequiredAlign) ||
(alignof(value_type) >= RequiredAlign);
};
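
To make the trait concrete, a hypothetical illustration of how it evaluates
(editorial, not from the commit): static knowledge can come either from
`alignof` of the value type or from the `alignment` property.

```c++
// Assumes it is compiled inside the same `detail` namespace as the trait.
namespace sycl_exp = sycl::ext::oneapi::experimental;
using Empty     = decltype(sycl_exp::properties{});
using Aligned16 = decltype(sycl_exp::properties{sycl_exp::alignment<16>});

// alignof(int) == 4 already satisfies the 4-byte load requirement.
static_assert(is_statically_known_aligned<int *, 4, Empty>::value);
// The 16-byte store requirement is not statically known without the property...
static_assert(!is_statically_known_aligned<int *, 16, Empty>::value);
// ...but becomes statically known once alignment<16> is supplied.
static_assert(is_statically_known_aligned<int *, 16, Aligned16>::value);
```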

// Returns either a pointer decorated with the deduced address space, suitable
// to use in a block read/write builtin, or nullptr if some legality conditions
// aren't satisfied. If deduced address space is generic then returned pointer
@@ -221,10 +242,6 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
// compiler to optimize the IR further.
__builtin_assume(iter != nullptr);

// No early return as that would mess up return type deduction.
bool is_aligned = alignof(value_type) >= RequiredAlign ||
reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0;

using block_pointer_type =
typename BlockTypeInfo<BlkInfo>::block_pointer_type;

@@ -235,7 +252,13 @@ auto get_block_op_ptr(IteratorT iter, [[maybe_unused]] Properties props) {
deduced_address_space ==
access::address_space::global_space ||
deduced_address_space == access::address_space::local_space) {
return is_aligned ? reinterpret_cast<block_pointer_type>(iter) : nullptr;
if constexpr (is_statically_known_aligned<IteratorT, RequiredAlign, Properties>::value) {
return reinterpret_cast<block_pointer_type>(iter);
} else {
return reinterpret_cast<uintptr_t>(iter) % RequiredAlign == 0
? reinterpret_cast<block_pointer_type>(iter)
: nullptr;
}
} else {
return nullptr;
}
@@ -266,79 +289,85 @@ group_load(Group g, InputIteratorT in_ptr,
} else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
return group_load(g, in_ptr, out, use_naive{});
} else {
auto ptr =
detail::get_block_op_ptr<4 /* load align */, ElementsPerWorkItem>(
in_ptr, props);
if (!ptr)
return group_load(g, in_ptr, out, use_naive{});
constexpr int RequiredAlign = 4;
auto ptr = detail::get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
in_ptr, props);

if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
// Do optimized load.
using value_type = remove_decoration_t<
typename std::iterator_traits<InputIteratorT>::value_type>;
using block_info = typename detail::BlockTypeInfo<
detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
static constexpr auto deduced_address_space =
block_info::deduced_address_space;
using block_op_type = typename block_info::block_op_type;

if constexpr (deduced_address_space ==
access::address_space::local_space &&
!props.template has_property<
detail::native_local_block_io_key>())
if constexpr (detail::is_statically_known_aligned<InputIteratorT, RequiredAlign,
Properties>::value) {
if constexpr (std::is_same_v<std::nullptr_t, decltype(ptr)>) {
return group_load(g, in_ptr, out, use_naive{});

block_op_type load;
if constexpr (deduced_address_space ==
access::address_space::generic_space) {
if (auto local_ptr = detail::dynamic_address_cast<
access::address_space::local_space>(ptr)) {
if constexpr (props.template has_property<
detail::native_local_block_io_key>())
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
else
return group_load(g, in_ptr, out, use_naive{});
} else if (auto global_ptr = detail::dynamic_address_cast<
access::address_space::global_space>(ptr)) {
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
} else {
return group_load(g, in_ptr, out, use_naive{});
}
} else {
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
}
} else {
if (!ptr)
return group_load(g, in_ptr, out, use_naive{});
}

// Do optimized load.
using value_type = remove_decoration_t<
typename std::iterator_traits<InputIteratorT>::value_type>;
using block_info = typename detail::BlockTypeInfo<
detail::BlockInfo<InputIteratorT, ElementsPerWorkItem, blocked>>;
static constexpr auto deduced_address_space =
block_info::deduced_address_space;
using block_op_type = typename block_info::block_op_type;

if constexpr (deduced_address_space == access::address_space::local_space &&
!props.template has_property<
detail::native_local_block_io_key>())
return group_load(g, in_ptr, out, use_naive{});

// TODO: accessor_iterator's value_type is weird, so we need
// `std::remove_const_t` below:
//
// static_assert(
// std::is_same_v<
// typename std::iterator_traits<
// sycl::detail::accessor_iterator<const int, 1>>::value_type,
// const int>);
//
// yet
//
// static_assert(
// std::is_same_v<
// typename std::iterator_traits<const int *>::value_type, int>);

if constexpr (std::is_same_v<std::remove_const_t<value_type>, OutputT>) {
static_assert(sizeof(load) == out.size_bytes());
sycl::detail::memcpy_no_adl(out.begin(), &load, out.size_bytes());
block_op_type load;
if constexpr (deduced_address_space ==
access::address_space::generic_space) {
if (auto local_ptr =
detail::dynamic_address_cast<access::address_space::local_space>(
ptr)) {
if constexpr (props.template has_property<
detail::native_local_block_io_key>())
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(local_ptr);
else
return group_load(g, in_ptr, out, use_naive{});
} else if (auto global_ptr = detail::dynamic_address_cast<
access::address_space::global_space>(ptr)) {
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(global_ptr);
} else {
std::remove_const_t<value_type> values[ElementsPerWorkItem];
static_assert(sizeof(load) == sizeof(values));
sycl::detail::memcpy_no_adl(values, &load, sizeof(values));

// Note: can't `memcpy` directly into `out` because that might bypass
// an implicit conversion required by the specification.
for (int i = 0; i < ElementsPerWorkItem; ++i)
out[i] = values[i];
return group_load(g, in_ptr, out, use_naive{});
}
} else {
load = __spirv_SubgroupBlockReadINTEL<block_op_type>(ptr);
}

return;
// TODO: accessor_iterator's value_type is weird, so we need
// `std::remove_const_t` below:
//
// static_assert(
// std::is_same_v<
// typename std::iterator_traits<
// sycl::detail::accessor_iterator<const int, 1>>::value_type,
// const int>);
//
// yet
//
// static_assert(
// std::is_same_v<
// typename std::iterator_traits<const int *>::value_type, int>);

if constexpr (std::is_same_v<std::remove_const_t<value_type>, OutputT>) {
static_assert(sizeof(load) == out.size_bytes());
std::memcpy(out.begin(), &load, out.size_bytes());
} else {
std::remove_const_t<value_type> values[ElementsPerWorkItem];
static_assert(sizeof(load) == sizeof(values));
std::memcpy(values, &load, sizeof(values));

// Note: can't `memcpy` directly into `out` because that might bypass
// an implicit conversion required by the specification.
for (int i = 0; i < ElementsPerWorkItem; ++i)
out[i] = values[i];
}

return;
}
}

@@ -365,55 +394,61 @@ group_store(Group g, const span<InputT, ElementsPerWorkItem> in,
} else if constexpr (!std::is_same_v<Group, sycl::sub_group>) {
return group_store(g, in, out_ptr, use_naive{});
} else {
auto ptr =
detail::get_block_op_ptr<16 /* store align */, ElementsPerWorkItem>(
out_ptr, props);
if (!ptr)
return group_store(g, in, out_ptr, use_naive{});
constexpr int RequiredAlign = 16;
auto ptr = detail::get_block_op_ptr<RequiredAlign, ElementsPerWorkItem>(
out_ptr, props);

if constexpr (!std::is_same_v<std::nullptr_t, decltype(ptr)>) {
using block_info = typename detail::BlockTypeInfo<
detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
static constexpr auto deduced_address_space =
block_info::deduced_address_space;
if constexpr (deduced_address_space ==
access::address_space::local_space &&
!props.template has_property<
detail::native_local_block_io_key>())
if constexpr (detail::is_statically_known_aligned<OutputIteratorT, RequiredAlign,
Properties>::value) {
if constexpr (std::is_same_v<std::nullptr_t, decltype(ptr)>) {
return group_store(g, in, out_ptr, use_naive{});
}
} else {
if (!ptr)
return group_store(g, in, out_ptr, use_naive{});
}

using block_info = typename detail::BlockTypeInfo<
detail::BlockInfo<OutputIteratorT, ElementsPerWorkItem, blocked>>;
static constexpr auto deduced_address_space =
block_info::deduced_address_space;
if constexpr (deduced_address_space == access::address_space::local_space &&
!props.template has_property<
detail::native_local_block_io_key>())
return group_store(g, in, out_ptr, use_naive{});

// Do optimized store.
std::remove_const_t<remove_decoration_t<
typename std::iterator_traits<OutputIteratorT>::value_type>>
values[ElementsPerWorkItem];
// Do optimized store.
std::remove_const_t<remove_decoration_t<
typename std::iterator_traits<OutputIteratorT>::value_type>>
values[ElementsPerWorkItem];

for (int i = 0; i < ElementsPerWorkItem; ++i) {
// Including implicit conversion.
values[i] = in[i];
}
for (int i = 0; i < ElementsPerWorkItem; ++i) {
// Including implicit conversion.
values[i] = in[i];
}

using block_op_type = typename block_info::block_op_type;
if constexpr (deduced_address_space ==
access::address_space::generic_space) {
if (auto local_ptr = detail::dynamic_address_cast<
access::address_space::local_space>(ptr)) {
if constexpr (props.template has_property<
detail::native_local_block_io_key>())
__spirv_SubgroupBlockWriteINTEL(
local_ptr, sycl::bit_cast<block_op_type>(values));
else
return group_store(g, in, out_ptr, use_naive{});
} else if (auto global_ptr = detail::dynamic_address_cast<
access::address_space::global_space>(ptr)) {
using block_op_type = typename block_info::block_op_type;
if constexpr (deduced_address_space ==
access::address_space::generic_space) {
if (auto local_ptr =
detail::dynamic_address_cast<access::address_space::local_space>(
ptr)) {
if constexpr (props.template has_property<
detail::native_local_block_io_key>())
__spirv_SubgroupBlockWriteINTEL(
global_ptr, sycl::bit_cast<block_op_type>(values));
} else {
local_ptr, sycl::bit_cast<block_op_type>(values));
else
return group_store(g, in, out_ptr, use_naive{});
}
} else {
__spirv_SubgroupBlockWriteINTEL(ptr,
} else if (auto global_ptr = detail::dynamic_address_cast<
access::address_space::global_space>(ptr)) {
__spirv_SubgroupBlockWriteINTEL(global_ptr,
sycl::bit_cast<block_op_type>(values));
} else {
return group_store(g, in, out_ptr, use_naive{});
}
} else {
__spirv_SubgroupBlockWriteINTEL(ptr,
sycl::bit_cast<block_op_type>(values));
}
}
}