diff --git a/flashbax/buffers/prioritised_trajectory_buffer.py b/flashbax/buffers/prioritised_trajectory_buffer.py index ea1206f..92a20a3 100644 --- a/flashbax/buffers/prioritised_trajectory_buffer.py +++ b/flashbax/buffers/prioritised_trajectory_buffer.py @@ -202,8 +202,16 @@ def calculate_item_indices_and_priorities( max_num_items = (add_sequence_length // period) + 1 # We get the actual number of items that will be created and use for masking. - actual_num_items = (ending_priority_item_index - starting_priority_item_index) % ( - max_length_time_axis // period + actual_num_items_given_full = ( + ending_priority_item_index - starting_priority_item_index + ) % (max_length_time_axis // period) + # If not full, we simply take the maximum + actual_num_items_given_not_full = jnp.maximum( + 0, (ending_priority_item_index - starting_priority_item_index) + ) + + actual_num_items = jax.lax.select( + state.is_full, actual_num_items_given_full, actual_num_items_given_not_full ) priority_indices = _get_priority_indices(