[Pallas/Fuser] Bugfix for broadcasting, lax.slice_p, and lax.dynamic_slice_p with Element block shapes #28207

Merged: 1 commit, Apr 24, 2025
43 changes: 32 additions & 11 deletions jax/_src/pallas/fuser/block_spec.py
@@ -94,6 +94,12 @@ def wrapped(*args):
   return wrapped
 
 
+def _block_size(dim: pallas_core.Element | int | None) -> int | None:
+  if isinstance(dim, pallas_core.Element):
+    return dim.block_size
+  return dim
+
+
 @dataclasses.dataclass
 class UsageRuleContext:
   avals_in: tuple[core.AbstractValue, ...]
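
The helper simply normalizes a block dimension that may be a plain int, an Element wrapper, or None (a squeezed dimension). A minimal standalone sketch, with a stand-in Element class rather than the real pallas_core.Element:

import dataclasses

@dataclasses.dataclass(frozen=True)
class Element:  # stand-in for pallas_core.Element, not the real class
  block_size: int
  padding: tuple[int, int] = (0, 0)

def _block_size(dim: Element | int | None) -> int | None:
  # Element-wrapped dims carry their size in .block_size; plain ints and
  # None (squeezed dims) pass through unchanged.
  if isinstance(dim, Element):
    return dim.block_size
  return dim

assert _block_size(Element(128)) == 128
assert _block_size(8) == 8
assert _block_size(None) is None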
@@ -420,11 +426,6 @@ def make_kernel_function(
   invar_usages = util.safe_map(read_usage_env, jaxpr.invars)
   bs_env, scalar_prefetch_fn_env = block_spec_env
 
-  def _block_size(dim: pallas_core.Element | int | None) -> int | None:
-    if isinstance(dim, pallas_core.Element):
-      return dim.block_size
-    return dim
-
   def _remove_nones(
       shape: tuple[pallas_core.Element | int | None, ...] | None
   ) -> tuple[int, ...]:
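
This hunk only hoists the nested _block_size to module scope. The body of _remove_nones is truncated by the diff; judging from its signature alone, it presumably turns a possibly-None block shape into a concrete integer shape. A hedged sketch of that assumed behavior, reusing Element and _block_size from the sketch above:

def _remove_nones(
    shape: tuple[Element | int | None, ...] | None
) -> tuple[int, ...]:
  # Assumed behavior: drop squeezed (None) dims and unwrap Element dims.
  if shape is None:
    return ()
  return tuple(_block_size(s) for s in shape if s is not None)

assert _remove_nones((Element(8), None, 128)) == (8, 128)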
@@ -727,7 +728,14 @@ def new_index_map(i, *args):
     idx = util.tuple_update(idx, i, 0)
     return idx
 
-  new_block_shape = util.tuple_update(block_spec.block_shape, i, 1)
+  # TODO(wdvi): This is a hack needed since lowering rules require block shape
+  # to contain either all pl.Element or none
+  bcast_dim_block_shape = 1
+  if isinstance(block_spec.block_shape[i], pallas_core.Element):
+    bcast_dim_block_shape = pallas_core.Element(1)
+  new_block_shape = util.tuple_update(
+      block_spec.block_shape, i, bcast_dim_block_shape
+  )
   return pallas_core.BlockSpec(
       new_block_shape, functools.partial(new_index_map, i)
   )
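
The TODO spells out the constraint this works around: lowering accepts block shapes that are either all pl.Element or contain none at all, so the broadcast dimension must keep its wrapper. A sketch of the selection logic, reusing the Element stand-in (tuple_update here is a hypothetical mirror of jax._src.util.tuple_update):

def tuple_update(t, i, v):
  # Hypothetical mirror of jax._src.util.tuple_update.
  return t[:i] + (v,) + t[i + 1:]

block_shape = (Element(8), Element(128))
i = 0  # the broadcasted dimension
# Element(1), not a bare 1, so the block shape stays all-Element.
bcast = Element(1) if isinstance(block_shape[i], Element) else 1
assert tuple_update(block_shape, i, bcast) == (Element(1), Element(128))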
@@ -876,10 +884,13 @@ def _slice_rule(
   ):
     if bs is None:
       continue
-    assert slice_start % bs == 0, (start_indices, block_spec.block_shape)
-    assert slice_size % bs == 0, (slice_sizes, block_spec.block_shape)
+    block_size = _block_size(bs)
+    assert (
+        slice_start % block_size == 0
+    ), (start_indices, block_spec.block_shape)
+    assert slice_size % block_size == 0, (slice_sizes, block_spec.block_shape)
   offsets = tuple(
-      slice_start // bs if bs is not None else slice_start
+      slice_start // _block_size(bs) if bs is not None else slice_start
       for slice_start, bs in zip(start_indices, block_spec.block_shape)
   )
 
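Previously the asserts computed slice_start % bs directly, which breaks when bs is an Element wrapper rather than a plain int. The fixed arithmetic, sketched with the stand-ins from the first example:

start_indices, slice_sizes = (256, 0), (128, 128)
block_shape = (Element(128), Element(128))
for slice_start, slice_size, bs in zip(start_indices, slice_sizes, block_shape):
  if bs is None:
    continue
  block_size = _block_size(bs)
  # Slice boundaries must be block-aligned.
  assert slice_start % block_size == 0
  assert slice_size % block_size == 0
# Offsets are measured in whole blocks, not elements.
offsets = tuple(
    s // _block_size(bs) if bs is not None else s
    for s, bs in zip(start_indices, block_shape)
)
assert offsets == (2, 0)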
@@ -957,7 +968,7 @@ def new_index_map(*args):
     # We then add these block indices to block indices produced by the index
     # map.
     block_indices = tuple(
-        _offset(i, o, s)
+        _offset(i, o, _block_size(s))
         for i, o, s in zip(
             idx, slice_starts, block_spec.block_shape, strict=True
         )
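
_offset itself lies outside this hunk; the call-site change suggests it now receives an integer block size as its third argument. A hedged guess at its shape (an assumption, not the actual implementation):

def _offset(block_index, slice_start, block_size):
  # Assumed: shift a block index by a slice start given in elements,
  # converted into whole blocks.
  return block_index + slice_start // block_size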
@@ -976,6 +987,11 @@ def _concatenate_eval_rule(ctx: KernelEvalContext, *args, dimension):
   # divides the block size.
   block_spec = ctx.out_block_specs[0]
   block_shape = block_spec.block_shape
+  is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape]
+  if any(is_element_block):
+    raise NotImplementedError(
+        "Concatenation with Element indexing is not yet supported."
+    )
   block_dim = block_shape[dimension]
   if block_dim is None:
     block_dim = 1
@@ -1019,6 +1035,11 @@ def _concatenate_rule(
     dimension: int,
 ):
   block_shape = block_spec.block_shape
+  is_element_block = [isinstance(bd, pallas_core.Element) for bd in block_shape]
+  if any(is_element_block):
+    raise NotImplementedError(
+        "Concatenation with Element indexing is not yet supported."
+    )
   num_blocks = []
   block_dim = block_shape[dimension]
   if block_dim is None:
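
Both concatenate rules (eval and pull) now reject Element-indexed block shapes up front rather than computing wrong offsets. The guard in isolation, again with the Element stand-in:

block_shape = (Element(8), 128)
try:
  if any(isinstance(bd, Element) for bd in block_shape):
    raise NotImplementedError(
        "Concatenation with Element indexing is not yet supported."
    )
except NotImplementedError as e:
  print(e)  # Concatenation with Element indexing is not yet supported.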
@@ -1093,7 +1114,7 @@ def _broadcast_in_dim_eval_rule(
   if not eval_ctx.avals_in[0].shape:  # pytype: disable=attribute-error
     # Scalar -> Array broadcast
     block_spec = eval_ctx.out_block_specs[0]
-    shape = tuple(s for s in block_spec.block_shape if s is not None)
+    shape = tuple(_block_size(s) for s in block_spec.block_shape if s is not None)
     return jax.lax.broadcast_in_dim(x, broadcast_dimensions=(), shape=shape)
   return x

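jax.lax.broadcast_in_dim requires a shape of plain ints, so Element dims must be unwrapped before the call. A self-contained sketch:

import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class Element:  # stand-in for pallas_core.Element
  block_size: int

def _block_size(dim):
  return dim.block_size if isinstance(dim, Element) else dim

block_shape = (Element(8), Element(128), None)
# Drop squeezed (None) dims and unwrap Elements to get an integer shape.
shape = tuple(_block_size(s) for s in block_shape if s is not None)
out = jax.lax.broadcast_in_dim(jnp.float32(2.0), shape=shape,
                               broadcast_dimensions=())
assert out.shape == (8, 128)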
11 changes: 9 additions & 2 deletions jax/_src/pallas/mosaic/lowering.py
@@ -764,12 +764,19 @@ def dynamic_shape_replacement_fn(
   is_element_block = [isinstance(bd, pallas_core.Element)
                       for bd in bm.block_shape]
   if any(is_element_block):
-    if not all(is_element_block):
+    is_element_or_squeezed_block = [
+        isinstance(bd, (pallas_core.Element, pallas_core.Squeezed))
+        for bd in bm.block_shape
+    ]
+    if not all(is_element_or_squeezed_block):
       raise NotImplementedError(
           "All block dimensions must be Elements or none of them can be"
           " Elements."
       )
-    padding = [bd.padding for bd in bm.block_shape]  # pytype: disable=attribute-error
+    padding = [
+        bd.padding if isinstance(bd, pallas_core.Element) else (0, 0)
+        for bd in bm.block_shape
+    ]
     pad_low, pad_high = map(list, zip(*padding))
     block_params["window_kind"] = ir.Attribute.parse(
         f"#tpu.element_window<{pad_low},{pad_high}>"
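
On the lowering side, Squeezed dims have no padding attribute, so they now contribute (0, 0) instead of raising. A sketch with stand-in Element and Squeezed classes (assumptions, not the real pallas_core types):

import dataclasses

@dataclasses.dataclass(frozen=True)
class Element:  # stand-in for pallas_core.Element
  block_size: int
  padding: tuple[int, int] = (0, 0)

class Squeezed:  # stand-in for pallas_core.Squeezed
  pass

block_shape = (Element(8, (0, 4)), Squeezed(), Element(128))
# Squeezed dims carry no padding; they now contribute (0, 0) instead of
# raising AttributeError on bd.padding.
padding = [
    bd.padding if isinstance(bd, Element) else (0, 0)
    for bd in block_shape
]
pad_low, pad_high = map(list, zip(*padding))
assert (pad_low, pad_high) == ([0, 0, 0], [4, 0, 0])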