Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LinearAlgebra] constexpr if statement when possible #4352

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions Sofa/framework/LinearAlgebra/src/sofa/linearalgebra/BaseMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace sofa::linearalgebra
{

BaseMatrix::BaseMatrix() {}
BaseMatrix::BaseMatrix() = default;

BaseMatrix::~BaseMatrix()
{}
Expand Down Expand Up @@ -64,20 +64,27 @@ struct BaseMatrixLinearOpMV_BlockDiagonal
const Index colSize = mat->colSize();
BlockData buffer;

if (!add)
opVresize(result, (transpose ? colSize : rowSize));
for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = mat->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
for (auto [rowIt, rowEnd] = mat->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
if (colRange.first != colRange.second) // diagonal block exists
auto [colBegin, colEnd] = rowIt.range();
if (colBegin != colEnd) // diagonal block exists
{
BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colBegin.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index i = block.getRow() * NL;
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
type::VecNoInit<NC,Real> vj;
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -162,8 +169,17 @@ struct BaseMatrixLinearOpMV_BlockDiagonal<Real, 1, 1, add, transpose, M, V1, V2>
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand All @@ -188,14 +204,14 @@ struct BaseMatrixLinearOpMV_BlockSparse
BlockData buffer;
type::Vec<NC,Real> vtmpj;
type::Vec<NL,Real> vtmpi;
if (!add)
if constexpr (!add)
{
opVresize(result, (transpose ? colSize : rowSize));
for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = mat->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
}
for (auto [rowIt, rowEnd] = mat->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row() * NL;
if (!transpose)
const Index i = rowIt.row() * NL;
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
vtmpi[bi] = (Real)0;
Expand All @@ -205,14 +221,12 @@ struct BaseMatrixLinearOpMV_BlockSparse
for (int bi = 0; bi < NL; ++bi)
vtmpi[bi] = (Real)opVget(v, i+bi);
}
for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{
BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
for (int bj = 0; bj < NC; ++bj)
vtmpj[bj] = (Real)opVget(v, j+bj);
Expand All @@ -231,14 +245,11 @@ struct BaseMatrixLinearOpMV_BlockSparse
opVadd(result, j+bj, vtmpj[bj]);
}
}
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
opVadd(result, i+bi, vtmpi[bi]);
}
else
{
}
}
}
};
Expand All @@ -253,9 +264,18 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if (!transpose)
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
if constexpr (!transpose)
{
for (Index i=0; i<rowSize; ++i)
{
Expand Down Expand Up @@ -285,8 +305,17 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand All @@ -300,8 +329,17 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand Down Expand Up @@ -498,22 +536,18 @@ struct BaseMatrixLinearOpAM_BlockSparse
{
BlockData buffer;

for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = m1->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
for (auto [rowIt, rowEnd] = m1->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row() * NL;
const Index i = rowIt.row() * NL;

for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{

BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;

if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -544,21 +578,16 @@ struct BaseMatrixLinearOpAMS_BlockSparse
{
BlockData buffer;

for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = m1->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
for (auto [rowIt, rowEnd] = m1->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row() * NL;
const Index i = rowIt.row() * NL;

for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{

BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -588,22 +617,17 @@ struct BaseMatrixLinearOpAM1_BlockSparse
{
BlockData buffer;

for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = m1->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
for (auto [rowIt, rowEnd] = m1->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row();
const Index i = rowIt.row();

for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{

BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(&buffer);
const Index j = block.getCol();

if (!transpose)
if constexpr (!transpose)
{
m2->add(i,j,bdata * fact);
}
Expand All @@ -626,7 +650,7 @@ class BaseMatrixLinearOpAM
{
const Index rowSize = m1->rowSize();
const Index colSize = m2->colSize();
if (!transpose)
if constexpr (!transpose)
{
for (Index j=0; j<rowSize; ++j)
{
Expand Down
Loading