From 3830f6f1a9c7fd29dae971b263b04740f8914754 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Mon, 12 Apr 2021 09:26:59 +0300 Subject: [PATCH] don't regularize the metric if there were too few samples to update it --- src/stan/mcmc/covar_adaptation.hpp | 10 ++++--- src/stan/mcmc/var_adaptation.hpp | 8 ++++-- src/test/unit/mcmc/covar_adaptation_test.cpp | 29 ++++++++++++++++++++ 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/stan/mcmc/covar_adaptation.hpp b/src/stan/mcmc/covar_adaptation.hpp index 6c21c63e33a..423f17fca67 100644 --- a/src/stan/mcmc/covar_adaptation.hpp +++ b/src/stan/mcmc/covar_adaptation.hpp @@ -23,10 +23,12 @@ class covar_adaptation : public windowed_adaptation { estimator_.sample_covariance(covar); - double n = static_cast(estimator_.num_samples()); - covar = (n / (n + 5.0)) * covar - + 1e-3 * (5.0 / (n + 5.0)) - * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); + if (estimator_.num_samples() > 1) { + double n = static_cast(estimator_.num_samples()); + covar = (n / (n + 5.0)) * covar + + 1e-3 * (5.0 / (n + 5.0)) + * Eigen::MatrixXd::Identity(covar.rows(), covar.cols()); + } estimator_.restart(); diff --git a/src/stan/mcmc/var_adaptation.hpp b/src/stan/mcmc/var_adaptation.hpp index c81de41d980..36f9089c10d 100644 --- a/src/stan/mcmc/var_adaptation.hpp +++ b/src/stan/mcmc/var_adaptation.hpp @@ -23,9 +23,11 @@ class var_adaptation : public windowed_adaptation { estimator_.sample_variance(var); - double n = static_cast(estimator_.num_samples()); - var = (n / (n + 5.0)) * var - + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); + if (estimator_.num_samples() > 1) { + double n = static_cast(estimator_.num_samples()); + var = (n / (n + 5.0)) * var + + 1e-3 * (5.0 / (n + 5.0)) * Eigen::VectorXd::Ones(var.size()); + } estimator_.restart(); diff --git a/src/test/unit/mcmc/covar_adaptation_test.cpp b/src/test/unit/mcmc/covar_adaptation_test.cpp index f1e0853d884..afc85b67f19 100644 --- a/src/test/unit/mcmc/covar_adaptation_test.cpp +++ b/src/test/unit/mcmc/covar_adaptation_test.cpp @@ -27,3 +27,32 @@ TEST(McmcCovarAdaptation, learn_covariance) { } EXPECT_EQ(0, logger.call_count()); } + +TEST(McmcCovarAdaptation, learn_covariance_one_sample) { + stan::test::unit::instrumented_logger logger; + + const int n = 10; + Eigen::VectorXd q = Eigen::VectorXd::Zero(n); + Eigen::MatrixXd covar(Eigen::MatrixXd::Identity(n, n)); + + const int n_learn = 1; + + Eigen::MatrixXd target_covar(Eigen::MatrixXd::Identity(n, n)); + + stan::mcmc::covar_adaptation adapter(n); + adapter.set_window_params(50, 0, 0, n_learn, logger); + + bool update = false; + + for (int i = 0; i < n_learn; ++i) + update = adapter.learn_covariance(covar, q); + + EXPECT_TRUE(update); + + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + EXPECT_EQ(target_covar(i, j), covar(i, j)); + } + } + EXPECT_EQ(0, logger.call_count()); +}