Skip to content

Commit

Permalink
honor max_rows; fail when files are missing columns for sb2 or indica…
Browse files Browse the repository at this point in the history
…tors (fixes #22)
  • Loading branch information
j-faria committed Oct 4, 2024
1 parent 5ad4365 commit d95a276
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
41 changes: 34 additions & 7 deletions src/kima/Data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ void RVData::load(const string filename, const string units, int skip, int max_r

if (sb2)
{
if (data.size() < 5) {
std::string msg = "kima: RVData: sb2 is true but file (" + filename + ") contains less than 5 columns!";
throw std::runtime_error(msg);
}
y2 = data[3];
sig2 = data[4];
}
Expand All @@ -243,6 +247,11 @@ void RVData::load(const string filename, const string units, int skip, int max_r
number_indicators = static_cast<int>(indicators.size()) - nempty;
indicator_correlations = number_indicators > 0;

if (data.size() < 3 + number_indicators + nempty) {
std::string msg = "kima: RVData: file (" + filename + ") contains too few columns!";
throw std::runtime_error(msg);
}

_indicator_names = indicators;
_indicator_names.erase(std::remove(_indicator_names.begin(), _indicator_names.end(), ""), _indicator_names.end());

Expand Down Expand Up @@ -361,6 +370,12 @@ void RVData::load_multi(const string filename, const string units, int skip, int
int nempty = (int) count(indicators.begin(), indicators.end(), "");
number_indicators = (int)(indicators.size()) - nempty;
indicator_correlations = number_indicators > 0;

if (data.size() < 3 + number_indicators + nempty) {
std::string msg = "kima: RVData: file (" + filename + ") contains too few columns!";
throw std::runtime_error(msg);
}

_indicator_names = indicators;
_indicator_names.erase(std::remove(_indicator_names.begin(), _indicator_names.end(), ""), _indicator_names.end());

Expand Down Expand Up @@ -484,13 +499,20 @@ void RVData::load_multi(vector<string> filenames, const string units, int skip,

int filecount = 1;
for (auto& filename : filenames) {
auto data = loadtxt(filename).skiprows(skip)();
auto data = loadtxt(filename)
.skiprows(skip)
.max_rows(max_rows)();

if (data.size() < 3) {
std::string msg = "kima: RVData: file (" + filename + ") contains less than 3 columns! (is skip correct?)";
throw std::runtime_error(msg);
}

if (data.size() < 3 + number_indicators + nempty) {
std::string msg = "kima: RVData: file (" + filename + ") contains too few columns!";
throw std::runtime_error(msg);
}

t.insert(t.end(), data[0].begin(), data[0].end());
y.insert(y.end(), data[1].begin(), data[1].end());
sig.insert(sig.end(), data[2].begin(), data[2].end());
Expand All @@ -515,9 +537,7 @@ void RVData::load_multi(vector<string> filenames, const string units, int skip,
continue; // skip column
else
{
actind[j].insert(actind[j].end(),
data[3 + i].begin(),
data[3 + i].end());
actind[j].insert(actind[j].end(), data[3 + i].begin(), data[3 + i].end());
j++;
}
}
Expand Down Expand Up @@ -614,10 +634,8 @@ double RVData::get_RV_var() const
{
double sum = accumulate(begin(y), end(y), 0.0);
double mean = sum / y.size();

double accum = 0.0;
for_each(begin(y), end(y),
[&](const double d) { accum += (d - mean) * (d - mean); });
for_each(begin(y), end(y), [&](const double d) { accum += (d - mean) * (d - mean); });
return accum / (y.size() - 1);
}

Expand Down Expand Up @@ -729,6 +747,15 @@ int RVData::get_trend_magnitude(int degree) const
return (int)round(log10(get_RV_span() / pow(get_timespan(), degree)));
}

double RVData::get_actind_var(size_t i) const
{
double sum = accumulate(begin(actind[i]), end(actind[i]), 0.0);
double mean = sum / actind[i].size();
double accum = 0.0;
for_each(begin(actind[i]), end(actind[i]), [&](const double d) { accum += (d - mean) * (d - mean); });
return accum / (actind[i].size() - 1);
}


ostream& operator<<(ostream& os, const RVData& d)
{
Expand Down
4 changes: 4 additions & 0 deletions src/kima/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ class KIMA_API RVData {
double get_actind_max(size_t i) const { return *max_element(actind[i].begin(), actind[i].end()); }
/// Get the span of Activity Indicator i
double get_actind_span(size_t i) const { return get_actind_max(i) - get_actind_min(i); }
/// Get the variance of Activity Indicator i
double get_actind_var(size_t i) const;
/// Get the standard deviation of Activity Indicator i
double get_actind_std(size_t i) const { return sqrt(get_actind_var(i)); }


/// Normalize the activity indicators from 0 to 1
Expand Down
14 changes: 14 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ def test_RVData():
assert_equal(np.unique(D.obsi), [1, 2])


def test_RVData_missing_indicators():
# test the issue described in https://github.com/kima-org/kima/issues/22

# simulated2.txt is missing the 7th column
with pytest.raises(RuntimeError):
_ = kima.RVData('tests/simulated2.txt',
indicators=['i', 'j', 'n', 'missing'])

# simulated1.txt is missing the 3rd column
with pytest.raises(RuntimeError):
_ = kima.RVData(['tests/simulated2.txt', 'tests/simulated1.txt'],
indicators=['i', 'j'])


def test_RVmodel():
m = kima.RVmodel(True, 0, kima.RVData('tests/simulated1.txt'))

Expand Down

0 comments on commit d95a276

Please sign in to comment.