Skip to content

Commit

Permalink
More lenient in accepting errors of different types
Browse files Browse the repository at this point in the history
  • Loading branch information
gmeanti committed May 1, 2024
1 parent fe672ec commit 51d0885
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions falkon/models/falkon.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
import time
import warnings
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union, Any

import torch
from torch import Tensor
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
center_selection: Union[str, falkon.center_selection.CenterSelector] = "uniform",
maxiter: int = 20,
seed: Optional[int] = None,
error_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Union[float, Tuple[float, str]]]] = None,
error_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Union[Any, Tuple[Any, str]]]] = None,
error_every: Optional[int] = 1,
weight_fn: Optional[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = None,
options: Optional[FalkonOptions] = None,
Expand Down
2 changes: 1 addition & 1 deletion falkon/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def val_cback(it, beta, train_time):
if isinstance(err, tuple) and len(err) == 2:
err, err_name = err
print(
f"Iteration {it:3d} - Elapsed {self.fit_times_[-1]:.2f}s - {err_str} {err_name}: {err:.8f}",
f"Iteration {it:3d} - Elapsed {self.fit_times_[-1]:.2f}s - {err_str} {err_name}: {str(err)}",
flush=True,
)
self.val_errors_.append(err)
Expand Down

0 comments on commit 51d0885

Please sign in to comment.