Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 12, 2024
1 parent 59e3eb3 commit 355b776
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 30 deletions.
2 changes: 1 addition & 1 deletion sphericart-torch/src/autograd.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "sphericart/autograd.hpp"

#include "cuda.hpp"
#include "cuda_base.hpp"
#include "sphericart.hpp"
#include "sphericart/torch.hpp"
#include "sphericart/torch_cuda_wrapper.hpp"
Expand Down
10 changes: 5 additions & 5 deletions sphericart-torch/src/torch_cuda_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#define _SPHERICART_INTERNAL_IMPLEMENTATION // gives us access to
// templates/macros
#include "cuda.hpp"
#include "cuda_base.hpp"
#include "sphericart.hpp"

#define CHECK_CUDA(x) \
Expand Down Expand Up @@ -39,7 +39,7 @@ std::vector<torch::Tensor> sphericart_torch::spherical_harmonics_cuda(
torch::TensorOptions().dtype(xyz.dtype()).device(xyz.device()));

torch::Tensor d_sph;
if (xyz.requires_grad() || gradients) {
if (gradients) {
d_sph = torch::empty(
{xyz.size(0), 3, n_total},
torch::TensorOptions().dtype(xyz.dtype()).device(xyz.device()));
Expand All @@ -52,7 +52,7 @@ std::vector<torch::Tensor> sphericart_torch::spherical_harmonics_cuda(

torch::Tensor hess_sph;

if (xyz.requires_grad() && hessian) {
if (hessian) {
hess_sph = torch::empty(
{xyz.size(0), 3, 3, n_total},
torch::TensorOptions().dtype(xyz.dtype()).device(xyz.device()));
Expand All @@ -76,14 +76,14 @@ std::vector<torch::Tensor> sphericart_torch::spherical_harmonics_cuda(
sphericart::cuda::spherical_harmonics_cuda_base<double>(
xyz.data_ptr<double>(), xyz.size(0), prefactors.data_ptr<double>(),
prefactors.size(0), l_max, normalize, GRID_DIM_X, GRID_DIM_Y,
xyz.requires_grad(), gradients, hessian, sph.data_ptr<double>(),
gradients, hessian, sph.data_ptr<double>(),
d_sph.data_ptr<double>(), hess_sph.data_ptr<double>());
break;
case torch::ScalarType::Float:
sphericart::cuda::spherical_harmonics_cuda_base<float>(
xyz.data_ptr<float>(), xyz.size(0), prefactors.data_ptr<float>(),
prefactors.size(0), l_max, normalize, GRID_DIM_X, GRID_DIM_Y,
xyz.requires_grad(), gradients, hessian, sph.data_ptr<float>(),
gradients, hessian, sph.data_ptr<float>(),
d_sph.data_ptr<float>(), hess_sph.data_ptr<float>());
break;
}
Expand Down
14 changes: 5 additions & 9 deletions sphericart/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ set(COMMON_SOURCES
"include/sphericart.h"
)

#add_library(sphericart
# "src/sphericart.cpp"
# "src/sphericart-capi.cpp"
# "include/sphericart.hpp"
# "include/sphericart.h"
#)

# Find CUDA
include(CheckLanguage)
check_language(CUDA)
Expand All @@ -60,13 +53,16 @@ if (CMAKE_CUDA_COMPILER AND NOT SPHERICART_ENABLE_CUDA)
endif()

# Append the relevant CUDA files to sources
list(APPEND COMMON_SOURCES "include/cuda.hpp")
list(APPEND COMMON_SOURCES "include/cuda_base.hpp")
list(APPEND COMMON_SOURCES "include/sphericart_cuda.hpp")

if (CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)
list(APPEND COMMON_SOURCES "src/cuda_base.cu")
list(APPEND COMMON_SOURCES "src/sphericart_cuda.cu")
else()
list(APPEND COMMON_SOURCES "src/cuda_stub.cpp")
list(APPEND COMMON_SOURCES "src/sphericart_cuda_stub.cpp")
endif()
message("Value of SPHERICART_ENABLE_CUDA: ${SPHERICART_ENABLE_CUDA}")

add_library(sphericart ${COMMON_SOURCES})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef SPHERICART_CUDA_HPP
#define SPHERICART_CUDA_HPP
#ifndef SPHERICART_CUDA_BASE_HPP
#define SPHERICART_CUDA_BASE_HPP

#include "sphericart.hpp"

Expand Down Expand Up @@ -55,7 +55,7 @@ void spherical_harmonics_cuda_base(
const scalar_t *__restrict__ xyz, const int nedges,
const scalar_t *__restrict__ prefactors, const int nprefactors,
const int64_t l_max, const bool normalize, const int64_t GRID_DIM_X,
const int64_t GRID_DIM_Y, const bool xyz_requires_grad,
const int64_t GRID_DIM_Y,
const bool gradients, const bool hessian, scalar_t *__restrict__ sph,
scalar_t *__restrict__ dsph, scalar_t *__restrict__ ddsph);

Expand Down
61 changes: 61 additions & 0 deletions sphericart/include/sphericart_cuda.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/** \file sphericart_cuda.hpp
* Defines the CUDA API for `sphericart`.
*/

#ifndef SPHERICART_CUDA_HPP
#define SPHERICART_CUDA_HPP

#include "sphericart.hpp"

namespace sphericart {

namespace cuda {
/**
* A spherical harmonics calculator.
*
* It handles initialization of the prefactors upon initialization and it
* stores the buffers that are necessary to compute the spherical harmonics
* efficiently.
*/
template <typename T> class SphericalHarmonics {
public:
/** Initialize the SphericalHarmonics class setting maximum degree and
* normalization
*
* @param l_max
* The maximum degree of the spherical harmonics to be calculated.
* @param normalized
* If `false` (default) computes the scaled spherical harmonics, which
* are homogeneous polynomials in the Cartesian coordinates of the input
* points. If `true`, computes the normalized spherical harmonics that are
* evaluated on the unit sphere. In practice, this simply computes the
* scaled harmonics at the normalized coordinates \f$(x/r, y/r, z/r)\f$, and
* adapts the derivatives accordingly.
*/
SphericalHarmonics(size_t l_max, bool normalized = false);

/* @cond */
~SphericalHarmonics();
/* @endcond */

/** Computes the spherical harmonics for one or more 3D points, using
* pre-allocated device-side pointers
*
* @param xyz todo docs
* @param sph todo docs
*/
void compute(const T *xyz, size_t nsamples, bool compute_with_gradients,
bool compute_with_hessian, size_t GRID_DIM_X,
size_t GRID_DIM_Y, T *sph, T *dsph = nullptr,
T *ddsph = nullptr);

private:
size_t l_max; // maximum l value computed by this class
bool normalized; // should we normalize the input vectors?
T *prefactors_cpu; // host prefactors buffer
T *prefactors_cuda; // storage space for prefactors
};

} // namespace cuda
} // namespace sphericart
#endif
20 changes: 9 additions & 11 deletions sphericart/src/cuda_base.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#define _SPHERICART_INTERNAL_IMPLEMENTATION
#define CUDA_DEVICE_PREFIX __device__

#include "cuda.hpp"
#include "cuda_base.hpp"

#define HARDCODED_LMAX 1

Expand Down Expand Up @@ -575,9 +575,9 @@ void sphericart::cuda::spherical_harmonics_cuda_base(
const scalar_t *__restrict__ xyz, const int nedges,
const scalar_t *__restrict__ prefactors, const int nprefactors,
const int64_t l_max, const bool normalize, const int64_t GRID_DIM_X,
const int64_t GRID_DIM_Y, const bool xyz_requires_grad,
const bool gradients, const bool hessian, scalar_t *__restrict__ sph,
scalar_t *__restrict__ dsph, scalar_t *__restrict__ ddsph) {
const int64_t GRID_DIM_Y, const bool gradients, const bool hessian,
scalar_t *__restrict__ sph, scalar_t *__restrict__ dsph,
scalar_t *__restrict__ ddsph) {

int n_total = (l_max + 1) * (l_max + 1);

Expand All @@ -590,14 +590,12 @@ void sphericart::cuda::spherical_harmonics_cuda_base(
dim3 block_dim(find_num_blocks(nedges, GRID_DIM_Y));

size_t total_buff_size = total_buffer_size(
l_max, GRID_DIM_X, GRID_DIM_Y, sizeof(scalar_t),
xyz_requires_grad || gradients, xyz_requires_grad && hessian);
l_max, GRID_DIM_X, GRID_DIM_Y, sizeof(scalar_t), gradients, hessian);

spherical_harmonics_kernel<scalar_t>
<<<block_dim, grid_dim, total_buff_size>>>(
xyz, nedges, prefactors, nprefactors, l_max, n_total,
xyz_requires_grad || gradients, xyz_requires_grad && hessian,
normalize, sph, dsph, ddsph);
xyz, nedges, prefactors, nprefactors, l_max, n_total, gradients,
hessian, normalize, sph, dsph, ddsph);

cudaDeviceSynchronize();
}
Expand All @@ -606,15 +604,15 @@ template void sphericart::cuda::spherical_harmonics_cuda_base<float>(
const float *__restrict__ xyz, const int nedges,
const float *__restrict__ prefactors, const int nprefactors,
const int64_t l_max, const bool normalize, const int64_t GRID_DIM_X,
const int64_t GRID_DIM_Y, const bool xyz_requires_grad,
const int64_t GRID_DIM_Y,
const bool gradients, const bool hessian, float *__restrict__ sph,
float *__restrict__ dsph, float *__restrict__ ddsph);

template void sphericart::cuda::spherical_harmonics_cuda_base<double>(
const double *__restrict__ xyz, const int nedges,
const double *__restrict__ prefactors, const int nprefactors,
const int64_t l_max, const bool normalize, const int64_t GRID_DIM_X,
const int64_t GRID_DIM_Y, const bool xyz_requires_grad,
const int64_t GRID_DIM_Y,
const bool gradients, const bool hessian, double *__restrict__ sph,
double *__restrict__ dsph, double *__restrict__ ddsph);

Expand Down
2 changes: 1 addition & 1 deletion sphericart/src/cuda_stub.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <stdexcept>

#include "cuda.hpp"
#include "cuda_base.hpp"

template <typename scalar_t>
void sphericart::cuda::spherical_harmonics_cuda_base(
Expand Down
88 changes: 88 additions & 0 deletions sphericart/src/sphericart_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdexcept>
#include <iostream>

#define _SPHERICART_INTERNAL_IMPLEMENTATION
#include "sphericart_cuda.hpp"
#include "cuda_base.hpp"


/*host macro that checks for errors in CUDA calls, and prints the file + line
* and error string if one occurs
*/
#define CUDA_CHECK(call) \
do { \
cudaError_t cudaStatus = (call); \
if (cudaStatus != cudaSuccess) { \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \
<< " - " << cudaGetErrorString(cudaStatus) << std::endl; \
cudaDeviceReset(); \
exit(EXIT_FAILURE); \
} \
} while (0)

using namespace sphericart::cuda;

template <typename T>
SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bool normalized) {
/*
This is the constructor of the SphericalHarmonics class. It initizlizes
buffer space, compute prefactors, and sets the function pointers that are
used for the actual calls
*/

this->l_max = (int)l_max;
this->nprefactors = (int)(l_max + 1) * (l_max + 2);
this->normalized = normalized;
this->prefactors_cpu = new T[this->nprefactors];

// compute prefactors on host first
compute_sph_prefactors<T>((int)l_max, this->prefactors_cpu);
// allocate them on device and copy to device
CUDA_CHECK(
cudaMalloc((void **)this->prefactors_cuda, this->nprefactors * sizeof(T)));
CUDA_CHECK(cudaMemcpy(this->prefactors_cpu, this->prefactors_cuda,
this->nprefactors * sizeof(T),
cudaMemcpyHostToDevice));
}

template <typename T> SphericalHarmonics<T>::~SphericalHarmonics() {
// Destructor, frees the prefactors
delete[] (this->prefactors_cpu);
CUDA_CHECK(cudaFree(this->prefactors_cuda));
}
template <typename T>
void SphericalHarmonics<T>::compute(const T *xyz, const size_t nsamples,
bool compute_with_gradients,
bool compute_with_hessian,
size_t GRID_DIM_X, size_t GRID_DIM_Y,
T *sph, T *dsph,
T *ddsph) {

if (sph == nullptr) {
throw std::runtime_error(
"sphericart::cuda::SphericalHarmonics::compute expected "
"sph ptr initialised, instead nullptr found. Initialise "
"sph with cudaMalloc.");
}

if (compute_with_gradients && dsph == nullptr) {
throw std::runtime_error(
"sphericart::cuda::SphericalHarmonics::compute expected "
"dsph != nullptr since compute_with_gradients = true. "
"initialise dsph with cudaMalloc.");
}

if (compute_with_hessian && ddsph == nullptr) {
throw std::runtime_error(
"sphericart::cuda::SphericalHarmonics::compute expected "
"ddsph != nullptr since compute_with_hessian = true. "
"initialise ddsph with cudaMalloc.");
}

sphericart::cuda::spherical_harmonics_cuda_base<T>(
xyz, nsamples, this->prefactors_cuda, this->nprefactors, this->l_max,
this->normalized, GRID_DIM_X, GRID_DIM_Y, compute_with_gradients,
compute_with_hessian, sph, dsph, ddsph);
}
19 changes: 19 additions & 0 deletions sphericart/src/sphericart_cuda_stub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "sphericart_cuda.hpp"
#include <cuda.h>
#include <cuda_runtime.h>

using namespace sphericart::cuda;

template <typename T>
SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bool normalized) {}

template <typename T> SphericalHarmonics<T>::~SphericalHarmonics() {}

void SphericalHarmonics<T>::compute(const T *xyz, const size_t nsamples,
bool compute_with_gradients,
bool compute_with_hessian,
size_t GRID_DIM_X, size_t GRID_DIM_Y,
T *sph, T *dsph = nullptr,
T *ddsph = nullptr) {
throw std::runtime_error("sphericart was not compiled with CUDA support");
}

0 comments on commit 355b776

Please sign in to comment.