Skip to content

Commit

Permalink
Improve a potential out-of-memory edge case in CowData by propagati…
Browse files Browse the repository at this point in the history
…ng errors out of `_copy_on_write`.
  • Loading branch information
Ivorforce committed Dec 20, 2024
1 parent 89001f9 commit 0ee79d2
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions core/templates/cowdata.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class CowData {
void _unref();
void _ref(const CowData *p_from);
void _ref(const CowData &p_from);
USize _copy_on_write();
Error _copy_on_write();
Error _realloc(Size p_alloc_size);

public:
Expand Down Expand Up @@ -288,43 +288,43 @@ void CowData<T>::_unref() {
}

template <typename T>
typename CowData<T>::USize CowData<T>::_copy_on_write() {
Error CowData<T>::_copy_on_write() {
if (!_ptr) {
return 0;
return OK;
}

SafeNumeric<USize> *refc = _get_refcount();

USize rc = refc->get();
if (unlikely(rc > 1)) {
/* in use by more than me */
USize current_size = *_get_size();
if (likely(refc->get() <= 1)) {
return OK;
}

uint8_t *mem_new = (uint8_t *)Memory::alloc_static(_get_alloc_size(current_size) + DATA_OFFSET, false);
ERR_FAIL_NULL_V(mem_new, 0);
// There are other references to the data, so we need to fork.
USize current_size = *_get_size();

SafeNumeric<USize> *_refc_ptr = _get_refcount_ptr(mem_new);
USize *_size_ptr = _get_size_ptr(mem_new);
T *_data_ptr = _get_data_ptr(mem_new);
uint8_t *mem_new = (uint8_t *)Memory::alloc_static(_get_alloc_size(current_size) + DATA_OFFSET, false);
ERR_FAIL_NULL_V(mem_new, ERR_OUT_OF_MEMORY);

new (_refc_ptr) SafeNumeric<USize>(1); //refcount
*(_size_ptr) = current_size; //size
SafeNumeric<USize> *_refc_ptr = _get_refcount_ptr(mem_new);
USize *_size_ptr = _get_size_ptr(mem_new);
T *_data_ptr = _get_data_ptr(mem_new);

// initialize new elements
if constexpr (std::is_trivially_copyable_v<T>) {
memcpy((uint8_t *)_data_ptr, _ptr, current_size * sizeof(T));
} else {
for (USize i = 0; i < current_size; i++) {
memnew_placement(&_data_ptr[i], T(_ptr[i]));
}
new (_refc_ptr) SafeNumeric<USize>(1);
*(_size_ptr) = current_size;

// Copy over the elements.
if constexpr (std::is_trivially_copyable_v<T>) {
memcpy((uint8_t *)_data_ptr, _ptr, current_size * sizeof(T));
} else {
for (USize i = 0; i < current_size; i++) {
memnew_placement(&_data_ptr[i], T(_ptr[i]));
}
}

_unref();
_ptr = _data_ptr;
_unref();
_ptr = _data_ptr;

rc = 1;
}
return rc;
return OK;
}

template <typename T>
Expand All @@ -346,7 +346,10 @@ Error CowData<T>::resize(Size p_size) {
}

// possibly changing size, copy on write
_copy_on_write();
Error error = _copy_on_write();
if (error) {
return error;
}

USize current_alloc_size = _get_alloc_size(current_size);
USize alloc_size;
Expand All @@ -369,7 +372,7 @@ Error CowData<T>::resize(Size p_size) {
_ptr = _data_ptr;

} else {
const Error error = _realloc(alloc_size);
error = _realloc(alloc_size);
if (error) {
return error;
}
Expand Down Expand Up @@ -398,7 +401,7 @@ Error CowData<T>::resize(Size p_size) {
}

if (alloc_size != current_alloc_size) {
const Error error = _realloc(alloc_size);
error = _realloc(alloc_size);
if (error) {
return error;
}
Expand Down

0 comments on commit 0ee79d2

Please sign in to comment.