Skip to content

Commit

Permalink
Make tests more robust towards stray warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Nov 5, 2024
1 parent 59ebafd commit 2703fdc
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions skorch/tests/test_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,16 @@ def test_dimension_mismatch_warning(self, net_cls, module_cls, data, recwarn):
X, y = X[:100], y[:100].flatten() # make y 1d
net.fit(X, y)

w0, w1 = recwarn.list # one warning for train, one for valid
# The warning comes from PyTorch, so checking the exact wording is prone to
# error in future PyTorch versions. We thus check a substring of the
# whole message and cross our fingers that it's not changed.
msg_substr = (
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
assert msg_substr in str(w0.message)
assert msg_substr in str(w1.message)
warn_list = [w for w in recwarn.list if msg_substr in str(w.message)]
# one warning for train, one for valid
assert len(warn_list) == 2

def test_fitting_with_1d_target_and_pred(
self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
Expand All @@ -156,7 +159,11 @@ def test_fitting_with_1d_target_and_pred(

net = net_cls(module_pred_1d_cls)
net.fit(X, y)
assert not recwarn.list
msg_substr = (
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
assert not any(msg_substr in str(w.message) for w in recwarn.list)

def test_bagging_regressor(
self, net_cls, module_cls, data, module_pred_1d_cls, recwarn
Expand All @@ -171,4 +178,8 @@ def test_bagging_regressor(
regr = BaggingRegressor(net, n_estimators=2, random_state=0)
regr.fit(X, y) # does not raise
# ensure there is no broadcast warning from torch
assert not recwarn.list
msg_substr = (
"This will likely lead to incorrect results due to broadcasting. "
"Please ensure they have the same size"
)
assert not any(msg_substr in str(w.message) for w in recwarn.list)

0 comments on commit 2703fdc

Please sign in to comment.