Skip to content

Commit

Permalink
feat: register concat(jax, Q[dimensionless]) (#53)
Browse files Browse the repository at this point in the history
* feat: register concat(jax, Q[dimensionless])

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Mar 3, 2024
1 parent 2bf8f20 commit c9f0e33
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/jax_quantity/_register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _concatenate_p(*operands: Quantity, dimension: Any) -> Quantity:


@register(lax.concatenate_p)
def _concatenate_p_jqnd(
def _concatenate_p_qnd(
operand0: Quantity["dimensionless"], # type: ignore[type-arg]
*operands: Quantity["dimensionless"] | ArrayLike, # type: ignore[type-arg]
dimension: Any,
Expand Down Expand Up @@ -385,6 +385,41 @@ def _concatenate_p_jqnd(
)


@register(lax.concatenate_p)
def _concatenate_p_jqnd(
operand0: ArrayLike,
*operands: Quantity["dimensionless"], # type: ignore[type-arg]
dimension: Any,
) -> Quantity["dimensionless"]: # type: ignore[type-arg]
"""Concatenate quantities and arrays with dimensionless units.
Examples
--------
>>> import array_api_jax_compat as xp
>>> from jax_quantity import Quantity
>>> theta = Quantity(45, "deg")
>>> Rx = xp.asarray([[1.0, 0.0, 0.0 ],
... [0.0, xp.cos(theta), -xp.sin(theta)],
... [0.0, xp.sin(theta), xp.cos(theta) ]])
>>> Rx
Quantity[...](Array([[ 1. , 0. , 0. ],
[ 0. , 0.70710677, -0.70710677],
[ 0. , 0.70710677, 0.70710677]], dtype=float32),
unit='')
"""
return Quantity(
lax.concatenate(
[
(op.to_value(dimensionless) if hasattr(op, "unit") else op)
for op in (operand0, *operands)
],
dimension=dimension,
),
unit=dimensionless,
)


# ==============================================================================


Expand Down

0 comments on commit c9f0e33

Please sign in to comment.