Skip to content

Commit

Permalink
More native Complex tree operations: add, mult, Bank, send, rescale
Browse files Browse the repository at this point in the history
  • Loading branch information
gitpeterwind committed Jul 12, 2024
1 parent fa478e8 commit 303a545
Show file tree
Hide file tree
Showing 19 changed files with 267 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/treebuilders/AdditionCalculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ template <int D, typename T> class AdditionCalculator final : public TreeCalcula
const NodeIndex<D> &idx = node_o.getNodeIndex();
T *coefs_o = node_o.getCoefs();
for (int i = 0; i < this->sum_vec.size(); i++) {
double c_i = get_coef(this->sum_vec, i);
T c_i = get_coef(this->sum_vec, i);
FunctionTree<D, T> &func_i = get_func(this->sum_vec, i);
// This generates missing nodes
const MWNode<D, T> &node_i = func_i.getNode(idx);
Expand Down
2 changes: 1 addition & 1 deletion src/treebuilders/MultiplicationCalculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ template <int D, typename T> class MultiplicationCalculator final : public TreeC
T *coefs_o = node_o.getCoefs();
for (int j = 0; j < node_o.getNCoefs(); j++) { coefs_o[j] = 1.0; }
for (int i = 0; i < this->prod_vec.size(); i++) {
double c_i = get_coef(this->prod_vec, i);
T c_i = get_coef(this->prod_vec, i);
FunctionTree<D, T> &func_i = get_func(this->prod_vec, i);
// This generates missing nodes
MWNode<D, T> node_i = func_i.getNode(idx); // Copy node
Expand Down
18 changes: 8 additions & 10 deletions src/treebuilders/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ namespace mrcpp {
template <int D, typename T>
void add(double prec,
FunctionTree<D, T> &out,
double a,
T a,
FunctionTree<D, T> &inp_a,
double b,
T b,
FunctionTree<D, T> &inp_b,
int maxIter,
bool absPrec) {
Expand Down Expand Up @@ -190,29 +190,27 @@ template void add<3, double>(double prec,
bool absPrec);




template void add<1, ComplexDouble>(double prec,
FunctionTree<1, ComplexDouble> &out,
double a,
ComplexDouble a,
FunctionTree<1, ComplexDouble> &tree_a,
double b,
ComplexDouble b,
FunctionTree<1, ComplexDouble> &tree_b,
int maxIter,
bool absPrec);
template void add<2, ComplexDouble>(double prec,
FunctionTree<2, ComplexDouble> &out,
double a,
ComplexDouble a,
FunctionTree<2, ComplexDouble> &tree_a,
double b,
ComplexDouble b,
FunctionTree<2, ComplexDouble> &tree_b,
int maxIter,
bool absPrec);
template void add<3, ComplexDouble>(double prec,
FunctionTree<3, ComplexDouble> &out,
double a,
ComplexDouble a,
FunctionTree<3, ComplexDouble> &tree_a,
double b,
ComplexDouble b,
FunctionTree<3, ComplexDouble> &tree_b,
int maxIter,
bool absPrec);
Expand Down
4 changes: 2 additions & 2 deletions src/treebuilders/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace mrcpp {

template <int D, typename T> void add(double prec,
FunctionTree<D, T> &out,
double a,
T a,
FunctionTree<D, T> &tree_a,
double b,
T b,
FunctionTree<D, T> &tree_b,
int maxIter = -1,
bool absPrec = false);
Expand Down
2 changes: 1 addition & 1 deletion src/treebuilders/apply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ template <int D, typename T> void divergence(FunctionTree<D, T> &out, Derivative

FunctionTreeVector<D, T> tmp_vec;
for (int d = 0; d < D; d++) {
double coef_d = get_coef(inp, d);
T coef_d = get_coef(inp, d);
FunctionTree<D, T> &func_d = get_func(inp, d);
auto *out_d = new FunctionTree<D, T>(func_d.getMRA());
apply(*out_d, oper, func_d, d);
Expand Down
15 changes: 8 additions & 7 deletions src/treebuilders/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace mrcpp {
template <int D, typename T>
void multiply(double prec,
FunctionTree<D, T> &out,
double c,
T c,
FunctionTree<D, T> &inp_a,
FunctionTree<D, T> &inp_b,
int maxIter,
Expand Down Expand Up @@ -278,13 +278,14 @@ void dot(double prec,

FunctionTreeVector<D, T> tmp_vec;
for (int d = 0; d < inp_a.size(); d++) {
double coef_a = get_coef(inp_a, d);
double coef_b = get_coef(inp_b, d);
T coef_a = get_coef(inp_a, d);
T coef_b = get_coef(inp_b, d);
FunctionTree<D, T> &tree_a = get_func(inp_a, d);
FunctionTree<D, T> &tree_b = get_func(inp_b, d);
auto *out_d = new FunctionTree<D, T>(out.getMRA());
build_grid(*out_d, out);
multiply(prec, *out_d, 1.0, tree_a, tree_b, maxIter, absPrec);
T One = 1.0;
multiply(prec, *out_d, One, tree_a, tree_b, maxIter, absPrec);
tmp_vec.push_back({coef_a * coef_b, out_d});
}
build_grid(out, tmp_vec);
Expand Down Expand Up @@ -509,23 +510,23 @@ template double node_norm_dot<3, double>(FunctionTree<3, double> &bra, FunctionT

template void multiply<1, ComplexDouble>(double prec,
FunctionTree<1, ComplexDouble> &out,
double c,
ComplexDouble c,
FunctionTree<1, ComplexDouble> &tree_a,
FunctionTree<1, ComplexDouble> &tree_b,
int maxIter,
bool absPrec,
bool useMaxNorms);
template void multiply<2, ComplexDouble>(double prec,
FunctionTree<2, ComplexDouble> &out,
double c,
ComplexDouble c,
FunctionTree<2, ComplexDouble> &tree_a,
FunctionTree<2, ComplexDouble> &tree_b,
int maxIter,
bool absPrec,
bool useMaxNorms);
template void multiply<3, ComplexDouble>(double prec,
FunctionTree<3, ComplexDouble> &out,
double c,
ComplexDouble c,
FunctionTree<3, ComplexDouble> &tree_a,
FunctionTree<3, ComplexDouble> &tree_b,
int maxIter,
Expand Down
2 changes: 1 addition & 1 deletion src/treebuilders/multiply.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ template <int D, typename T> double node_norm_dot(FunctionTree<D, T> &bra,

template <int D, typename T> void multiply(double prec,
FunctionTree<D, T> &out,
double c,
T c,
FunctionTree<D, T> &inp_a,
FunctionTree<D, T> &inp_b,
int maxIter = -1,
Expand Down
39 changes: 31 additions & 8 deletions src/trees/FunctionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ template <int D, typename T> void FunctionTree<D, T>::power(double p) {
* in-place multiplied by the given coefficient, no grid refinement.
*
*/
template <int D, typename T> void FunctionTree<D, T>::rescale(double c) {
template <int D, typename T> void FunctionTree<D, T>::rescale(T c) {
if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared");
#pragma omp parallel firstprivate(c) num_threads(mrcpp_get_num_threads())
{
Expand Down Expand Up @@ -399,7 +399,7 @@ template <int D, typename T> void FunctionTree<D, T>::normalize() {
* the function, i.e. no further grid refinement.
*
*/
template <int D, typename T> void FunctionTree<D, T>::add(double c, FunctionTree<D, T> &inp) {
template <int D, typename T> void FunctionTree<D, T>::add(T c, FunctionTree<D, T> &inp) {
if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");
if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared");
#pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads())
Expand Down Expand Up @@ -428,7 +428,7 @@ template <int D, typename T> void FunctionTree<D, T>::add(double c, FunctionTree
* function, i.e. no further grid refinement.
*
*/
template <int D, typename T> void FunctionTree<D, T>::absadd(double c, FunctionTree<D, T> &inp) {
template <int D, typename T> void FunctionTree<D, T>::absadd (T c, FunctionTree<D, T> &inp) {
if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared");
#pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads())
{
Expand All @@ -443,7 +443,7 @@ template <int D, typename T> void FunctionTree<D, T>::absadd(double c, FunctionT
inp_node.cvTransform(Forward);
T *out_coefs = out_node.getCoefs();
const T *inp_coefs = inp_node.getCoefs();
for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = abs(out_coefs[i]) + c * abs(inp_coefs[i]); }
for (int i = 0; i < inp_node.getNCoefs(); i++) { out_coefs[i] = std::norm(out_coefs[i]) + std::norm(c * inp_coefs[i]); }
out_node.cvTransform(Backward);
out_node.mwTransform(Compression);
out_node.calcNorms();
Expand All @@ -463,7 +463,7 @@ template <int D, typename T> void FunctionTree<D, T>::absadd(double c, FunctionT
* of the function, i.e. no further grid refinement.
*
*/
template <int D, typename T> void FunctionTree<D, T>::multiply(double c, FunctionTree<D, T> &inp) {
template <int D, typename T> void FunctionTree<D, T>::multiply(T c, FunctionTree<D, T> &inp) {
if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");
if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared");
#pragma omp parallel firstprivate(c) shared(inp) num_threads(mrcpp_get_num_threads())
Expand Down Expand Up @@ -763,16 +763,39 @@ template <int D, typename T> void FunctionTree<D, T>::deleteGeneratedParents() {
for (int n = 0; n < this->getRootBox().size(); n++) this->getRootMWNode(n).deleteParent();
}

template <> int FunctionTree<3>::saveNodesAndRmCoeff() {
template <> int FunctionTree<3, double>::saveNodesAndRmCoeff() {
if (this->isLocal) MSG_INFO("Tree is already in local representation");
NodesCoeff = new BankAccount; // NB: must be a collective call!
int stack_p = 0;
if (mpi::wrk_rank == 0) {
int sizecoeff = (1 << 3) * this->getKp1_d();
std::vector<MWNode<3> *> stack; // nodes from this Tree
std::vector<MWNode<3, double> *> stack; // nodes from this Tree
for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); }
while (stack.size() > stack_p) {
MWNode<3> *Node = stack[stack_p++];
MWNode<3, double> *Node = stack[stack_p++];
int id = 0;
NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs());
for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); }
}
}
this->nodeAllocator_p->deallocAllCoeff();
mpi::broadcast_Tree_noCoeff(*this, mpi::comm_wrk);
this->isLocal = true;
assert(this->NodeIndex2serialIx.size() == getNNodes());
return this->NodeIndex2serialIx.size();
}

template <> int FunctionTree<3, ComplexDouble>::saveNodesAndRmCoeff() {
if (this->isLocal) MSG_INFO("Tree is already in local representation");
NodesCoeff = new BankAccount; // NB: must be a collective call!
int stack_p = 0;
if (mpi::wrk_rank == 0) {
int sizecoeff = (1 << 3) * this->getKp1_d();
sizecoeff *= 2; // double->ComplexDouble. Saved as twice as many doubles
std::vector<MWNode<3, ComplexDouble> *> stack; // nodes from this Tree
for (int rIdx = 0; rIdx < this->getRootBox().size(); rIdx++) { stack.push_back(this->getRootBox().getNodes()[rIdx]); }
while (stack.size() > stack_p) {
MWNode<3, ComplexDouble> *Node = stack[stack_p++];
int id = 0;
NodesCoeff->put_data(Node->getNodeIndex(), sizecoeff, Node->getCoefs());
for (int i = 0; i < Node->getNChildren(); i++) { stack.push_back(Node->children[i]); }
Expand Down
8 changes: 4 additions & 4 deletions src/trees/FunctionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ template <int D, typename T> class FunctionTree final : public MWTree<D, T>, pub
// In place operations
void square();
void power(double p);
void rescale(double c);
void rescale(T c);
void normalize();
void add(double c, FunctionTree<D, T> &inp);
void absadd(double c, FunctionTree<D, T> &inp);
void multiply(double c, FunctionTree<D, T> &inp);
void add(T c, FunctionTree<D, T> &inp);
void absadd(T c, FunctionTree<D, T> &inp);
void multiply(T c, FunctionTree<D, T> &inp);
void map(FMap<T, T> fmap);

int getNChunks() { return this->getNodeAllocator().getNChunks(); }
Expand Down
4 changes: 2 additions & 2 deletions src/trees/FunctionTreeVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

namespace mrcpp {

template <int D, typename T = double> using CoefsFunctionTree = std::tuple<double, FunctionTree<D, T> *>;
template <int D, typename T = double> using CoefsFunctionTree = std::tuple<T, FunctionTree<D, T> *>;
template <int D, typename T = double> using FunctionTreeVector = std::vector<CoefsFunctionTree<D, T>>;

/** @brief Remove all entries in the vector
Expand Down Expand Up @@ -77,7 +77,7 @@ template <int D, typename T> int get_size_nodes(const FunctionTreeVector<D, T> &
* @param[in] fs: Vector to fetch from
* @param[in] i: Position in vector
*/
template <int D, typename T> double get_coef(const FunctionTreeVector<D, T> &fs, int i) {
template <int D, typename T> T get_coef(const FunctionTreeVector<D, T> &fs, int i) {
return std::get<0>(fs[i]);
}

Expand Down
4 changes: 2 additions & 2 deletions src/trees/MWNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ namespace mrcpp {
* translation index, the norm, pointers to parent node and child
* nodes, pointer to the corresponding MWTree etc... See member and
* data descriptions for details.
*
*
*/
template <int D, typename T> class MWNode {
template <int D, typename T> class MWNode {
public:
MWNode(const MWNode<D, T> &node, bool allocCoef = true, bool SetCoef = true);
MWNode<D , T> &operator=(const MWNode<D , T> &node) = delete;
Expand Down
10 changes: 5 additions & 5 deletions src/trees/MWTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ template <int D, typename T> void MWTree<D, T>::calcSquareNorm() {
* @details It performs a Multiwavlet transform of the whole tree. The
* input parameters will specify the direction (upwards or downwards)
* and whether the result is added to the coefficients or it
* overwrites them. See the documentation for the #mwTransformUp
* overwrites them. See the documentation for the #mwTransformUp
* and #mwTransformDown for details.
* \f[
* \f[
* \pmatrix{
* s_{nl}\\
* d_{nl}
Expand Down Expand Up @@ -215,7 +215,7 @@ template <int D, typename T> void MWTree<D, T>::mwTransformDown(bool overwrite)
}

/** @brief Set the MW coefficients to zero, keeping the same tree structure
*
*
* @details Keeps the node structure of the tree, even though the zero
* function is representable at depth zero. One should then use \ref cropTree to remove
* unnecessary nodes.
Expand Down Expand Up @@ -447,7 +447,7 @@ template <int D, typename T> MWNodeVector<D, T> *MWTree<D, T>::copyEndNodeTable(
*
* @details the endNodeTable is first deleted and then rebuilt from
* scratch. It makes use of the TreeIterator to traverse the tree.
*
*
*/
template <int D, typename T> void MWTree<D, T>::resetEndNodeTable() {
clearEndNodeTable();
Expand Down Expand Up @@ -552,7 +552,7 @@ template <int D, typename T> int MWTree<D, T>::getIx(NodeIndex<D> nIdx) {
else return NodeIndex2serialIx[nIdx];
}

template <int D, typename T> void MWTree<D, T>::getNodeCoeff(NodeIndex<D> nIdx, double *data) {
template <int D, typename T> void MWTree<D, T>::getNodeCoeff(NodeIndex<D> nIdx, T *data) {
assert(this->isLocal);
int size = (1 << D) * kp1_d;
int id = 0;
Expand Down
2 changes: 1 addition & 1 deletion src/trees/MWTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class BankAccount;
const NodeAllocator<D, T> &getNodeAllocator() const { return *this->nodeAllocator_p; }
MWNodeVector<D, T> endNodeTable; ///< Final projected nodes

void getNodeCoeff(NodeIndex<D> nIdx, double *data); // fetch coefficient from a specific node stored in Bank
void getNodeCoeff(NodeIndex<D> nIdx, T *data); // fetch coefficient from a specific node stored in Bank

friend std::ostream &operator<<(std::ostream &o, const MWTree<D, T> &tree) { return tree.print(o); }

Expand Down
Loading

0 comments on commit 303a545

Please sign in to comment.