Skip to content

Commit

Permalink
Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
joshivanhoe committed Jan 16, 2024
1 parent a45c399 commit c0391b2
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 42 deletions.
File renamed without changes.
39 changes: 0 additions & 39 deletions .github/workflows/python-publish.yml

This file was deleted.

1 change: 0 additions & 1 deletion src/sparsely/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> SparseLinearRegressor:
model.add_linear_constr(sum(selected) <= self.k_)
model.optimize()
selected = np.round([model.var_value(var) for var in selected]).astype(bool)
model.reset()

# Compute coefficients
self.coef_ = np.zeros(self.n_features_in_)
Expand Down
3 changes: 2 additions & 1 deletion src/sparsely/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def tune_estimator(
search_log = list()

for k in tqdm(
range(k_min, k_max or X.shape[0], step_size), disable=not show_progress_bar
range(k_min, (k_max or X.shape[1]) + 1, step_size),
disable=not show_progress_bar,
):
# Perform cross-validation
output = cross_validate(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_tune_estimator(
if return_search_log:
estimator, search_log = output
assert isinstance(search_log, pd.DataFrame)
assert search_log.columns == ["k", "score", "std"]
assert (search_log.columns == ["k", "score", "std"]).all()
if max_iters_no_improvement is None:
assert len(search_log) == 5
else:
Expand Down

0 comments on commit c0391b2

Please sign in to comment.