Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prolate spheroidal coordinates #220

Merged
merged 35 commits into from
Dec 1, 2024
Merged

Conversation

adrn
Copy link
Contributor

@adrn adrn commented Nov 13, 2024

I'm gearing up to add the Staeckel Fudge to Galax, so it would be useful to have a representation of prolate spheroidal coordinates. Transformations to this coordinate system need to specify a focal length parameter, Delta, so it's a bit different from other d3 representations. Thoughts on this approach?

@adrn
Copy link
Contributor Author

adrn commented Nov 13, 2024

Tests are failing for two reasons that might be related:

  • How can I tell coordinax that Delta is not a coordinate component and is instead a parameter that helps specify the coordinate system?
  • I can't get the self-transform to work (without explicitly specifying a value for Delta in represent_as) because something is inspecting the call signature and knows that that is a missing field in the dataclass.

One approach could be to allow a default Delta=None (because writing the coordinates themselves doesn't require a focal length, only transformations)?

@adrn adrn requested a review from nstarman November 13, 2024 16:40
@nstarman
Copy link
Contributor

nstarman commented Nov 13, 2024

How can I tell coordinax that Delta is not a coordinate component and is instead a parameter that helps specify the coordinate system?

That's a good question! We don't have a mechanism for that. This is complicated, essentially each value of $\Delta$ should result in a different representation. The non-viable but technically correct solution is to dynamically create a new class for each $Delta$. Not doing that, we could make $\Delta$ just be a property of the transformation, e.g. a mandatory kwarg to represent_as. However we still want to ensure math happens correctly, e.g. when 2 coords with different $\Delta$ are added it errors (or does some transformation then adds correctly). Therefore __add__ needs to know about $\Delta$, which can only happen if it's stored on the object. There are two options, both are a fair bit of work, and I'm not sure which is best.

  1. $\Delta$ is a field on the representation class. We would then need to change the fundamental assumption on representations, which is that all fields are coordinate axes. coordinax uses field_items() and related functions liberally to perform generalized operations on representations. We would need to define new functions like axis_field_items to replace the other functions everywhere (or leverage multiple dispatch as I discuss below).
  2. We define a new class (name TBD) RepDeltaWrapper that accepts the representation and the delta. The representation is still "pure" in that all fields are coordinates. RepDeltaWrapper would need to be made to work with the existing representation machinery. This might end up involving all the changes described for (1) + registering all the primitives.

So writing this out, (2) might be technically / mathematically better, but (1) seems best to do now.

Order of events: 😓

  1. Define an Axis field descriptor, like how I've done in astropy.cosmology.Parameter. This is used to annotate all the axis fields in coordinax (all of them currently). It should be parametric wrt the type it sets, so Axis[Distance] or Axis[Angle].
  2. Open a PR in dataclassish to define a filter flag type base class, like we did in unxt. The Flag ABC should error, the 'default' filter flag should be the same as if it weren't present. This is actually technically not needed because the multiple dispatch only cares about step 4. So maybe we just open an Issue on dataclassish for now.
  3. In coordinax define an AxesFilterFlag and register its behaviour for the field_items() etc functions. It can filter the dataclass' fields by whether they are annotated as Axis types in the class definition (they would still be Quantity on the instance).
    At this point all the existing representations should work completely unchanged since all their fields are axes.
  4. Define the prolate spheroidal coordinate classes, with Delta not as an axis. Register in all the behaviours for checking Delta when doing math operations, etc. Random extra thought: Delta should be a scalar only.

@nstarman
Copy link
Contributor

nstarman commented Nov 19, 2024

@adrn I started steps 1-3 in #225. The tests on using filtered axes is pretty minimal — just on FourVector — so there might be a couple functions I forgot to refactor. But this PR should now be able to define the Delta attribute!

@adrn adrn marked this pull request as ready for review November 25, 2024 19:15
@adrn
Copy link
Contributor Author

adrn commented Nov 25, 2024

@nstarman I think this is ready for a look! Tests are failing with a gnarly recursion error in test_jax_ops.py that I haven't figured out yet.

@adrn
Copy link
Contributor Author

adrn commented Nov 26, 2024

OK it turns out the recursion error was because one of the tests was calling the fallback represent_as(CartesianPos3D, ProlateSpheroidalPos), which shouldn't be possible because this transformation has no defined Delta, and (unrelated) was triggering a recursion. Fixing in the next batch of commits

pyproject.toml Outdated Show resolved Hide resolved
Copy link

codecov bot commented Nov 27, 2024

Codecov Report

Attention: Patch coverage is 97.35099% with 4 lines in your changes missing coverage. Please review.

Project coverage is 97.04%. Comparing base (7feffc9) to head (68ea19c).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
src/coordinax/_src/vectors/d3/spheroidal.py 93.65% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #220      +/-   ##
==========================================
+ Coverage   97.00%   97.04%   +0.03%     
==========================================
  Files         115      116       +1     
  Lines        3473     3617     +144     
==========================================
+ Hits         3369     3510     +141     
- Misses        104      107       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@adrn adrn added this to the v0.15.0 milestone Nov 27, 2024
@adrn
Copy link
Contributor Author

adrn commented Nov 27, 2024

@nstarman OK tests are passing, so take a look at how I handled Delta and the kwargs and let me know what you think!

Copy link
Contributor

@nstarman nstarman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks really good. Small comments only. Then should be ready 🚀

src/coordinax/_src/vectors/d3/spheroidal.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/d3/spheroidal.py Show resolved Hide resolved
src/coordinax/_src/vectors/checks.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/checks.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/checks.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/checks.py Show resolved Hide resolved
src/coordinax/_src/vectors/d3/transform.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/d3/transform.py Show resolved Hide resolved
src/coordinax/_src/vectors/d3/transform.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/d3/transform.py Show resolved Hide resolved
@adrn
Copy link
Contributor Author

adrn commented Nov 28, 2024

Hm, not sure about the failing test. Adding the converter to VectorAttribute causes filter_jit to fail here:

@eqx.filter_jit
def func(q):
    return q.represent_as(cx.ProlateSpheroidalPos, Delta=Quantity(1.0, "kpc"))

q = cx.CartesianPos3D.from_([1, 2, 3], "kpc")
func(q)

@nstarman
Copy link
Contributor

Can you try removing the parametrization in the converter? Maybe just Quantity.from_. I think the dimensions will be checked against the field annotation when it's being set.

@nstarman
Copy link
Contributor

But also Happy Thanksgiving 🦃! For after the holidays :)

