Skip to content

Commit

Permalink
Change C++ API
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Aug 18, 2024
1 parent b57751e commit 0d61c52
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 98 deletions.
22 changes: 13 additions & 9 deletions sphericart/include/sphericart.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,8 @@ template <typename T> class SphericalHarmonics {
*
* @param l_max
* The maximum degree of the spherical harmonics to be calculated.
* @param normalized
* If `false` (default) computes the scaled spherical harmonics, which
* are homogeneous polynomials in the Cartesian coordinates of the input
* points. If `true`, computes the normalized spherical harmonics that are
* evaluated on the unit sphere. In practice, this simply computes the
* scaled harmonics at the normalized coordinates \f$(x/r, y/r, z/r)\f$, and
* adapts the derivatives accordingly.
*/
SphericalHarmonics(size_t l_max, bool normalized = false);
SphericalHarmonics(size_t l_max);

/* @cond */
~SphericalHarmonics();
Expand Down Expand Up @@ -364,12 +357,12 @@ template <typename T> class SphericalHarmonics {
*/
int get_omp_num_threads() { return this->omp_num_threads; }

template <typename U> friend class SolidHarmonics;
/* @cond */
private:
size_t l_max; // maximum l value computed by this class
size_t size_y; // size of the Ylm rows (l_max+1)**2
size_t size_q; // size of the prefactor-like arrays (l_max+1)*(l_max+2)/2
bool normalized; // should we normalize the input vectors?
int omp_num_threads; // number of openmp thread
T* prefactors; // storage space for prefactor and buffers
T* buffers;
Expand All @@ -388,6 +381,17 @@ template <typename T> class SphericalHarmonics {
/* @endcond */
};

template <typename T> class SolidHarmonics : public SphericalHarmonics<T> {
public:
/** Initialize the SphericalHarmonics class setting maximum degree and
* normalization
*
* @param l_max
* The maximum degree of the spherical harmonics to be calculated.
*/
SolidHarmonics(size_t l_max);
};

} // namespace sphericart

#endif
168 changes: 101 additions & 67 deletions sphericart/src/sphericart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@ using namespace sphericart;
// This macro defines the different possible hardcoded function calls. It is
// used to initialize the function pointers that are used by the `compute_`
// calls in the SphericalHarmonics class
#define _HARCODED_SWITCH_CASE(L_MAX) \
if (this->normalized) { \
this->_array_no_derivatives = &hardcoded_sph<T, false, false, true, L_MAX>; \
this->_array_with_derivatives = &hardcoded_sph<T, true, false, true, L_MAX>; \
this->_sample_no_derivatives = &hardcoded_sph_sample<T, false, false, true, L_MAX>; \
this->_sample_with_derivatives = &hardcoded_sph_sample<T, true, false, true, L_MAX>; \
} else { \
this->_array_no_derivatives = &hardcoded_sph<T, false, false, false, L_MAX>; \
this->_array_with_derivatives = &hardcoded_sph<T, true, false, false, L_MAX>; \
this->_sample_no_derivatives = &hardcoded_sph_sample<T, false, false, false, L_MAX>; \
this->_sample_with_derivatives = &hardcoded_sph_sample<T, true, false, false, L_MAX>; \
}

template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bool normalized) {
#define _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(L_MAX) \
this->_array_no_derivatives = &hardcoded_sph<T, false, false, true, L_MAX>; \
this->_array_with_derivatives = &hardcoded_sph<T, true, false, true, L_MAX>; \
this->_sample_no_derivatives = &hardcoded_sph_sample<T, false, false, true, L_MAX>; \
this->_sample_with_derivatives = &hardcoded_sph_sample<T, true, false, true, L_MAX>;

// Same, but for SolidHarmonics
#define _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(L_MAX) \
this->_array_no_derivatives = &hardcoded_sph<T, false, false, false, L_MAX>; \
this->_array_with_derivatives = &hardcoded_sph<T, true, false, false, L_MAX>; \
this->_sample_no_derivatives = &hardcoded_sph_sample<T, false, false, false, L_MAX>; \
this->_sample_with_derivatives = &hardcoded_sph_sample<T, true, false, false, L_MAX>;

template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max) {
/*
This is the constructor of the SphericalHarmonics class. It initizlizes
buffer space, compute prefactors, and sets the function pointers that are
Expand All @@ -32,7 +32,6 @@ template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bo
this->l_max = (int)l_max;
this->size_y = (int)(l_max + 1) * (l_max + 1);
this->size_q = (int)(l_max + 1) * (l_max + 2) / 2;
this->normalized = normalized;
this->prefactors = new T[this->size_q * 2];
this->omp_num_threads = omp_get_max_threads();

Expand All @@ -48,76 +47,48 @@ template <typename T> SphericalHarmonics<T>::SphericalHarmonics(size_t l_max, bo
// using a macro to avoid even more code duplication.
switch (this->l_max) {
case 0:
_HARCODED_SWITCH_CASE(0);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(0);
break;
case 1:
_HARCODED_SWITCH_CASE(1);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(1);
break;
case 2:
_HARCODED_SWITCH_CASE(2);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(2);
break;
case 3:
_HARCODED_SWITCH_CASE(3);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(3);
break;
case 4:
_HARCODED_SWITCH_CASE(4);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(4);
break;
case 5:
_HARCODED_SWITCH_CASE(5);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(5);
break;
case 6:
_HARCODED_SWITCH_CASE(6);
_HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(6);
break;
}
} else {
if (this->normalized) {
this->_array_no_derivatives =
&generic_sph<T, false, false, true, SPHERICART_LMAX_HARDCODED>;
this->_array_with_derivatives =
&generic_sph<T, true, false, true, SPHERICART_LMAX_HARDCODED>;
this->_sample_no_derivatives =
&generic_sph_sample<T, false, false, true, SPHERICART_LMAX_HARDCODED>;
this->_sample_with_derivatives =
&generic_sph_sample<T, true, false, true, SPHERICART_LMAX_HARDCODED>;
} else {
this->_array_no_derivatives =
&generic_sph<T, false, false, false, SPHERICART_LMAX_HARDCODED>;
this->_array_with_derivatives =
&generic_sph<T, true, false, false, SPHERICART_LMAX_HARDCODED>;
this->_sample_no_derivatives =
&generic_sph_sample<T, false, false, false, SPHERICART_LMAX_HARDCODED>;
this->_sample_with_derivatives =
&generic_sph_sample<T, true, false, false, SPHERICART_LMAX_HARDCODED>;
}
// if (this->normalized) {
this->_array_no_derivatives = &generic_sph<T, false, false, true, SPHERICART_LMAX_HARDCODED>;
this->_array_with_derivatives =
&generic_sph<T, true, false, true, SPHERICART_LMAX_HARDCODED>;
this->_sample_no_derivatives =
&generic_sph_sample<T, false, false, true, SPHERICART_LMAX_HARDCODED>;
this->_sample_with_derivatives =
&generic_sph_sample<T, true, false, true, SPHERICART_LMAX_HARDCODED>;
}

// set up the second derivative functions
if (this->normalized) {
if (this->l_max == 0) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, true, 0>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, true, 0>;
} else if (this->l_max == 1) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, true, 1>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, true, 1>;
} else { // second derivatives are not hardcoded past l = 1. Call
// generic
// implementations
this->_array_with_hessians = &generic_sph<T, true, true, true, 1>;
this->_sample_with_hessians = &generic_sph_sample<T, true, true, true, 1>;
}
} else {
if (this->l_max == 0) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, false, 0>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, false, 0>;
} else if (this->l_max == 1) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, false, 1>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, false, 1>;
} else { // second derivatives are not hardcoded past l = 1. Call
// generic
// implementations
this->_array_with_hessians = &generic_sph<T, true, true, false, 1>;
this->_sample_with_hessians = &generic_sph_sample<T, true, true, false, 1>;
}
if (this->l_max == 0) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, true, 0>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, true, 0>;
} else if (this->l_max == 1) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, true, 1>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, true, 1>;
} else { // second derivatives are not hardcoded past l = 1. Call generic implementations
this->_array_with_hessians = &generic_sph<T, true, true, true, 1>;
this->_sample_with_hessians = &generic_sph_sample<T, true, true, true, 1>;
}
}

Expand Down Expand Up @@ -414,3 +385,66 @@ void SphericalHarmonics<T>::compute_sample_with_hessians(
// instantiates the SphericalHarmonics class for basic floating point types
template class sphericart::SphericalHarmonics<float>;
template class sphericart::SphericalHarmonics<double>;

template <typename T>
SolidHarmonics<T>::SolidHarmonics(size_t l_max) : SphericalHarmonics<T>(l_max) {
/*
This is the constructor of the SolidHarmonics class. It initizlizes
buffer space, compute prefactors, and sets the function pointers that are
used for the actual calls
*/

// Just override the function pointers with the SolidHarmonics versions
if (this->l_max <= SPHERICART_LMAX_HARDCODED) {
// If we only need hard-coded calls, we set them up at this point
// using a macro to avoid even more code duplication.
switch (this->l_max) {
case 0:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(0);
break;
case 1:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(1);
break;
case 2:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(2);
break;
case 3:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(3);
break;
case 4:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(4);
break;
case 5:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(5);
break;
case 6:
_HARDCODED_SWITCH_CASE_SOLID_HARMONICS(6);
break;
}
} else {
this->_array_no_derivatives =
&generic_sph<T, false, false, false, SPHERICART_LMAX_HARDCODED>;
this->_array_with_derivatives =
&generic_sph<T, true, false, false, SPHERICART_LMAX_HARDCODED>;
this->_sample_no_derivatives =
&generic_sph_sample<T, false, false, false, SPHERICART_LMAX_HARDCODED>;
this->_sample_with_derivatives =
&generic_sph_sample<T, true, false, false, SPHERICART_LMAX_HARDCODED>;
}

// set up the second derivative functions
if (this->l_max == 0) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, false, 0>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, false, 0>;
} else if (this->l_max == 1) {
this->_array_with_hessians = &hardcoded_sph<T, true, true, false, 1>;
this->_sample_with_hessians = &hardcoded_sph_sample<T, true, true, false, 1>;
} else { // second derivatives are not hardcoded past l = 1. Call generic implementations
this->_array_with_hessians = &generic_sph<T, true, true, false, 1>;
this->_sample_with_hessians = &generic_sph_sample<T, true, true, false, 1>;
}
}

