
Commit

Use shared memory in CUDA kernel
mszulc913 committed Mar 20, 2022
1 parent 2d269dc commit 0a4ef59
Showing 2 changed files with 27 additions and 13 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -8,7 +8,6 @@ It contains the main ideas introduced in
 The only layer available right now is `BinaryLinear`, which performs
 a binarized version of `torch.nn.Linear`. The optimized forward pass kernel
 is available via the `use_xnor_kernel` argument.
-The kernel implementation is quite naive and will be optimized in the future.
 
 
 ## Install
39 changes: 27 additions & 12 deletions cuda/bin_cuda_kernel.cu
@@ -6,8 +6,10 @@
 #include <vector>
 
 const auto ENCODE_SIZE = 32;
+const auto BLOCK_SIZE = 32;
 
 
+// See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory
 template <typename scalar_t>
 __global__ void bin_matmul_kernel(
     torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> input,
@@ -18,16 +20,31 @@ __global__ void bin_matmul_kernel(
     const int k,
     int encoded_dim
 ) {
-  int row = blockIdx.x * blockDim.x + threadIdx.x;
-  int col = blockIdx.y * blockDim.y + threadIdx.y;
+  int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
+  int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
 
-  if (row < m && col < k) {
-    scalar_t val = 0;
+  scalar_t val = 0;
 
+  for (int i = 0; i < ceil(double(encoded_dim) / double(BLOCK_SIZE)); i++) {
+    // We iterate over (BLOCK_SIZE x BLOCK_SIZE) sub-matrices and use shared memory to
+    // minimize the number of global memory accesses.
+    __shared__ int32_t rowsSub[BLOCK_SIZE][BLOCK_SIZE];
+    __shared__ int32_t colsSub[BLOCK_SIZE][BLOCK_SIZE];
+
+    int encodedCol = i * BLOCK_SIZE + threadIdx.x;
+    int encodedRow = i * BLOCK_SIZE + threadIdx.y;
+
+    rowsSub[threadIdx.y][threadIdx.x] = (row < m && encodedCol < encoded_dim) ? encoded_rows[row][encodedCol] : 0;
+    colsSub[threadIdx.y][threadIdx.x] = (col < k && encodedRow < encoded_dim) ? encoded_cols[encodedRow][col] : 0;
+
-    #pragma unroll
-    for (int i = 0; i < encoded_dim; i++) {
-      val += __popc(encoded_rows[row][i] ^ encoded_cols[i][col]);
+    __syncthreads();
+    for (int j = 0; j < BLOCK_SIZE; j++) {
+      val += __popc(rowsSub[threadIdx.y][j] ^ colsSub[j][threadIdx.x]);
     }
+    __syncthreads();
+  }
 
+  if (row < m && col < k) {
     input[row][col] += n - 2 * val;
   }
 }
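
The tiling added above is the classic shared-memory matrix-multiply pattern from the CUDA programming guide linked in the new comment. For reference, a minimal standalone sketch of the same structure on plain float matrices follows; it is an illustration only, and the names tiledMatmulSketch, TILE, A, B and C are made up here rather than taken from this repository. bin_matmul_kernel applies the identical load tile / __syncthreads / accumulate / __syncthreads cycle to the bit-packed int32 matrices, with __popc of an XOR in place of the multiply-add.

// Sketch (not from this repo): shared-memory tiled matmul, C = A * B,
// where A is m x n, B is n x k and C is m x k, all row-major.
#define TILE 32

__global__ void tiledMatmulSketch(const float* A, const float* B, float* C,
                                  int m, int n, int k) {
  __shared__ float aSub[TILE][TILE];
  __shared__ float bSub[TILE][TILE];

  int row = blockIdx.y * TILE + threadIdx.y;  // row of C owned by this thread
  int col = blockIdx.x * TILE + threadIdx.x;  // column of C owned by this thread
  float acc = 0.0f;

  // Slide a TILE-wide window across the shared dimension n.
  for (int t = 0; t < (n + TILE - 1) / TILE; t++) {
    int aCol = t * TILE + threadIdx.x;
    int bRow = t * TILE + threadIdx.y;

    // Stage one tile of A and one tile of B in shared memory, zero-padding the edges.
    aSub[threadIdx.y][threadIdx.x] = (row < m && aCol < n) ? A[row * n + aCol] : 0.0f;
    bSub[threadIdx.y][threadIdx.x] = (col < k && bRow < n) ? B[bRow * k + col] : 0.0f;

    __syncthreads();  // the whole tile must be loaded before anyone reads it
    for (int j = 0; j < TILE; j++) {
      acc += aSub[threadIdx.y][j] * bSub[j][threadIdx.x];
    }
    __syncthreads();  // everyone must finish reading before the next load overwrites the tile
  }

  if (row < m && col < k) {
    C[row * k + col] = acc;
  }
}
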
@@ -45,7 +62,6 @@ __global__ void encode_rows_kernel(
 
     int32_t encoded_value = 0;
 
-    #pragma unroll
     for (int i = 0; i < encode_size && col < n; i++) {
       encoded_value = (encoded_value << 1) | (matrix[row][col] > 0);
       col++;
@@ -65,7 +81,6 @@ __global__ void encode_cols_kernel(
 
     int32_t encoded_value = 0;
 
-    #pragma unroll
     for (int i = 0; i < encode_size && row < n; i++) {
       encoded_value = (encoded_value << 1) | (matrix[row][col] > 0);
       row++;
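
encode_rows_kernel and encode_cols_kernel pack ENCODE_SIZE = 32 sign bits into each int32_t (bit set when the entry is positive, packed most-significant-bit first). That packing is what lets bin_matmul_kernel replace a +/-1 dot product with XOR plus popcount and recover it as n - 2 * val: the XOR flags the positions where the signs disagree, and every disagreement turns a +1 term into a -1. Below is a host-side sketch of that identity, assuming made-up names (encodeSigns, kEncodeSize) and the GCC/Clang __builtin_popcount intrinsic; it mirrors the encoding loop above but is not code from this repository.

// Host-side sketch: pack 32 signs into one word and recover the +/-1 dot
// product from a single XOR + popcount (dot = n - 2 * #mismatches).
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int kEncodeSize = 32;

// Same scheme as the kernels: 1 if the value is positive, packed MSB-first.
uint32_t encodeSigns(const std::vector<float>& v) {
  uint32_t enc = 0;
  for (int i = 0; i < kEncodeSize; i++) {
    enc = (enc << 1) | (v[i] > 0.0f ? 1u : 0u);
  }
  return enc;
}

int main() {
  std::vector<float> a(kEncodeSize), b(kEncodeSize);
  for (int i = 0; i < kEncodeSize; i++) {  // arbitrary +/-1 test vectors
    a[i] = (i % 3 == 0) ? 1.0f : -1.0f;
    b[i] = (i % 5 == 0) ? 1.0f : -1.0f;
  }

  int reference = 0;  // plain dot product
  for (int i = 0; i < kEncodeSize; i++) {
    reference += static_cast<int>(a[i] * b[i]);
  }

  // XOR marks mismatched signs; each mismatch turns a +1 term into a -1.
  int mismatches = __builtin_popcount(encodeSigns(a) ^ encodeSigns(b));
  assert(reference == kEncodeSize - 2 * mismatches);
  return 0;
}
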
@@ -109,10 +124,10 @@ torch::Tensor bin_matmul_cuda(
     );
   }));
 
-  dim3 blockSize(32, 32);
+  dim3 blockSize(BLOCK_SIZE, BLOCK_SIZE);
   dim3 blocksPerGrid(
-    ceil(double(m) / double(blockSize.x)),
-    ceil(double(k) / double(blockSize.y))
+    ceil(double(k) / double(blockSize.y)),
+    ceil(double(m) / double(blockSize.x))
   );
   AT_DISPATCH_FLOATING_TYPES(result.type(), "bin_matmul_kernel", ([&] {
     bin_matmul_kernel<scalar_t><<<blocksPerGrid, blockSize>>>(
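
The grid is reshaped together with the new in-kernel indexing: blockIdx.x / threadIdx.x now walk the k output columns and blockIdx.y / threadIdx.y walk the m output rows, so each block maps onto one BLOCK_SIZE x BLOCK_SIZE tile of the result. A small sketch of the intended mapping, written with a hypothetical ceilDiv helper that is not part of this commit:

// Illustration only: integer ceiling division spelled out, and the axis
// mapping made explicit (x -> column tiles of the result, y -> row tiles).
constexpr unsigned int ceilDiv(unsigned int a, unsigned int b) {
  return (a + b - 1) / b;
}

inline dim3 makeGrid(int m, int k, int blockSize) {
  return dim3(ceilDiv(k, blockSize),   // gridDim.x covers the k columns
              ceilDiv(m, blockSize));  // gridDim.y covers the m rows
}
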
