@@ -111,25 +111,49 @@ def _round_up_to_nearest_multiple(s: int, multiple: int) -> int:
111
111
return s - s % multiple + multiple
112
112
113
113
114
- def _make_ds (
114
+ def _make_block_ds (
115
115
idx : jax .Array | int , size : jax .Array | int
116
116
) -> pl .Slice :
117
117
"""Make a DMA slice with mosaic size hints."""
118
118
out = pl .ds (idx * size , size )
119
119
assert isinstance (out , pl .Slice )
120
120
return out
121
121
122
-
123
122
def _make_block_slice (
124
- block_index : jax .Array , block_size : int , size : int , tiling : int
123
+ block_index : jax .Array , block_size : pl .BlockDim | int | None , size : int ,
124
+ tiling : int
125
125
) -> pl .Slice | slice :
126
126
# Computes a slice given a block index and block size. In the default case,
127
127
# we return slice(block_index * block_size, (block_index + 1) * block_size).
128
128
# However, if the total size of the ref does not divide block size and we are
129
129
# selecting the last block, we need to pick the lowest tiling size multiple
130
130
# that contains the block.
131
+ match block_size :
132
+ case pl .Blocked ():
133
+ block_start = block_size .block_size * block_index
134
+ block_size = block_size .block_size
135
+ case pl .Element ():
136
+ block_start = block_index
137
+ block_size = block_size .block_size
138
+ case pl .BoundedSlice ():
139
+ if not isinstance (block_index , pl .Slice ):
140
+ raise ValueError (
141
+ "Must return a pl.ds from the index_map for a BoundedSlice"
142
+ " dimension."
143
+ )
144
+ block_start = block_index .start
145
+ block_size = block_index .size
146
+ return pl .ds (block_start , block_size )
147
+ case int ():
148
+ # This is same as Blocked.
149
+ block_start = block_index * block_size
150
+ case None | pl .Squeezed ():
151
+ block_start = block_index
152
+ block_size = 1
153
+ case _:
154
+ raise ValueError (f"Unsupported block dimension type: { block_size } " )
131
155
if size % block_size == 0 :
132
- return _make_ds ( block_index , block_size )
156
+ return pl . ds ( block_start , block_size )
133
157
if block_size % tiling != 0 :
134
158
raise ValueError (f"Block size must divide tiling: { block_size = } , { tiling = } " )
135
159
num_blocks = pl .cdiv (size , block_size )
@@ -145,7 +169,7 @@ def _make_block_slice(
145
169
146
170
def _tuples_differ (xs , ys ):
147
171
"""Dynamic index-tuple comparison calculation."""
148
- differences = jax .tree .map (lambda x , y : x != y , xs , ys )
172
+ differences = jax .tree .leaves ( jax . tree . map (lambda x , y : x != y , xs , ys ) )
149
173
return functools .reduce (lambda x , y : x | y , differences , False )
150
174
151
175
@@ -167,6 +191,26 @@ class BufferType(enum.Enum):
167
191
168
192
MANUAL = 5
169
193
194
+ def _get_block_shape (spec : pl .BlockSpec ) -> tuple [int , ...]:
195
+ """Get the block shape for a given block spec."""
196
+ def _get_dim_size (bd ):
197
+ match bd :
198
+ case pl .Blocked (block_size ):
199
+ return block_size
200
+ case pl .Element ():
201
+ return bd .block_size
202
+ case pl .BoundedSlice (block_size ):
203
+ return block_size
204
+ case int ():
205
+ return bd
206
+ case None :
207
+ return 1
208
+ case _:
209
+ raise ValueError (f"Unsupported block dimension type: { bd } " )
210
+ if spec .block_shape is None :
211
+ raise ValueError ("Block shape must be specified." )
212
+ block_shape = tuple (_get_dim_size (x ) for x in spec .block_shape )
213
+ return block_shape
170
214
171
215
@tree_util .register_pytree_node_class
172
216
@dataclasses .dataclass (frozen = True )
@@ -236,7 +280,8 @@ def buffer_types() -> type[BufferType]:
236
280
return BufferType
237
281
238
282
@classmethod
239
- def create (cls , spec , dtype , buffer_type , needs_swap_ref = True ) -> BufferedRef :
283
+ def create (cls , spec : pl .BlockSpec , dtype , buffer_type , needs_swap_ref = True
284
+ ) -> BufferedRef :
240
285
"""Create a BufferedRef.
241
286
242
287
Args:
@@ -249,7 +294,7 @@ def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
249
294
Returns:
250
295
Initialized BufferedRef
251
296
"""
252
- block_shape = tuple ( 1 if x is None else x for x in spec . block_shape )
297
+ block_shape = _get_block_shape ( spec )
253
298
if buffer_type is BufferType .ACCUMULATOR :
254
299
accum_ref = VMEM (block_shape , dtype )
255
300
else :
@@ -375,9 +420,22 @@ def bind_existing_ref(self, window_ref, indices):
375
420
376
421
def compute_slice (self , grid_indices ):
377
422
"""Compute DMA slice from grid indices."""
378
- block_shape = tuple (1 if x is None else x for x in self .block_shape )
423
+ block_shape = []
424
+ for bd in self .block_shape :
425
+ if isinstance (bd , (pl .Element , pl .BoundedSlice )):
426
+ raise ValueError (
427
+ "Element and BoundedSlice block dimensions are not supported."
428
+ )
429
+ if bd is None :
430
+ block_shape .append (1 )
431
+ elif isinstance (bd , pl .Blocked ):
432
+ block_shape .append (bd .block_size )
433
+ elif isinstance (bd , int ):
434
+ block_shape .append (bd )
435
+ else :
436
+ raise ValueError (f"Unsupported block dimension type: { type (bd )} " )
379
437
indices = self .compute_index (* grid_indices )
380
- return jax .tree .map (_make_ds , indices , block_shape )
438
+ return jax .tree .map (_make_block_ds , indices , tuple ( block_shape ) )
381
439
382
440
def init_slots (self ):
383
441
"""Initialize slot indices."""
@@ -444,10 +502,12 @@ def get_dma_slice(self, src_shape, src_dtype, grid_indices):
444
502
raise NotImplementedError ("Must use >1D values." )
445
503
446
504
tiling = _make_tiling (src_shape , src_dtype )
447
- block_shape = tuple (1 if b is None else b for b in self .block_shape )
448
505
block_indices = self .compute_index (* grid_indices )
449
- return jax .tree .map (
450
- _make_block_slice , block_indices , block_shape , src_shape , tiling
506
+ return tuple (
507
+ _make_block_slice (bi , bs , ss , t )
508
+ for bi , bs , ss , t in zip (
509
+ block_indices , self .block_shape , src_shape , tiling , strict = True
510
+ )
451
511
)
452
512
453
513
def copy_in (self , src_ref , grid_indices ):
0 commit comments