Skip to content

Commit

Permalink
Clean up ticker class, make it immutable and hashable.
Browse files Browse the repository at this point in the history
  • Loading branch information
cmeyer committed Aug 28, 2024
1 parent 61f0213 commit 6daad9b
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 114 deletions.
245 changes: 147 additions & 98 deletions nion/utils/Geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def arange(start: float, stop: float, step: float) -> typing.Sequence[float]:
return [start + x * step for x in range(math.ceil((stop - start) / step))]


def make_pretty_range2(value_low: float, value_high: float, ticks: int = 5, logarithmic: bool = False) -> typing.Tuple[float, float, typing.Sequence[float], float, int, float]:
def make_pretty_range2(value_low: float, value_high: float, ticks: int = 5, logarithmic: bool = False) -> typing.Tuple[float, float, typing.Tuple[float, ...], float, int, float]:
"""Returns minimum, maximum, list of tick values, division, and precision.
Value high and value low specify the data range.
Expand All @@ -98,7 +98,7 @@ def make_pretty_range2(value_low: float, value_high: float, ticks: int = 5, loga

# check for small range
if value_high == value_low:
return value_low, value_low, [value_low], 0, 0, 0
return value_low, value_low, (value_low,), 0, 0, 0

# make the value range a pretty range
value_range = make_pretty2(value_high - value_low, False)
Expand All @@ -108,7 +108,7 @@ def make_pretty_range2(value_low: float, value_high: float, ticks: int = 5, loga

# calculate the graph minimum and maximum
if division == 0:
return 0, 0, [0], 0, 0, 0
return 0, 0, (0,), 0, 0, 0

graph_minimum = math.floor(value_low / division) * division
graph_maximum = math.ceil(value_high / division) * division
Expand All @@ -127,151 +127,200 @@ def make_pretty_range2(value_low: float, value_high: float, ticks: int = 5, loga
for x in arange(graph_minimum, graph_maximum + 0.5 * division, division):
tick_values.append(x)

return graph_minimum, graph_maximum, tick_values, division, precision, factor10
return graph_minimum, graph_maximum, tuple(tick_values), division, precision, factor10


def make_pretty_range(value_low: float, value_high: float, tight: bool = False, ticks: int = 5) -> typing.Tuple[float, float, typing.Sequence[float], float, int]:
return make_pretty_range2(value_low, value_high, ticks)[:-1]


@dataclasses.dataclass(frozen=True)
class TickerValues:
"""A class representing the initial values of a ticker."""
value_low: float
value_high: float
ticks: int = 5
tick_values: typing.Sequence[float] = dataclasses.field(default_factory=tuple)
tick_labels: typing.Sequence[str] = dataclasses.field(default_factory=tuple)
minor_tick_indices: typing.Sequence[int] = dataclasses.field(default_factory=tuple)
minimum: float = 0.0
maximum: float = 0.0
division: float = 1.0
precision: int = 0


class Ticker:
def __init__(self, ticker_values: TickerValues) -> None:
self.__ticker_values = ticker_values

def __init__(self, value_low: float, value_high: float, *, ticks: int = 5) -> None:
self._value_low = value_low
self._value_high = value_high
self._ticks = ticks

self._tick_values: typing.Sequence[float] = []
self._tick_labels: typing.Sequence[str] = []
self._minor_tick_indices: typing.List[int] = []
self._minimum = 0.0
self._maximum = 0.0
self._division = 1.0
self._precision = 0
def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, Ticker):
return False
return self._is_equal(other)

def __hash__(self) -> int:
return hash(self._hash_components())

def _is_equal(self, other: Ticker) -> bool:
return self.__ticker_values == other.__ticker_values

def _hash_components(self) -> typing.Tuple[typing.Any, ...]:
return (self.__ticker_values,)

@property
def value_low(self) -> float:
return self.__ticker_values.value_low

@property
def value_high(self) -> float:
return self.__ticker_values.value_high

def value_label(self, value: float) -> str:
raise NotImplementedError
raise NotImplementedError()

@property
def ticks(self) -> int:
return self._ticks
return self.__ticker_values.ticks

@property
def values(self) -> typing.Sequence[float]:
return self._tick_values
return self.__ticker_values.tick_values

@property
def labels(self) -> typing.Sequence[str]:
return self._tick_labels
return self.__ticker_values.tick_labels

@property
def minimum(self) -> float:
return self._minimum
return self.__ticker_values.minimum

@property
def maximum(self) -> float:
return self._maximum
return self.__ticker_values.maximum

@property
def division(self) -> float:
return self._division
return self.__ticker_values.division

@property
def precision(self) -> int:
return self._precision
return self.__ticker_values.precision

@property
def minor_tick_indices(self) -> typing.Sequence[int]:
return self._minor_tick_indices
return self.__ticker_values.minor_tick_indices


class LinearTicker(Ticker):
def linear_value_label(value: float, precision: int, factor10: float) -> str:
f10 = int(math.log10(factor10)) if factor10 > 0 else 0
if abs(f10) > 5:
f10x = int(math.log10(value)) if value > 0 else f10
precision = max(0, f10x - f10)
return (u"{0:0." + u"{0:d}".format(precision) + "e}").format(value)
else:
return (u"{0:0." + u"{0:d}".format(precision) + "f}").format(value)

def __init__(self, value_low: float, value_high: float, *, ticks: int=5):
super().__init__(value_low, value_high, ticks=ticks)
self._minimum, self._maximum, self._tick_values, self._division, self._precision, self._factor10 = make_pretty_range2(value_low, value_high, ticks=ticks)
self._tick_labels = list(self.value_label(tick_value) for tick_value in self._tick_values)

def __nice_label(self, value: float, precision: int, factor10: float) -> str:
f10 = int(math.log10(factor10)) if factor10 > 0 else 0
if abs(f10) > 5:
f10x = int(math.log10(value)) if value > 0 else f10
precision = max(0, f10x - f10)
return (u"{0:0." + u"{0:d}".format(precision) + "e}").format(value)
else:
return (u"{0:0." + u"{0:d}".format(precision) + "f}").format(value)

def value_label(self, value: float) -> str:
return self.__nice_label(value, self.precision, self._factor10)
def log_value_label(value: float, precision: int) -> str:
return (u"{0:." + u"{0:d}".format(precision) + "e}").format(value)


class LogTicker(Ticker):
def configure_log_ticker_values(value_low: float, value_high: float, ticks: int, base: int) -> TickerValues:
if not all([math.isfinite(val) for val in [value_low, value_high, base]]):
return TickerValues(value_low, value_high, ticks, (1,), ("0e+00",))

def __init__(self, value_low: float, value_high:float, *, ticks: int = 5, base: int = 10):
super().__init__(value_low, value_high, ticks=ticks)
self._base = base

if not all([math.isfinite(val) for val in [value_low, value_high, base]]):
self._tick_values = [1]
self._tick_labels = ["0e+00"]
return

val_range = abs(self._value_high - self._value_low)
self._factor_b = math.pow(self.base, math.floor(math.log(val_range, self.base))) if (self._ticks-2)/self._base > val_range > 0 else 1.0
self._minimum = math.floor(self._value_low / self._factor_b)
self._maximum = max(math.ceil(self._value_high / self._factor_b), self._minimum + 1)
self._precision = round(abs(math.log(self._factor_b, self.base)))

numdec = self._maximum - self._minimum

while abs(numdec) > 1.5 * val_range and self._factor_b == 1.0 and numdec > 0:
numdec -= 1

self._division = max((numdec + 1) // self._ticks, 1)
decades = arange(self._minimum, self._maximum + self.division, self.division)
if self._factor_b == 1.0:
self._maximum = self._minimum + numdec
# We will get len(decades) * subs ticks, so calculate the number of subs we need
num_subs = self._ticks / (val_range / self._division) if val_range > 0 else 0.0

subs: typing.List[float]
if self._factor_b != 1.0:
subs = []
elif num_subs >= (self.base - 2):
subs = list(arange(2, self.base, 1))
elif num_subs >= (self.base - 2) * 0.5:
subs = list(arange(2, self.base, 2))
elif num_subs >= (self.base - 2) * 0.25:
subs = [round(self.base * 0.5)]
else:
subs = []
val_range = abs(value_high - value_low)
factor_b = math.pow(base, math.floor(math.log(val_range, base))) if (ticks - 2) / base > val_range > 0 else 1.0
minimum = float(math.floor(value_low / factor_b))
maximum = float(max(float(math.ceil(value_high / factor_b)), minimum + 1))
precision = round(abs(math.log(factor_b, base)))

numdec = maximum - minimum

if subs and self._value_high >= self._maximum:
high_floor = math.floor(self._value_high)
self._maximum = high_floor + math.log(math.floor(math.pow(self.base, self._value_high - high_floor)) + 1, self.base)
while abs(numdec) > 1.5 * val_range and factor_b == 1.0 and numdec > 0:
numdec -= 1

tick_values = list()
for decade_start in decades:
decade = math.pow(self.base, decade_start * self._factor_b)
tick_values.append(decade)
for sub in subs:
tick_values.append(sub * decade)
self._minor_tick_indices.append(len(tick_values) - 1)
division = max((numdec + 1) // ticks, 1)
decades = arange(minimum, maximum + division, division)
if factor_b == 1.0:
maximum = minimum + numdec
# We will get len(decades) * subs ticks, so calculate the number of subs we need
num_subs = ticks / (val_range / division) if val_range > 0 else 0.0

self._tick_labels = [self.value_label(value) for value in tick_values]
self._tick_values = [math.log(value, self.base) for value in tick_values]
subs: typing.List[float]
if factor_b != 1.0:
subs = []
elif num_subs >= (base - 2):
subs = list(arange(2, base, 1))
elif num_subs >= (base - 2) * 0.5:
subs = list(arange(2, base, 2))
elif num_subs >= (base - 2) * 0.25:
subs = [round(base * 0.5)]
else:
subs = []

if subs and value_high >= maximum:
high_floor = math.floor(value_high)
maximum = high_floor + math.log(math.floor(math.pow(base, value_high - high_floor)) + 1, base)

tick_values = list[float]()
minor_tick_indices = list[int]()
for decade_start in decades:
decade = math.pow(base, decade_start * factor_b)
tick_values.append(decade)
for sub in subs:
tick_values.append(sub * decade)
minor_tick_indices.append(len(tick_values) - 1)

tick_labels = [log_value_label(value, precision) for value in tick_values]
tick_values = [math.log(value, base) for value in tick_values]

# Revert maximum to its original value because it is used for auto display limits
maximum *= factor_b
# Set minimum slightly lower than the data minimum because it is used for auto display limits
minimum = value_low - (maximum - value_low) * 0.01

return TickerValues(value_low, value_high, ticks, tuple(tick_values), tuple(tick_labels), tuple(minor_tick_indices), minimum, maximum, division, precision)

# Revert maximum to its original value because it is used for auto display limits
self._maximum *= self._factor_b
# Set minimum slightly lower than the data minimum because it is used for auto display limits
self._minimum = self._value_low - (self._maximum - self._value_low) * 0.01

class LinearTicker(Ticker):
def __init__(self, value_low: float, value_high: float, *, ticks: int = 5) -> None:
minimum, maximum, tick_values, division, precision, factor10 = make_pretty_range2(value_low, value_high, ticks=ticks)
tick_labels = tuple(linear_value_label(tick_value, precision, factor10) for tick_value in tick_values)
super().__init__(TickerValues(value_low, value_high, ticks, tick_values, tick_labels, tuple(), minimum, maximum, division, precision))
self.__factor10 = factor10

def value_label(self, value: float) -> str:
return (u"{0:." + u"{0:d}".format(self.precision) + "e}").format(value)
return linear_value_label(value, self.precision, self.__factor10)

def _is_equal(self, other: Ticker) -> bool:
if not isinstance(other, LinearTicker):
return False
return super()._is_equal(other) and self.__factor10 == other.__factor10

def _hash_components(self) -> typing.Tuple[typing.Any, ...]:
return super()._hash_components() + (self.__factor10,)


class LogTicker(Ticker):
def __init__(self, value_low: float, value_high:float, *, ticks: int = 5, base: int = 10):
super().__init__(configure_log_ticker_values(value_low, value_high, ticks, base))
self.__base = base

def value_label(self, value: float) -> str:
return log_value_label(value, self.precision)

def _is_equal(self, other: Ticker) -> bool:
if not isinstance(other, LogTicker):
return False
return super()._is_equal(other) and self.base == other.base

def _hash_components(self) -> typing.Tuple[typing.Any, ...]:
return super()._hash_components() + (self.base,)

@property
def base(self) -> int:
return self._base
return self.__base


def fit_to_aspect_ratio(rect_: typing.Union[FloatRectTuple, IntRectTuple], aspect_ratio: float) -> FloatRect:
Expand Down
51 changes: 35 additions & 16 deletions nion/utils/test/Geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,24 @@ def test_ticker_produces_unique_labels(self) -> None:
# print(ticker.labels)

def test_linear_ticker_handles_edge_cases(self) -> None:
self.assertEqual(Geometry.LinearTicker(0, 0).labels, ['0'])
self.assertEqual(Geometry.LinearTicker(1, 1).labels, ['1'])
self.assertEqual(Geometry.LinearTicker(-1, -1).labels, ['-1'])
self.assertEqual(Geometry.LinearTicker(-math.inf, math.inf).labels, ['0'])
self.assertEqual(Geometry.LinearTicker(-math.nan, math.nan).labels, ['0'])
self.assertEqual(Geometry.LinearTicker(math.nan, 1).labels, ['0'])
self.assertEqual(Geometry.LinearTicker(-math.inf, 1).labels, ['0'])
self.assertEqual(Geometry.LinearTicker(0, math.inf).labels, ['0'])
self.assertEqual(Geometry.LinearTicker(0, 0).labels, ('0',))
self.assertEqual(Geometry.LinearTicker(1, 1).labels, ('1',))
self.assertEqual(Geometry.LinearTicker(-1, -1).labels, ('-1',))
self.assertEqual(Geometry.LinearTicker(-math.inf, math.inf).labels, ('0',))
self.assertEqual(Geometry.LinearTicker(-math.nan, math.nan).labels, ('0',))
self.assertEqual(Geometry.LinearTicker(math.nan, 1).labels, ('0',))
self.assertEqual(Geometry.LinearTicker(-math.inf, 1).labels, ('0',))
self.assertEqual(Geometry.LinearTicker(0, math.inf).labels, ('0',))

def test_log_ticker_handles_edge_cases(self) -> None:
self.assertEqual(Geometry.LogTicker(0, 0, ticks=3).labels, ['1e+00', '1e+01'])
self.assertEqual(Geometry.LogTicker(1, 1, ticks=3).labels, ['1e+01', '1e+02'])
self.assertEqual(Geometry.LogTicker(-1, -1, ticks=3).labels, ['1e-01', '1e+00'])
self.assertEqual(Geometry.LogTicker(-math.inf, math.inf).labels, ['0e+00'])
self.assertEqual(Geometry.LogTicker(-math.nan, math.nan).labels, ['0e+00'])
self.assertEqual(Geometry.LogTicker(math.nan, 1).labels, ['0e+00'])
self.assertEqual(Geometry.LogTicker(-math.inf, 1).labels, ['0e+00'])
self.assertEqual(Geometry.LogTicker(0, math.inf).labels, ['0e+00'])
self.assertEqual(Geometry.LogTicker(0, 0, ticks=3).labels, ('1e+00', '1e+01'))
self.assertEqual(Geometry.LogTicker(1, 1, ticks=3).labels, ('1e+01', '1e+02'))
self.assertEqual(Geometry.LogTicker(-1, -1, ticks=3).labels, ('1e-01', '1e+00'))
self.assertEqual(Geometry.LogTicker(-math.inf, math.inf).labels, ('0e+00',))
self.assertEqual(Geometry.LogTicker(-math.nan, math.nan).labels, ('0e+00',))
self.assertEqual(Geometry.LogTicker(math.nan, 1).labels, ('0e+00',))
self.assertEqual(Geometry.LogTicker(-math.inf, 1).labels, ('0e+00',))
self.assertEqual(Geometry.LogTicker(0, math.inf).labels, ('0e+00',))

def test_ticker_produces_expected_labels(self) -> None:
self.assertListEqual(list(Geometry.LinearTicker(0, 1e8, ticks=3).labels), ['0e+00', '5e+07', '1.0e+08'])
Expand All @@ -95,6 +95,25 @@ def test_ticker_value_label(self) -> None:
ticker = Geometry.LinearTicker(mn, mx)
self.assertIsNotNone(ticker.value_label(900000))

def test_linear_ticker_equal(self) -> None:
self.assertEqual(Geometry.LinearTicker(0, 1), Geometry.LinearTicker(0, 1))
self.assertNotEqual(Geometry.LinearTicker(0, 1), Geometry.LinearTicker(0, 2))
self.assertNotEqual(Geometry.LinearTicker(0, 1), Geometry.LogTicker(1, 1))

def test_log_ticker_equal(self) -> None:
self.assertEqual(Geometry.LogTicker(0, 1), Geometry.LogTicker(0, 1))
self.assertNotEqual(Geometry.LogTicker(0, 1), Geometry.LogTicker(0, 2))
self.assertNotEqual(Geometry.LogTicker(0, 1), Geometry.LinearTicker(1, 1))

def test_linear_ticker_hash(self) -> None:
d = {Geometry.LinearTicker(0, 1): Geometry.LinearTicker(0, 1)}
self.assertEqual(d[Geometry.LinearTicker(0, 1)], Geometry.LinearTicker(0, 1))

def test_log_ticker_hash(self) -> None:
d = {Geometry.LogTicker(0, 1): Geometry.LogTicker(0, 1)}
self.assertEqual(d[Geometry.LogTicker(0, 1)], Geometry.LogTicker(0, 1))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.DEBUG)
unittest.main()

0 comments on commit 6daad9b

Please sign in to comment.