From cea939065c0811b06a880a1bb4c872d200808756 Mon Sep 17 00:00:00 2001 From: Edan Toledo <42650996+EdanToledo@users.noreply.github.com> Date: Tue, 21 Jan 2025 13:23:19 +0000 Subject: [PATCH] fix: overflow error (#48) --- flashbax/buffers/prioritised_trajectory_buffer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/flashbax/buffers/prioritised_trajectory_buffer.py b/flashbax/buffers/prioritised_trajectory_buffer.py index ea1206f..92a20a3 100644 --- a/flashbax/buffers/prioritised_trajectory_buffer.py +++ b/flashbax/buffers/prioritised_trajectory_buffer.py @@ -202,8 +202,16 @@ def calculate_item_indices_and_priorities( max_num_items = (add_sequence_length // period) + 1 # We get the actual number of items that will be created and use for masking. - actual_num_items = (ending_priority_item_index - starting_priority_item_index) % ( - max_length_time_axis // period + actual_num_items_given_full = ( + ending_priority_item_index - starting_priority_item_index + ) % (max_length_time_axis // period) + # If not full, we simply take the maximum + actual_num_items_given_not_full = jnp.maximum( + 0, (ending_priority_item_index - starting_priority_item_index) + ) + + actual_num_items = jax.lax.select( + state.is_full, actual_num_items_given_full, actual_num_items_given_not_full ) priority_indices = _get_priority_indices(