diff --git a/falkon/models/falkon.py b/falkon/models/falkon.py index 4b0e5be..93336c0 100644 --- a/falkon/models/falkon.py +++ b/falkon/models/falkon.py @@ -1,7 +1,7 @@ import dataclasses import time import warnings -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, Any import torch from torch import Tensor @@ -125,7 +125,7 @@ def __init__( center_selection: Union[str, falkon.center_selection.CenterSelector] = "uniform", maxiter: int = 20, seed: Optional[int] = None, - error_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Union[float, Tuple[float, str]]]] = None, + error_fn: Optional[Callable[[torch.Tensor, torch.Tensor], Union[Any, Tuple[Any, str]]]] = None, error_every: Optional[int] = 1, weight_fn: Optional[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = None, options: Optional[FalkonOptions] = None, diff --git a/falkon/models/model_utils.py b/falkon/models/model_utils.py index e61dc3e..69e5609 100644 --- a/falkon/models/model_utils.py +++ b/falkon/models/model_utils.py @@ -109,7 +109,7 @@ def val_cback(it, beta, train_time): if isinstance(err, tuple) and len(err) == 2: err, err_name = err print( - f"Iteration {it:3d} - Elapsed {self.fit_times_[-1]:.2f}s - {err_str} {err_name}: {err:.8f}", + f"Iteration {it:3d} - Elapsed {self.fit_times_[-1]:.2f}s - {err_str} {err_name}: {str(err)}", flush=True, ) self.val_errors_.append(err)