Extract matmul operation
mszulc913 committed Mar 17, 2022
1 parent aa1f78b commit 91b372b
Showing 7 changed files with 105 additions and 58 deletions.
16 changes: 14 additions & 2 deletions binnet/functional.py
@@ -1,6 +1,6 @@
from typing import List, Optional, Tuple, cast

import bin_linear_cuda
import bin_cuda
import torch
from torch.autograd import Function
from torch.autograd.function import FunctionCtx
@@ -20,7 +20,7 @@ def forward( # type: ignore
ctx.save_for_backward(x, weight, bias)
else:
ctx.save_for_backward(x, weight)
return bin_linear_cuda.forward(x, weight, bias)
return bin_cuda.bin_linear(x, weight, bias)

@staticmethod
def backward( # type: ignore
@@ -61,3 +61,15 @@ def backward( # type: ignore
out = torch.ones_like(x)
out[torch.abs(x) > 1] = 0
return grad_output * out


def bin_matmul(mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""Binarized matrix product of two tensors.
Currently, only two-dimensional inputs are supported.
:param mat1: The first tensor to be multiplied.
:param mat2: The second tensor to be multiplied.
:return: Tensor with the result of the operation.
"""
return bin_cuda.bin_matmul(mat1, mat2)
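
For context, a minimal usage sketch of the new `bin_matmul` entry point (illustrative, not part of this commit): it assumes a CUDA device, the compiled `bin_cuda` extension, and inputs already binarized to ±1, on which the result matches `torch.matmul`, as the new test further below asserts.

import torch

from binnet.functional import bin_matmul

mat1 = torch.sign(torch.randn(128, 256)).cuda()  # values in {-1, +1}; shapes are illustrative
mat2 = torch.sign(torch.randn(256, 64)).cuda()

out = bin_matmul(mat1, mat2)  # binarized matrix product on the GPU
assert (out == torch.matmul(mat1, mat2)).all()  # agrees with the float result on +/-1 inputs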
33 changes: 33 additions & 0 deletions cuda/bin_cuda.cpp
@@ -0,0 +1,33 @@
#include <torch/extension.h>
#include <vector>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")


torch::Tensor bin_matmul_cuda(const torch::Tensor mat1, const torch::Tensor mat2);


torch::Tensor bin_linear(const torch::Tensor input, const torch::Tensor weight, const c10::optional<torch::Tensor> bias_opt) {
CHECK_CUDA(input);
CHECK_CUDA(weight);
if (bias_opt.has_value()) {
CHECK_CUDA(bias_opt.value());
}
auto result = bin_matmul_cuda(input, weight.t());
if (bias_opt.has_value()) {
result.add_(bias_opt.value());
}
return result;
}

torch::Tensor bin_matmul(const torch::Tensor mat1, const torch::Tensor mat2) {
CHECK_CUDA(mat1);
CHECK_CUDA(mat2);

return bin_matmul_cuda(mat1, mat2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bin_linear", &bin_linear, "Binary linear projection (CUDA)");
m.def("bin_matmul", &bin_matmul, "Binary matrix multiplication (CUDA)");
}
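
As the wrapper above makes explicit, `bin_linear` now reduces to `bin_matmul` on the transposed weight plus an optional bias add. A rough equivalence check follows (a sketch assuming ±1 CUDA tensors and the built extension; the shapes are illustrative).

import torch

import bin_cuda

x = torch.sign(torch.randn(8, 16)).cuda()
w = torch.sign(torch.randn(4, 16)).cuda()
b = torch.ones(4).cuda()

# bin_linear(x, w, b) == bin_matmul(x, w.t()) + b, per the C++ wrapper above
assert (bin_cuda.bin_linear(x, w, b) == bin_cuda.bin_matmul(x, w.t()) + b).all()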
7 changes: 6 additions & 1 deletion cuda/bin_linear_cuda.pyi → cuda/bin_cuda.pyi
@@ -2,6 +2,11 @@ from typing import Optional

import torch

def forward(
def bin_linear(
x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor: ...


def bin_matmul(
mat1: torch.Tensor, mat2: torch.Tensor
) -> torch.Tensor: ...
39 changes: 16 additions & 23 deletions cuda/bin_linear_cuda_kernel.cu → cuda/bin_cuda_kernel.cu
@@ -9,7 +9,7 @@ const auto ENCODE_SIZE = 32;


template <typename scalar_t>
__global__ void bin_linear_kernel(
__global__ void bin_matmul_kernel(
torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> input,
const torch::PackedTensorAccessor<int32_t, 2, torch::RestrictPtrTraits, size_t> encoded_rows,
const torch::PackedTensorAccessor<int32_t, 2, torch::RestrictPtrTraits, size_t> encoded_cols,
@@ -74,42 +74,35 @@ __global__ void encode_cols_kernel(
}


torch::Tensor bin_linear_cuda(
const torch::Tensor input,
const torch::Tensor weight,
const c10::optional<torch::Tensor> bias_opt
torch::Tensor bin_matmul_cuda(
const torch::Tensor mat1,
const torch::Tensor mat2
) {
torch::Tensor input_expanded;
const int m = mat1.size(0);
const int n = mat1.size(1);
const int k = mat2.size(1);

const int m = input.size(0);
const int n = input.size(1);
const int k = weight.size(0);

auto result = input.new_zeros({m, k});
if (bias_opt.has_value()) {
result.add_(bias_opt.value());
}
auto result = mat1.new_zeros({m, k});

const int encoded_dim = ceil(double(n) / double(ENCODE_SIZE));

auto optionsEncode = torch::TensorOptions()
.device(weight.device())
.device(mat1.device())
.dtype(torch::kInt32);
auto encoded_rows = torch::zeros({m, encoded_dim}, optionsEncode);
AT_DISPATCH_FLOATING_TYPES(input.type(), "encode_rows_kernel", ([&] {
encode_rows_kernel<scalar_t><<<m, encoded_dim>>>(
input.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
AT_DISPATCH_FLOATING_TYPES(mat1.type(), "encode_rows_kernel", ([&] {
encode_rows_kernel<scalar_t><<<m, encoded_dim>>>(
mat1.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
encoded_rows.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
n,
ENCODE_SIZE
);
}));

auto encoded_cols = torch::zeros({encoded_dim, k}, optionsEncode);
auto weight_transposed = weight.t();
AT_DISPATCH_FLOATING_TYPES(weight_transposed.type(), "encode_cols_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES(mat2.type(), "encode_cols_kernel", ([&] {
encode_cols_kernel<scalar_t><<<k, encoded_dim>>>(
weight_transposed.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
mat2.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
encoded_cols.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
n,
ENCODE_SIZE
@@ -121,8 +114,8 @@ torch::Tensor bin_linear_cuda(
ceil(double(m) / double(blockSize.x)),
ceil(double(k) / double(blockSize.y))
);
AT_DISPATCH_FLOATING_TYPES(result.type(), "bin_linear_forward_kernel", ([&] {
bin_linear_kernel<scalar_t><<<blocksPerGrid, blockSize>>>(
AT_DISPATCH_FLOATING_TYPES(result.type(), "bin_matmul_kernel", ([&] {
bin_matmul_kernel<scalar_t><<<blocksPerGrid, blockSize>>>(
result.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
encoded_rows.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
encoded_cols.packed_accessor<int32_t, 2, torch::RestrictPtrTraits, size_t>(),
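The kernel file above encodes rows of `mat1` and columns of `mat2` into 32-bit words (`ENCODE_SIZE = 32`) before launching `bin_matmul_kernel`. The kernel body is collapsed in this view, but the usual way to combine such encodings (and presumably what the kernel does) is an XNOR/popcount reduction: for ±1 vectors, each element becomes one bit and a dot product is recovered from the number of matching bits. A rough pure-Python sketch of that idea, illustrative only and not the kernel's exact logic:

def pack_bits(vec):
    # Encode a {-1, +1} vector as a bit mask: bit i is set iff vec[i] == +1.
    word = 0
    for i, v in enumerate(vec):
        if v > 0:
            word |= 1 << i
    return word

def bin_dot(a, b):
    # For a, b in {-1, +1}^n: dot(a, b) = 2 * (#matching positions) - n.
    n = len(a)
    matches = bin(~(pack_bits(a) ^ pack_bits(b)) & ((1 << n) - 1)).count("1")
    return 2 * matches - n

assert bin_dot([1, -1, 1], [1, 1, 1]) == 1  # same value as the ordinary dot product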
21 changes: 0 additions & 21 deletions cuda/bin_linear_cuda.cpp

This file was deleted.

8 changes: 4 additions & 4 deletions cuda/setup.py
@@ -2,13 +2,13 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
name="bin_linear_cuda",
name="bin_cuda",
ext_modules=[
CUDAExtension(
"bin_linear_cuda",
"bin_cuda",
[
"bin_linear_cuda.cpp",
"bin_linear_cuda_kernel.cu",
"bin_cuda.cpp",
"bin_cuda_kernel.cu",
],
),
],
39 changes: 32 additions & 7 deletions tests/test_functional.py
@@ -4,7 +4,7 @@
import torch
import torch.nn.functional as F

from binnet.functional import QuantizeSignSTE, BinLinearXOR
from binnet.functional import QuantizeSignSTE, BinLinearXOR, bin_matmul


def test_quantize_ste_forward():
@@ -38,11 +38,10 @@ def test_quantize_ste_backward(x: torch.Tensor, expected_grad: torch.Tensor):
"x, weight, bias",
[
(torch.ones(2, 3), torch.ones(4, 3), None),
(torch.ones(5, 2, 3), torch.ones(4, 3), None),
(torch.ones(5, 2, 3), torch.ones(4, 3), torch.ones(4)),
(torch.ones(10, 100, 200), torch.ones(300, 200), torch.ones(300)),
(torch.ones(2, 3), torch.ones(4, 3), torch.ones(4)),
(torch.ones(100, 200), torch.ones(300, 200), torch.ones(300)),
(
torch.sign(torch.randn(10, 100, 200)),
torch.sign(torch.randn(100, 200)),
torch.sign(torch.randn(300, 200)),
torch.sign(torch.randn(300)),
),
@@ -70,7 +69,7 @@ def test_bin_linear_xor_forward(
"x, weight, bias",
[
(
torch.sign(torch.randn(10, 100, 200, requires_grad=True)),
torch.sign(torch.randn(100, 200, requires_grad=True)),
torch.sign(torch.randn(300, 200, requires_grad=True)),
None,
),
@@ -80,7 +79,7 @@
torch.sign(torch.randn(300, requires_grad=True)),
),
(
torch.sign(torch.randn(10, 100, 200, requires_grad=True)),
torch.sign(torch.randn(100, 200, requires_grad=True)),
torch.sign(torch.randn(300, 200, requires_grad=True)),
torch.sign(torch.randn(300, requires_grad=True)),
),
@@ -117,3 +116,29 @@ def test_bin_linear_xor_backward(
if bias is not None:
assert bias.grad.shape == expected_d_bias.shape
assert (bias.grad == expected_d_bias).all()


@pytest.mark.cuda
@pytest.mark.parametrize(
"mat1, mat2",
[
(torch.ones(2, 3), torch.ones(3, 4)),
(torch.ones(100, 200), torch.ones(200, 300)),
(
torch.sign(torch.randn(100, 200)),
torch.sign(torch.randn(200, 300)),
),
],
)
def test_bin_matmul(
mat1: torch.Tensor,
mat2: torch.Tensor,
):
mat1 = mat1.cuda()
mat2 = mat2.cuda()
expected_result = torch.matmul(mat1, mat2)

result = bin_matmul(mat1, mat2)

assert result.shape == expected_result.shape
assert (result == expected_result).all()
