Skip to content

Commit ccabcd8

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas TPU] Introduce a BoundedSlice block shape type
* Also add Python pipeline emitter support PiperOrigin-RevId: 748131203
1 parent d109cd8 commit ccabcd8

File tree

9 files changed

+343
-34
lines changed

9 files changed

+343
-34
lines changed

jax/_src/pallas/core.py

+65-9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from jax._src.interpreters import mlir
4040
from jax._src.interpreters import partial_eval as pe
4141
from jax._src.state import discharge as state_discharge
42+
from jax._src.state import indexing
4243
from jax._src.state import types as state_types
4344
from jax._src.state.types import TransformedRef
4445
import jax.numpy as jnp
@@ -359,7 +360,20 @@ class Blocked:
359360
def __str__(self):
360361
return f"Blocked({self.block_size})"
361362

362-
BlockDim: TypeAlias = Element | Squeezed | Blocked
363+
@dataclasses.dataclass(frozen=True)
364+
class BoundedSlice:
365+
"""Allows to specify a bounded slice of a dimension.
366+
367+
Specifically, the index_map need to return a `pl.Slice/pl.ds` for this
368+
dimension. The start and size may be dynamic, as long as the size <=
369+
block_size.
370+
"""
371+
block_size: int
372+
373+
def __repr__(self):
374+
return f"BoundedSlice({self.block_size})"
375+
376+
BlockDim: TypeAlias = Element | Squeezed | Blocked | BoundedSlice
363377

364378

365379
def default_index_map(ndim: int) -> Callable:
@@ -372,7 +386,7 @@ def _canonicalize_block_dim(dim: BlockDim | int | None) -> BlockDim:
372386
return squeezed
373387
case int():
374388
return Blocked(int(dim))
375-
case Squeezed() | Blocked() | Element():
389+
case Squeezed() | Blocked() | Element() | BoundedSlice():
376390
return dim
377391
case _:
378392
# Handle case where the dim is a symbolic dimension so we assume it is
@@ -400,6 +414,8 @@ def _get_block_dim_size(dim: BlockDim) -> int:
400414
return block_size
401415
case Element():
402416
return dim.block_size
417+
case BoundedSlice(block_size):
418+
return block_size
403419
case _:
404420
raise ValueError(f"Unsupported block shape type: {type(dim)}")
405421

@@ -420,7 +436,16 @@ def _get_ref_block_shape(block_shape: tuple[BlockDim, ...]) -> tuple[int, ...]:
420436
class BlockSpec:
421437
"""Specifies how an array should be sliced for each invocation of a kernel.
422438
423-
See :ref:`pallas_blockspec` for more details.
439+
The `block_shape` is a sequence of `int | None`s, or `BlockDim` types (e.g.
440+
`pl.Element`, `pl.Squeezed`, `pl.Blocked`, `pl.BoundedSlice`). Each of these
441+
types specifies the size of the block dimension. `None` is used to specify a
442+
dimension that is squeezed out of the kernel. The `BlockDim` types allow for
443+
more fine-grained control over the indexing of the dimension. The `index_map`
444+
needs to return a tuple of the same length as `block_shape`, with each entry
445+
depending on the type of `BlockDim`.
446+
447+
See :ref:`pallas_blockspec` and the individual `BlockDim` type docstrings for
448+
more details.
424449
"""
425450
# An internal canonicalized version is in BlockMapping.
426451
block_shape: Sequence[BlockDim | int | None] | None = None
@@ -437,6 +462,17 @@ def __post_init__(self):
437462
" block dimension in `block_shape` instead to enable 'Unblocked'"
438463
" indexing."
439464
)
465+
if self.index_map is not None:
466+
old_index_map = self.index_map
467+
@functools.wraps(old_index_map)
468+
def _wrapper_index_map(*args, **kwargs):
469+
indices = old_index_map(*args, **kwargs)
470+
if isinstance(indices, list):
471+
indices = tuple(indices)
472+
if not isinstance(indices, tuple):
473+
indices = (indices,)
474+
return indices
475+
self.index_map = _wrapper_index_map
440476