@adrn
Copy link
Contributor Author

adrn commented Nov 29, 2024

Hm, removing the "length" parametrization didn't fix it! No rush to respond (but I'll be working on and off this weekend)

@nstarman
Copy link
Contributor

Then remove and flag in an Issue for followup! Let's get this in. Looks great otherwise.

nstarman
nstarman previously approved these changes Nov 29, 2024
@adrn
Copy link
Contributor Author

adrn commented Nov 30, 2024

I realized one other issue with this PR! Velocity transforms aren't working at the moment - I can look into it tomorrow:

pos = cx.CartesianPos3D.from_([1., 2., 3.], "kpc")
vel = cx.CartesianVel3D.from_([1., 2., 3.], "km/s")
vel.represent_as(cx.ProlateSpheroidalVel, pos, Delta=Quantity(0.1, "kpc"))
---------------------------------------------------------------------------
EquinoxTracetimeError                     Traceback (most recent call last)
Cell In[3], line 3
      1 pos = cx.CartesianPos3D.from_([1., 2., 3.], "kpc")
      2 vel = cx.CartesianVel3D.from_([1., 2., 3.], "km/s")
----> 3 vel.represent_as(cx.ProlateSpheroidalVel, pos, Delta=Quantity(0.1, "kpc"))

    [... skipping hidden 2 frame]

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py:449, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
    446             raise TypeCheckError(msg) from e
    448 # Actually call the function.
--> 449 out = fn(*args, **kwargs)
    451 if full_signature.return_annotation is not inspect.Signature.empty:
    452     # Now type-check the return value. We need to include the
    453     # parameters in the type-checking here in case there are any
   (...)
    464     # checking of the parameters. Unfortunately there doesn't seem
    465     # to be a way around that, so c'est la vie.
    466     kwargs[output_name] = out

File ~/projects/coordinax/src/coordinax/_src/vectors/base/base.py:203, in AbstractVector.represent_as(self, target, *args, **kwargs)
    153 def represent_as(self, target: type, *args: Any, **kwargs: Any) -> "AbstractVector":
    154     """Represent the vector as another type.
    155 
    156     This just forwards to `coordinax.represent_as`.
   (...)
    201 
    202     """
