Skip to content

Commit

Permalink
Standalong CUDA code (#83)
Browse files Browse the repository at this point in the history
The CUDA code is now part of the main sphericart package and no longer sphericart-torch

Co-authored-by: Guillaume Fraux <[email protected]>
Co-authored-by: frostedoyster <[email protected]>
Co-authored-by: Guillaume Fraux <[email protected]>
  • Loading branch information
4 people authored Jan 16, 2024
1 parent bec16ed commit 6bbcb3a
Show file tree
Hide file tree
Showing 26 changed files with 892 additions and 428 deletions.
2 changes: 1 addition & 1 deletion ci/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_job:
- mkdir buildcpp
- cd buildcpp
- export Torch_DIR=/usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch/
- cmake .. -DSPHERICART_BUILD_TESTS=ON -DSPHERICART_OPENMP=ON -DSPHERICART_BUILD_EXAMPLES=ON -DSPHERICART_BUILD_TORCH=ON
- cmake .. -DSPHERICART_BUILD_TESTS=ON -DSPHERICART_OPENMP=ON -DSPHERICART_BUILD_EXAMPLES=ON -DSPHERICART_ENABLE_CUDA=ON -DSPHERICART_BUILD_TORCH=ON
- cmake --build .
- ctest

Expand Down
11 changes: 11 additions & 0 deletions docs/src/cuda-examples.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CUDA C++
--------

The ``sphericart::cuda::SphericalHarmonics`` class automatically initializes and internally stores
pre-factors and buffers, and its usage is similar to the C++ API, although here the class provides
a single unified function for all purposes (values, gradients, and Hessians). This is
illustrated in the example below. The CUDA C++ API is undocumented at this time and subject
to change, but the example below should be sufficient to get started.

.. literalinclude:: ../../examples/cuda/example.cu
:language: cuda
1 change: 1 addition & 0 deletions docs/src/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ floating-point arithmetics, and they evaluate the mean relative error between th

cpp-examples
c-examples
cuda-examples
python-examples
pytorch-examples
jax-examples
Expand Down
18 changes: 18 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
cmake_minimum_required(VERSION 3.10)

if (CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)
project(sphericart_examples LANGUAGES C CXX CUDA)
else()
project(sphericart_examples LANGUAGES C CXX)
endif()

add_executable(example_cpp cpp/example.cpp)
target_link_libraries(example_cpp sphericart)
add_test(NAME example_cpp COMMAND ./example_cpp)
Expand All @@ -6,3 +14,13 @@ target_compile_features(example_cpp PRIVATE cxx_std_14)
add_executable(example_c c/example.c)
target_link_libraries(example_c sphericart)
add_test(NAME example_c COMMAND ./example_c)

if (CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)
if(SPHERICART_ARCH_NATIVE)
set(CMAKE_CUDA_ARCHITECTURES native)
endif()

add_executable(example_cuda cuda/example.cu)
target_link_libraries(example_cuda sphericart)
add_test(NAME example_cuda COMMAND ./example_cuda)
endif()
78 changes: 78 additions & 0 deletions examples/cuda/example.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/** @file example.cpp
* @brief Usage example for the C++ API
*/

#include "sphericart_cuda.hpp"
#include <cmath>
#include <cstdio>
#include <cuda.h>
#include <cuda_runtime.h>
#include <iostream>
#include <vector>

using namespace std;
using namespace sphericart::cuda;

/*host macro that checks for errors in CUDA calls, and prints the file + line
* and error string if one occurs
*/
#define CUDA_CHECK(call) \
do { \
cudaError_t cudaStatus = (call); \
if (cudaStatus != cudaSuccess) { \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \
<< " - " << cudaGetErrorString(cudaStatus) << std::endl; \
cudaDeviceReset(); \
exit(EXIT_FAILURE); \
} \
} while (0)

int main() {
/* ===== set up the calculation ===== */

// hard-coded parameters for the example
size_t n_samples = 10000;
size_t l_max = 10;

// initializes samples
auto xyz = std::vector<double>(n_samples * 3, 0.0);
for (size_t i = 0; i < n_samples * 3; ++i) {
xyz[i] = (double)rand() / (double)RAND_MAX * 2.0 - 1.0;
}

// to avoid unnecessary allocations, calculators can use pre-allocated
// memory, one also can provide uninitialized vectors that will be
// automatically reshaped
auto sph = std::vector<double>(n_samples * (l_max + 1) * (l_max + 1), 0.0);
auto dsph =
std::vector<double>(n_samples * 3 * (l_max + 1) * (l_max + 1), 0.0);
auto ddsph =
std::vector<double>(n_samples * 3 * 3 * (l_max + 1) * (l_max + 1), 0.0);

/* ===== API calls ===== */

// internal buffers and numerical factors are initalized at construction
sphericart::cuda::SphericalHarmonics<double> calculator_cuda(l_max);

double *xyz_cuda;
CUDA_CHECK(cudaMalloc(&xyz_cuda, n_samples * 3 * sizeof(double)));
CUDA_CHECK(cudaMemcpy(xyz_cuda, xyz.data(), n_samples * 3 * sizeof(double),
cudaMemcpyHostToDevice));
double *sph_cuda;
CUDA_CHECK(cudaMalloc(&sph_cuda, n_samples * (l_max + 1) * (l_max + 1) *
sizeof(double)));

calculator_cuda.compute(xyz_cuda, n_samples, false, false,
sph_cuda); // no gradients */

CUDA_CHECK(
cudaMemcpy(sph.data(), sph_cuda,
n_samples * (l_max + 1) * (l_max + 1) * sizeof(double),
cudaMemcpyDeviceToHost));

for (int i = 0; i < 4; i++) {
std::cout << sph[i] << std::endl;
}

return 0;
}
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def run(self):
f"-DSPHERICART_ARCH_NATIVE={SPHERICART_ARCH_NATIVE}",
]