// instantiates the SolidHarmonics class for basic floating point types
template class sphericart::SolidHarmonics<float>;
template class sphericart::SolidHarmonics<double>;
48 changes: 28 additions & 20 deletions sphericart/tests/test_derivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
1e-4 // High tolerance: finite differences are inaccurate for second
// derivatives

bool check_gradient_call(
int l_max, sphericart::SphericalHarmonics<double>& calculator, const std::vector<double>& xyz
) {
template <template <typename> class C>
bool check_gradient_call(int l_max, C<double>& calculator, const std::vector<double>& xyz) {
bool is_passed = true;
int n_samples = xyz.size() / 3;
int n_sph = (l_max + 1) * (l_max + 1);
Expand Down Expand Up @@ -60,9 +59,8 @@ bool check_gradient_call(
return is_passed;
}

bool check_hessian_call(
int l_max, sphericart::SphericalHarmonics<double>& calculator, const std::vector<double>& xyz
) {
template <template <typename> class C>
bool check_hessian_call(int l_max, C<double>& calculator, const std::vector<double>& xyz) {
bool is_passed = true;
int n_samples = xyz.size() / 3;
int n_sph = (l_max + 1) * (l_max + 1);
Expand Down Expand Up @@ -178,20 +176,30 @@ int main() {
}
}

for (bool normalize : {false, true}) { // Test with and without normalization
for (int l_max = 0; l_max < l_max_max; l_max++) { // Test for a range of l_max values
sphericart::SphericalHarmonics<double> normalized_calculator =
sphericart::SphericalHarmonics<double>(l_max, normalize);
is_passed = check_gradient_call(l_max, normalized_calculator, xyz);
if (!is_passed) {
std::cout << "Test failed" << std::endl;
return -1;
}
is_passed = check_hessian_call(l_max, normalized_calculator, xyz);
if (!is_passed) {
std::cout << "Test failed" << std::endl;
return -1;
}
for (int l_max = 0; l_max < l_max_max; l_max++) { // Test for a range of l_max values
sphericart::SphericalHarmonics<double> calculator =
sphericart::SphericalHarmonics<double>(l_max);
is_passed = check_gradient_call(l_max, calculator, xyz);
if (!is_passed) {
std::cout << "Test failed" << std::endl;
return -1;
}
is_passed = check_hessian_call(l_max, calculator, xyz);
if (!is_passed) {
std::cout << "Test failed" << std::endl;
return -1;
}

sphericart::SolidHarmonics<double> calculator = sphericart::SolidHarmonics<double>(l_max);
is_passed = check_gradient_call(l_max, calculator, xyz);
if (!is_passed) {
std::cout << "Test failed" << std::endl;
return -1;
}
is_passed = check_hessian_call(l_max, calculator, xyz);
if (!is_passed) {
std::cout << "Test failed" << std::endl;
return -1;
}
}

Expand Down
2 changes: 1 addition & 1 deletion sphericart/tests/test_hardcoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ int main(int argc, char* argv[]) {
auto sph1 = std::vector<DTYPE>(n_samples * (l_max + 1) * (l_max + 1), 0.0);
auto dsph1 = std::vector<DTYPE>(n_samples * 3 * (l_max + 1) * (l_max + 1), 0.0);

SphericalHarmonics<DTYPE> SH(l_max, false);
SphericalHarmonics<DTYPE> SH(l_max);
SH.compute_with_gradients(xyz, sph1, dsph1);

int size3 = 3 * (l_max + 1) * (l_max + 1); // Size of the third dimension in derivative
Expand Down
2 changes: 1 addition & 1 deletion sphericart/tests/test_samples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char* argv[]) {
auto dsph = std::vector<DTYPE>(n_samples * 3 * (l_max + 1) * (l_max + 1), 0.0);
auto sph_sample = std::vector<DTYPE>(1 * (l_max + 1) * (l_max + 1), 0.0);
auto dsph_sample = std::vector<DTYPE>(1 * 3 * (l_max + 1) * (l_max + 1), 0.0);
SphericalHarmonics<DTYPE> SH(l_max, false);
SphericalHarmonics<DTYPE> SH(l_max);
SH.compute_with_gradients(xyz_sample, sph_sample, dsph_sample);
SH.compute_with_gradients(xyz, sph, dsph);
int size3 = 3 * (l_max + 1) * (l_max + 1); // Size of the third dimension in derivative
Expand Down

0 comments on commit 0d61c52

Please sign in to comment.