Skip to content

Commit

Permalink
[Pallas] Add support for setting operand as HBM memory.
Browse files Browse the repository at this point in the history
For inputs it behaves like pl.ANY; for outputs it forces the
memory-space assignment pass not to place the buffer in VMEM.

PiperOrigin-RevId: 730222777
  • Loading branch information
Marcello Maggioni authored and Google-ML-Automation committed Feb 24, 2025
1 parent 7d3c63e commit 170612f
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 1 deletion.
1 change: 1 addition & 0 deletions jax/_src/pallas/mosaic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class TPUCompilerParams(pallas_core.CompilerParams):

class TPUMemorySpace(enum.Enum):
ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY.
HBM = "hbm"
VMEM = "vmem"
SMEM = "smem"
CMEM = "cmem"
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
NDIndexer = indexing.NDIndexer
TPUMemorySpace = tpu_core.TPUMemorySpace
MemorySpace = pallas_core.MemorySpace | TPUMemorySpace
HBM = tpu_core.TPUMemorySpace.HBM
VMEM = tpu_core.TPUMemorySpace.VMEM
SMEM = tpu_core.TPUMemorySpace.SMEM
# Booleans are stored as the following type in memrefs.
Expand Down Expand Up @@ -653,11 +654,13 @@ def dynamic_shape_replacement_fn(
if grid:
for i, bm in enumerate(grid_mapping.block_mappings):
func_name = f"transform_{i}"
# ANY and SEMAPHORE operands don't support windowing and require empty window_params.
# ANY, HBM and SEMAPHORE operands don't support windowing and require empty window_params.
tpu_memory_space = _memory_space_to_tpu_memory_space(
bm.block_aval.memory_space)
print("index_map_jaxpr: ", bm.index_map_jaxpr)
if (
tpu_memory_space == tpu_core.TPUMemorySpace.ANY
or tpu_memory_space == tpu_core.TPUMemorySpace.HBM
or tpu_memory_space == tpu_core.TPUMemorySpace.SEMAPHORE
):
# We checked above that the block does not require windowing.
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def _get_memory_space_from_aval(
return None
case tpu_core.TPUMemorySpace.ANY:
return None
case tpu_core.TPUMemorySpace.HBM:
return tpu_custom_call.MemorySpace.HBM
case tpu_core.TPUMemorySpace.VMEM:
return tpu_custom_call.MemorySpace.VMEM
case tpu_core.TPUMemorySpace.SMEM:
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/pallas/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
del types, assume, pretend, skip, define_model # Clean up.

ANY = TPUMemorySpace.ANY
HBM = TPUMemorySpace.HBM
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM
Expand Down
67 changes: 67 additions & 0 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,16 @@ def body(sem):
grid=(2,),
)(x)
np.testing.assert_allclose(y, x)
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.HBM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),
grid=(2,),
)(x)
np.testing.assert_allclose(y, x)

def test_hbm_vmem_dma(self):
def kernel(x_hbm_ref, y_ref):
Expand All @@ -1278,6 +1288,14 @@ def body(x_ref, sem):
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.HBM),
],
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)

def test_vmem_hbm_dma(self):
def kernel(x_ref, y_hbm_ref):
Expand All @@ -1294,6 +1312,12 @@ def body(y_ref, sem):
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)
y = self.pallas_call(
kernel,
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)

def test_vmem_hbm_vmem_dma(self):
def kernel(x_hbm_ref, y_hbm_ref):
Expand All @@ -1315,6 +1339,13 @@ def body(x_ref, y_ref, sem):
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)
y = self.pallas_call(
kernel,
in_specs=[pl.BlockSpec(memory_space=pltpu.HBM)],
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)

def test_hbm_smem_dma(self):
def kernel(x_hbm_ref, y_ref):
Expand All @@ -1333,6 +1364,14 @@ def body(x_ref, sem):
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.HBM),
],
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x)

def test_smem_hbm_dma(self):
def kernel(x_ref, y_hbm_ref):
Expand All @@ -1354,6 +1393,16 @@ def body(y_ref, sem):
)(x)
expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4])
np.testing.assert_allclose(y, expected)
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.HBM),
out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32),
)(x)
expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4])
np.testing.assert_allclose(y, expected)

def test_vmem_vmem_dma(self):
def kernel(x_ref, y_ref):
Expand Down Expand Up @@ -1416,6 +1465,15 @@ def body(sem):
out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x.reshape((16, 128)))
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.HBM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x.reshape((16, 128)))

def test_hbm_vmem_dma_multiple_indexing(self):
if self.INTERPRET:
Expand Down Expand Up @@ -1445,6 +1503,15 @@ def body(sem):
out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x.reshape((3, 16, 128)))
y = self.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.HBM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32),
)(x)
np.testing.assert_allclose(y, x.reshape((3, 16, 128)))

def test_cannot_squeeze_lane_sublane(self):
if self.INTERPRET:
Expand Down

0 comments on commit 170612f

Please sign in to comment.