CUDA_HOME = os.environ.get("CUDA_HOME")
if CUDA_HOME is not None:
cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={CUDA_HOME}")
cmake_options.append("-DSPHERICART_ENABLE_CUDA=ON")

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")

Expand Down
7 changes: 4 additions & 3 deletions sphericart-torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,17 @@ find_package(Torch 1.13 REQUIRED)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'")

add_library(sphericart_torch SHARED
"include/sphericart/torch_cuda_wrapper.hpp"
"include/sphericart/torch.hpp"
"include/sphericart/autograd.hpp"
"src/autograd.cpp"
"src/torch.cpp"
)

if (CMAKE_CUDA_COMPILER)
target_sources(sphericart_torch PUBLIC "src/cuda.cu")
if (CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)
target_sources(sphericart_torch PUBLIC "src/torch_cuda_wrapper.cpp")
else()
target_sources(sphericart_torch PUBLIC "src/cuda_stub.cpp")
target_sources(sphericart_torch PUBLIC "src/torch_cuda_wrapper_stub.cpp")
endif()

target_link_libraries(sphericart_torch PUBLIC torch sphericart)
Expand Down
2 changes: 2 additions & 0 deletions sphericart-torch/include/sphericart/autograd.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#ifndef SPHERICART_TORCH_AUTOGRAD_HPP
#define SPHERICART_TORCH_AUTOGRAD_HPP

#include <ATen/Tensor.h>
#include <torch/autograd.h>
#include <torch/data.h>
#include <vector>

