Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for MLX arrays #301

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

gabrieldemarmiesse
Copy link

@gabrieldemarmiesse gabrieldemarmiesse commented Feb 21, 2025

Fixes #299

@gabrieldemarmiesse gabrieldemarmiesse changed the title Allow MLX types to be recognized Add support for MLX arrays Feb 21, 2025
@gabrieldemarmiesse gabrieldemarmiesse marked this pull request as ready for review February 21, 2025 10:50
@@ -617,15 +617,14 @@ def _make_array(x, dim_str, dtype):

if type(out) is tuple:
array_type, name, dtypes, dims, index_variadic, dim_str = out
# Nanobind classes can't be a base type
can_subclass = array_type is Any or "nanobind." in str(type(array_type))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was afraid things might look something like this!

FWIW we could maybe just remove subclassing altogether -- I'm not sure how much this buys us, and we're really threading the needle here to do so -- WDYT?

(I think I might have introduced this subclassing to make things work better when using type annotations with plum, although I might have that wrong. I think that's a use-case that can be supported in other ways, though.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no opinion on the matter. I'll remove the subclassing entirely and see where it goes

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

______________________________________________________ test_subclass _______________________________________________________

    def test_subclass():
>       assert issubclass(Float[Array, ""], Array)
E       AssertionError: assert False
E        +  where False = issubclass(<class 'jaxtyping.Float[Array, '']'>, Array)

test/test_array.py:607: AssertionError

The only test failing is this one: test/test_array.py::test_subclass - AssertionError: assert False. Every assertion on this test is failing.
I find it strange that a type hint need to be a subclass of a type. If that's not needed, we can remove this completely. How should I proceed?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. I think let's go ahead and remove it then!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, I'll let you review one more time!

else _make_metaclass(type(array_type))
)

metaclass = _make_metaclass(type)
out = metaclass(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we skip creating the metaclass too? And have this line just be type(...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for MLX
2 participants