diff --git a/pyproject.toml b/pyproject.toml index 7fde9593..ef5dc5e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ "optional-dependencies>=0.3.2", "plum-dispatch>=2.5.4", "quax>=0.0.5", - "quaxed>=0.7.0", + "quaxed>=0.7.1", "xmmutablemap>=0.1", "zeroth>=1.0.0", ] diff --git a/src/unxt/_src/quantity/register_primitives.py b/src/unxt/_src/quantity/register_primitives.py index 1a052982..d3fe5ed0 100644 --- a/src/unxt/_src/quantity/register_primitives.py +++ b/src/unxt/_src/quantity/register_primitives.py @@ -77,7 +77,7 @@ def _abs_p(x: AbstractQuantity) -> AbstractQuantity: UncheckedQuantity(Array(1, dtype=int32, ...), unit='m') """ - return replace(x, value=lax.abs(ustrip(x))) + return replace(x, value=qlax.abs(ustrip(x))) # ============================================================================== @@ -102,7 +102,7 @@ def _acos_p_aq(x: AbstractQuantity) -> AbstractQuantity: """ x_ = ustrip(one, x) - return type_np(x)(value=lax.acos(x_), unit=radian) + return type_np(x)(value=qlax.acos(x_), unit=radian) # ============================================================================== @@ -127,7 +127,7 @@ def _acosh_p_aq(x: AbstractQuantity) -> AbstractQuantity: """ x_ = ustrip(one, x) - return type_np(x)(value=lax.acosh(x_), unit=radian) + return type_np(x)(value=qlax.acosh(x_), unit=radian) # ============================================================================== @@ -159,7 +159,7 @@ def _add_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(1.5, dtype=float32, ...), unit='km') """ - return replace(x, value=lax.add(ustrip(x), ustrip(x.unit, y))) + return replace(x, value=qlax.add(ustrip(x), ustrip(x.unit, y))) @register(lax.add_p) @@ -226,7 +226,7 @@ def _add_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: """ y = uconvert(one, y) - return replace(y, value=lax.add(x, ustrip(y))) + return replace(y, value=qlax.add(x, ustrip(y))) @register(lax.add_p) @@ -299,7 +299,7 @@ def _add_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: """ x = uconvert(one, x) - return replace(x, value=lax.add(ustrip(x), y)) + return replace(x, value=qlax.add(ustrip(x), y)) # ============================================================================== @@ -380,7 +380,7 @@ def _argmax_p( UncheckedQuantity(Array(2, dtype=int32), unit='m') """ - return replace(operand, value=lax.argmax(ustrip(operand), axes[0], index_dtype)) + return replace(operand, value=qlax.argmax(ustrip(operand), axes[0], index_dtype)) # ============================================================================== @@ -406,7 +406,7 @@ def _argmin_p( UncheckedQuantity(Array(0, dtype=int32), unit='m') """ - return replace(operand, value=lax.argmin(ustrip(operand), axes[0], index_dtype)) + return replace(operand, value=qlax.argmin(ustrip(operand), axes[0], index_dtype)) # ============================================================================== @@ -688,7 +688,7 @@ def _atanh_p_q( @register(lax.broadcast_in_dim_p) def _broadcast_in_dim_p(operand: AbstractQuantity, **kwargs: Any) -> AbstractQuantity: """Broadcast a quantity in a specific dimension.""" - return replace(operand, value=lax.broadcast_in_dim(ustrip(operand), **kwargs)) + return replace(operand, value=qlax.broadcast_in_dim(ustrip(operand), **kwargs)) # ============================================================================== @@ -737,7 +737,7 @@ def _ceil_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(2., dtype=float32, ...), unit='m') """ - return replace(x, value=lax.ceil(ustrip(x))) + return replace(x, value=qlax.ceil(ustrip(x))) # ============================================================================== @@ -770,7 +770,7 @@ def _clamp_p( """ return replace( - x, value=lax.clamp(ustrip(x.unit, min), ustrip(x), ustrip(x.unit, max)) + x, value=qlax.clamp(ustrip(x.unit, min), ustrip(x), ustrip(x.unit, max)) ) @@ -803,7 +803,7 @@ def _clamp_p_vaqaq( Quantity['dimensionless'](Array([0, 1, 2], dtype=int32), unit='') """ - return replace(x, value=lax.clamp(min, ustrip(one, x), ustrip(one, max))) + return replace(x, value=qlax.clamp(min, ustrip(one, x), ustrip(one, max))) # --------------------------- @@ -877,7 +877,7 @@ def _clamp_p_aqaqv( UncheckedQuantity(Array([0, 1, 2], dtype=int32), unit='') """ - return replace(x, value=lax.clamp(ustrip(one, min), ustrip(one, x), max)) + return replace(x, value=qlax.clamp(ustrip(one, min), ustrip(one, x), max)) @register(lax.clamp_p) @@ -901,7 +901,7 @@ def _clamp_p_qqv( Quantity['dimensionless'](Array([0, 1, 2], dtype=int32), unit='') """ - return replace(x, value=lax.clamp(ustrip(one, min), ustrip(one, x), max)) + return replace(x, value=qlax.clamp(ustrip(one, min), ustrip(one, x), max)) # ============================================================================== @@ -929,7 +929,7 @@ def _complex_p(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: """ x, y = promote(x, y) # e.g. Distance -> Quantity y_ = ustrip(x.unit, y) - return replace(x, value=lax.complex(ustrip(x), y_)) + return replace(x, value=qlax.complex(ustrip(x), y_)) # ============================================================================== @@ -961,7 +961,7 @@ def _concatenate_p_aq(*operands: AbstractQuantity, dimension: Any) -> AbstractQu units_ = operand0.unit return replace( operand0, - value=lax.concatenate( + value=qlax.concatenate( [ustrip(units_, op) for op in operands], dimension=dimension ), ) @@ -1096,7 +1096,7 @@ def _conj_p(x: AbstractQuantity, *, input_dtype: Any) -> AbstractQuantity: """ del input_dtype # TODO: use this? - return replace(x, value=lax.conj(ustrip(x))) + return replace(x, value=qlax.conj(ustrip(x))) # ============================================================================== @@ -1256,7 +1256,7 @@ def _cumlogsumexp_p( """ # TODO: double check units make sense here. return replace( - operand, value=lax.cumlogsumexp(ustrip(operand), axis=axis, reverse=reverse) + operand, value=qlax.cumlogsumexp(ustrip(operand), axis=axis, reverse=reverse) ) @@ -1285,7 +1285,7 @@ def _cummax_p( """ return replace( - operand, value=lax.cummax(ustrip(operand), axis=axis, reverse=reverse) + operand, value=qlax.cummax(ustrip(operand), axis=axis, reverse=reverse) ) @@ -1314,7 +1314,7 @@ def _cummin_p( """ return replace( - operand, value=lax.cummin(ustrip(operand), axis=axis, reverse=reverse) + operand, value=qlax.cummin(ustrip(operand), axis=axis, reverse=reverse) ) @@ -1343,7 +1343,7 @@ def _cumprod_p( """ return replace( - operand, value=lax.cumprod(ustrip(one, operand), axis=axis, reverse=reverse) + operand, value=qlax.cumprod(ustrip(one, operand), axis=axis, reverse=reverse) ) @@ -1372,7 +1372,7 @@ def _cumsum_p( """ return replace( - operand, value=lax.cumsum(ustrip(operand), axis=axis, reverse=reverse) + operand, value=qlax.cumsum(ustrip(operand), axis=axis, reverse=reverse) ) @@ -1423,7 +1423,7 @@ def _digamma_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(-0.5772154, dtype=float32, ...), unit='') """ - return replace(x, value=lax.digamma(ustrip(one, x))) + return replace(x, value=qlax.digamma(ustrip(one, x))) # ============================================================================== @@ -1777,7 +1777,7 @@ def _erf_inv_p(x: AbstractQuantity) -> AbstractQuantity: """ # TODO: can this support non-dimensionless quantities? - return replace(x, value=lax.erf_inv(ustrip(one, x))) + return replace(x, value=qlax.erf_inv(ustrip(one, x))) # ============================================================================== @@ -1804,7 +1804,7 @@ def _erf_p(x: AbstractQuantity) -> AbstractQuantity: """ # TODO: can this support non-dimensionless quantities? - return replace(x, value=lax.erf(ustrip(one, x))) + return replace(x, value=qlax.erf(ustrip(one, x))) # ============================================================================== @@ -1830,7 +1830,7 @@ def _erfc_p(x: AbstractQuantity) -> AbstractQuantity: """ # TODO: can this support non-dimensionless quantities? - return replace(x, value=lax.erfc(ustrip(one, x))) + return replace(x, value=qlax.erfc(ustrip(one, x))) # ============================================================================== @@ -1855,7 +1855,7 @@ def _exp2_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(8., dtype=float32, ...), unit='') """ - return replace(x, value=lax.exp2(ustrip(one, x))) + return replace(x, value=qlax.exp2(ustrip(one, x))) # ============================================================================== @@ -1888,7 +1888,7 @@ def _exp_p(x: AbstractQuantity) -> AbstractQuantity: """ # TODO: more meaningful error message. - return replace(x, value=lax.exp(ustrip(one, x))) + return replace(x, value=qlax.exp(ustrip(one, x))) # ============================================================================== @@ -1913,7 +1913,7 @@ def _expm1_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(0., dtype=float32, ...), unit='') """ - return replace(x, value=lax.expm1(ustrip(one, x))) + return replace(x, value=qlax.expm1(ustrip(one, x))) # ============================================================================== @@ -1941,7 +1941,7 @@ def _fft_p(x: AbstractQuantity, *, fft_type: Any, fft_lengths: Any) -> AbstractQ """ # TODO: what units can this support? - return replace(x, value=lax.fft(ustrip(one, x), fft_type, fft_lengths)) + return replace(x, value=qlax.fft(ustrip(one, x), fft_type, fft_lengths)) # ============================================================================== @@ -1966,7 +1966,7 @@ def _floor_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(1., dtype=float32, ...), unit='') """ - return replace(x, value=lax.floor(ustrip(x))) + return replace(x, value=qlax.floor(ustrip(x))) # ============================================================================== @@ -1979,7 +1979,7 @@ def _gather_p( ) -> AbstractQuantity: # TODO: examples return replace( - operand, value=lax.gather_p.bind(ustrip(operand), start_indices, **kwargs) + operand, value=qlax.gather_p.bind(ustrip(operand), start_indices, **kwargs) ) @@ -2219,7 +2219,7 @@ def _igamma_p(a: AbstractQuantity, x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(0.6321202, dtype=float32, ...), unit='') """ - return replace(x, value=lax.igamma(ustrip(one, a), ustrip(one, x))) + return replace(x, value=qlax.igamma(ustrip(one, a), ustrip(one, x))) # ============================================================================== @@ -2227,7 +2227,7 @@ def _igamma_p(a: AbstractQuantity, x: AbstractQuantity) -> AbstractQuantity: @register(lax.imag_p) def _imag_p(x: AbstractQuantity) -> AbstractQuantity: - return replace(x, value=lax.imag(ustrip(x))) + return replace(x, value=qlax.imag(ustrip(x))) # ============================================================================== @@ -2250,7 +2250,7 @@ def _integer_pow_p(x: AbstractQuantity, *, y: Any) -> AbstractQuantity: Quantity['volume'](Array(8, dtype=int32, ...), unit='m3') """ - return type_np(x)(value=lax.integer_pow(ustrip(x), y), unit=x.unit**y) + return type_np(x)(value=qlax.integer_pow(ustrip(x), y), unit=x.unit**y) # ============================================================================== @@ -2411,7 +2411,7 @@ def _lgamma_p(x: AbstractQuantity) -> AbstractQuantity: """ # TODO: are there any units that this can support? - return replace(x, value=lax.lgamma(ustrip(one, x))) + return replace(x, value=qlax.lgamma(ustrip(one, x))) # ============================================================================== @@ -2435,7 +2435,7 @@ def _log1p_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(-inf, dtype=float32, weak_type=True), unit='') """ - return replace(x, value=lax.log1p(ustrip(one, x))) + return replace(x, value=qlax.log1p(ustrip(one, x))) # ============================================================================== @@ -2459,7 +2459,7 @@ def _log_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(0., dtype=float32, weak_type=True), unit='') """ - return replace(x, value=lax.log(ustrip(one, x))) + return replace(x, value=qlax.log(ustrip(one, x))) # ============================================================================== @@ -2483,7 +2483,7 @@ def _logistic_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(0.7310586, dtype=float32, ...), unit='') """ - return replace(x, value=lax.logistic(ustrip(one, x))) + return replace(x, value=qlax.logistic(ustrip(one, x))) # ============================================================================== @@ -2713,7 +2713,7 @@ def _max_p_vq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: """ yv = ustrip(one, y) - return replace(y, value=lax.max(x, yv)) + return replace(y, value=qlax.max(x, yv)) @register(lax.max_p) @@ -2737,7 +2737,7 @@ def _max_p_qv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: """ xv = ustrip(one, x) - return replace(x, value=lax.max(xv, y)) + return replace(x, value=qlax.max(xv, y)) # ============================================================================== @@ -2769,7 +2769,7 @@ def _min_p_qq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array([1, 1, 3], dtype=int32), unit='m') """ - return replace(x, value=lax.min(ustrip(x), ustrip(x.unit, y))) + return replace(x, value=qlax.min(ustrip(x), ustrip(x.unit, y))) @register(lax.min_p) @@ -2791,7 +2791,7 @@ def _min_p_vq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array([1, 2, 2], dtype=int32), unit='') """ - return replace(y, value=lax.min(x, ustrip(one, y))) + return replace(y, value=qlax.min(x, ustrip(one, y))) @register(lax.min_p) @@ -2813,7 +2813,7 @@ def _min_p_qv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: Quantity['dimensionless'](Array([1, 2, 2], dtype=int32), unit='') """ - return replace(x, value=lax.min(ustrip(one, x), y)) + return replace(x, value=qlax.min(ustrip(one, x), y)) # ============================================================================== @@ -2887,7 +2887,7 @@ def _mul_p_vq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array([4, 6], dtype=int32), unit='m') """ - return replace(y, value=lax.mul(x, ustrip(y))) + return replace(y, value=qlax.mul(x, ustrip(y))) @register(lax.mul_p) @@ -2923,7 +2923,7 @@ def _mul_p_qv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: Quantity['length'](Array([4, 6], dtype=int32), unit='m') """ - return replace(x, value=lax.mul(ustrip(x), y)) + return replace(x, value=qlax.mul(ustrip(x), y)) # ============================================================================== @@ -3088,7 +3088,7 @@ def _neg_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(-1, dtype=int32, weak_type=True), unit='m') """ - return replace(x, value=lax.neg(ustrip(x))) + return replace(x, value=qlax.neg(ustrip(x))) # ============================================================================== @@ -3122,7 +3122,7 @@ def _pow_p_qq( yv = ustrip(one, y) y0 = yv[(0,) * yv.ndim] yv = eqx.error_if(yv, jnp.any(yv != y0), "power must be a scalar") - return type_np(x)(value=lax.pow(ustrip(x), y0), unit=x.unit**y0) + return type_np(x)(value=qlax.pow(ustrip(x), y0), unit=x.unit**y0) @register(lax.pow_p) @@ -3148,7 +3148,7 @@ def _pow_p_qf(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: Quantity['volume'](Array(8., dtype=float32, weak_type=True), unit='m3') """ - return type_np(x)(value=lax.pow(ustrip(x), y), unit=x.unit**y) + return type_np(x)(value=qlax.pow(ustrip(x), y), unit=x.unit**y) @register(lax.pow_p) @@ -3168,7 +3168,7 @@ def _pow_p_vq( Quantity['dimensionless'](Array([8.], dtype=float32), unit='') """ - return replace(y, value=lax.pow(x, ustrip(y))) + return replace(y, value=qlax.pow(x, ustrip(y))) # ============================================================================== @@ -3196,7 +3196,7 @@ def _real_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(1., dtype=float32, weak_type=True), unit='m') """ - return replace(x, value=lax.real(ustrip(x))) + return replace(x, value=qlax.real(ustrip(x))) # ============================================================================== @@ -3272,7 +3272,7 @@ def _rem_p_qq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(1, dtype=int32, ...), unit='m') """ - return replace(x, value=lax.rem(ustrip(x), ustrip(x.unit, y))) + return replace(x, value=qlax.rem(ustrip(x), ustrip(x.unit, y))) @register(lax.rem_p) @@ -3290,7 +3290,7 @@ def _rem_p_uqv(x: Quantity["dimensionless"], y: ArrayLike) -> Quantity["dimensio Quantity['dimensionless'](Array(1, dtype=int32, ...), unit='') """ - return replace(x, value=lax.rem(ustrip(x), y)) + return replace(x, value=qlax.rem(ustrip(x), y)) # ============================================================================== @@ -3320,7 +3320,7 @@ def _reshape_p( [4, 5]], dtype=int32), unit='m') """ - return replace(operand, value=lax.reshape(ustrip(operand), new_sizes, dimensions)) + return replace(operand, value=qlax.reshape(ustrip(operand), new_sizes, dimensions)) # ============================================================================== @@ -3344,7 +3344,7 @@ def _rev_p(operand: AbstractQuantity, *, dimensions: Any) -> AbstractQuantity: Quantity['length'](Array([3, 2, 1, 0], dtype=int32), unit='m') """ - return replace(operand, value=lax.rev(ustrip(operand), dimensions)) + return replace(operand, value=qlax.rev(ustrip(operand), dimensions)) # ============================================================================== @@ -3368,7 +3368,7 @@ def _round_p(x: AbstractQuantity, *, rounding_method: Any) -> AbstractQuantity: Quantity['length'](Array(1.23, dtype=float32, ...), unit='m') """ - return replace(x, value=lax.round(ustrip(x), rounding_method)) + return replace(x, value=qlax.round(ustrip(x), rounding_method)) # ============================================================================== @@ -3521,7 +3521,7 @@ def _select_n_p_jjq( ) -> AbstractQuantity: """Select from an array and quantity using a quantity selector.""" # Used by a `jnp.linalg.trace` - return replace(case1, value=lax.select_n(which, case0, ustrip(case1))) + return replace(case1, value=qlax.select_n(which, case0, ustrip(case1))) @register(lax.select_n_p) @@ -3546,7 +3546,7 @@ def _select_n_p_jqj( [0, 4]], dtype=int32), unit='km') """ - return replace(case0, value=lax.select_n(which, ustrip(case0), case1)) + return replace(case0, value=qlax.select_n(which, ustrip(case0), case1)) @register(lax.select_n_p) @@ -3575,7 +3575,7 @@ def _select_n_p_jqq(which: ArrayLike, *cases: AbstractQuantity) -> AbstractQuant """ u = unit_of(cases[0]) return replace( - cases[0], value=lax.select_n(which, *(ustrip(u, case) for case in cases)) + cases[0], value=qlax.select_n(which, *(ustrip(u, case) for case in cases)) ) @@ -3787,7 +3787,7 @@ def _stop_gradient_p(x: AbstractQuantity) -> AbstractQuantity: UncheckedQuantity(Array(1., dtype=float32, ...), unit='m') """ - return replace(x, value=lax.stop_gradient(ustrip(x))) + return replace(x, value=qlax.stop_gradient(ustrip(x))) # ============================================================================== @@ -3819,7 +3819,7 @@ def _sub_p_qq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(0.5, dtype=float32, ...), unit='km') """ - return replace(x, value=lax.sub(ustrip(x.unit, x), ustrip(x.unit, y))) + return replace(x, value=qlax.sub(ustrip(x.unit, x), ustrip(x.unit, y))) @register(lax.sub_p) diff --git a/uv.lock b/uv.lock index 08a44836..98ab6dfe 100644 --- a/uv.lock +++ b/uv.lock @@ -1732,7 +1732,7 @@ wheels = [ [[package]] name = "quaxed" -version = "0.7.0" +version = "0.7.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jax" }, @@ -1741,9 +1741,9 @@ dependencies = [ { name = "plum-dispatch" }, { name = "quax" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2f/42/e49f9026cb99c9946a34be3849ec49cd3ee6c97729d0ea0e684038eddaf6/quaxed-0.7.0.tar.gz", hash = "sha256:dbd540ab9413ffe3a69cf4a1c9331106c8b8cdd3265e697da2208f98978c5a81", size = 121190 } +sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/01e425d053ee37ae2fd2cc736aaf418088c0b71c402719d06bda768bdf83/quaxed-0.7.1.tar.gz", hash = "sha256:bbc3596be9211324c2e5045e7fd919281f1a76eab63ad0dd1c8027378b1760f9", size = 123702 } wheels = [ - { url = "https://files.pythonhosted.org/packages/71/e3/3be4eabd6cc5d67ee9c8861a135c56e8b24f04701b061efd2107e1da9bc9/quaxed-0.7.0-py3-none-any.whl", hash = "sha256:2ce589bc63b7697e1e872fca36fa865c41ac61c457b4cdd5a83d4116dc1c3f56", size = 36527 }, + { url = "https://files.pythonhosted.org/packages/0a/63/a28e96b31fc27edac3c62d4359b82bdcb61da978b6a891061309a5bfb422/quaxed-0.7.1-py3-none-any.whl", hash = "sha256:1e1f66f3a523ff87f0a70b69c7079d5809f23091636747d8af4a6a99fa9925ba", size = 36547 }, ] [[package]] @@ -2267,7 +2267,7 @@ wheels = [ [[package]] name = "unxt" -version = "1.0.1.dev12+gb867ef8.d20241217" +version = "1.0.1.dev14+gcad0bc9.d20241217" source = { editable = "." } dependencies = [ { name = "astropy" }, @@ -2390,7 +2390,7 @@ requires-dist = [ { name = "optional-dependencies", specifier = ">=0.3.2" }, { name = "plum-dispatch", specifier = ">=2.5.4" }, { name = "quax", specifier = ">=0.0.5" }, - { name = "quaxed", specifier = ">=0.7.0" }, + { name = "quaxed", specifier = ">=0.7.1" }, { name = "xmmutablemap", specifier = ">=0.1" }, { name = "zeroth", specifier = ">=1.0.0" }, ]