--> 203     return represent_as(self, target, *args, **kwargs)

    [... skipping hidden 2 frame]

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/jaxtyping/_decorator.py:449, in jaxtyped.<locals>.wrapped_fn_impl(args, kwargs, bound, memos)
    446             raise TypeCheckError(msg) from e
    448 # Actually call the function.
--> 449 out = fn(*args, **kwargs)
    451 if full_signature.return_annotation is not inspect.Signature.empty:
    452     # Now type-check the return value. We need to include the
    453     # parameters in the type-checking here in case there are any
   (...)
    464     # checking of the parameters. Unfortunately there doesn't seem
    465     # to be a way around that, so c'est la vie.
    466     kwargs[output_name] = out

File ~/projects/coordinax/src/coordinax/_src/vectors/transform/differentials.py:130, in represent_as(current, target, position, **kwargs)
    115 current_pos = replace(
    116     current_pos,
    117     **{
   (...)
    121     },
    122 )
    124 # Takes the Jacobian through the representation transformation function.  This
    125 # returns a representation of the target type, where the value of each field the
    126 # corresponding row of the Jacobian. The value of the field is a Quantity with
    127 # the correct numerator unit (of the Jacobian row). The value is a Vector of the
    128 # original type, with fields that are the columns of that row, but with only the
    129 # denomicator's units.
--> 130 jac_nested_vecs = jac_rep_as(current_pos, target.integral_cls)
    132 # This changes the Jacobian to be a dictionary of each row, with the value
    133 # being that row's column as a dictionary, now with the correct units for
    134 # each element:  {row_i: {col_j: Quantity(value, row.unit / column.unit)}}
    135 jac_rows = {
    136     f"d_{k}": {
    137         kk: Quantity(vv.value, unit=v.unit / vv.unit)
   (...)
    140     for k, v in field_items(jac_nested_vecs)
    141 }

    [... skipping hidden 27 frame]

File ~/projects/coordinax/src/coordinax/_src/vectors/d3/transform.py:801, in represent_as(current, target, **kwargs)
    765 @dispatch
    766 def represent_as(
    767     current: AbstractPos3D,
   (...)
    770     **kwargs: Any,
    771 ) -> ProlateSpheroidalPos:
    772     """AbstractPos3D -> ProlateSpheroidalPos.
    773 
    774     Examples
   (...)
    799 
    800     """
--> 801     Delta = eqx.error_if(
    802         kwargs.get("Delta"),
    803         "Delta" not in kwargs,
    804         "Delta must be provided for ProlateSpheroidalPos.",
    805     )
    806     cyl = represent_as(current, CylindricalPos)
    807     return represent_as(cyl, target, Delta=Delta)

    [... skipping hidden 2 frame]

File ~/projects/coordinax/.venv/lib/python3.10/site-packages/equinox/_errors.py:313, in branched_error_if_impl(x, pred, index, msgs, on_error)
    311 assert type(index) is int
    312 if on_error == "raise":
--> 313     raise EquinoxTracetimeError(msgs[index])
    314 elif on_error == "breakpoint":
    315     print(msgs[index])

EquinoxTracetimeError: Delta must be provided for ProlateSpheroidalPos.

@nstarman
Copy link
Contributor

Looks like Delta needs to be passed through the represent_as calls.

@adrn
Copy link
Contributor Author

adrn commented Nov 30, 2024

Looks like Delta needs to be passed through the represent_as calls.

Yes exactly - but the way it is vmap'd, I think the kwargs get batched alone axis=0 which would obviously fail for a scalar Delta.

@nstarman
Copy link
Contributor

nstarman commented Nov 30, 2024

If it's enforced scalar, then vmap axis for Delta should be None! Should any kwargs be mapped along an axis? Are the 2D - to - 3D kwargs in this vmap?

src/coordinax/_src/vectors/transform/d3.py Outdated Show resolved Hide resolved
src/coordinax/_src/vectors/transform/d3.py Show resolved Hide resolved
src/coordinax/_src/vectors/transform/d3.py Outdated Show resolved Hide resolved
Copy link
Contributor

@nstarman nstarman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

@adrn adrn merged commit 1a79c13 into GalacticDynamics:main Dec 1, 2024
14 checks passed
@adrn
Copy link
Contributor Author

adrn commented Dec 1, 2024

🎉

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants