These are the internal annotation custom calls we might use to annotate a buffer with a specific memory space. However, we think memory space annotations are too heavyweight and not necessary.
To get good performance and maximal flexibility, users should be able to pin specific buffers to specific memory spaces. Consider the following case: a JAX user wants a buffer to be a persistent temp buffer or a pinned host buffer so that they can use it inside their FFI custom ops.
In general, we should do the following:
Fix SPMD partitioning to recognize heterogeneous memory spaces.
Don't drop user-set memory spaces in layout assignment.
Fix layout normalization.
Add some memory space propagation.
Can you confirm that by "persistent temp buffer" you want an activation to always have the same buffer?
Why is that useful?
JAX has started to support activation offloading to host memory via the jax.remat API. Instead of recomputing an activation, it offloads it. Is that what you want to do here?
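For reference, the offloading path mentioned here can be sketched roughly as follows. This is an illustration under assumptions, not the code from this issue: the function `layer` and the name `"act"` are made up, and the example only traces the gradient with `jax.make_jaxpr` so that it does not depend on the backend actually supporting `pinned_host` at run time.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Offload residuals tagged "act" from device memory to pinned host memory
# instead of saving them on device or recomputing them.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["act"],
    offload_src="device",
    offload_dst="pinned_host",
)


def layer(x):
    # Hypothetical layer: tag the activation so the policy can find it.
    h = checkpoint_name(jnp.sin(x), "act")
    return jnp.sum(h * h)


offloaded_layer = jax.checkpoint(layer, policy=policy)

# Trace only (no compilation or execution) to inspect the backward pass.
grad_jaxpr = jax.make_jaxpr(jax.grad(offloaded_layer))(jnp.ones((4,)))
print(grad_jaxpr)
```

Note that this offloads an activation for the backward pass; it does not by itself give the user a buffer pinned to a memory space that an FFI custom op can rely on.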
Prototype Reference: https://github.com/openxla/xla/pull/23149/files