Skip to content

Commit

Permalink
Merge pull request #96 from chrhansk/feature-step-control-refactor
Browse files Browse the repository at this point in the history
Refactor step controllers
  • Loading branch information
chrhansk authored Jun 3, 2024
2 parents 516648e + 8e0bfbc commit 32a58a6
Show file tree
Hide file tree
Showing 22 changed files with 120 additions and 74 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
project = "pygradflow"
copyright = "2023, Christoph Hansknecht"
author = "Christoph Hansknecht"
release = "0.5.3"
release = "0.5.4"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
4 changes: 2 additions & 2 deletions pygradflow/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from pygradflow.log import logger as lgg
from pygradflow.params import NewtonType, Params
from pygradflow.problem import Problem
from pygradflow.step import step_solver
from pygradflow.step.step_solver import StepResult, StepSolver
from pygradflow.step.solver import step_solver
from pygradflow.step.solver.step_solver import StepResult, StepSolver

logger = lgg.getChild("newton")

Expand Down
18 changes: 2 additions & 16 deletions pygradflow/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pygradflow.display import Format, StateData, print_problem_stats, solver_display
from pygradflow.iterate import Iterate
from pygradflow.log import logger
from pygradflow.newton import newton_method
from pygradflow.params import Params
from pygradflow.penalty import penalty_strategy
from pygradflow.problem import Problem
Expand Down Expand Up @@ -92,22 +91,9 @@ def compute_step(
display: bool,
timer: Timer,
) -> StepControlResult:
problem = self.problem
params = self.params
assert self.rho != -1.0

method = newton_method(problem, params, iterate, dt, self.rho)

def next_steps():
curr_iterate = iterate
while True:
next_step = method.step(curr_iterate)
yield next_step
curr_iterate = next_step.iterate

return controller.compute_step(
iterate, self.rho, dt, next_steps(), display, timer
)
assert self.rho != -1.0
return controller.compute_step(iterate, self.rho, dt, display, timer)

def _deriv_check(self, x: np.ndarray, y: np.ndarray) -> None:
from pygradflow.deriv_check import deriv_check
Expand Down
2 changes: 1 addition & 1 deletion pygradflow/step/box_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def hessian(x):
raise StepSolverError("Box-constrained solver failed to converge") from e

def step(
self, iterate, rho: float, dt: float, next_steps, display: bool, timer
self, iterate, rho: float, dt: float, display: bool, timer
) -> StepControlResult:

lamb = 1.0 / dt
Expand Down
9 changes: 6 additions & 3 deletions pygradflow/step/distance_ratio_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@
from pygradflow.log import logger
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.step_control import StepController, StepControlResult
from pygradflow.step.newton_control import NewtonController
from pygradflow.step.step_control import StepControlResult


class DistanceRatioController(StepController):
class DistanceRatioController(NewtonController):
def __init__(self, problem: Problem, params: Params) -> None:
super().__init__(problem, params)
settings = ControllerSettings.from_params(params)
self.controller = LogController(settings, params.theta_ref)

def step(self, iterate, rho, dt, next_steps, display, timer):
def step(self, iterate, rho, dt, display, timer):
assert dt > 0.0
lamb = 1.0 / dt

problem = self.problem
params = self.params

next_steps = self.newton_steps(iterate, rho, dt)

func = ImplicitFunc(problem, iterate, dt)

mid_step = next(next_steps)
Expand Down
14 changes: 7 additions & 7 deletions pygradflow/step/exact_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@

from pygradflow.implicit_func import ImplicitFunc
from pygradflow.log import logger
from pygradflow.step.step_control import (
StepController,
StepControlResult,
StepSolverError,
)
from pygradflow.step.newton_control import NewtonController
from pygradflow.step.step_control import StepControlResult
from pygradflow.step.step_solver_error import StepSolverError


class ExactController(StepController):
class ExactController(NewtonController):
def __init__(self, problem, params, max_num_it=10, rate_bound=0.5):
super().__init__(problem, params)
self.max_num_it = max_num_it
self.rate_bound = rate_bound

def step(self, iterate, rho, dt, next_steps, display, timer):
def step(self, iterate, rho, dt, display, timer):
assert dt > 0.0
lamb = 1.0 / dt

Expand All @@ -26,6 +24,8 @@ def func_val(iterate):

curr_func_val = func_val(iterate)

next_steps = self.newton_steps(iterate, rho, dt)

rcond = None
active_set = None

Expand Down
9 changes: 6 additions & 3 deletions pygradflow/step/fixed_control.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.step_control import StepController, StepControlResult
from pygradflow.step.newton_control import NewtonController
from pygradflow.step.step_control import StepControlResult


class FixedStepSizeController(StepController):
class FixedStepSizeController(NewtonController):
def __init__(self, problem: Problem, params: Params) -> None:
super().__init__(problem, params)
self.lamb = params.lamb_init

def step(self, iterate, rho, dt, next_steps, display, timer):
def step(self, iterate, rho, dt, display, timer):
assert dt > 0.0

next_steps = self.newton_steps(iterate, rho, dt)

step = next(next_steps)

return StepControlResult.from_step_result(step, self.lamb, True)
6 changes: 5 additions & 1 deletion pygradflow/step/linear_solver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from abc import ABC, abstractmethod

import numpy as np
import scipy as sp
from numpy import ndarray
Expand All @@ -16,7 +18,9 @@ class LinearSolverError(Exception):
pass


class LinearSolver:
class LinearSolver(ABC):

@abstractmethod
def solve(self, b: ndarray, trans: bool = False) -> ndarray:
raise NotImplementedError()

