Skip to content

Commit

Permalink
Change CUDA API
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Aug 19, 2024
1 parent b07a831 commit eff52b4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 19 deletions.
21 changes: 13 additions & 8 deletions sphericart/include/sphericart_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,8 @@ template <typename T> class SphericalHarmonics {
*
* @param l_max
* The maximum degree of the spherical harmonics to be calculated.
* @param normalized
* If `false` (default) computes the scaled spherical harmonics, which
* are homogeneous polynomials in the Cartesian coordinates of the input
* points. If `true`, computes the normalized spherical harmonics that are
* evaluated on the unit sphere. In practice, this simply computes the
* scaled harmonics at the normalized coordinates \f$(x/r, y/r, z/r)\f$, and
* adapts the derivatives accordingly.
*/
SphericalHarmonics(size_t l_max, bool normalized = false);
SphericalHarmonics(size_t l_max);

/** Default constructor
* Required so sphericart_torch can conditionally instantiate this class
Expand Down Expand Up @@ -96,6 +89,7 @@ template <typename T> class SphericalHarmonics {
void* cuda_stream = nullptr
);

template <typename U> friend class SolidHarmonics;
/* @cond */
private:
size_t l_max; // maximum l value computed by this class
Expand All @@ -112,6 +106,17 @@ template <typename T> class SphericalHarmonics {
/* @endcond */
};

template <typename T> class SolidHarmonics : public SphericalHarmonics<T> {
public:
/** Initialize the SolidHarmonics class setting maximum degree and
* normalization
*
* @param l_max
* The maximum degree of the spherical harmonics to be calculated.
*/
SolidHarmonics(size_t l_max);
};

} // namespace cuda

/* @cond */
Expand Down
9 changes: 4 additions & 5 deletions sphericart/src/sphericart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,6 @@ void SphericalHarmonics<T>::compute_sample_with_hessians(
);
}

// instantiates the SphericalHarmonics class for basic floating point types
template class sphericart::SphericalHarmonics<float>;
template class sphericart::SphericalHarmonics<double>;

template <typename T>
SolidHarmonics<T>::SolidHarmonics(size_t l_max) : SphericalHarmonics<T>(l_max) {
/*
Expand Down Expand Up @@ -445,6 +441,9 @@ SolidHarmonics<T>::SolidHarmonics(size_t l_max) : SphericalHarmonics<T>(l_max) {
}
}

// instantiates the SolidHarmonics class for basic floating point types
// instantiates the SphericalHarmonics and SolidHarmonics classes
// for basic floating point types
template class sphericart::SphericalHarmonics<float>;
template class sphericart::SphericalHarmonics<double>;
template class sphericart::SolidHarmonics<float>;
template class sphericart::SolidHarmonics<double>;
14 changes: 11 additions & 3 deletions sphericart/src/sphericart_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ using namespace sphericart::cuda;
} \
} while (0)

template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bool normalized) {
template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max) {
/*
This is the constructor of the SphericalHarmonics class. It initizlizes
buffer space, compute prefactors, and sets the function pointers that are
used for the actual calls
*/
this->l_max = (int)l_max;
this->nprefactors = (int)(l_max + 1) * (l_max + 2);
this->normalized = normalized;
this->normalized = true; // SphericalHarmonics class
this->prefactors_cpu = new T[this->nprefactors];

CUDA_CHECK(cudaGetDeviceCount(&this->device_count));
Expand Down Expand Up @@ -211,6 +211,14 @@ void SphericalHarmonics<T>::compute(
CUDA_CHECK(cudaSetDevice(current_device));
}

// instantiates the SphericalHarmonics class for basic floating point types
template <typename T>
SolidHarmonics<T>::SolidHarmonics(size_t l_max) : SphericalHarmonics<T>(l_max) {
this->normalized = false; // SolidHarmonics class
}

// instantiates the SphericalHarmonics and SolidHarmonics classes
// for basic floating point types
template class sphericart::cuda::SphericalHarmonics<float>;
template class sphericart::cuda::SphericalHarmonics<double>;
template class sphericart::cuda::SolidHarmonics<float>;
template class sphericart::cuda::SolidHarmonics<double>;
11 changes: 8 additions & 3 deletions sphericart/src/sphericart_cuda_stub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

using namespace sphericart::cuda;

template <typename T>
SphericalHarmonics<T>::SphericalHarmonics(size_t /*l_max*/, bool /*normalized*/) {}
template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t /*l_max*/) {}

template <typename T> SphericalHarmonics<T>::~SphericalHarmonics() {}

Expand All @@ -23,6 +22,12 @@ void SphericalHarmonics<T>::compute(
throw std::runtime_error("sphericart was not compiled with CUDA support");
}

// instantiates the SphericalHarmonics class for basic floating point types
template <typename T>
SolidHarmonics<T>::SolidHarmonics(size_t l_max) : SphericalHarmonics<T>(l_max) {}

// instantiates the SphericalHarmonics and SolidHarmonics classes
// for basic floating point types
template class sphericart::cuda::SphericalHarmonics<float>;
template class sphericart::cuda::SphericalHarmonics<double>;
template class sphericart::cuda::SolidHarmonics<float>;
template class sphericart::cuda::SolidHarmonics<double>;

0 comments on commit eff52b4

Please sign in to comment.