[Pallas TPU] Introduce a BoundedSlice block shape type #28127

Merged · 1 commit · Apr 25, 2025
74 changes: 65 additions & 9 deletions jax/_src/pallas/core.py
@@ -39,6 +39,7 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import types as state_types
from jax._src.state.types import TransformedRef
import jax.numpy as jnp
@@ -359,7 +360,20 @@ class Blocked:
def __str__(self):
return f"Blocked({self.block_size})"

BlockDim: TypeAlias = Element | Squeezed | Blocked
@dataclasses.dataclass(frozen=True)
class BoundedSlice:
"""Allows to specify a bounded slice of a dimension.

Specifically, the index_map need to return a `pl.Slice/pl.ds` for this
dimension. The start and size may be dynamic, as long as the size <=
block_size.
"""
block_size: int

def __repr__(self):
return f"BoundedSlice({self.block_size})"

BlockDim: TypeAlias = Element | Squeezed | Blocked | BoundedSlice


def default_index_map(ndim: int) -> Callable:
@@ -372,7 +386,7 @@ def _canonicalize_block_dim(dim: BlockDim | int | None) -> BlockDim:
return squeezed
case int():
return Blocked(int(dim))
case Squeezed() | Blocked() | Element():
case Squeezed() | Blocked() | Element() | BoundedSlice():
return dim
case _:
# Handle case where the dim is a symbolic dimension so we assume it is
@@ -400,6 +414,8 @@ def _get_block_dim_size(dim: BlockDim) -> int:
return block_size
case Element():
return dim.block_size
case BoundedSlice(block_size):
return block_size
case _:
raise ValueError(f"Unsupported block shape type: {type(dim)}")

@@ -420,7 +436,16 @@ def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]:
class BlockSpec:
"""Specifies how an array should be sliced for each invocation of a kernel.

See :ref:`pallas_blockspec` for more details.
The `block_shape` is a sequence of `int | None` values or `BlockDim` types
(e.g. `pl.Element`, `pl.Squeezed`, `pl.Blocked`, `pl.BoundedSlice`). Each of
these types specifies the size of the block dimension. `None` is used to
specify a dimension that is squeezed out of the kernel. The `BlockDim` types
allow for more fine-grained control over the indexing of the dimension. The
`index_map` must return a tuple of the same length as `block_shape`, with
each entry depending on the type of the corresponding `BlockDim`.

See :ref:`pallas_blockspec` and the individual `BlockDim` type docstrings for
more details.
"""
# An internal canonicalized version is in BlockMapping.
block_shape: Sequence[BlockDim | int | None] | None = None
@@ -437,6 +462,17 @@ def __post_init__(self):
" block dimension in `block_shape` instead to enable 'Unblocked'"
" indexing."
)
if self.index_map is not None:
old_index_map = self.index_map
@functools.wraps(old_index_map)
def _wrapper_index_map(*args, **kwargs):
indices = old_index_map(*args, **kwargs)
if isinstance(indices, list):
indices = tuple(indices)
if not isinstance(indices, tuple):
indices = (indices,)
return indices
self.index_map = _wrapper_index_map

def to_block_mapping(
self,
@@ -497,14 +533,36 @@ def to_block_mapping(
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
flat_index_map_fun, index_map_avals
)
index_map_out_tree = index_map_out_tree_thunk()
unflat_avals = tree_util.tree_unflatten(index_map_out_tree, out_avals)

if len(out_avals) != len(block_shape):
if len(unflat_avals) != len(block_shape):
raise ValueError(
f"Index map function {debug.func_src_info} for "
f"{origin} must return "
f"{len(block_shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
f"Currently returning {len(unflat_avals)} values:"
)
# Verify that each returned index matches its block dimension type.
for i, (idx_aval, bd) in enumerate(zip(unflat_avals, block_shape)):
match bd:
case BoundedSlice():
if not isinstance(idx_aval, indexing.Slice):
raise ValueError(
"index_map returned a value of type"
f" {type(idx_aval)} at position {i} with block dimension"
f" {bd} when it should be pl.Slice"
)
case Blocked() | Element() | Squeezed() | int():
if (
not isinstance(idx_aval, jax_core.ShapedArray)
or idx_aval.shape
):
raise ValueError(
"index_map returned a value of type"
f" {type(idx_aval)} at position {i} with block dimension"
f" {bd} when it should be a scalar"
)
for i, ov in enumerate(out_avals):
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(
@@ -525,6 +583,7 @@ def to_block_mapping(
block_shape=block_shape,
transformed_block_aval=block_aval, # There are no transforms by default
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_out_tree=index_map_out_tree,
array_shape_dtype=jax.ShapeDtypeStruct(
array_aval_shape, array_aval.dtype
),
@@ -566,6 +625,7 @@ class BlockMapping:
block_shape: tuple[BlockDim, ...]
transformed_block_aval: AbstractMemoryRef
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_out_tree: tree_util.PyTreeDef
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
transforms: Sequence[MemoryRefTransform] = ()
@@ -582,10 +642,6 @@ def check_invariants(self) -> None:
)

assert not self.index_map_jaxpr.consts
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), (
self.block_shape,
self.index_map_jaxpr.out_avals,
)
assert all(ov.shape == () and
(ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
for ov in self.index_map_jaxpr.out_avals), (
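As a usage illustration, here is a minimal sketch of a BlockSpec built with the new type (hypothetical shapes and index_map; it assumes BoundedSlice is exported as pl.BoundedSlice). Note that the Mosaic lowering change below still rejects BoundedSlice outside the pipeline path, so this sketches only the block_shape/index_map contract that to_block_mapping verifies:

    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    rows = 1000  # hypothetical total number of rows

    spec = pl.BlockSpec(
        # Up to 128 rows per step (dynamic start/size), fixed 256-column blocks.
        block_shape=(pl.BoundedSlice(128), pl.Blocked(256)),
        # The BoundedSlice dimension gets a pl.Slice (via pl.ds) whose start
        # and size may be dynamic, with size <= 128; the Blocked dimension
        # gets a scalar block index, as before.
        index_map=lambda i, j: (
            pl.ds(i * 128, jnp.minimum(128, rows - i * 128)), j),
    )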
14 changes: 14 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
@@ -717,6 +717,12 @@ def dynamic_shape_replacement_fn(
window_params = []
static_grid = None
grid = mosaic_grid_mapping.grid
if not grid and any(
not bm.has_trivial_window() for bm in grid_mapping.block_mappings
):
raise NotImplementedError(
"Non-trivial windowing is not supported for grid-free pallas_call."
)
if grid:
for i, bm in enumerate(grid_mapping.block_mappings):
func_name = f"transform_{i}"
@@ -761,6 +767,14 @@ def dynamic_shape_replacement_fn(
window_bounds=window_shape,
transform_indices=ir.FlatSymbolRefAttr.get(func_name),
)
for bd in bm.block_shape:
if not isinstance(
bd, (pallas_core.Element, pallas_core.Squeezed, pallas_core.Blocked)
):
raise NotImplementedError(
"Unsupported block dimension type: "
f"{type(bd)} for block shape: {bm.block_shape}"
)
is_element_block = [isinstance(bd, pallas_core.Element)
for bd in bm.block_shape]
if any(is_element_block):
84 changes: 72 additions & 12 deletions jax/_src/pallas/mosaic/pipeline.py
@@ -111,25 +111,49 @@ def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
return s - s % multiple + multiple


def _make_ds(
def _make_block_ds(
idx: jax.Array | int, size: jax.Array | int
) -> pl.Slice:
"""Make a DMA slice with mosaic size hints."""
out = pl.ds(idx * size, size)
assert isinstance(out, pl.Slice)
return out


def _make_block_slice(
block_index: jax.Array, block_size: int, size: int, tiling: int
block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
tiling: int
) -> pl.Slice | slice:
# Computes a slice given a block index and block size. In the default case,
# we return slice(block_index * block_size, (block_index + 1) * block_size).
# However, if the block size does not divide the total size of the ref and we
# are selecting the last block, we need to pick the lowest multiple of the
# tiling size that contains the block.
match block_size:
case pl.Blocked():
block_start = block_size.block_size * block_index
block_size = block_size.block_size
case pl.Element():
block_start = block_index
block_size = block_size.block_size
case pl.BoundedSlice():
if not isinstance(block_index, pl.Slice):
raise ValueError(
"Must return a pl.ds from the index_map for a BoundedSlice"
" dimension."
)
block_start = block_index.start
block_size = block_index.size
return pl.ds(block_start, block_size)
case int():
# This is the same as Blocked.
block_start = block_index * block_size
case None | pl.Squeezed():
block_start = block_index
block_size = 1
case _:
raise ValueError(f"Unsupported block dimension type: {block_size}")
if size % block_size == 0:
return _make_ds(block_index, block_size)
return pl.ds(block_start, block_size)
if block_size % tiling != 0:
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
num_blocks = pl.cdiv(size, block_size)
@@ -145,7 +169,7 @@ def _make_block_slice(

def _tuples_differ(xs, ys):
"""Dynamic index-tuple comparison calculation."""
differences = jax.tree.map(lambda x, y: x != y, xs, ys)
differences = jax.tree.leaves(jax.tree.map(lambda x, y: x != y, xs, ys))
return functools.reduce(lambda x, y: x | y, differences, False)


@@ -167,6 +191,26 @@ class BufferType(enum.Enum):

MANUAL = 5

def _get_block_shape(spec: pl.BlockSpec) -> tuple[int, ...]:
"""Get the block shape for a given block spec."""
def _get_dim_size(bd):
match bd:
case pl.Blocked(block_size):
return block_size
case pl.Element():
return bd.block_size
case pl.BoundedSlice(block_size):
return block_size
case int():
return bd
case None:
return 1
case _:
raise ValueError(f"Unsupported block dimension type: {bd}")
if spec.block_shape is None:
raise ValueError("Block shape must be specified.")
block_shape = tuple(_get_dim_size(x) for x in spec.block_shape)
return block_shape

@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
@@ -236,7 +280,8 @@ def buffer_types() -> type[BufferType]:
return BufferType

@classmethod
def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True
) -> BufferedRef:
"""Create a BufferedRef.

Args:
@@ -249,7 +294,7 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
Returns:
Initialized BufferedRef
"""
block_shape = tuple(1 if x is None else x for x in spec.block_shape)
block_shape = _get_block_shape(spec)
if buffer_type is BufferType.ACCUMULATOR:
accum_ref = VMEM(block_shape, dtype)
else:
@@ -375,9 +420,22 @@ def bind_existing_ref(self, window_ref, indices):

def compute_slice(self, grid_indices):
"""Compute DMA slice from grid indices."""
block_shape = tuple(1 if x is None else x for x in self.block_shape)
block_shape = []
for bd in self.block_shape:
if isinstance(bd, (pl.Element, pl.BoundedSlice)):
raise ValueError(
"Element and BoundedSlice block dimensions are not supported."
)
if bd is None:
block_shape.append(1)
elif isinstance(bd, pl.Blocked):
block_shape.append(bd.block_size)
elif isinstance(bd, int):
block_shape.append(bd)
else:
raise ValueError(f"Unsupported block dimension type: {type(bd)}")
indices = self.compute_index(*grid_indices)
return jax.tree.map(_make_ds, indices, block_shape)
return jax.tree.map(_make_block_ds, indices, tuple(block_shape))

def init_slots(self):
"""Initialize slot indices."""
@@ -444,10 +502,12 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
raise NotImplementedError("Must use >1D values.")

tiling = _make_tiling(src_shape, src_dtype)
block_shape = tuple(1 if b is None else b for b in self.block_shape)
block_indices = self.compute_index(*grid_indices)
return jax.tree.map(
_make_block_slice, block_indices, block_shape, src_shape, tiling
return tuple(
_make_block_slice(bi, bs, ss, t)
for bi, bs, ss, t in zip(
block_indices, self.block_shape, src_shape, tiling, strict=True
)
)

def copy_in(self, src_ref, grid_indices):
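The last-block rounding described in _make_block_slice's comment can be sketched in plain Python (assumed semantics; _last_block_size is an illustrative name, not part of the library):

    def _last_block_size(size: int, block_size: int, tiling: int) -> int:
        # Size of the final (possibly partial) block, rounded up to a whole
        # number of tiles so the DMA always moves complete tiles.
        assert block_size % tiling == 0, "tiling must divide block size"
        rem = size % block_size
        if rem == 0:
            return block_size
        return -(-rem // tiling) * tiling  # ceil(rem / tiling) * tiling

    # 600 rows in 512-row blocks with 128-row tiling: the final block holds
    # 88 rows, widened to 128 (one full tile).
    assert _last_block_size(600, 512, 128) == 128
    assert _last_block_size(1024, 512, 128) == 512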
5 changes: 3 additions & 2 deletions jax/_src/pallas/mosaic_gpu/pipeline.py
@@ -40,8 +40,9 @@
map = util.safe_map
zip = util.safe_zip

def _get_block_size(bd: pl.Blocked | pl.Element | pl.Squeezed | int | None
) -> int:
def _get_block_size(
bd: pl.Blocked | pl.Element | pl.Squeezed | pl.BoundedSlice | int | None,
) -> int:
match bd:
case int():
return bd
22 changes: 16 additions & 6 deletions jax/_src/pallas/pallas_call.py
@@ -221,14 +221,19 @@ def _block_map_function(new_idx, *args):
block_mapping.index_map_jaxpr.consts,
*drop_last_args,
)
unflat_indices = tree_util.tree_unflatten(
block_mapping.index_map_out_tree, indices)
if not isinstance(unflat_indices, tuple):
unflat_indices = (unflat_indices,)
unflat_indices = list(unflat_indices)
if dim is not batching.not_mapped:
if isinstance(dim, batching.RaggedAxis):
assert for_ragged, "Ragged axis not supported for non-ragged batching."
stacked_axis = dim.stacked_axis
indices.insert(stacked_axis, new_idx)
unflat_indices.insert(stacked_axis, new_idx)
else:
indices.insert(dim, new_idx)
return tuple(indices)
unflat_indices.insert(dim, new_idx)
return tuple(unflat_indices)
idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]

if for_ragged:
@@ -243,11 +248,15 @@ def _block_map_function(new_idx, *args):
)
idx_avals = [*idx_avals, i32_aval_memref]

block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs(
lu.wrap_init(_block_map_function,
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
tree_util.tree_structure(idx_avals))
with grid_mapping.trace_env():
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(_block_map_function,
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
block_mapping_flat_fn,
idx_avals)
new_index_map_out_tree = out_tree_thunk()
shape = block_mapping.block_shape
if dim is batching.not_mapped:
new_block_shape = shape
@@ -278,7 +287,8 @@ def _block_map_function(new_idx, *args):
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
return block_mapping.replace(block_shape=new_block_shape,
array_shape_dtype=new_array_shape_dtype,
index_map_jaxpr=jaxpr)
index_map_jaxpr=jaxpr,
index_map_out_tree=new_index_map_out_tree)


def _broadcast_input_output_aliases(
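The reason index_map_out_tree has to be threaded through the batching rule: a pl.Slice is itself a pytree, so once an index_map may return slices, its flattened outputs no longer line up 1:1 with block_shape entries. A standalone sketch, assuming dynamic Slice start/size flatten to leaves as in jax._src.state.indexing:

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    start = jnp.int32(256)  # dynamic start
    size = jnp.int32(100)   # dynamic size, <= the BoundedSlice block_size
    indices = (pl.ds(start, size), jnp.int32(3))  # (BoundedSlice dim, Blocked dim)

    leaves, out_tree = jax.tree_util.tree_flatten(indices)
    # The Slice contributes two leaves, so len(leaves) == 3 even though
    # block_shape has only two entries; unflattening with out_tree recovers
    # the per-dimension structure that the batching rule inserts new
    # indices into.
    indices_again = jax.tree_util.tree_unflatten(out_tree, leaves)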