// bin_cuda_kernel.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
const auto ENCODE_SIZE = 32; // number of matrix entries packed into each int32 word
const auto BLOCK_SIZE = 32;  // width/height of the shared-memory tiles used by bin_matmul_kernel
// See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory
// Multiplies the bit-packed operands: each thread counts the mismatching bits
// (XOR + popcount) between one encoded row of mat1 and one encoded column of mat2,
// then writes the dot product n - 2 * mismatches into its output element.
template <typename scalar_t>
__global__ void bin_matmul_kernel(
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> result,
    const torch::PackedTensorAccessor<int32_t, 2, torch::RestrictPtrTraits, size_t> encoded_rows,
    const torch::PackedTensorAccessor<int32_t, 2, torch::RestrictPtrTraits, size_t> encoded_cols,
    const int m,
    const int n,
    const int k,
    int encoded_dim
) {
    int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
    int col = blockIdx.x * BLOCK_SIZE + threadIdx.x;
    scalar_t val = 0;
    const int numTiles = (encoded_dim + BLOCK_SIZE - 1) / BLOCK_SIZE;
    for (int i = 0; i < numTiles; i++) {
        // We iterate over (BLOCK_SIZE x BLOCK_SIZE) sub-matrices and use shared memory to
        // minimize the number of global memory accesses.
        __shared__ int32_t rowsSub[BLOCK_SIZE][BLOCK_SIZE];
        __shared__ int32_t colsSub[BLOCK_SIZE][BLOCK_SIZE];
        int encodedCol = i * BLOCK_SIZE + threadIdx.x;
        int encodedRow = i * BLOCK_SIZE + threadIdx.y;
        rowsSub[threadIdx.y][threadIdx.x] = (row < m && encodedCol < encoded_dim) ? encoded_rows[row][encodedCol] : 0;
        colsSub[threadIdx.y][threadIdx.x] = (col < k && encodedRow < encoded_dim) ? encoded_cols[encodedRow][col] : 0;
        __syncthreads();
        // Accumulate the number of mismatching bits across this tile.
        for (int j = 0; j < BLOCK_SIZE; j++) {
            val += __popc(rowsSub[threadIdx.y][j] ^ colsSub[j][threadIdx.x]);
        }
        __syncthreads();
    }
    if (row < m && col < k) {
        // For {-1, +1} entries: dot = matches - mismatches = n - 2 * mismatches.
        result[row][col] += n - 2 * val;
    }
}
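// Worked example (illustrative, n = 4): the row [+1, -1, -1, +1] packs to bits 1001 and
// the column [+1, +1, -1, -1] packs to bits 1100; 1001 XOR 1100 = 0101, so there are
// 2 mismatches and the kernel writes 4 - 2 * 2 = 0, which equals the exact dot product
// (+1)(+1) + (-1)(+1) + (-1)(-1) + (+1)(-1) = 0.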
// Bit-packs the rows of a matrix whose entries encode {-1, +1} by their sign:
// each entry contributes one bit (1 if positive, 0 otherwise), most significant bit first.
// Launched with one block per row and one thread per packed 32-bit word.
template <typename scalar_t>
__global__ void encode_rows_kernel(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> matrix,
    torch::PackedTensorAccessor<int32_t, 2, torch::RestrictPtrTraits, size_t> result,
    int n,
    int encode_size
) {
    int col = threadIdx.x * encode_size;
    int row = blockIdx.x;
    int32_t encoded_value = 0;
    for (int i = 0; i < encode_size && col < n; i++) {
        encoded_value = (encoded_value << 1) | (matrix[row][col] > 0);
        col++;
    }
    result[row][threadIdx.x] = encoded_value;
}
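// Packing example (illustrative, using encode_size = 8 instead of 32 for brevity):
// the row segment [+1, -1, -1, +1, +1, +1, -1, -1] becomes the bit pattern 10011100.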
// Bit-packs the columns of a matrix the same way; word j of column c is stored at
// result[j][c], so the encoded matrix has shape (encoded_dim, k).
// Launched with one block per column and one thread per packed 32-bit word.
template <typename scalar_t>
__global__ void encode_cols_kernel(
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> matrix,
    torch::PackedTensorAccessor<int32_t, 2, torch::RestrictPtrTraits, size_t> result,
    int n,
    int encode_size
) {
    int col = blockIdx.x;
    int row = threadIdx.x * encode_size;
    int32_t encoded_value = 0;
    for (int i = 0; i < encode_size && row < n; i++) {
        encoded_value = (encoded_value << 1) | (matrix[row][col] > 0);
        row++;
    }
    result[threadIdx.x][col] = encoded_value;
}
// Host entry point: encodes both operands into bit-packed form, then runs the
// tiled XOR/popcount matmul kernel. mat1 is (m, n), mat2 is (n, k).
torch::Tensor bin_matmul_cuda(
    const torch::Tensor mat1,
    const torch::Tensor mat2
) {
    const int m = mat1.size(0);
    const int n = mat1.size(1);
    const int k = mat2.size(1);
    auto result = mat1.new_zeros({m, k});
    // Number of 32-bit words needed to hold one encoded row or column.
    const int encoded_dim = ceil(double(n) / double(ENCODE_SIZE));
    auto optionsEncode = torch::TensorOptions()
        .device(mat1.device())
        .dtype(torch::kInt32);

    // Pack the rows of mat1: one block per row, one thread per packed word.
    auto encoded_rows = torch::zeros({m, encoded_dim}, optionsEncode);
    AT_DISPATCH_FLOATING_TYPES(mat1.type(), "encode_rows_kernel", ([&] {
        encode_rows_kernel<scalar_t><<<m, encoded_dim>>>(
            mat1.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            encoded_rows.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
            n,
            ENCODE_SIZE
        );
    }));

    // Pack the columns of mat2: one block per column, one thread per packed word.
    auto encoded_cols = torch::zeros({encoded_dim, k}, optionsEncode);
    AT_DISPATCH_FLOATING_TYPES(mat2.type(), "encode_cols_kernel", ([&] {
        encode_cols_kernel<scalar_t><<<k, encoded_dim>>>(
            mat2.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            encoded_cols.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
            n,
            ENCODE_SIZE
        );
    }));

    // One thread per output element, in BLOCK_SIZE x BLOCK_SIZE tiles:
    // grid x covers the k output columns, grid y covers the m output rows.
    dim3 blockSize(BLOCK_SIZE, BLOCK_SIZE);
    dim3 blocksPerGrid(
        ceil(double(k) / double(blockSize.x)),
        ceil(double(m) / double(blockSize.y))
    );
    AT_DISPATCH_FLOATING_TYPES(result.type(), "bin_matmul_kernel", ([&] {
        bin_matmul_kernel<scalar_t><<<blocksPerGrid, blockSize>>>(
            result.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
            encoded_rows.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
            encoded_cols.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
            m,
            n,
            k,
            encoded_dim
        );
    }));
    return result;
}
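
// A minimal usage sketch, kept behind an opt-in guard because the Python binding for
// this extension presumably lives in a separate C++ source file. The exposed name
// `bin_matmul` is an assumption chosen for illustration; TORCH_EXTENSION_NAME is
// supplied by the PyTorch extension build.
#ifdef BIN_CUDA_STANDALONE_BINDING_EXAMPLE
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Inputs: 2-D CUDA tensors whose entries encode {-1, +1} by their sign.
    m.def("bin_matmul", &bin_matmul_cuda, "Binary (+/-1) matrix multiplication (CUDA)");
}
#endif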