Skip to content

Commit

Permalink
Add time limit and solver status
Browse files Browse the repository at this point in the history
  • Loading branch information
chrhansk committed Nov 9, 2023
1 parent 93d57a9 commit 5ee3585
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
2 changes: 2 additions & 0 deletions pygradflow/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class Params:

validate_input: bool = True

time_limit: float = np.inf

@property
def dtype(self):
return np.float32 if self.precision == Precision.Single else np.float64
39 changes: 34 additions & 5 deletions pygradflow/solver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from collections import namedtuple
import time

from enum import Enum, auto

import numpy as np
from termcolor import colored
Expand All @@ -14,14 +17,27 @@
DualNormUpdate,
PenaltyStrategy,
)

from pygradflow.step.step_control import (
StepResult,
step_controller,
StepController,
)


Result = namedtuple("Result", ["x", "y", "d", "success"])
class SolverStatus(Enum):
Converged = auto(),
IterationLimit = auto(),
TimeLimit = auto()


class Result:
def __init__(self, x, y, d, success, status):
self.x = x
self.y = y
self.d = d
self.success = success
self.status = status


def bold(s: str) -> str:
Expand Down Expand Up @@ -155,12 +171,21 @@ def solve(self, x_0: np.ndarray, y_0: np.ndarray) -> Result:

logger.info("Initial Aug Lag: %.10e", iterate.aug_lag(self.rho))

status = None
start_time = time.time()

for i in range(params.num_it):
if (i % 25) == 0:
print_header()

if iterate.total_res <= params.opt_tol:
logger.info("Convergence achieved")
status = SolverStatus.Converged
break

if time.time() - start_time >= params.time_limit:
logger.info("Reached time limit")
status = SolverStatus.TimeLimit
break

step_result = self.compute_step(controller, iterate, 1.0 / lamb)
Expand Down Expand Up @@ -201,9 +226,9 @@ def solve(self, x_0: np.ndarray, y_0: np.ndarray) -> Result:
next_rho = self.penalty.update(iterate, next_iterate)

if next_rho != self.rho:
logger.debug(
"Updating penalty parameter from %e to %e", self.rho, next_rho
)
logger.debug("Updating penalty parameter from %e to %e",
self.rho,
next_rho)
self.rho = next_rho

delta = iterate.dist(next_iterate)
Expand All @@ -212,10 +237,12 @@ def solve(self, x_0: np.ndarray, y_0: np.ndarray) -> Result:

if (lamb <= params.lamb_term) and (delta <= params.opt_tol):
logger.info("Convergence achieved")
status = SolverStatus.Converged
break

else:
success = False
status = SolverStatus.IterationLimit
logger.info("Iteration limit reached")

self.print_result(iterate)
Expand All @@ -224,4 +251,6 @@ def solve(self, x_0: np.ndarray, y_0: np.ndarray) -> Result:
y = iterate.y
d = iterate.bound_duals

return Result(x, y, d, success)
assert status is not None

return Result(x, y, d, success, status)

0 comments on commit 5ee3585

Please sign in to comment.