Skip to content

Commit

Permalink
minor cleanup to IterControl
Browse files Browse the repository at this point in the history
  • Loading branch information
kkappler committed Jul 12, 2024
1 parent 4db8c40 commit 792c859
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
66 changes: 48 additions & 18 deletions aurora/transfer_function/regression/iter_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
follows Gary's IterControl.m in
iris_mt_scratch/egbert_codes-20210121T193218Z-001/egbert_codes/matlabPrototype_10-13-20/TF/classes
"""
from loguru import logger
import numpy as np

from aurora.transfer_function.regression.helper_functions import rme_beta
Expand Down Expand Up @@ -46,8 +47,8 @@ def __init__(
etc.
"""
self.number_of_iterations = 0
self.number_of_redescending_iterations = 0
self._number_of_iterations = 0
self._number_of_redescending_iterations = 0

self.tolerance = 0.005
self.epsilon = 1000
Expand All @@ -69,14 +70,45 @@ def __init__(
self.robust_diagonalize = False
# </Additional properties>

def converged(self, b, b0):
@property
def number_of_iterations(self) -> int:
return self._number_of_iterations

# @number_of_iterations.setter
# def number_of_iterations(self, value) -> int:
# self._number_of_iterations = value

def reset_number_of_iterations(self) -> int:
self._number_of_iterations = 0

def increment_iteration_number(self):
self._number_of_iterations += 1

@property
def number_of_redescending_iterations(self) -> int:
return self._number_of_redescending_iterations

@number_of_redescending_iterations.setter
def number_of_redescending_iterations(self, value) -> int:
self._number_of_redescending_iterations = value

def increment_redescending_iteration_number(self):
self._number_of_redescending_iterations += 1

@property
def max_iterations_reached(self):
return self.number_of_iterations >= self.max_number_of_iterations

def converged(self, b, b0, verbose=False):
"""
Parameters
----------
b : complex-valued numpy array
the most recent regression estimate
b0 : complex-valued numpy array
The previous regression estimate
verbose: bool
Set to True for debugging
Returns
-------
Expand All @@ -89,27 +121,25 @@ def converged(self, b, b0):
1 - abs(b/b0), however, that will be insensitive to phase changes in b,
which is complex valued. The way it is coded np.max(np.abs(1 - b / b0)) is
correct as it stands.
"""

converged = False
maximum_change = np.max(np.abs(1 - b / b0))
tolerance_cond = maximum_change <= self.tolerance
iteration_cond = self.number_of_iterations >= self.max_number_of_iterations
iteration_cond = self.max_iterations_reached
if tolerance_cond or iteration_cond:
converged = True
# These print statments are not very clear and
# Should be reworded.
# if tolerance_cond:
# print(
# f"Converged Due to MaxChange < Tolerance after "
# f" {self.number_of_iterations} of "
# f" {self.max_number_of_iterations} iterations"
# )
# elif iteration_cond:
# print(
# f"Converged Due to maximum number_of_iterations "
# f" {self.max_number_of_iterations}"
# )
if verbose:
msg_start = "Converged due to"
msg_end = (
f"{self.number_of_iterations} of "
f"{self.max_number_of_iterations} iterations"
)
if tolerance_cond:
msg = f"{msg_start} MaxChange < Tolerance after {msg_end}"
elif iteration_cond:
msg = f"{msg_start} maximum number_of_iterations {msg_end}"
logger.info(msg)
else:
converged = False

Expand Down
16 changes: 11 additions & 5 deletions aurora/transfer_function/regression/m_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,18 @@ def initial_estimate(self):
self.update_residual_variance()

def apply_huber_regression(self):
"""This is the 'convergence loop' from TRME, TRME_RR"""
converged = self.iter_control.max_number_of_iterations <= 0
self.iter_control.number_of_iterations = 0
"""
This is the 'convergence loop' from TRME, TRME_RR
TODO: Consider not setting iter_control.number_of_iterations
- Instead, Initialize a new iter_control object
"""
converged = self.iter_control.max_iterations_reached
if self.iter_control.number_of_iterations:
self.iter_control.reset_number_of_iterations()
while not converged:
b0 = self.b
self.iter_control.number_of_iterations += 1
self.iter_control.increment_iteration_number()
self.update_y_cleaned_via_huber_weights()
self.update_b()
self.update_y_hat()
Expand All @@ -210,7 +216,7 @@ def apply_redecending_influence_function(self):
if self.iter_control.max_number_of_redescending_iterations:
self.iter_control.number_of_redescending_iterations = 0 # reset per channel
while self.iter_control.continue_redescending:
self.iter_control.number_of_redescending_iterations += 1
self.iter_control.increment_redescending_iteration_number()
self.update_y_cleaned_via_redescend_weights()
self.update_b()
self.update_y_hat()
Expand Down

0 comments on commit 792c859

Please sign in to comment.