Skip to content

Commit

Permalink
Merge pull request #212 from gitpeterwind/vecmult
Browse files Browse the repository at this point in the history
Vecmult
  • Loading branch information
gitpeterwind authored Oct 21, 2023
2 parents 0dde01e + 6f6452c commit f8def0a
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 35 deletions.
9 changes: 2 additions & 7 deletions src/trees/FunctionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -737,8 +737,8 @@ template <> int FunctionTree<3>::saveNodesAndRmCoeff() {
for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); }
while (stack.size() > stack_p) {
MWNode<3> *Node = stack[stack_p++];
this->NodeIndex2serialIx[Node->getNodeIndex()] = Node->serialIx;
NodesCoeff->put_data(Node->serialIx, sizecoeff, Node->getCoefs());
int id = 0;
NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs());
for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); }
}
}
Expand All @@ -749,11 +749,6 @@ template <> int FunctionTree<3>::saveNodesAndRmCoeff() {
return this->NodeIndex2serialIx.size();
}

template <> void FunctionTree<3>::getNodeCoeff(int id, int size, double *data) {
assert(this->isLocal);
this->NodesCoeff->get_data(id, size, data);
}

template class FunctionTree<1>;
template class FunctionTree<2>;
template class FunctionTree<3>;
Expand Down
6 changes: 1 addition & 5 deletions src/trees/FunctionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@

namespace mrcpp {

class BankAccount;

/** @class FunctionTree
*
* @brief Function representation in MW basis
Expand Down Expand Up @@ -113,12 +111,10 @@ template <int D> class FunctionTree final : public MWTree<D>, public Representab
void appendTreeNoCoeff(MWTree<D> &inTree);

// tools for use of local (nodes are stored in Bank) representation
int saveNodesAndRmCoeff(); // put all nodes coefficients in Bank and delete all coefficients
void getNodeCoeff(int id, int size, double *data); // fetch coefficient from a specific node stored in Bank
int saveNodesAndRmCoeff(); // put all nodes coefficients in Bank and delete all coefficients
protected:
std::unique_ptr<NodeAllocator<D>> genNodeAllocator_p{nullptr};
std::ostream &print(std::ostream &o) const override;
BankAccount *NodesCoeff = nullptr;

void allocRootNodes();
};
Expand Down
12 changes: 10 additions & 2 deletions src/trees/MWNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "utils/Printer.h"
#include "utils/Timer.h"
#include "utils/math_utils.h"
#include "utils/parallel.h"
#include "utils/tree_utils.h"

using namespace Eigen;
Expand Down Expand Up @@ -409,6 +410,7 @@ template <int D> void MWNode<D>::copyCoefsFromChildren() {
* coefficients for the children
*/
template <int D> void MWNode<D>::threadSafeGenChildren() {
if (tree->isLocal) { NOT_IMPLEMENTED_ABORT; }
MRCPP_SET_OMP_LOCK();
if (isLeafNode()) {
genChildren();
Expand Down Expand Up @@ -734,7 +736,7 @@ template <int D> Coord<D> MWNode<D>::getCenter() const {
auto scaling_factor = getMWTree().getMRA().getWorldBox().getScalingFactors();
auto &l = getNodeIndex();
auto r = Coord<D>{};
for (int d = 0; d < D; d++) r[d] = scaling_factor[d]*two_n*(l[d] + 0.5);
for (int d = 0; d < D; d++) r[d] = scaling_factor[d] * two_n * (l[d] + 0.5);
return r;
}

Expand Down Expand Up @@ -1076,10 +1078,16 @@ template <int D> MWNode<D> *MWNode<D>::retrieveNode(const Coord<D> &r, int depth
* Recursive routine to find and return the node with a given NodeIndex. This
* routine always returns the appropriate node, and will generate nodes that
* does not exist. Recursion starts at this node and ASSUMES the requested
* node is in fact decending from this node.
* node is in fact descending from this node.
*/
template <int D> MWNode<D> *MWNode<D>::retrieveNode(const NodeIndex<D> &idx) {
if (getScale() == idx.getScale()) { // we're done
if (tree->isLocal) {
// has to fetch coeff in Bank. NOT USED YET
int ncoefs = (1 << D) * this->getKp1_d();
coefs = new double[ncoefs]; // TODO must be cleaned at some stage
tree->getNodeCoeff(idx, coefs);
}
assert(getNodeIndex() == idx);
return this;
}
Expand Down
9 changes: 9 additions & 0 deletions src/trees/MWTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "TreeIterator.h"
#include "MultiResolutionAnalysis.h"
#include "NodeAllocator.h"
#include "utils/Bank.h"
#include "utils/Printer.h"
#include "utils/math_utils.h"
#include "utils/periodic_utils.h"
Expand Down Expand Up @@ -550,6 +551,14 @@ template <int D> int MWTree<D>::getIx(NodeIndex<D> nIdx) {
else return NodeIndex2serialIx[nIdx];
}

template <int D> void MWTree<D>::getNodeCoeff(NodeIndex<D> nIdx, double *data) {
assert(this->isLocal);
int size = (1 << D) * kp1_d;
int id = 0;
for (int i = 0; i < D; i++) id += std::abs(nIdx.getTranslation(i));
this->NodesCoeff->get_data(id, size, data);
}

template class MWTree<1>;
template class MWTree<2>;
template class MWTree<3>;
Expand Down
6 changes: 6 additions & 0 deletions src/trees/MWTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@

namespace mrcpp {

class BankAccount;

/** @class MWTree
*
* @brief Base class for Multiwavelet tree structures, such as FunctionTree and OperatorTree
Expand Down Expand Up @@ -140,6 +142,8 @@ template <int D> class MWTree {
const NodeAllocator<D> &getNodeAllocator() const { return *this->nodeAllocator_p; }
MWNodeVector<D> endNodeTable; ///< Final projected nodes

void getNodeCoeff(NodeIndex<D> nIdx, double *data); // fetch coefficient from a specific node stored in Bank

friend std::ostream &operator<<(std::ostream &o, const MWTree<D> &tree) { return tree.print(o); }

friend class MWNode<D>;
Expand Down Expand Up @@ -175,6 +179,8 @@ template <int D> class MWTree {
void incrementNodeCount(int scale);
void decrementNodeCount(int scale);

BankAccount *NodesCoeff = nullptr;

virtual std::ostream &print(std::ostream &o) const;
};

Expand Down
76 changes: 72 additions & 4 deletions src/utils/Bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ struct Blockdata_struct {
std::map<int, std::map<int, Blockdata_struct *> *> get_nodeid2block; // to get block from its nodeid (all coeff for one node)
std::map<int, std::map<int, Blockdata_struct *> *> get_orbid2block; // to get block from its orbid

int const MIN_SCALE = -999; // Smaller than smallest scale

void Bank::open() {
#ifdef MRCPP_HAS_MPI
MPI_Status status;
Expand All @@ -36,7 +38,7 @@ void Bank::open() {
int messages[message_size];
int datasize = -1;
std::map<int, int> get_numberofclients;

std::map<NodeIndex<3>, int> nIdx2id;
bool printinfo = false;
int max_account_id = -1;
int next_task = 0;
Expand Down Expand Up @@ -287,6 +289,18 @@ void Bank::open() {
else if (message == GET_FUNCTION or message == GET_FUNCTION_AND_WAIT or message == GET_FUNCTION_AND_DELETE or message == GET_FUNCTION or message == GET_DATA) {
// withdrawal
int id = messages[2];
if (message == GET_DATA and messages[3] > MIN_SCALE) {
NodeIndex<3> nIdx;
nIdx.setScale(messages[4]);
nIdx.setTranslation({messages[2], messages[5], messages[6]});
if (nIdx2id.count(nIdx) == 0) {
// data is not yet saved, but one can hope it will be created at some stage
id = nIdx2id.size();
nIdx2id[nIdx] = id;
} else {
id = nIdx2id[nIdx];
}
}
int ix = id2ix[id];
if (id2ix.count(id) == 0 or ix == 0) {
if (printinfo) std::cout << world_rank << " not found " << id << " " << message << std::endl;
Expand Down Expand Up @@ -366,8 +380,20 @@ void Bank::open() {
if (printinfo) std::cout << " written block " << nodeid << " id " << orbid << " subblocks " << nodeid2block[nodeid]->data.size() << std::endl;
} else if (message == SAVE_FUNCTION or message == SAVE_DATA) {
// make a new deposit
int exist_flag = 0;
int id = messages[2];
if (message == SAVE_DATA and messages[4] > MIN_SCALE) {
// has to find or create unique id from NodeIndex. Use the same internal mapping for all trees
NodeIndex<3> nIdx;
nIdx.setScale(messages[4]);
nIdx.setTranslation({messages[2], messages[5], messages[6]});
if (nIdx2id.count(nIdx) == 0) {
id = nIdx2id.size();
nIdx2id[nIdx] = id;
} else {
id = nIdx2id[nIdx];
}
}
int exist_flag = 0;
if (id2ix[id]) {
std::cout << "WARNING: id " << id << " exists already"
<< " " << status.MPI_SOURCE << " " << message << " " << messages[1] << std::endl;
Expand Down Expand Up @@ -753,11 +779,33 @@ int BankAccount::put_data(int id, int size, double *data) {
#ifdef MRCPP_HAS_MPI
// for now we distribute according to id
int messages[message_size];

messages[0] = SAVE_DATA;
messages[1] = account_id;
messages[2] = id;
messages[3] = size;
MPI_Send(messages, 4, MPI_INT, bankmaster[id % bank_size], 0, comm_bank);
messages[4] = MIN_SCALE; // to indicate that it is defined by id
MPI_Send(messages, 5, MPI_INT, bankmaster[id % bank_size], 0, comm_bank);
MPI_Send(data, size, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank);
#endif
return 1;
}

// save data in Bank with identity nIdx. datasize MUST have been set already. NB:not tested
int BankAccount::put_data(NodeIndex<3> nIdx, int size, double *data) {
#ifdef MRCPP_HAS_MPI
// for now we distribute according to id
int messages[message_size];
messages[0] = SAVE_DATA;
messages[1] = account_id;
messages[2] = nIdx.getTranslation(0);
messages[3] = size;
messages[4] = nIdx.getScale();
messages[5] = nIdx.getTranslation(1);
messages[6] = nIdx.getTranslation(2);
int id = std::abs(nIdx.getTranslation(0) + nIdx.getTranslation(1) + nIdx.getTranslation(2));
// std::cout<<mpi::wrk_rank<<" bankidx "<<bank_size<<" ID "<<messages[4]<<" "<<messages[2]<<" "<<messages[5]<<" "<<messages[6]<<std::endl;
MPI_Send(messages, 7, MPI_INT, bankmaster[id % bank_size], 0, comm_bank);
MPI_Send(data, size, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank);
#endif
return 1;
Expand All @@ -771,7 +819,27 @@ int BankAccount::get_data(int id, int size, double *data) {
messages[0] = GET_DATA;
messages[1] = account_id;
messages[2] = id;
MPI_Send(messages, 3, MPI_INT, bankmaster[id % bank_size], 0, comm_bank);
messages[3] = MIN_SCALE;
MPI_Send(messages, 4, MPI_INT, bankmaster[id % bank_size], 0, comm_bank);
MPI_Recv(data, size, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank, &status);
#endif
return 1;
}

// get data with identity id
int BankAccount::get_data(NodeIndex<3> nIdx, int size, double *data) {
#ifdef MRCPP_HAS_MPI
MPI_Status status;
int messages[message_size];
int id = std::abs(nIdx.getTranslation(0) + nIdx.getTranslation(1) + nIdx.getTranslation(2));
messages[0] = GET_DATA;
messages[1] = account_id;
messages[2] = id;
messages[3] = nIdx.getScale();
messages[4] = nIdx.getTranslation(0);
messages[5] = nIdx.getTranslation(1);
messages[6] = nIdx.getTranslation(2);
MPI_Send(messages, 7, MPI_INT, bankmaster[id % bank_size], 0, comm_bank);
MPI_Recv(data, size, MPI_DOUBLE, bankmaster[id % bank_size], 1, comm_bank, &status);
#endif
return 1;
Expand Down
5 changes: 4 additions & 1 deletion src/utils/Bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "ComplexFunction.h"
#include "parallel.h"
#include "trees/NodeIndex.h"

namespace mrcpp {

Expand Down Expand Up @@ -100,6 +101,8 @@ class BankAccount {
int get_func(int id, ComplexFunction &func, int wait = 0);
int put_data(int id, int size, double *data);
int get_data(int id, int size, double *data);
int put_data(NodeIndex<3> nIdx, int size, double *data);
int get_data(NodeIndex<3> nIdx, int size, double *data);
int put_nodedata(int id, int nodeid, int size, double *data);
int get_nodedata(int id, int nodeid, int size, double *data, std::vector<int> &idVec);
int get_nodeblock(int nodeid, double *data, std::vector<int> &idVec);
Expand All @@ -120,6 +123,6 @@ class TaskManager {
int n_tasks = 0; // used in serial case only
};

int const message_size = 5;
int const message_size = 7;

} // namespace mrcpp
46 changes: 32 additions & 14 deletions src/utils/ComplexFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,28 @@ void cplxfunc::multiply(ComplexFunction &out, ComplexFunction inp_a, ComplexFunc
multiply_imag(out, inp_a, inp_b, prec, absPrec, useMaxNorms);
}

/** @brief out = inp_a * f
*
*/
void cplxfunc::multiply(ComplexFunction &out, ComplexFunction &inp_a, RepresentableFunction<3> &f, double prec, int nrefine) {
// uses the mpifuncvec multiply
MPI_FuncVector mpi_funcvec_a;
mpi_funcvec_a.push_back(inp_a);
MPI_FuncVector mpi_funcvec_out;
mpi_funcvec_out = mpifuncvec::multiply(mpi_funcvec_a, f, prec, nullptr, nrefine, true);
out = mpi_funcvec_out[0];
}

/** @brief out = inp_a * f
*
*/
void cplxfunc::multiply(ComplexFunction &out, FunctionTree<3> &inp_a, RepresentableFunction<3> &f, double prec, int nrefine) {
ComplexFunction cplxfunc_a;
cplxfunc_a.setReal(&inp_a);
cplxfunc::multiply(out, cplxfunc_a, f, prec, nrefine);
cplxfunc_a.setReal(nullptr); // otherwise inp_a is deleted by cplxfunc_a destructor
}

/** @brief out = c_0*inp_0 + c_1*inp_1 + ... + c_N*inp_N
*
*/
Expand Down Expand Up @@ -1066,24 +1088,21 @@ void save_nodes(MPI_FuncVector &Phi, FunctionTree<3> &refTree, BankAccount &acco
* in parallel using a local representation.
* Input trees are extended by one scale at most.
*/
MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double prec, ComplexFunction *Func, int nrefine) {
MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double prec, ComplexFunction *Func, int nrefine, bool all) {

int N = Phi.size();
const int D = 3;
bool serial = mpi::wrk_size == 1; // flag for serial/MPI switch
if (serial) nrefine = 2;

// 1a) extend grid where f is large (around nuclei)
// TODO: do it in save_nodes + refTree, only saving the extra nodes, without keeping them permanently. Or refine refTree?

for (int i = 0; i < N; i++) {
if (!mpi::my_orb(i)) continue;
if (mrcpp::mpi::wrk_rank == 0) {
int irefine = 0;
while (Phi[i].hasReal() and irefine < nrefine and refine_grid(Phi[i].real(), f) > 0) irefine++;
irefine = 0;
while (Phi[i].hasImag() and irefine < nrefine and refine_grid(Phi[i].imag(), f) > 0) irefine++;
}
int irefine = 0;
while (Phi[i].hasReal() and irefine < nrefine and refine_grid(Phi[i].real(), f) > 0) irefine++;
irefine = 0;
while (Phi[i].hasImag() and irefine < nrefine and refine_grid(Phi[i].imag(), f) > 0) irefine++;
}

// 1b) make union tree without coefficients
Expand Down Expand Up @@ -1284,7 +1303,7 @@ MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double
fval[j] = f.evalf(r);
}
} else {
Func->real().getNodeCoeff(nIdx, nCoefs, fval); // fetch coef from Bank
Func->real().getNodeCoeff(nIdx, fval); // fetch coef from Bank
Fnode.attachCoefs(fval);
Fnode.mwTransform(Reconstruction);
Fnode.cvTransform(Forward);
Expand Down Expand Up @@ -1316,7 +1335,6 @@ MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double
Fnode.attachCoefs(nullptr); // to avoid deletion of valid multipliedCoeff by destructor
}
mrcpp::mpi::barrier(mrcpp::mpi::comm_wrk); // wait until everything is stored before fetching!
// if(mrcpp::mpi::wrk_rank==0){std::cout<<" not found "<<count2<<" of "<<count1<<std::endl;}
}

// 5) reconstruct trees using multiplied nodes.
Expand Down Expand Up @@ -1349,7 +1367,7 @@ MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double
}
} else {
for (int j = 0; j < Neff; j++) {
if (not mpi::my_orb(j % N)) continue;
if (not mpi::my_orb(j % N) and not all) continue;
// traverse possible nodes, and stop descending when norm is zero (leaf in out[j])
std::vector<double *> coeffpVec; //
std::map<int, int> ix2coef; // to find the index in coeffVec[] corresponding to a serialIx in refTree
Expand Down Expand Up @@ -1380,7 +1398,7 @@ MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double
out[j].real().calcSquareNorm();
out[j].real().resetEndNodeTable();
// out[j].real().crop(prec, 1.0, false); //bad convergence if out is cropped
Phi[j].real().crop(prec, 1.0, false); // restablishes original Phi
if (nrefine > 0) Phi[j].real().crop(prec, 1.0, false); // restablishes original Phi
}
} else {
if (Phi[j % N].hasImag()) {
Expand All @@ -1389,8 +1407,8 @@ MPI_FuncVector multiply(MPI_FuncVector &Phi, RepresentableFunction<3> &f, double
out[j % N].imag().makeTreefromCoeff(refTree, coeffpVec, ix2coef, -1.0, "copy");
out[j % N].imag().mwTransform(BottomUp);
out[j % N].imag().calcSquareNorm();
out[j % N].imag().crop(prec, 1.0, false);
Phi[j % N].imag().crop(prec, 1.0, false);
// out[j % N].imag().crop(prec, 1.0, false);
if (nrefine > 0) Phi[j % N].imag().crop(prec, 1.0, false);
}
}

Expand Down
Loading

0 comments on commit f8def0a

Please sign in to comment.