-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The CUDA code is now part of the main sphericart package and no longer sphericart-torch Co-authored-by: Guillaume Fraux <[email protected]> Co-authored-by: frostedoyster <[email protected]> Co-authored-by: Guillaume Fraux <[email protected]>
- Loading branch information
1 parent
bec16ed
commit 6bbcb3a
Showing
26 changed files
with
892 additions
and
428 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
CUDA C++ | ||
-------- | ||
|
||
The ``sphericart::cuda::SphericalHarmonics`` class automatically initializes and internally stores | ||
pre-factors and buffers, and its usage is similar to the C++ API, although here the class provides | ||
a single unified function for all purposes (values, gradients, and Hessians). This is | ||
illustrated in the example below. The CUDA C++ API is undocumented at this time and subject | ||
to change, but the example below should be sufficient to get started. | ||
|
||
.. literalinclude:: ../../examples/cuda/example.cu | ||
:language: cuda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/** @file example.cpp | ||
* @brief Usage example for the C++ API | ||
*/ | ||
|
||
#include "sphericart_cuda.hpp" | ||
#include <cmath> | ||
#include <cstdio> | ||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
#include <iostream> | ||
#include <vector> | ||
|
||
using namespace std; | ||
using namespace sphericart::cuda; | ||
|
||
/*host macro that checks for errors in CUDA calls, and prints the file + line | ||
* and error string if one occurs | ||
*/ | ||
#define CUDA_CHECK(call) \ | ||
do { \ | ||
cudaError_t cudaStatus = (call); \ | ||
if (cudaStatus != cudaSuccess) { \ | ||
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ \ | ||
<< " - " << cudaGetErrorString(cudaStatus) << std::endl; \ | ||
cudaDeviceReset(); \ | ||
exit(EXIT_FAILURE); \ | ||
} \ | ||
} while (0) | ||
|
||
int main() { | ||
/* ===== set up the calculation ===== */ | ||
|
||
// hard-coded parameters for the example | ||
size_t n_samples = 10000; | ||
size_t l_max = 10; | ||
|
||
// initializes samples | ||
auto xyz = std::vector<double>(n_samples * 3, 0.0); | ||
for (size_t i = 0; i < n_samples * 3; ++i) { | ||
xyz[i] = (double)rand() / (double)RAND_MAX * 2.0 - 1.0; | ||
} | ||
|
||
// to avoid unnecessary allocations, calculators can use pre-allocated | ||
// memory, one also can provide uninitialized vectors that will be | ||
// automatically reshaped | ||
auto sph = std::vector<double>(n_samples * (l_max + 1) * (l_max + 1), 0.0); | ||
auto dsph = | ||
std::vector<double>(n_samples * 3 * (l_max + 1) * (l_max + 1), 0.0); | ||
auto ddsph = | ||
std::vector<double>(n_samples * 3 * 3 * (l_max + 1) * (l_max + 1), 0.0); | ||
|
||
/* ===== API calls ===== */ | ||
|
||
// internal buffers and numerical factors are initalized at construction | ||
sphericart::cuda::SphericalHarmonics<double> calculator_cuda(l_max); | ||
|
||
double *xyz_cuda; | ||
CUDA_CHECK(cudaMalloc(&xyz_cuda, n_samples * 3 * sizeof(double))); | ||
CUDA_CHECK(cudaMemcpy(xyz_cuda, xyz.data(), n_samples * 3 * sizeof(double), | ||
cudaMemcpyHostToDevice)); | ||
double *sph_cuda; | ||
CUDA_CHECK(cudaMalloc(&sph_cuda, n_samples * (l_max + 1) * (l_max + 1) * | ||
sizeof(double))); | ||
|
||
calculator_cuda.compute(xyz_cuda, n_samples, false, false, | ||
sph_cuda); // no gradients */ | ||
|
||
CUDA_CHECK( | ||
cudaMemcpy(sph.data(), sph_cuda, | ||
n_samples * (l_max + 1) * (l_max + 1) * sizeof(double), | ||
cudaMemcpyDeviceToHost)); | ||
|
||
for (int i = 0; i < 4; i++) { | ||
std::cout << sph[i] << std::endl; | ||
} | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
sphericart-torch/include/sphericart/torch_cuda_wrapper.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef SPHERICART_TORCH_TORCH_CUDA_WRAPPER_HPP | ||
#define SPHERICART_TORCH_TORCH_CUDA_WRAPPER_HPP | ||
|
||
#include <ATen/Tensor.h> | ||
#include <torch/torch.h> | ||
#include <vector> | ||
|
||
namespace sphericart_torch { | ||
|
||
at::Tensor spherical_harmonics_backward_cuda(at::Tensor xyz, at::Tensor dsph, | ||
at::Tensor sph_grad); | ||
|
||
} // namespace sphericart_torch | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.