diff --git a/sphericart/include/sphericart.hpp b/sphericart/include/sphericart.hpp index 8023afd62..73251552e 100644 --- a/sphericart/include/sphericart.hpp +++ b/sphericart/include/sphericart.hpp @@ -30,15 +30,8 @@ template class SphericalHarmonics { * * @param l_max * The maximum degree of the spherical harmonics to be calculated. - * @param normalized - * If `false` (default) computes the scaled spherical harmonics, which - * are homogeneous polynomials in the Cartesian coordinates of the input - * points. If `true`, computes the normalized spherical harmonics that are - * evaluated on the unit sphere. In practice, this simply computes the - * scaled harmonics at the normalized coordinates \f$(x/r, y/r, z/r)\f$, and - * adapts the derivatives accordingly. */ - SphericalHarmonics(size_t l_max, bool normalized = false); + SphericalHarmonics(size_t l_max); /* @cond */ ~SphericalHarmonics(); @@ -364,12 +357,12 @@ template class SphericalHarmonics { */ int get_omp_num_threads() { return this->omp_num_threads; } + template friend class SolidHarmonics; /* @cond */ private: size_t l_max; // maximum l value computed by this class size_t size_y; // size of the Ylm rows (l_max+1)**2 size_t size_q; // size of the prefactor-like arrays (l_max+1)*(l_max+2)/2 - bool normalized; // should we normalize the input vectors? int omp_num_threads; // number of openmp thread T* prefactors; // storage space for prefactor and buffers T* buffers; @@ -388,6 +381,17 @@ template class SphericalHarmonics { /* @endcond */ }; +template class SolidHarmonics : public SphericalHarmonics { + public: + /** Initialize the SphericalHarmonics class setting maximum degree and + * normalization + * + * @param l_max + * The maximum degree of the spherical harmonics to be calculated. + */ + SolidHarmonics(size_t l_max); +}; + } // namespace sphericart #endif diff --git a/sphericart/src/sphericart.cpp b/sphericart/src/sphericart.cpp index f0a14601b..6cc9c275c 100644 --- a/sphericart/src/sphericart.cpp +++ b/sphericart/src/sphericart.cpp @@ -9,20 +9,20 @@ using namespace sphericart; // This macro defines the different possible hardcoded function calls. It is // used to initialize the function pointers that are used by the `compute_` // calls in the SphericalHarmonics class -#define _HARCODED_SWITCH_CASE(L_MAX) \ - if (this->normalized) { \ - this->_array_no_derivatives = &hardcoded_sph; \ - this->_array_with_derivatives = &hardcoded_sph; \ - this->_sample_no_derivatives = &hardcoded_sph_sample; \ - this->_sample_with_derivatives = &hardcoded_sph_sample; \ - } else { \ - this->_array_no_derivatives = &hardcoded_sph; \ - this->_array_with_derivatives = &hardcoded_sph; \ - this->_sample_no_derivatives = &hardcoded_sph_sample; \ - this->_sample_with_derivatives = &hardcoded_sph_sample; \ - } - -template SphericalHarmonics::SphericalHarmonics(size_t l_max, bool normalized) { +#define _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(L_MAX) \ + this->_array_no_derivatives = &hardcoded_sph; \ + this->_array_with_derivatives = &hardcoded_sph; \ + this->_sample_no_derivatives = &hardcoded_sph_sample; \ + this->_sample_with_derivatives = &hardcoded_sph_sample; + +// Same, but for SolidHarmonics +#define _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(L_MAX) \ + this->_array_no_derivatives = &hardcoded_sph; \ + this->_array_with_derivatives = &hardcoded_sph; \ + this->_sample_no_derivatives = &hardcoded_sph_sample; \ + this->_sample_with_derivatives = &hardcoded_sph_sample; + +template SphericalHarmonics::SphericalHarmonics(size_t l_max) { /* This is the constructor of the SphericalHarmonics class. It initizlizes buffer space, compute prefactors, and sets the function pointers that are @@ -32,7 +32,6 @@ template SphericalHarmonics::SphericalHarmonics(size_t l_max, bo this->l_max = (int)l_max; this->size_y = (int)(l_max + 1) * (l_max + 1); this->size_q = (int)(l_max + 1) * (l_max + 2) / 2; - this->normalized = normalized; this->prefactors = new T[this->size_q * 2]; this->omp_num_threads = omp_get_max_threads(); @@ -48,76 +47,48 @@ template SphericalHarmonics::SphericalHarmonics(size_t l_max, bo // using a macro to avoid even more code duplication. switch (this->l_max) { case 0: - _HARCODED_SWITCH_CASE(0); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(0); break; case 1: - _HARCODED_SWITCH_CASE(1); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(1); break; case 2: - _HARCODED_SWITCH_CASE(2); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(2); break; case 3: - _HARCODED_SWITCH_CASE(3); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(3); break; case 4: - _HARCODED_SWITCH_CASE(4); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(4); break; case 5: - _HARCODED_SWITCH_CASE(5); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(5); break; case 6: - _HARCODED_SWITCH_CASE(6); + _HARDCODED_SWITCH_CASE_SPHERICAL_HARMONICS(6); break; } } else { - if (this->normalized) { - this->_array_no_derivatives = - &generic_sph; - this->_array_with_derivatives = - &generic_sph; - this->_sample_no_derivatives = - &generic_sph_sample; - this->_sample_with_derivatives = - &generic_sph_sample; - } else { - this->_array_no_derivatives = - &generic_sph; - this->_array_with_derivatives = - &generic_sph; - this->_sample_no_derivatives = - &generic_sph_sample; - this->_sample_with_derivatives = - &generic_sph_sample; - } + // if (this->normalized) { + this->_array_no_derivatives = &generic_sph; + this->_array_with_derivatives = + &generic_sph; + this->_sample_no_derivatives = + &generic_sph_sample; + this->_sample_with_derivatives = + &generic_sph_sample; } // set up the second derivative functions - if (this->normalized) { - if (this->l_max == 0) { - this->_array_with_hessians = &hardcoded_sph; - this->_sample_with_hessians = &hardcoded_sph_sample; - } else if (this->l_max == 1) { - this->_array_with_hessians = &hardcoded_sph; - this->_sample_with_hessians = &hardcoded_sph_sample; - } else { // second derivatives are not hardcoded past l = 1. Call - // generic - // implementations - this->_array_with_hessians = &generic_sph; - this->_sample_with_hessians = &generic_sph_sample; - } - } else { - if (this->l_max == 0) { - this->_array_with_hessians = &hardcoded_sph; - this->_sample_with_hessians = &hardcoded_sph_sample; - } else if (this->l_max == 1) { - this->_array_with_hessians = &hardcoded_sph; - this->_sample_with_hessians = &hardcoded_sph_sample; - } else { // second derivatives are not hardcoded past l = 1. Call - // generic - // implementations - this->_array_with_hessians = &generic_sph; - this->_sample_with_hessians = &generic_sph_sample; - } + if (this->l_max == 0) { + this->_array_with_hessians = &hardcoded_sph; + this->_sample_with_hessians = &hardcoded_sph_sample; + } else if (this->l_max == 1) { + this->_array_with_hessians = &hardcoded_sph; + this->_sample_with_hessians = &hardcoded_sph_sample; + } else { // second derivatives are not hardcoded past l = 1. Call generic implementations + this->_array_with_hessians = &generic_sph; + this->_sample_with_hessians = &generic_sph_sample; } } @@ -414,3 +385,66 @@ void SphericalHarmonics::compute_sample_with_hessians( // instantiates the SphericalHarmonics class for basic floating point types template class sphericart::SphericalHarmonics; template class sphericart::SphericalHarmonics; + +template +SolidHarmonics::SolidHarmonics(size_t l_max) : SphericalHarmonics(l_max) { + /* + This is the constructor of the SolidHarmonics class. It initizlizes + buffer space, compute prefactors, and sets the function pointers that are + used for the actual calls + */ + + // Just override the function pointers with the SolidHarmonics versions + if (this->l_max <= SPHERICART_LMAX_HARDCODED) { + // If we only need hard-coded calls, we set them up at this point + // using a macro to avoid even more code duplication. + switch (this->l_max) { + case 0: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(0); + break; + case 1: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(1); + break; + case 2: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(2); + break; + case 3: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(3); + break; + case 4: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(4); + break; + case 5: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(5); + break; + case 6: + _HARDCODED_SWITCH_CASE_SOLID_HARMONICS(6); + break; + } + } else { + this->_array_no_derivatives = + &generic_sph; + this->_array_with_derivatives = + &generic_sph; + this->_sample_no_derivatives = + &generic_sph_sample; + this->_sample_with_derivatives = + &generic_sph_sample; + } + + // set up the second derivative functions + if (this->l_max == 0) { + this->_array_with_hessians = &hardcoded_sph; + this->_sample_with_hessians = &hardcoded_sph_sample; + } else if (this->l_max == 1) { + this->_array_with_hessians = &hardcoded_sph; + this->_sample_with_hessians = &hardcoded_sph_sample; + } else { // second derivatives are not hardcoded past l = 1. Call generic implementations + this->_array_with_hessians = &generic_sph; + this->_sample_with_hessians = &generic_sph_sample; + } +} + +// instantiates the SolidHarmonics class for basic floating point types +template class sphericart::SolidHarmonics; +template class sphericart::SolidHarmonics; diff --git a/sphericart/tests/test_derivatives.cpp b/sphericart/tests/test_derivatives.cpp index 77c8e22f7..d7ed4a8f2 100644 --- a/sphericart/tests/test_derivatives.cpp +++ b/sphericart/tests/test_derivatives.cpp @@ -13,9 +13,8 @@ 1e-4 // High tolerance: finite differences are inaccurate for second // derivatives -bool check_gradient_call( - int l_max, sphericart::SphericalHarmonics& calculator, const std::vector& xyz -) { +template