Skip to content

Commit

Permalink
polish the implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Jan 27, 2025
1 parent 64d9d78 commit 0b7f1d7
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 23 deletions.
18 changes: 9 additions & 9 deletions benchmarks/cpp/g2s_copy/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
|Shape|Warp Layout|tilefusion(ms)|cutlass(ms)|Ratio|
|:---|:---:|:---:|:---:|:---:|
|RowMajor(64, 64)|(1, 1)|0.05042|0.05069|0.9947|
|RowMajor(64, 64)|(2, 2)|0.05326|0.05095|1.045|
|RowMajor(64, 64)|(2, 4)|0.07204|0.05219|1.38|
|RowMajor(128, 128)|(1, 1)|0.1395|0.1541|0.9057|
|RowMajor(128, 128)|(2, 2)|0.1352|0.134|1.009|
|RowMajor(128, 128)|(2, 4)|0.1437|0.1383|1.039|
|RowMajor(128, 256)|(1, 1)|0.2403|0.3694|0.6505|
|RowMajor(128, 256)|(2, 2)|0.2468|0.2457|1.004|
|RowMajor(128, 256)|(2, 4)|0.2529|0.2509|1.008|
|RowMajor(64, 64)|(1, 1)|0.05039|0.0506|0.996|
|RowMajor(64, 64)|(2, 2)|0.0531|0.05082|1.045|
|RowMajor(64, 64)|(2, 4)|0.07193|0.05196|1.384|
|RowMajor(128, 128)|(1, 1)|0.1396|0.1539|0.9074|
|RowMajor(128, 128)|(2, 2)|0.1353|0.1339|1.01|
|RowMajor(128, 128)|(2, 4)|0.1435|0.138|1.04|
|RowMajor(128, 256)|(1, 1)|0.2402|0.3695|0.6501|
|RowMajor(128, 256)|(2, 2)|0.2468|0.2462|1.002|
|RowMajor(128, 256)|(2, 4)|0.2528|0.251|1.007|
7 changes: 3 additions & 4 deletions benchmarks/cpp/g2s_copy/cutlass_copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@

#pragma once

#include "cell/sync.hpp"
#include "cutlass/copy.cuh"

#include <cute/swizzle.hpp>
#include <cute/tensor.hpp>

using namespace cute;
using namespace tilefusion::cell;
using namespace benchmarks;

namespace {
// NOTE: The current implementation of Loader/Storer supports only
// half-precision (FP16) RowMajor data tiles. It is not implemented for other
// data types or memory layouts. Be cautious when using it for other cases.

template <typename Element, //
const int kRows, const int kCols, //
const int kWarpRows, const int kWarpCols>
Expand Down Expand Up @@ -132,7 +131,7 @@ __global__ void cutlass_g2s_data_transfer(const Element* src, Element* dst) {
for (int k = 0; k < kRepeat; ++k) {
loader(src, buf);

__copy_async();
cutlass_wrapper::__copy_async();
__syncthreads();

storer(buf, dst);
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/g2s_copy/tilefusion_copy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#include "cell/sync.hpp"
#include "cell/copy/sync.hpp"

using namespace tilefusion::cell;

Expand All @@ -20,7 +20,7 @@ __global__ void g2s_data_transfer(const Element* src_ptr, Element* dst_ptr,

for (int i = 0; i < kRepeat; ++i) {
loader(src, inter);
__copy_async();
copy::__copy_async();
__syncthreads();

storer(inter, dst);
Expand Down
1 change: 1 addition & 0 deletions include/cell/copy/mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
#include "cell/copy/global_to_shared.hpp"
#include "cell/copy/register.hpp"
#include "cell/copy/shared_to_register.hpp"
#include "cell/copy/sync.hpp"
#include "cell/copy/vectorize.hpp"
#include "cell/copy/warp.hpp"
8 changes: 3 additions & 5 deletions include/cell/sync.hpp → include/cell/copy/sync.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "cuda_utils.hpp"

namespace tilefusion::cell {
namespace tilefusion::cell::copy {

template <int N>
DEVICE void wait_group() {
Expand All @@ -15,15 +15,13 @@ DEVICE void wait_group() {
}

DEVICE void commit_copy_group() {
// FIXME(ying): make the implementation cutlass-independent.
#if defined(CP_ASYNC_SM80_ENABLED)

cute::cp_async_fence();
asm volatile("cp.async.commit_group;\n" ::);
#endif
}

DEVICE void __copy_async() {
commit_copy_group();
wait_group<0>();
}
} // namespace tilefusion::cell
} // namespace tilefusion::cell::copy
1 change: 0 additions & 1 deletion include/cell/mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include "cell/compute/mod.hpp"
#include "cell/copy/mod.hpp"
#include "cell/sync.hpp"
#include "cell/warp.hpp"
#include "traits/base.hpp"
#include "types/mod.hpp"
2 changes: 1 addition & 1 deletion tests/cpp/cell/test_g2s_load.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ __global__ void copy_g2s(const Element* src_ptr, Element* dst_ptr,
SrcTile dst(dst_ptr); // global memory tile

loader(src, inter);
__copy_async();
copy::__copy_async();
__syncthreads();

storer(inter, dst);
Expand Down
1 change: 0 additions & 1 deletion tests/cpp/cell/test_swizzled_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

#include "cell/copy/mod.hpp"
#include "cell/sync.hpp"
#include "common/test_utils.hpp"
#include "types/mod.hpp"

Expand Down

0 comments on commit 0b7f1d7

Please sign in to comment.