Skip to content

Commit

Permalink
[Unity][SWA] Overriding windowed cache support (#15963)
Browse files Browse the repository at this point in the history
[unity] window cache support
  • Loading branch information
davidpissarra authored Oct 22, 2023
1 parent b8a1d63 commit 9ff2450
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 0 deletions.
86 changes: 86 additions & 0 deletions src/runtime/relax_vm/lm_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class AttentionKVCacheObj : public Object {
*/
int64_t fill_count{0};

/*!
* \brief current cache position (windowed kv cache only).
*/
int64_t window_attention_current_pos{0};

/*!
* \brief View all current cached values as one array.
* \param shape The cached values.
Expand Down Expand Up @@ -105,6 +110,77 @@ class AttentionKVCacheObj : public Object {
this->fill_count = value->shape[0];
}

/*!
* \brief Append value to the cache, overrides if full.
* \param value The value to override previous elements.
*/
void WindowOverride(NDArray value, int64_t max_cache_size) {
CHECK(data.DataType() == value.DataType()) << "dtype mismatch";
CHECK_LE(value->shape[0], max_cache_size) << "dim 0 of value too large";
// reallocate cache
if (fill_count + value->shape[0] <= max_cache_size) {
int64_t reserved_slots = data->shape[0];
while (fill_count + value->shape[0] > reserved_slots) {
reserved_slots *= 2;
}
if (reserved_slots != data->shape[0]) {
std::vector<int64_t> new_shape(data->shape, data->shape + data->ndim);
new_shape[0] = reserved_slots;
NDArray new_data = NDArray::Empty(new_shape, data->dtype, data->device);
new_data.CreateView(data.Shape(), data->dtype).CopyFrom(data);
this->data = new_data;
}
}
// copy into the current position.
ICHECK(data.IsContiguous());

int64_t num_elements_to_copy =
std::min(value->shape[0], max_cache_size - window_attention_current_pos);
int64_t num_elements_p_entry = 1;
std::vector<int64_t> shape;
shape.push_back(num_elements_to_copy);
for (int i = 1; i < data->ndim; ++i) {
CHECK_EQ(value->shape[i], data->shape[i]) << "Dimension " << i << " mismatch";
num_elements_p_entry *= data->shape[i];
shape.push_back(data->shape[i]);
}
int64_t num_filled_elements = window_attention_current_pos * num_elements_p_entry;

DLTensor copy_dst = *(data.operator->());
copy_dst.byte_offset = num_filled_elements * ((data->dtype.bits * data->dtype.lanes + 7) / 8);
copy_dst.shape = &shape[0];

DLTensor copy_src = *(value.operator->());
copy_src.byte_offset = 0;
copy_src.shape = &shape[0];

NDArray::CopyFromTo(&copy_src, &copy_dst);
this->fill_count = std::min(this->fill_count + value->shape[0], max_cache_size);
this->window_attention_current_pos =
std::min(this->window_attention_current_pos + value->shape[0], max_cache_size);

// copy the remainder to the beginning of the cache
if (num_elements_to_copy < value->shape[0]) {
ICHECK_EQ(this->fill_count, max_cache_size);
ICHECK_EQ(this->fill_count, this->window_attention_current_pos);

shape[0] = value->shape[0] - num_elements_to_copy;
num_filled_elements = num_elements_to_copy * num_elements_p_entry;

DLTensor copy_dst = *(data.operator->());
copy_dst.byte_offset = 0;
copy_dst.shape = &shape[0];

DLTensor copy_src = *(value.operator->());
copy_src.byte_offset =
num_filled_elements * ((value->dtype.bits * value->dtype.lanes + 7) / 8);
copy_src.shape = &shape[0];

NDArray::CopyFromTo(&copy_src, &copy_dst);
this->window_attention_current_pos = value->shape[0] - num_elements_to_copy;
}
}

/*!
* \brief Append value to the cache.
* \param value The value to be appended.
Expand Down Expand Up @@ -159,6 +235,7 @@ class AttentionKVCache : public ObjectRef {
n->Append(init_data);
if (init_fill_count >= 0) {
n->fill_count = init_fill_count;
n->window_attention_current_pos = init_fill_count; // window attention only
}
return AttentionKVCache(n);
}
Expand Down Expand Up @@ -232,6 +309,15 @@ AttentionKVCache AttentionKVCacheAppend(AttentionKVCache cache, NDArray value) {

TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_append").set_body_typed(AttentionKVCacheAppend);

AttentionKVCache AttentionKVCacheWindowOverride(AttentionKVCache cache, NDArray value,
int64_t max_cache_size) {
cache->WindowOverride(value, max_cache_size);
return cache;
}

TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_window_override")
.set_body_typed(AttentionKVCacheWindowOverride);

NDArray AttentionKVCacheView(AttentionKVCache cache, ShapeTuple shape) {
return cache->View(shape);
}
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relax/test_runtime_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,34 @@ def test_ndarray_cache():
np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6)


def test_attention_kv_cache_window_override():
fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override")
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")

current_pos = 4
cache = fcreate(
tvm.nd.array(np.full((16, 2), -1).astype("int32")),
tvm.runtime.ShapeTuple([16, 2]),
current_pos,
)
np_all_arrays = np.zeros((0, 2)).astype("int32")

num_steps = 10
for i in range(1, num_steps):
np_array = i * np.ones((i, 2)).astype("int32")
np_all_arrays = np.concatenate((np_all_arrays, np_array), axis=0)
cache = foverride(cache, tvm.nd.array(np_array), 16)
current_pos = (current_pos + i) % 16

res = fview(cache, tvm.runtime.ShapeTuple((16, 2))).numpy()

# unrotate cache and assert cache matches last 16 elements
assert (
np_all_arrays[np_all_arrays.shape[0] - 16 :, :]
== np.concatenate((res[current_pos:], res[:current_pos]))
).all()


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 9ff2450

Please sign in to comment.