Skip to content

Commit

Permalink
some examples
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Nov 25, 2024
1 parent 817a054 commit fc5023d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 15 deletions.
37 changes: 25 additions & 12 deletions src/coordinax/_src/vectors/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def check_non_negative_non_zero(


def check_less_than(
x: AbstractQuantity, max_val: AbstractQuantity, /, name: str = ""
x: AbstractQuantity,
max_val: AbstractQuantity,
/,
name: str = "",
comparison_name: str = "the specified maximum value",
) -> AbstractQuantity:
"""Check that the input value is less than the input maximum value.
Expand All @@ -153,12 +157,16 @@ def check_less_than(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be less than the specified maximum value."
msg = f"The input{name} must be less than {comparison_name}."
return eqx.error_if(x, xp.any(x >= max_val), msg)


def check_less_than_equal(
x: AbstractQuantity, max_val: AbstractQuantity, /, name: str = ""
x: AbstractQuantity,
max_val: AbstractQuantity,
/,
name: str = "",
comparison_name: str = "the specified maximum value",
) -> AbstractQuantity:
"""Check that the input value is less than or equal to the input maximum value.
Expand All @@ -174,12 +182,16 @@ def check_less_than_equal(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be less than or equal to the specified maximum value."
msg = f"The input{name} must be less than or equal to {comparison_name}."
return eqx.error_if(x, xp.any(x > max_val), msg)


def check_greater_than(
x: AbstractQuantity, min_val: AbstractQuantity, /, name: str = ""
x: AbstractQuantity,
min_val: AbstractQuantity,
/,
name: str = "",
comparison_name: str = "the specified minimum value",
) -> AbstractQuantity:
"""Check that the input value is greater than the input minimum value.
Expand All @@ -195,12 +207,16 @@ def check_greater_than(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be greater than the specified minimum value."
msg = f"The input{name} must be greater than {comparison_name}."
return eqx.error_if(x, xp.any(x <= min_val), msg)


def check_greater_than_equal(
x: AbstractQuantity, min_val: AbstractQuantity, /, name: str = ""
x: AbstractQuantity,
min_val: AbstractQuantity,
/,
name: str = "",
comparison_name: str = "the specified minimum value",
) -> AbstractQuantity:
"""Check that the input value is greater than or equal to the input minimum value.
Expand All @@ -211,13 +227,10 @@ def check_greater_than_equal(
Raise an error if the input is smaller than the minimum value.
>>> x = Quantity([-1, 1, 2], "m")
>>> try: check_greater_than(x, 1.0)
>>> try: check_greater_than_equal(x, 1.0)
... except Exception: pass
"""
name = f" {name}" if name else name
msg = (
f"The input{name} must be greater than or equal to the specified minimum "
"value."
)
msg = f"The input{name} must be greater than or equal to {comparison_name}."
return eqx.error_if(x, xp.any(x < min_val), msg)
45 changes: 42 additions & 3 deletions src/coordinax/_src/vectors/d3/spheroidal.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,42 @@ class ProlateSpheroidalPos(AbstractPos3D):
Examples
--------
TODO: add valid and invalid examples
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.ProlateSpheroidalPos(
... mu=Quantity(3.0, "kpc2"),
... nu=Quantity(0.5, "kpc2"),
... phi=Quantity(0.25, "rad"),
... Delta=Quantity(1.5, "kpc"),
... )
>>> vec
ProlateSpheroidalPos(
mu=Quantity[PhysicalType('area')](value=f32[], unit=Unit("kpc2")),
nu=Quantity[PhysicalType('area')](value=f32[], unit=Unit("kpc2")),
phi=Angle(value=f32[], unit=Unit("rad")),
Delta=Quantity[PhysicalType('length')](value=weak_f32[], unit=Unit("kpc"))
)
This fails with a zero or negative Delta:
>>> try: vec = cx.ProlateSpheroidalPos(
... mu=Quantity(3.0, "kpc2"),
... nu=Quantity(0.5, "kpc2"),
... phi=Quantity(0.25, "rad"),
... Delta=Quantity(0.0, "kpc"),
... )
... except Exception as e: pass
Or with invalid mu and nu:
>>> try: vec = cx.ProlateSpheroidalPos(
... mu=Quantity(0.5, "kpc2"),
... nu=Quantity(0.5, "kpc2"),
... phi=Quantity(0.25, "rad"),
... Delta=Quantity(1.5, "kpc"),
... )
... except Exception as e: pass
"""

Expand Down Expand Up @@ -77,8 +112,12 @@ class ProlateSpheroidalPos(AbstractPos3D):
def __check_init__(self) -> None:
"""Check the validity of the initialization."""
check_non_negative_non_zero(self.Delta, name="Delta")
check_greater_than_equal(self.mu, self.Delta**2, name="mu")
check_less_than_equal(jnp.abs(self.nu), self.Delta**2, name="nu")
check_greater_than_equal(
self.mu, self.Delta**2, name="mu", comparison_name="Delta^2"
)
check_less_than_equal(
jnp.abs(self.nu), self.Delta**2, name="nu", comparison_name="Delta^2"
)

@classproperty
@classmethod
Expand Down

0 comments on commit fc5023d

Please sign in to comment.