Skip to content

Commit

Permalink
[Build] Split spmm.cu and sddmm.cu for building on Windows (#3789)
Browse files Browse the repository at this point in the history
* split files

* fix
  • Loading branch information
BarclayII authored Feb 28, 2022
1 parent 6e1c699 commit 3521fbe
Show file tree
Hide file tree
Showing 6 changed files with 863 additions and 809 deletions.
212 changes: 0 additions & 212 deletions src/array/cuda/sddmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,66 +10,6 @@
namespace dgl {
namespace aten {

#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \
} while (0)

#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \
} while (0)

#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
} while (0)

/*!
* \brief CUDA implementation of g-SDDMM on Csr format.
*/
Expand All @@ -91,38 +31,6 @@ void SDDMMCsr(const std::string& op,
});
}

/*!
* \brief CUDA implementation of g-SDDMM on heterograph using
Csr format.
*/
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
CSRMatrix csr = vec_csr[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}
});
});
});
}


/*!
* \brief CUDA implementation of g-SDDMM on Coo format.
Expand All @@ -146,40 +54,6 @@ void SDDMMCoo(const std::string& op,
}


/*!
* \brief CUDA implementation of g-SDDMM on heterograph using
Csr format.
*/
template <int XPU, typename IdType, int bits>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
COOMatrix coo = vec_coo[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}
});
});
});
}


template void SDDMMCsr<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
Expand All @@ -205,49 +79,6 @@ template void SDDMMCsr<kDLGPU, int64_t, 64>(
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);

template void SDDMMCsrHetero<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);

template void SDDMMCoo<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
Expand All @@ -273,48 +104,5 @@ template void SDDMMCoo<kDLGPU, int64_t, 64>(
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);

template void SDDMMCooHetero<kDLGPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);

} // namespace aten
} // namespace dgl
61 changes: 61 additions & 0 deletions src/array/cuda/sddmm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "functor.cuh"
#include "fp16.cuh"
#include "./utils.h"
#include "./functor.cuh"
#include "../selector.h"
#include "../../runtime/cuda/cuda_common.h"

Expand All @@ -22,6 +23,66 @@ using namespace cuda;
namespace aten {
namespace cuda {

#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \
} while (0)

#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \
} while (0)

#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
} while (0)

constexpr unsigned int full_mask = 0xffffffff;

/*!
Expand Down
Loading

0 comments on commit 3521fbe

Please sign in to comment.