Skip to content

Commit

Permalink
Merge pull request #723 from cwentland0/alt_gpod
Browse files Browse the repository at this point in the history
Add alternative form of weighted Gauss-Newton nonlinear solve
  • Loading branch information
fnrizzi authored Feb 14, 2025
2 parents cf8b987 + b438718 commit 4077a90
Show file tree
Hide file tree
Showing 10 changed files with 522 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/pressio/solvers_nonlinear/impl/diagnostics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#define PRESSIO_SOLVERS_NONLINEAR_IMPL_DIAGNOSTICS_HPP_

#include <iostream>
#include <iomanip>

namespace pressio{
namespace nonlinearsolvers{
Expand Down
71 changes: 71 additions & 0 deletions include/pressio/solvers_nonlinear/impl/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,23 @@ auto compute_nonlinearls_objective(WeightedGaussNewtonNormalEqTag /*tag*/,
return v * (static_cast<sc_t>(1) / static_cast<sc_t>(2));
}

template<class RegistryType, class StateType, class SystemType>
auto compute_nonlinearls_objective(CompactWeightedGaussNewtonNormalEqTag /*tag*/,
RegistryType & reg,
const StateType & state,
const SystemType & system)
{
const auto & W = reg.template get<WeightingOperatorTag>();
auto & r = reg.template get<ResidualTag>();
auto & Wr = reg.template get<WeightedResidualTag>();
compute_residual(reg, state, system);
W.get()(r, Wr);

const auto v = ::pressio::ops::dot(Wr, Wr);
using sc_t = mpl::remove_cvref_t< decltype(v) >;
return v * (static_cast<sc_t>(1) / static_cast<sc_t>(2));
}

#ifdef PRESSIO_ENABLE_CXX20
template<class RegistryType, class SystemType>
requires RealValuedNonlinearSystemFusingResidualAndJacobian<SystemType>
Expand Down Expand Up @@ -234,6 +251,50 @@ auto compute_nonlinearls_operators_and_objective(WeightedGaussNewtonNormalEqTag
return v * (static_cast<sc_t>(1) / static_cast<sc_t>(2));
}

/* Special case of weighted Gauss Newton, just changing the action of W such that
H = (J^T_r * W^T) * (W * J_r)
g = (J^T_r * W^T) * (W * r)
In instances where W is shape [M x N] and M << N, this results in a significant memory usage reduction
Particularly useful for GNAT on large sample meshes,
where M is the number of modes and N is the number of sampling points
*/
#ifdef PRESSIO_ENABLE_CXX20
template<class RegistryType, class SystemType>
requires RealValuedNonlinearSystemFusingResidualAndJacobian<SystemType>
#else
template<
class RegistryType, class SystemType,
std::enable_if_t<
RealValuedNonlinearSystemFusingResidualAndJacobian<SystemType>::value,
int> = 0
>
#endif
auto compute_nonlinearls_operators_and_objective(CompactWeightedGaussNewtonNormalEqTag /*tag*/,
RegistryType & reg,
const SystemType & system)
{
compute_residual_and_jacobian(reg, system);

constexpr auto pT = ::pressio::transpose();
constexpr auto pnT = ::pressio::nontranspose();
const auto & W = reg.template get<WeightingOperatorTag>();
const auto & r = reg.template get<ResidualTag>();
const auto & J = reg.template get<JacobianTag>();
auto & Wr = reg.template get<WeightedResidualTag>();
auto & WJ = reg.template get<WeightedJacobianTag>();
auto & g = reg.template get<GradientTag>();
auto & H = reg.template get<HessianTag>();

W.get()(r, Wr);
W.get()(J, WJ);
::pressio::ops::product(pT, pnT, 1, WJ, WJ, 0, H);
::pressio::ops::product(pT, 1, WJ, Wr, 0, g);

using sc_t = scalar_trait_t<typename SystemType::state_type>;
const auto v = ::pressio::ops::dot(Wr, Wr);
return v * (static_cast<sc_t>(1) / static_cast<sc_t>(2));
}


#ifdef PRESSIO_ENABLE_CXX20
template<class RegistryType, class SystemType>
Expand Down Expand Up @@ -333,6 +394,16 @@ void compute_correction(WeightedGaussNewtonNormalEqTag /*tag*/,
::pressio::ops::scale(c, -1);
}

template<class RegistryType>
void compute_correction(CompactWeightedGaussNewtonNormalEqTag /*tag*/,
RegistryType & reg)
{
// this is same as regular GN since we solve H delta = g
solve_hessian_gradient_linear_system(reg);
auto & c = reg.template get<CorrectionTag>();
::pressio::ops::scale(c, -1);
}

template<class RegistryType>
void compute_correction(LevenbergMarquardtNormalEqTag /*tag*/,
RegistryType & reg)
Expand Down
1 change: 1 addition & 0 deletions include/pressio/solvers_nonlinear/impl/internal_tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct NewtonTag{};
struct MatrixFreeNewtonTag{};
struct GaussNewtonNormalEqTag{};
struct WeightedGaussNewtonNormalEqTag{};
struct CompactWeightedGaussNewtonNormalEqTag{};
struct LevenbergMarquardtNormalEqTag{};
struct GaussNewtonQrTag{};

Expand Down
73 changes: 73 additions & 0 deletions include/pressio/solvers_nonlinear/impl/registries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,79 @@ class RegistryWeightedGaussNewtonNormalEqs
GETMETHOD(11)
};

template<class SystemType, class InnSolverType, class WeightingOpType>
class RegistryCompactWeightedGaussNewtonNormalEqs
{
using state_t = typename SystemType::state_type;
using r_t = typename SystemType::residual_type;
using j_t = typename SystemType::jacobian_type;
using hg_default = normal_eqs_default_types<state_t>;
using hessian_t = typename hg_default::hessian_type;
using gradient_t = typename hg_default::gradient_type;

using Tag1 = nonlinearsolvers::CorrectionTag;
using Tag2 = nonlinearsolvers::InitialGuessTag;
using Tag3 = nonlinearsolvers::ResidualTag;
using Tag4 = nonlinearsolvers::JacobianTag;
using Tag5 = nonlinearsolvers::WeightedResidualTag;
using Tag6 = nonlinearsolvers::WeightedJacobianTag;
using Tag7 = nonlinearsolvers::GradientTag;
using Tag8 = nonlinearsolvers::HessianTag;
using Tag9 = nonlinearsolvers::InnerSolverTag;
using Tag10 = nonlinearsolvers::WeightingOperatorTag;
using Tag11 = nonlinearsolvers::impl::SystemTag;

state_t d1_;
state_t d2_;
r_t d3_;
j_t d4_;
r_t d5_;
j_t d6_;
gradient_t d7_;
hessian_t d8_;
InstanceOrReferenceWrapper<InnSolverType> d9_;
InstanceOrReferenceWrapper<WeightingOpType> d10_;
SystemType const * d11_;

public:
template<class _InnSolverType, class _WeightingOpType>
RegistryCompactWeightedGaussNewtonNormalEqs(const SystemType & system,
_InnSolverType && innS,
_WeightingOpType && weigher)
: d1_(system.createState()),
d2_(system.createState()),
d3_(system.createResidual()),
d4_(system.createJacobian()),
d5_(system.createResidual()),
d6_(system.createJacobian()),
d7_(system.createState()),
d8_( hg_default::createHessian(system.createState()) ),
d9_(std::forward<InnSolverType>(innS)),
d10_(std::forward<_WeightingOpType>(weigher)),
d11_(&system){
// resize Wr and WJ leading dimension according to weighing operator
pressio::ops::resize(d5_, weigher.leadingDim());
pressio::ops::resize(d6_, weigher.leadingDim(), pressio::ops::extent(d6_, 1));
}

template<class TagToFind>
static constexpr bool contains(){
return (mpl::variadic::find_if_binary_pred_t<TagToFind, std::is_same,
Tag1, Tag2, Tag3, Tag4, Tag5, Tag6, Tag7, Tag8, Tag9, Tag10, Tag11>::value) < 11;
}

GETMETHOD(1)
GETMETHOD(2)
GETMETHOD(3)
GETMETHOD(4)
GETMETHOD(5)
GETMETHOD(6)
GETMETHOD(7)
GETMETHOD(8)
GETMETHOD(9)
GETMETHOD(10)
GETMETHOD(11)
};

template<class SystemType, class QRSolverType>
class RegistryGaussNewtonQr
Expand Down
2 changes: 2 additions & 0 deletions include/pressio/solvers_nonlinear/impl/root_finder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#ifndef SOLVERS_NONLINEAR_IMPL_ROOT_FINDER_HPP_
#define SOLVERS_NONLINEAR_IMPL_ROOT_FINDER_HPP_

#include <utility>

namespace pressio{
namespace nonlinearsolvers{
namespace impl{
Expand Down
35 changes: 35 additions & 0 deletions include/pressio/solvers_nonlinear/solvers_create_gauss_newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,41 @@ auto create_gauss_newton_solver(const SystemType & system, /*(2)*/
std::forward<WeightingOpType>(weigher));
}


/*
Special modifications for using CompactWeightedGaussNewtonNormalEqTag
Requires that weigher have member leading_dim specifying leading dimension of Wr and WJ
*/
template<class SystemType, class LinearSolverType, class WeightingOpType>
auto create_gauss_newton_solver(const SystemType & system, /*(3)*/
LinearSolverType && linSolver,
WeightingOpType && weigher,
nonlinearsolvers::impl::CompactWeightedGaussNewtonNormalEqTag /*tag*/)
{

using nonlinearsolvers::Diagnostic;
const std::vector<Diagnostic> defaultDiagnostics =
{Diagnostic::objectiveAbsolute,
Diagnostic::objectiveRelative,
Diagnostic::residualAbsolutel2Norm,
Diagnostic::residualRelativel2Norm,
Diagnostic::correctionAbsolutel2Norm,
Diagnostic::correctionRelativel2Norm,
Diagnostic::gradientAbsolutel2Norm,
Diagnostic::gradientRelativel2Norm};

using tag_t = nonlinearsolvers::impl::CompactWeightedGaussNewtonNormalEqTag;
using state_t = typename SystemType::state_type;
using reg_t = nonlinearsolvers::impl::RegistryCompactWeightedGaussNewtonNormalEqs<
SystemType, LinearSolverType, WeightingOpType>;
using scalar_t = nonlinearsolvers::scalar_of_t<SystemType>;

return nonlinearsolvers::impl::NonLinLeastSquares<tag_t, state_t, reg_t, scalar_t>
(tag_t{}, defaultDiagnostics, system,
std::forward<LinearSolverType>(linSolver),
std::forward<WeightingOpType>(weigher));
}

namespace experimental{

/*
Expand Down
14 changes: 14 additions & 0 deletions tests/functional_small/solvers_nonlinear/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ if (PRESSIO_ENABLE_TPL_EIGEN)
add_serial_exe_and_test(${name} ${ROOTNAME} ${CMAKE_CURRENT_SOURCE_DIR}/${name}.cc "PASSED")
endif()

# -----------------------------
# compact weighted gauss-newton with normal equations
# -----------------------------
if (PRESSIO_ENABLE_TPL_EIGEN)
set(name compact_weighted_gaussnewton_normaleqs_problem3_eigen)
add_serial_exe_and_test(${name} ${ROOTNAME} ${CMAKE_CURRENT_SOURCE_DIR}/${name}.cc "PASSED")

set(name compact_weighted_gaussnewton_normaleqs_nontrivial_problem3_eigen)
add_serial_exe_and_test(${name} ${ROOTNAME} ${CMAKE_CURRENT_SOURCE_DIR}/${name}.cc "PASSED")

set(name compact_weighted_gaussnewton_normaleqs_custom_types_compile_only)
add_serial_exe_and_test(${name} ${ROOTNAME} ${CMAKE_CURRENT_SOURCE_DIR}/${name}.cc "PASSED")
endif()

# -----------------------------
# gauss-newton via QR
# -----------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

#include "pressio/type_traits.hpp"
#include "pressio/ops.hpp"
#include <optional>

struct CustomVecB{};
struct CustomMat{};

struct MyProblem{
using state_type = Eigen::VectorXd;
using residual_type = CustomVecB;
using jacobian_type = CustomMat;
state_type createState() const { return state_type{}; }
residual_type createResidual() const { return residual_type{}; }
jacobian_type createJacobian() const { return jacobian_type{}; }
void residualAndJacobian(const state_type& /*x*/,
residual_type& /*r*/,
std::optional<jacobian_type*> /*Jo*/) const{}
};

struct Weigher{
int leadingDim() { return {}; }
void operator()(const CustomVecB & /*operand*/, CustomVecB & /*result*/) const{}
void operator()(const CustomMat & /*operand*/, CustomMat & /*result*/) const{}
};

using my_hessian_type = Eigen::MatrixXd;
using my_gradient_type = Eigen::VectorXd;

namespace pressio{
template<> struct Traits<CustomVecB>{
static constexpr int rank = 1;
using scalar_type = double;
};
template<> struct Traits<CustomMat>{
static constexpr int rank = 2;
using scalar_type = double;
};

namespace ops{
double norm2(const CustomVecB &){ return {}; }
double dot(const CustomVecB &, const CustomVecB &){ return {}; }
void product(transpose, nontranspose, double, const CustomMat &, const CustomMat &, double, my_hessian_type &){}
void product(transpose, double, const CustomMat &, const CustomVecB &, double, my_gradient_type &){}
void resize(CustomVecB &, int){}
void resize(CustomMat &, int, int){}
std::size_t extent(CustomMat &, int){ return {}; }
}//end namespace ops
}//end namespace pressio

#include "pressio/solvers_nonlinear_gaussnewton.hpp"

struct MyLinSolver{
void solve(const my_hessian_type & /*A*/,
const my_gradient_type & /*b*/,
typename MyProblem::state_type & /*x*/){}
};

int main()
{
pressio::log::initialize(pressio::logto::terminal);
pressio::log::setVerbosity({pressio::log::level::debug});
{
using namespace pressio;
using problem_t = MyProblem;
using state_t = typename problem_t::state_type;
using tag_t = nonlinearsolvers::impl::CompactWeightedGaussNewtonNormalEqTag;
problem_t sys;
state_t y;
auto nonLinSolver = create_gauss_newton_solver(sys, MyLinSolver{}, Weigher{}, tag_t{});
nonLinSolver.solve(y);
(void)y;
std::cout << "PASSED" << std::endl;
}
pressio::log::finalize();
}
Loading

0 comments on commit 4077a90

Please sign in to comment.