Expand Down
40 changes: 40 additions & 0 deletions pygradflow/step/newton_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Iterator, Optional

from pygradflow.iterate import Iterate
from pygradflow.newton import newton_method
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.solver.step_solver import StepResult
from pygradflow.step.step_control import StepController


class NewtonController(StepController):
    """
    Step controller working by solving the implicit Euler
    equations using a semi-smooth Newton method
    """

    def __init__(self, problem: Problem, params: Params) -> None:
        super().__init__(problem, params)

    def newton_steps(
        self,
        iterate: Iterate,
        rho: float,
        dt: float,
        initial_iterate: Optional[Iterate] = None,
    ) -> Iterator[StepResult]:
        """Yield an infinite stream of Newton step results.

        Creates a Newton method linearized at *iterate* for the given
        penalty *rho* and step size *dt*, then repeatedly applies it,
        feeding each result's iterate back in as the next starting point.
        Iteration begins from *initial_iterate* when given, otherwise
        from *iterate* itself.
        """
        # NOTE: this is a generator function — nothing below executes
        # (including the assignment of self.method) until the first
        # next() call on the returned iterator.
        self.method = newton_method(self.problem, self.params, iterate, dt, rho)

        curr_iterate = iterate if initial_iterate is None else initial_iterate

        while True:
            result = self.method.step(curr_iterate)
            yield result
            # Continue the Newton iteration from the point just produced.
            curr_iterate = result.iterate
2 changes: 1 addition & 1 deletion pygradflow/step/opti_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def solve(self, timer):


class OptimizingController(StepController):
def step(self, iterate, rho, dt, next_steps, display, timer) -> StepControlResult:
def step(self, iterate, rho, dt, display, timer) -> StepControlResult:

problem = self.problem
params = self.params
Expand Down
9 changes: 6 additions & 3 deletions pygradflow/step/residuum_ratio_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@
from pygradflow.log import logger
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.step_control import StepController, StepControlResult
from pygradflow.step.newton_control import NewtonController
from pygradflow.step.step_control import StepControlResult


class ResiduumRatioController(StepController):
class ResiduumRatioController(NewtonController):
def __init__(self, problem: Problem, params: Params) -> None:
settings = ControllerSettings.from_params(params)
self.controller = LogController(settings, params.theta_ref)
super().__init__(problem, params)

def step(self, iterate, rho, dt, next_steps, display, timer):
def step(self, iterate, rho, dt, display, timer):
assert dt > 0.0
lamb = 1.0 / dt

problem = self.problem
params = self.params

next_steps = self.newton_steps(iterate, rho, dt)

func = ImplicitFunc(problem, iterate, dt)

mid_step = next(next_steps)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from pygradflow.iterate import Iterate
from pygradflow.params import Params, StepSolverType
from pygradflow.problem import Problem
from pygradflow.step.asymmetric_step_solver import AsymmetricStepSolver
from pygradflow.step.extended_step_solver import ExtendedStepSolver
from pygradflow.step.standard_step_solver import StandardStepSolver
from pygradflow.step.step_solver import StepSolver
from pygradflow.step.symmetric_step_solver import SymmetricStepSolver

from .asymmetric_step_solver import AsymmetricStepSolver
from .extended_step_solver import ExtendedStepSolver
from .standard_step_solver import StandardStepSolver
from .step_solver import StepSolver
from .symmetric_step_solver import SymmetricStepSolver


def step_solver(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.linear_solver import LinearSolverError
from pygradflow.step.scaled_step_solver import ScaledStepSolver
from pygradflow.step.step_control import StepSolverError
from pygradflow.step.step_solver_error import StepSolverError

from .scaled_step_solver import ScaledStepSolver


class AsymmetricStepSolver(ScaledStepSolver):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.linear_solver import LinearSolverError
from pygradflow.step.scaled_step_solver import ScaledStepSolver
from pygradflow.step.step_control import StepSolverError
from pygradflow.step.step_solver_error import StepSolverError

from .scaled_step_solver import ScaledStepSolver


class ExtendedStepSolver(ScaledStepSolver):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from pygradflow.iterate import Iterate
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.step_solver import StepResult, StepSolver

from .step_solver import StepResult, StepSolver


class ScaledStepSolver(StepSolver):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.linear_solver import LinearSolverError
from pygradflow.step.step_control import StepSolverError
from pygradflow.step.step_solver import StepResult, StepSolver
from pygradflow.step.step_solver_error import StepSolverError

from .step_solver import StepResult, StepSolver


class StandardStepSolver(StepSolver):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ def hess(self) -> sp.sparse.spmatrix:
return cast(sp.sparse.spmatrix, self._hess)

def linear_solver(self, mat: sp.sparse.spmatrix) -> LinearSolver:
from .linear_solver import linear_solver
from pygradflow.step.linear_solver import linear_solver

solver_type = self.params.linear_solver_type
return linear_solver(mat, solver_type)

def estimate_rcond(
self, mat: sp.sparse.spmatrix, solver: LinearSolver
) -> Optional[float]:
from .cond_estimate import ConditionEstimator
from pygradflow.step.cond_estimate import ConditionEstimator

estimator = ConditionEstimator(mat, solver, self.params)
rcond = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from pygradflow.params import Params
from pygradflow.problem import Problem
from pygradflow.step.linear_solver import LinearSolverError
from pygradflow.step.scaled_step_solver import ScaledStepSolver
from pygradflow.step.step_control import StepSolverError
from pygradflow.step.step_solver_error import StepSolverError

from .scaled_step_solver import ScaledStepSolver


class SymmetricStepSolver(ScaledStepSolver):
Expand Down
Loading

0 comments on commit 32a58a6

Please sign in to comment.