namespace sphericart_torch {

Expand Down
26 changes: 0 additions & 26 deletions sphericart-torch/include/sphericart/cuda.hpp

This file was deleted.

31 changes: 4 additions & 27 deletions sphericart-torch/include/sphericart/torch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,13 @@
#include <mutex>

#include "sphericart.hpp"
#include "sphericart_cuda.hpp"

namespace sphericart_torch {

class SphericalHarmonicsAutograd;
class SphericalHarmonicsAutogradBackward;

class CudaSharedMemorySettings {
public:
CudaSharedMemorySettings()
: scalar_size_(0), l_max_(-1), grid_dim_x_(-1), grid_dim_y_(-1),
requires_grad_(false), requires_hessian_(false) {}

bool update_if_required(torch::ScalarType scalar_type, int64_t l_max,
int64_t GRID_DIM_X, int64_t GRID_DIM_Y,
bool gradients, bool hessian);

private:
int64_t l_max_;
int64_t grid_dim_x_;
int64_t grid_dim_y_;
bool requires_grad_;
bool requires_hessian_;
size_t scalar_size_;
};

class SphericalHarmonics : public torch::CustomClassHolder {
public:
SphericalHarmonics(int64_t l_max, bool normalized = false,
Expand Down Expand Up @@ -64,14 +46,9 @@ class SphericalHarmonics : public torch::CustomClassHolder {
sphericart::SphericalHarmonics<double> calculator_double_;
sphericart::SphericalHarmonics<float> calculator_float_;

// CUDA sdata
torch::Tensor prefactors_cuda_double_;
torch::Tensor prefactors_cuda_float_;

int64_t CUDA_GRID_DIM_X_ = 8;
int64_t CUDA_GRID_DIM_Y_ = 8;
CudaSharedMemorySettings cuda_shmem_;
std::mutex cuda_shmem_mutex_;
// CUDA implementation
sphericart::cuda::SphericalHarmonics<double> calculator_cuda_double_;
sphericart::cuda::SphericalHarmonics<float> calculator_cuda_float_;
};

} // namespace sphericart_torch
Expand Down
15 changes: 15 additions & 0 deletions sphericart-torch/include/sphericart/torch_cuda_wrapper.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef SPHERICART_TORCH_TORCH_CUDA_WRAPPER_HPP
#define SPHERICART_TORCH_TORCH_CUDA_WRAPPER_HPP

#include <ATen/Tensor.h>
#include <torch/torch.h>
#include <vector>

namespace sphericart_torch {

at::Tensor spherical_harmonics_backward_cuda(at::Tensor xyz, at::Tensor dsph,
at::Tensor sph_grad);

} // namespace sphericart_torch

#endif
19 changes: 9 additions & 10 deletions sphericart-torch/python/tests/test_e3nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# only run e3nn tests if e3nn is present
try:
import e3nn
import e3nn.o3 as o3

_HAS_E3NN = True
except ModuleNotFoundError:
Expand Down Expand Up @@ -36,12 +37,12 @@ def test_e3nn_inputs(xyz):
xyz_sh = xyz.clone().detach().requires_grad_()
xyz_e3nn = xyz.clone().detach().requires_grad_()

e3nn_reference = e3nn.o3.spherical_harmonics(8, xyz, False)
e3nn_reference = o3.spherical_harmonics(8, xyz, False)
sh = sphericart.torch.e3nn_spherical_harmonics(8, xyz, False)

assert relative_mse(e3nn_reference.detach(), sh.detach()) < TOLERANCE

e3nn_reference = e3nn.o3.spherical_harmonics([1, 3, 5], xyz_e3nn, True)
e3nn_reference = o3.spherical_harmonics([1, 3, 5], xyz_e3nn, True)
sh = sphericart.torch.e3nn_spherical_harmonics([1, 3, 5], xyz_sh, True)

assert relative_mse(e3nn_reference.detach(), sh.detach()) < TOLERANCE
Expand All @@ -63,9 +64,7 @@ def test_e3nn_parameters(xyz, normalize, normalization):
"""Checks that the different normalization options match."""

l_list = list(range(10))
e3nn_reference = e3nn.o3.spherical_harmonics(
l_list, xyz, normalize, normalization
)
e3nn_reference = o3.spherical_harmonics(l_list, xyz, normalize, normalization)
sh = sphericart.torch.e3nn_spherical_harmonics(
l_list, xyz, normalize, normalization
)
Expand All @@ -74,16 +73,16 @@ def test_e3nn_parameters(xyz, normalize, normalization):

def test_e3nn_patch(xyz):
"""Tests the patch function."""
e3nn_reference = e3nn.o3.spherical_harmonics([1, 3, 5], xyz, True)
e3nn_builtin = e3nn.o3.spherical_harmonics
e3nn_reference = o3.spherical_harmonics([1, 3, 5], xyz, True)
e3nn_builtin = o3.spherical_harmonics

sphericart.torch.patch_e3nn(e3nn)

assert e3nn.o3.spherical_harmonics is sphericart.torch.e3nn_spherical_harmonics
sh = e3nn.o3.spherical_harmonics([1, 3, 5], xyz, True)
assert o3.spherical_harmonics is sphericart.torch.e3nn_spherical_harmonics
sh = o3.spherical_harmonics([1, 3, 5], xyz, True)

# restore spherical_harmonics
sphericart.torch.unpatch_e3nn(e3nn)
assert e3nn.o3.spherical_harmonics is e3nn_builtin
assert o3.spherical_harmonics is e3nn_builtin

assert relative_mse(e3nn_reference.detach(), sh.detach()) < TOLERANCE
1 change: 1 addition & 0 deletions sphericart-torch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def run(self):
CUDA_HOME = os.environ.get("CUDA_HOME")
if CUDA_HOME is not None:
cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={CUDA_HOME}")
cmake_options.append("-DSPHERICART_ENABLE_CUDA=ON")

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")
Expand Down
Loading

0 comments on commit 6bbcb3a

Please sign in to comment.