Skip to content

Commit

Permalink
Copy new tests from #1290
Browse files Browse the repository at this point in the history
Also fix CSV precision bug in matrix outputs
  • Loading branch information
WardBrian committed Nov 1, 2024
1 parent 4cd3325 commit 064adc4
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 4 deletions.
5 changes: 1 addition & 4 deletions src/cmdstan/stansummary_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,7 @@ void write_all_model_params(const stan::mcmc::chainset &chains,
if (as_csv) {
*out << "\"" << chains.param_name(row_maj_index_chains) << "\"";
for (int j = 0; j < params.cols(); j++) {
*out << "," << std::fixed
<< std::setprecision(compute_precision(
params(row_maj_index, j), sig_figs, false))
<< params(row_maj_index, j);
*out << "," << params(row_maj_index, j);
}
} else {
*out << std::setw(max_name_length + 1) << std::left
Expand Down
23 changes: 23 additions & 0 deletions src/test/interface/example_output/matrix_summary.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name,Mean,MCSE,StdDev,MAD,5%,50%,95%,ESS_bulk,ESS_tail,R_hat
"lp__",-15.3787,1.17576,3.71807,0.779848,-20.9245,-14.8305,-12.6946,5,5,1.31109
"accept_stat__",0.68789,0.0961912,0.304183,0.252626,0.208269,0.779947,0.989055,5,5,1.05408
"stepsize__",0.894277,nan,1.17028e-16,0,0.894277,0.894277,0.894277,nan,nan,nan
"treedepth__",0.9,0.1,0.316228,0,0.45,1,1,nan,nan,nan
"x[1]",0.608916,0.0690549,0.218371,0.108465,0.284524,0.666485,0.866975,5,5,0.905216
"x[2]",0.479369,0.051013,0.161317,0.164219,0.263389,0.451793,0.686842,5,5,1.00156
"y[1,1]",0.425556,0.124575,0.393941,0.150819,0.0588995,0.164637,0.926039,5,5,1.431
"y[1,2]",0.595682,0.0758338,0.239808,0.178212,0.249647,0.663856,0.850511,5,5,1.11989
"y[1,3]",0.473713,0.104663,0.330973,0.39317,0.0677836,0.483405,0.863312,5,5,0.986085
"y[2,1]",0.511464,0.0868387,0.274608,0.383569,0.167417,0.487338,0.895304,5,5,1.43105
"y[2,2]",0.559192,0.0872593,0.275938,0.307409,0.230475,0.525767,0.932888,5,5,0.915761
"y[2,3]",0.51956,0.0997935,0.315575,0.212694,0.0361943,0.536618,0.930452,5,5,1.32719
# Inference for Stan model: issue_342_model
# 1 chains: each with iter=10; warmup=1000; thin=1; 10 iterations saved.
#
# Warmup took 0.23 seconds
# Sampling took 0.0018 seconds
# Samples were drawn using hmc with nuts.
# For each parameter, ESS_bulk and ESS_tail measure the effective sample size
for the entire sample (bulk) and for the the .05 and .95 tails (tail),
# and R_hat measures the potential scale reduction on split chains.
At convergence R_hat will be very close to 1.00.
27 changes: 27 additions & 0 deletions src/test/interface/example_output/matrix_summary.nom
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Inference for Stan model: issue_342_model
1 chains: each with iter=10; warmup=1000; thin=1; 10 iterations saved.

Warmup took 0.23 seconds
Sampling took 0.0018 seconds

Mean MCSE StdDev MAD 5% 50% 95% ESS_bulk ESS_tail R_hat

lp__ -15 1.2 3.7 0.78 -21 -15 -13 5.0 5.0 1.3
accept_stat__ 0.69 0.096 3.0e-01 0.25 0.21 0.78 0.99 5.0 5.0 1.1
stepsize__ 0.89 nan 1.2e-16 0.00 0.89 0.89 0.89 nan nan nan
treedepth__ 0.90 0.10 3.2e-01 0.00 0.45 1.0 1.0 nan nan nan

x[1] 0.61 0.069 0.22 0.11 0.28 0.67 0.87 5.0 5.0 0.91
x[2] 0.48 0.051 0.16 0.16 0.26 0.45 0.69 5.0 5.0 1.0
y[1,1] 0.43 0.12 0.39 0.15 0.059 0.16 0.93 5.0 5.0 1.4
y[1,2] 0.60 0.076 0.24 0.18 0.25 0.66 0.85 5.0 5.0 1.1
y[1,3] 0.47 0.10 0.33 0.39 0.068 0.48 0.86 5.0 5.0 0.99
y[2,1] 0.51 0.087 0.27 0.38 0.17 0.49 0.90 5.0 5.0 1.4
y[2,2] 0.56 0.087 0.28 0.31 0.23 0.53 0.93 5.0 5.0 0.92
y[2,3] 0.52 0.100 0.32 0.21 0.036 0.54 0.93 5.0 5.0 1.3

Samples were drawn using hmc with nuts.
For each parameter, ESS_bulk and ESS_tail measure the effective sample size
for the entire sample (bulk) and for the the .05 and .95 tails (tail),
and R_hat measures the potential scale reduction on split chains.
At convergence R_hat will be very close to 1.00.
62 changes: 62 additions & 0 deletions src/test/interface/stansummary_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,65 @@ TEST(CommandStansummary, check_csv_output_include_param) {
if (return_code != 0)
FAIL();
}

TEST(CommandStansummary, check_reorder_stats) {
std::string path_separator;
path_separator.push_back(get_path_separator());
std::string csv_file = "src" + path_separator + "test" + path_separator
+ "interface" + path_separator + "matrix_output.csv";
std::stringstream ss_command;
ss_command << "bin" << path_separator << "stansummary " << csv_file;
run_command_output out = run_command(ss_command.str());
ASSERT_FALSE(out.hasError);

std::string expected_file = "src" + path_separator + "test" + path_separator
+ "interface" + path_separator + "example_output"
+ path_separator + "matrix_summary.nom";
std::ifstream expected_output(expected_file.c_str());
EXPECT_FALSE(expected_output.bad());
std::stringstream ss;
ss << expected_output.rdbuf();
EXPECT_EQ(ss.str(), out.output);
}

TEST(CommandStansummary, check_reorder_stats_csv) {
std::string path_separator;
path_separator.push_back(get_path_separator());
std::string target_csv_file = "test" + path_separator + "interface"
+ path_separator
+ "tmp_test_target_csv_file_e.csv";
std::string csv_file = "src" + path_separator + "test" + path_separator
+ "interface" + path_separator + "matrix_output.csv";
std::stringstream ss_command;
ss_command << "bin" << path_separator << "stansummary "
<< "-c " << target_csv_file << " " << csv_file;
run_command_output out = run_command(ss_command.str());
ASSERT_FALSE(out.hasError);

std::string expected_file = "src" + path_separator + "test" + path_separator
+ "interface" + path_separator + "example_output"
+ path_separator + "matrix_summary.csv";
std::ifstream expected_output(expected_file.c_str());
EXPECT_FALSE(expected_output.bad());
std::stringstream ss_expected;
ss_expected << expected_output.rdbuf();

std::ifstream target_stream(target_csv_file.c_str());
if (!target_stream.is_open()) {
std::cerr << "Failed to open file: " << target_csv_file << "\n";
std::cerr << "Error: " << std::strerror(errno) << std::endl;
FAIL();
}
std::stringstream ss_actual;
ss_actual << target_stream.rdbuf();
target_stream.close();

EXPECT_EQ(ss_expected.str(), ss_actual.str());

int return_code = std::remove(target_csv_file.c_str());
if (return_code != 0) {
std::cerr << "Failed to remove file: " << target_csv_file << "\n";
std::cerr << "Error: " << std::strerror(errno) << std::endl;
FAIL();
}
}

0 comments on commit 064adc4

Please sign in to comment.