Skip to content

Commit

Permalink
[LinearAlgebra] constexpr if statement when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
alxbilger authored and fredroy committed Dec 14, 2023
1 parent 6ed9f8f commit ecfb87e
Showing 1 changed file with 72 additions and 28 deletions.
100 changes: 72 additions & 28 deletions Sofa/framework/LinearAlgebra/src/sofa/linearalgebra/BaseMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace sofa::linearalgebra
{

BaseMatrix::BaseMatrix() {}
BaseMatrix::BaseMatrix() = default;

BaseMatrix::~BaseMatrix()
{}
Expand Down Expand Up @@ -64,11 +64,20 @@ struct BaseMatrixLinearOpMV_BlockDiagonal
const Index colSize = mat->colSize();
BlockData buffer;

if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = mat->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
rowRange.first != rowRange.second;
++rowRange.first)
{
std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
if (colRange.first != colRange.second) // diagonal block exists
Expand All @@ -77,7 +86,7 @@ struct BaseMatrixLinearOpMV_BlockDiagonal
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index i = block.getRow() * NL;
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
type::VecNoInit<NC,Real> vj;
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -162,8 +171,17 @@ struct BaseMatrixLinearOpMV_BlockDiagonal<Real, 1, 1, add, transpose, M, V1, V2>
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand All @@ -188,14 +206,16 @@ struct BaseMatrixLinearOpMV_BlockSparse
BlockData buffer;
type::Vec<NC,Real> vtmpj;
type::Vec<NL,Real> vtmpi;
if (!add)
if constexpr (!add)
{
opVresize(result, (transpose ? colSize : rowSize));
}
for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = mat->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
rowRange.first != rowRange.second;
++rowRange.first)
{
const Index i = rowRange.first.row() * NL;
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
vtmpi[bi] = (Real)0;
Expand All @@ -212,7 +232,7 @@ struct BaseMatrixLinearOpMV_BlockSparse
BlockConstAccessor block = colRange.first.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
for (int bj = 0; bj < NC; ++bj)
vtmpj[bj] = (Real)opVget(v, j+bj);
Expand All @@ -231,14 +251,11 @@ struct BaseMatrixLinearOpMV_BlockSparse
opVadd(result, j+bj, vtmpj[bj]);
}
}
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
opVadd(result, i+bi, vtmpi[bi]);
}
else
{
}
}
}
};
Expand All @@ -253,9 +270,18 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if (!transpose)
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
if constexpr (!transpose)
{
for (Index i=0; i<rowSize; ++i)
{
Expand Down Expand Up @@ -285,8 +311,17 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand All @@ -300,8 +335,17 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand Down Expand Up @@ -513,7 +557,7 @@ struct BaseMatrixLinearOpAM_BlockSparse
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;

if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -558,7 +602,7 @@ struct BaseMatrixLinearOpAMS_BlockSparse
BlockConstAccessor block = colRange.first.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -603,7 +647,7 @@ struct BaseMatrixLinearOpAM1_BlockSparse
const BlockData& bdata = *(const BlockData*)block.elements(&buffer);
const Index j = block.getCol();

if (!transpose)
if constexpr (!transpose)
{
m2->add(i,j,bdata * fact);
}
Expand All @@ -626,7 +670,7 @@ class BaseMatrixLinearOpAM
{
const Index rowSize = m1->rowSize();
const Index colSize = m2->colSize();
if (!transpose)
if constexpr (!transpose)
{
for (Index j=0; j<rowSize; ++j)
{
Expand Down

0 comments on commit ecfb87e

Please sign in to comment.