diff --git a/c/reciprocal_to_normal.c b/c/reciprocal_to_normal.c
index bf30e175..c3f39bef 100644
--- a/c/reciprocal_to_normal.c
+++ b/c/reciprocal_to_normal.c
@@ -34,6 +34,11 @@
 
 #include "reciprocal_to_normal.h"
 
+#if defined(MKL_LAPACKE) || defined(SCIPY_MKL_H)
+#include <mkl.h>
+#else
+#include <cblas.h>
+#endif
 #include <stdio.h>
 #include <stdlib.h>
 
@@ -49,7 +54,11 @@ static double get_fc3_sum(const lapack_complex_double *e0,
                           const lapack_complex_double *e2,
                           const lapack_complex_double *fc3_reciprocal,
                           const long num_band);
-
+static double get_fc3_sum_blas(const lapack_complex_double *e0,
+                               const lapack_complex_double *e1,
+                               const lapack_complex_double *e2,
+                               const lapack_complex_double *fc3_reciprocal,
+                               const long num_band);
 void reciprocal_to_normal_squared(
     double *fc3_normal_squared, const long (*g_pos)[4], const long num_g_pos,
     const lapack_complex_double *fc3_reciprocal, const double *freqs0,
@@ -189,3 +198,69 @@ static double get_fc3_sum(const lapack_complex_double *e0,
     e_12_cache = NULL;
     return (sum_real * sum_real + sum_imag * sum_imag);
 }
+
+/* Squared modulus of the fc3 contraction with three eigenvectors (CBLAS).
+ *
+ * Forms S = sum_{i,j,k} e0[i] * e1[j] * e2[k] *
+ *           fc3_reciprocal[(i * num_band + j) * num_band + k]
+ * with no factor conjugated, and returns Re(S)^2 + Im(S)^2.
+ *
+ * e0, e1, e2: vectors of num_band complex elements each.
+ * fc3_reciprocal: num_band^3 complex elements, read as a row-major
+ *     (num_band x num_band^2) matrix.
+ *
+ * Returns 0 when num_band <= 0 (empty sum) and also when a scratch buffer
+ * cannot be allocated; the return value offers no way to tell these apart
+ * from a genuine zero.
+ *
+ * NOTE(review): the long sizes are converted implicitly to the integer type
+ * of the CBLAS interface; confirm num_band * num_band stays in its range. */
+static double get_fc3_sum_blas(const lapack_complex_double *e0,
+                               const lapack_complex_double *e1,
+                               const lapack_complex_double *e2,
+                               const lapack_complex_double *fc3_reciprocal,
+                               const long num_band) {
+    long i;
+    double sum_real, sum_imag;
+    lapack_complex_double *fc3_e12, *e_12, zero, one, retval;
+
+    if (num_band <= 0) {
+        return 0;
+    }
+
+    e_12 = (lapack_complex_double *)malloc(sizeof(lapack_complex_double) *
+                                           num_band * num_band);
+    fc3_e12 = (lapack_complex_double *)malloc(sizeof(lapack_complex_double) *
+                                              num_band);
+    if (e_12 == NULL || fc3_e12 == NULL) {
+        /* free(NULL) is a no-op, so whichever buffer succeeded is released. */
+        free(e_12);
+        free(fc3_e12);
+        return 0;
+    }
+
+    zero = lapack_make_complex_double(0, 0);
+    one = lapack_make_complex_double(1, 0);
+
+    /* Row i of e_12 holds e1[i] * e2[k] for k = 0 .. num_band - 1. */
+    for (i = 0; i < num_band; i++) {
+        cblas_zcopy(num_band, e2, 1, e_12 + i * num_band, 1);
+        cblas_zscal(num_band, e1 + i, e_12 + i * num_band, 1);
+    }
+
+    /* fc3_e12 = fc3_reciprocal * e_12 (matrix-vector product). */
+    cblas_zgemv(CblasRowMajor, CblasNoTrans, num_band, num_band * num_band,
+                &one, fc3_reciprocal, num_band * num_band, e_12, 1, &zero,
+                fc3_e12, 1);
+    /* retval = sum_i e0[i] * fc3_e12[i], unconjugated. */
+    cblas_zdotu_sub(num_band, e0, 1, fc3_e12, 1, &retval);
+
+    free(e_12);
+    e_12 = NULL;
+    free(fc3_e12);
+    fc3_e12 = NULL;
+
+    sum_real = lapack_complex_double_real(retval);
+    sum_imag = lapack_complex_double_imag(retval);
+    return sum_real * sum_real + sum_imag * sum_imag;
+}