441477
def to_block_mapping(
442478
self,
@@ -497,14 +533,36 @@ def to_block_mapping(
497533
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
498534
flat_index_map_fun, index_map_avals
499535
)
536+
index_map_out_tree = index_map_out_tree_thunk()
537+
unflat_avals = tree_util.tree_unflatten(index_map_out_tree, out_avals)
500538

501-
if len(out_avals) != len(block_shape):
539+
if len(unflat_avals) != len(block_shape):
502540
raise ValueError(
503541
f"Index map function {debug.func_src_info} for "
504542
f"{origin} must return "
505543
f"{len(block_shape)} values to match {block_shape=}. "
506-
f"Currently returning {len(out_avals)} values."
544+
f"Currently returning {len(unflat_avals)} values:"
507545
)
546+
# Verify types match
547+
for i, (idx_aval, bd) in enumerate(zip(unflat_avals, block_shape)):
548+
match bd:
549+
case BoundedSlice():
550+
if not isinstance(idx_aval, indexing.Slice):
551+
raise ValueError(
552+
"index_map returned a value of type"
553+
f" {type(idx_aval)} at position {i} with block dimension"
554+
f" {bd} when it should be pl.Slice"
555+
)
556+
case Blocked() | Element() | Squeezed() | int():
557+
if (
558+
not isinstance(idx_aval, jax_core.ShapedArray)
559+
and not idx_aval.shape
560+
):
561+
raise ValueError(
562+
"index_map returned a value of type"
563+
f" {type(idx_aval)} at position {i} with block dimension"
564+
f" {bd} when it should be a scalar"
565+
)
508566
for i, ov in enumerate(out_avals):
509567
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
510568
raise ValueError(
@@ -525,6 +583,7 @@ def to_block_mapping(
525583
block_shape=block_shape,
526584
transformed_block_aval=block_aval, # There are no transforms by default
527585
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
586+
index_map_out_tree=index_map_out_tree,
528587
array_shape_dtype=jax.ShapeDtypeStruct(
529588
array_aval_shape, array_aval.dtype
530589
),
@@ -566,6 +625,7 @@ class BlockMapping:
566625
block_shape: tuple[BlockDim, ...]
567626
transformed_block_aval: AbstractMemoryRef
568627
index_map_jaxpr: jax_core.ClosedJaxpr
628+
index_map_out_tree: tree_util.PyTreeDef
569629
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
570630
origin: OriginStr
571631
transforms: Sequence[MemoryRefTransform] = ()
@@ -582,10 +642,6 @@ def check_invariants(self) -> None:
582642
)
583643

584644
assert not self.index_map_jaxpr.consts
585-
assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), (
586-
self.block_shape,
587-
self.index_map_jaxpr.out_avals,
588-
)
589645
assert all(ov.shape == () and
590646
(ov.dtype == jnp.int32 or ov.dtype == jnp.int64)
591647
for ov in self.index_map_jaxpr.out_avals), (

jax/_src/pallas/mosaic/lowering.py

+14
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,12 @@ def dynamic_shape_replacement_fn(
717717
window_params = []
718718
static_grid = None
719719
grid = mosaic_grid_mapping.grid
720+
if not grid and any(
721+
not bm.has_trivial_window() for bm in grid_mapping.block_mappings
722+
):
723+
raise NotImplementedError(
724+
"Non-trivial windowing is not supported for grid-free pallas_call."
725+
)
720726
if grid:
721727
for i, bm in enumerate(grid_mapping.block_mappings):
722728
func_name = f"transform_{i}"
@@ -761,6 +767,14 @@ def dynamic_shape_replacement_fn(
761767
window_bounds=window_shape,
762768
transform_indices=ir.FlatSymbolRefAttr.get(func_name),
763769
)
770+
for bd in bm.block_shape:
771+
if not isinstance(
772+
bd, (pallas_core.Element, pallas_core.Squeezed, pallas_core.Blocked)
773+
):
774+
raise NotImplementedError(
775+
"Unsupported block dimension type: "
776+
f"{type(bd)} for block shape: {bm.block_shape}"
777+
)
764778
is_element_block = [isinstance(bd, pallas_core.Element)
765779
for bd in bm.block_shape]
766780
if any(is_element_block):

jax/_src/pallas/mosaic/pipeline.py

+72-12
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,49 @@ def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
111111
return s - s % multiple + multiple
112112

113113

114-
def _make_ds(
114+
def _make_block_ds(
115115
idx: jax.Array | int, size: jax.Array | int
116116
) -> pl.Slice:
117117
"""Make a DMA slice with mosaic size hints."""
118118
out = pl.ds(idx * size, size)
119119
assert isinstance(out, pl.Slice)
120120
return out
121121

122-
123122
def _make_block_slice(
124-
block_index: jax.Array, block_size: int, size: int, tiling: int
123+
block_index: jax.Array, block_size: pl.BlockDim | int | None, size: int,
124+
tiling: int
125125
) -> pl.Slice | slice:
126126
# Computes a slice given a block index and block size. In the default case,
127127
# we return slice(block_index * block_size, (block_index + 1) * block_size).
128128
# However, if the total size of the ref does not divide block size and we are
129129
# selecting the last block, we need to pick the lowest tiling size multiple
130130
# that contains the block.
131+
match block_size:
132+
case pl.Blocked():
133+
block_start = block_size.block_size * block_index
134+
block_size = block_size.block_size
135+
case pl.Element():
136+
block_start = block_index
137+
block_size = block_size.block_size
138+
case pl.BoundedSlice():
139+
if not isinstance(block_index, pl.Slice):
140+
raise ValueError(
141+
"Must return a pl.ds from the index_map for a BoundedSlice"
142+
" dimension."
143+
)
144+
block_start = block_index.start
145+
block_size = block_index.size
146+
return pl.ds(block_start, block_size)
147+
case int():
148+
# This is same as Blocked.
149+
block_start = block_index * block_size
150+
case None | pl.Squeezed():
151+
block_start = block_index
152+
block_size = 1
153+
case _:
154+
raise ValueError(f"Unsupported block dimension type: {block_size}")
131155
if size % block_size == 0:
132-
return _make_ds(block_index, block_size)
156+
return pl.ds(block_start, block_size)
133157
if block_size % tiling != 0:
134158
raise ValueError(f"Block size must divide tiling: {block_size=}, {tiling=}")
135159
num_blocks = pl.cdiv(size, block_size)
@@ -145,7 +169,7 @@ def _make_block_slice(
145169

146170
def _tuples_differ(xs, ys):
  """Dynamic index-tuple comparison calculation.

  Compares `xs` and `ys` leaf by leaf with `!=` and OR-reduces the per-leaf
  results with `|`.

  Args:
    xs: First pytree of indices.
    ys: Second pytree of indices; mapped together with `xs` via
      `jax.tree.map`.

  Returns:
    The OR of all per-leaf `x != y` results, or `False` when there are no
    leaves.
  """
  # Flatten the mapped result so nested entries are reduced leaf by leaf
  # instead of being OR-ed as whole containers.
  differences = jax.tree.leaves(jax.tree.map(lambda x, y: x != y, xs, ys))
  return functools.reduce(lambda x, y: x | y, differences, False)
150174

151175

@@ -167,6 +191,26 @@ class BufferType(enum.Enum):
167191

168192
MANUAL = 5
169193

194+
def _get_block_shape(spec: pl.BlockSpec) -> tuple[int, ...]:
195+
"""Get the block shape for a given block spec."""
196+
def _get_dim_size(bd):
197+
match bd:
198+
case pl.Blocked(block_size):
199+
return block_size
200+
case pl.Element():
201+
return bd.block_size
202+
case pl.BoundedSlice(block_size):
203+
return block_size
204+
case int():
205+
return bd
206+
case None:
207+
return 1
208+
case _:
209+
raise ValueError(f"Unsupported block dimension type: {bd}")
210+
if spec.block_shape is None:
211+
raise ValueError("Block shape must be specified.")
212+
block_shape = tuple(_get_dim_size(x) for x in spec.block_shape)
213+
return block_shape
170214

171215
@tree_util.register_pytree_node_class
172216
@dataclasses.dataclass(frozen=True)
@@ -236,7 +280,8 @@ def buffer_types() -> type[BufferType]:
236280
return BufferType
237281

238282
@classmethod
239-
def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
283+
def create(cls, spec: pl.BlockSpec, dtype, buffer_type, needs_swap_ref=True
284+
) -> BufferedRef:
240285
"""Create a BufferedRef.
241286
242287
Args:
@@ -249,7 +294,7 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
249294
Returns:
250295
Initialized BufferedRef
251296
"""
252-
block_shape = tuple(1 if x is None else x for x in spec.block_shape)
297+
block_shape = _get_block_shape(spec)
253298
if buffer_type is BufferType.ACCUMULATOR:
254299
accum_ref = VMEM(block_shape, dtype)
255300
else:
@@ -375,9 +420,22 @@ def bind_existing_ref(self, window_ref, indices):
375420

376421
def compute_slice(self, grid_indices):
  """Compute DMA slice from grid indices.

  Args:
    grid_indices: Grid indices, unpacked into `self.compute_index`.

  Returns:
    The result of mapping `_make_block_ds` over the computed block indices
    and the per-dimension block sizes.

  Raises:
    ValueError: If a block dimension is a `pl.Element` or
      `pl.BoundedSlice`, or is of an otherwise unsupported type.
  """
  def _dim_size(bd):
    if isinstance(bd, (pl.Element, pl.BoundedSlice)):
      raise ValueError(
          "Element and BoundedSlice block dimensions are not supported."
      )
    # `None` and `pl.Squeezed()` both denote a squeezed dimension of size 1.
    if bd is None or isinstance(bd, pl.Squeezed):
      return 1
    if isinstance(bd, pl.Blocked):
      return bd.block_size
    if isinstance(bd, int):
      return bd
    raise ValueError(f"Unsupported block dimension type: {type(bd)}")

  block_shape = tuple(_dim_size(bd) for bd in self.block_shape)
  indices = self.compute_index(*grid_indices)
  return jax.tree.map(_make_block_ds, indices, block_shape)
381439

382440
def init_slots(self):
383441
"""Initialize slot indices."""
@@ -444,10 +502,12 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
444502
raise NotImplementedError("Must use >1D values.")
445503

446504
tiling = _make_tiling(src_shape, src_dtype)
447-
block_shape = tuple(1 if b is None else b for b in self.block_shape)
448505
block_indices = self.compute_index(*grid_indices)
449-
return jax.tree.map(
450-
_make_block_slice, block_indices, block_shape, src_shape, tiling
506+
return tuple(
507+
_make_block_slice(bi, bs, ss, t)
508+
for bi, bs, ss, t in zip(
509+
block_indices, self.block_shape, src_shape, tiling, strict=True
510+
)
451511
)
452512

453513
def copy_in(self, src_ref, grid_indices):

jax/_src/pallas/mosaic_gpu/pipeline.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@
4040
map = util.safe_map
4141
zip = util.safe_zip
4242

43-
def _get_block_size(bd: pl.Blocked | pl.Element | pl.Squeezed | int | None
44-
) -> int:
43+
def _get_block_size(
44+
bd: pl.Blocked | pl.Element | pl.Squeezed | pl.BoundedSlice | int | None,
45+
) -> int:
4546
match bd:
4647
case int():
4748
return bd

jax/_src/pallas/pallas_call.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -221,14 +221,19 @@ def _block_map_function(new_idx, *args):
221221
block_mapping.index_map_jaxpr.consts,
222222
*drop_last_args,
223223
)
224+
unflat_indices = tree_util.tree_unflatten(
225+
block_mapping.index_map_out_tree, indices)
226+
if not isinstance(unflat_indices, tuple):
227+
unflat_indices = (unflat_indices,)
228+
unflat_indices = list(unflat_indices)
224229
if dim is not batching.not_mapped:
225230
if isinstance(dim, batching.RaggedAxis):
226231
assert for_ragged, "Ragged axis not supported for non-ragged batching."
227232
stacked_axis = dim.stacked_axis
228-
indices.insert(stacked_axis, new_idx)
233+
unflat_indices.insert(stacked_axis, new_idx)
229234
else:
230-
indices.insert(dim, new_idx)
231-
return tuple(indices)
235+
unflat_indices.insert(dim, new_idx)
236+
return tuple(unflat_indices)
232237
idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals]
233238

234239
if for_ragged:
@@ -243,11 +248,15 @@ def _block_map_function(new_idx, *args):
243248
)
244249
idx_avals = [*idx_avals, i32_aval_memref]
245250

251+
block_mapping_flat_fn, out_tree_thunk = api_util.flatten_fun_nokwargs(
252+
lu.wrap_init(_block_map_function,
253+
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
254+
tree_util.tree_structure(idx_avals))
246255
with grid_mapping.trace_env():
247256
block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(
248-
lu.wrap_init(_block_map_function,
249-
debug_info=block_mapping.index_map_jaxpr.jaxpr.debug_info),
257+
block_mapping_flat_fn,
250258
idx_avals)
259+
new_index_map_out_tree = out_tree_thunk()
251260
shape = block_mapping.block_shape
252261
if dim is batching.not_mapped:
253262
new_block_shape = shape
@@ -278,7 +287,8 @@ def _block_map_function(new_idx, *args):
278287
jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts)
279288
return block_mapping.replace(block_shape=new_block_shape,
280289
array_shape_dtype=new_array_shape_dtype,
281-
index_map_jaxpr=jaxpr)
290+
index_map_jaxpr=jaxpr,
291+
index_map_out_tree=new_index_map_out_tree)
282292

283293

284294
def _broadcast_input_output_aliases(

0 commit comments

Comments
 (0)