From 4aff034e58550fb2bb6cd3c3c67b000b4f291e93 Mon Sep 17 00:00:00 2001 From: kylebystrom Date: Mon, 11 Nov 2024 15:33:17 -0500 Subject: [PATCH] parallelize functions in conv_interpolation --- ciderpress/dft/lcao_interpolation.py | 6 +-- ciderpress/lib/mod_cider/conv_interpolation.c | 48 ++++++++++++++++--- 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/ciderpress/dft/lcao_interpolation.py b/ciderpress/dft/lcao_interpolation.py index be496e7..9db28cc 100644 --- a/ciderpress/dft/lcao_interpolation.py +++ b/ciderpress/dft/lcao_interpolation.py @@ -21,6 +21,7 @@ import ctypes import numpy as np +from pyscf import lib as pyscflib from ciderpress import lib from ciderpress.dft.lcao_convolutions import ( @@ -585,7 +586,7 @@ def _contract_grad_terms(self, excsum, f_g, a, v): assert iatom_list is not None assert iatom_list.flags.c_contiguous ngrids = iatom_list.size - libcider.contract_grad_terms2( + libcider.contract_grad_terms_parallel( excsum.ctypes.data_as(ctypes.c_void_p), f_g.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(self.atco.natm), @@ -637,8 +638,7 @@ def _interpolate_nopar_atom_deriv(self, f_arlpq, f_gq): args[3] = auxo_vgp[0].ctypes.data_as(ctypes.c_void_p) fn(*args) self._call_l1_fill(ftmp_gq, self.atom_coords[a], True) - # TODO accelerate since this step will be done many times - ftmp = np.einsum("gq,gq->g", ftmp_gq, f_gq) + ftmp = pyscflib.einsum("gq,gq->g", ftmp_gq, f_gq) self._contract_grad_terms(excsum, ftmp, a, v) return excsum diff --git a/ciderpress/lib/mod_cider/conv_interpolation.c b/ciderpress/lib/mod_cider/conv_interpolation.c index ff5dfce..fd901c5 100644 --- a/ciderpress/lib/mod_cider/conv_interpolation.c +++ b/ciderpress/lib/mod_cider/conv_interpolation.c @@ -23,6 +23,7 @@ #include "sph_harm.h" #include "spline.h" #include +#include #include #include @@ -737,8 +738,8 @@ void compute_pot_convs_single_new(double *f_gq, double *f_rlpq, double *auxo_gl, int q, ir, g, p; double *f_lpq; double *auxo_tmp_gl; - double BETA = - 0; // TODO beta should be 1, not 0, when doing grid batches + // TODO beta should be 1, not 0, when doing grid batches + double BETA = 0; double ALPHA = 1; char NTRANS = 'N'; char TRANS = 'T'; @@ -1209,8 +1210,8 @@ void add_lp1_term_onsite_bwd(double *f, double *coords, int natm, } // TODO might want to move this somewhere else -void contract_grad_terms(double *excsum, double *f_g, int natm, int a, int v, - int ngrids, int *ga_loc) { +void contract_grad_terms_old(double *excsum, double *f_g, int natm, int a, + int v, int ngrids, int *ga_loc) { double *tmp = (double *)calloc(natm, sizeof(double)); int ib; #pragma omp parallel @@ -1231,9 +1232,8 @@ void contract_grad_terms(double *excsum, double *f_g, int natm, int a, int v, free(tmp); } -void contract_grad_terms2(double *excsum, double *f_g, int natm, int a, int v, - int ngrids, int *atm_g) { - // TODO neeed to parallelize +void contract_grad_terms_serial(double *excsum, double *f_g, int natm, int a, + int v, int ngrids, int *atm_g) { double *tmp = (double *)calloc(natm, sizeof(double)); int ib; int ia; @@ -1247,3 +1247,37 @@ void contract_grad_terms2(double *excsum, double *f_g, int natm, int a, int v, } free(tmp); } + +void contract_grad_terms_parallel(double *excsum, double *f_g, int natm, int a, + int v, int ngrids, int *atm_g) { + double *tmp_priv; + double total = 0; +#pragma omp parallel + { + const int nthreads = omp_get_num_threads(); + const int ithread = omp_get_thread_num(); + const int ngrids_local = (ngrids + nthreads - 1) / nthreads; + const int ig0 = ithread * ngrids_local; + const int ig1 = MIN(ig0 + ngrids_local, ngrids); +#pragma omp single + { tmp_priv = (double *)calloc(nthreads * natm, sizeof(double)); } +#pragma omp barrier + double *my_tmp = tmp_priv + ithread * natm; + int ib; + int it; + int g; + for (g = ig0; g < ig1; g++) { + my_tmp[atm_g[g]] += f_g[g]; + } +#pragma omp barrier +#pragma omp for reduction(+ : total) + for (ib = 0; ib < natm; ib++) { + for (it = 0; it < nthreads; it++) { + excsum[ib * 3 + v] += tmp_priv[it * natm + ib]; + total += tmp_priv[it * natm + ib]; + } + } + } + excsum[a * 3 + v] -= total; + free(tmp_priv); +}