diff --git a/ci/conda/recipes/libmrc/meta.yaml b/ci/conda/recipes/libmrc/meta.yaml index fe402a7f0..30916f85c 100644 --- a/ci/conda/recipes/libmrc/meta.yaml +++ b/ci/conda/recipes/libmrc/meta.yaml @@ -58,7 +58,7 @@ requirements: - gtest =1.14 - libhwloc =2.9.2 - librmm {{ rapids_version }} - - nlohmann_json =3.9 + - nlohmann_json =3.11 - pybind11-abi # See: https://conda-forge.org/docs/maintainer/knowledge_base.html#pybind11-abi-constraints - pybind11-stubgen =0.10 - python {{ python }} @@ -90,12 +90,12 @@ outputs: - libgrpc =1.59 - libhwloc =2.9.2 - librmm {{ rapids_version }} - - nlohmann_json =3.9 + - nlohmann_json =3.11 - ucx =1.15 run: # Manually add any packages necessary for run that do not have run_exports. Keep sorted! - cuda-version {{ cuda_version }}.* - - nlohmann_json =3.9 + - nlohmann_json =3.11 - ucx =1.15 - cuda-cudart - boost-cpp =1.84 diff --git a/ci/iwyu/mappings.imp b/ci/iwyu/mappings.imp index 7e9f70083..97872205b 100644 --- a/ci/iwyu/mappings.imp +++ b/ci/iwyu/mappings.imp @@ -11,6 +11,7 @@ # boost { "include": ["@<boost/fiber/future/detail/.*>", "private", "<boost/fiber/future/future.hpp>", "public"] }, +{ "include": ["@<boost/algorithm/string/detail/.*>", "private", "<boost/algorithm/string.hpp>", "public"] }, # cuda { "include": ["<cuda_runtime_api.h>", "private", "<cuda_runtime.h>", "public"] }, @@ -33,6 +34,7 @@ { "symbol": ["@grpc::.*", "private", "<grpcpp/grpcpp.h>", "public"] }, # nlohmann json +{ "include": ["<nlohmann/json_fwd.hpp>", "public", "<nlohmann/json.hpp>", "public"] }, { "include": ["<nlohmann/detail/iterators/iter_impl.hpp>", "private", "<nlohmann/json.hpp>", "public"] }, { "include": ["<nlohmann/detail/iterators/iteration_proxy.hpp>", "private", "<nlohmann/json.hpp>", "public"] }, { "include": ["<nlohmann/detail/json_ref.hpp>", "private", "<nlohmann/json.hpp>", "public"] }, diff --git a/conda/environments/all_cuda-121_arch-x86_64.yaml b/conda/environments/all_cuda-121_arch-x86_64.yaml index 518b2f271..1e08b3f19 100644 --- 
a/conda/environments/all_cuda-121_arch-x86_64.yaml +++ b/conda/environments/all_cuda-121_arch-x86_64.yaml @@ -39,7 +39,7 @@ dependencies: - libxml2=2.11.6 - llvmdev=16 - ninja=1.11 -- nlohmann_json=3.9 +- nlohmann_json=3.11 - numactl-libs-cos7-x86_64 - numpy=1.24 - pkg-config=0.29 diff --git a/conda/environments/ci_cuda-121_arch-x86_64.yaml b/conda/environments/ci_cuda-121_arch-x86_64.yaml index 0d5803adb..d425705fb 100644 --- a/conda/environments/ci_cuda-121_arch-x86_64.yaml +++ b/conda/environments/ci_cuda-121_arch-x86_64.yaml @@ -29,7 +29,7 @@ dependencies: - librmm=24.02 - libxml2=2.11.6 - ninja=1.11 -- nlohmann_json=3.9 +- nlohmann_json=3.11 - numactl-libs-cos7-x86_64 - pkg-config=0.29 - pre-commit diff --git a/cpp/mrc/include/mrc/coroutines/task_container.hpp b/cpp/mrc/include/mrc/coroutines/task_container.hpp index 20cab894e..88730b919 100644 --- a/cpp/mrc/include/mrc/coroutines/task_container.hpp +++ b/cpp/mrc/include/mrc/coroutines/task_container.hpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -37,15 +37,14 @@ */ #pragma once - #include "mrc/coroutines/task.hpp" -#include <atomic> #include <cstddef> #include <list> #include <memory> #include <mutex> #include <optional> +#include <queue> #include <vector> namespace mrc::coroutines { @@ -60,7 +59,7 @@ class TaskContainer * @param e Tasks started in the container are scheduled onto this executor. For tasks created * from a coro::io_scheduler, this would usually be that coro::io_scheduler instance. 
*/ - TaskContainer(std::shared_ptr<Scheduler> e); + TaskContainer(std::shared_ptr<Scheduler> e, std::size_t max_concurrent_tasks = 0); TaskContainer(const TaskContainer&) = delete; TaskContainer(TaskContainer&&) = delete; @@ -93,30 +92,20 @@ class TaskContainer */ auto garbage_collect() -> std::size_t; - /** - * @return The number of tasks that are awaiting deletion. - */ - auto delete_task_size() const -> std::size_t; - - /** - * @return True if there are no tasks awaiting deletion. - */ - auto delete_tasks_empty() const -> bool; - /** * @return The number of active tasks in the container. */ - auto size() const -> std::size_t; + auto size() -> std::size_t; /** * @return True if there are no active tasks in the container. */ - auto empty() const -> bool; + auto empty() -> bool; /** * @return The capacity of this task manager before it will need to grow in size. */ - auto capacity() const -> std::size_t; + auto capacity() -> std::size_t; /** * Will continue to garbage collect and yield until all tasks are complete. This method can be @@ -138,6 +127,11 @@ class TaskContainer */ auto gc_internal() -> std::size_t; + /** + * Starts the next task in the queue if one is available and max concurrent tasks has not yet been met. + */ + void try_start_next_task(std::unique_lock<std::mutex> lock); + /** * Encapsulate the users tasks in a cleanup task which marks itself for deletion upon * completion. Simply co_await the users task until its completed and then mark the given @@ -156,7 +150,7 @@ class TaskContainer /// thread pools for indeterminate lifetime requests. std::mutex m_mutex{}; /// The number of alive tasks. - std::atomic<std::size_t> m_size{}; + std::size_t m_size{}; /// Maintains the lifetime of the tasks until they are completed. std::list<std::optional<Task<void>>> m_tasks{}; /// The set of tasks that have completed and need to be deleted. 
@@ -166,6 +160,10 @@ class TaskContainer std::shared_ptr<Scheduler> m_scheduler_lifetime{nullptr}; /// This is used internally since io_scheduler cannot pass itself in as a shared_ptr. Scheduler* m_scheduler{nullptr}; + /// tasks to be processed in order of start + std::queue<decltype(m_tasks.end())> m_next_tasks; + /// maximum number of tasks to be run simultaneously + std::size_t m_max_concurrent_tasks; friend Scheduler; }; diff --git a/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp b/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp index ba2847415..5d74f2168 100644 --- a/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp +++ b/cpp/mrc/include/mrc/coroutines/test_scheduler.hpp @@ -80,6 +80,13 @@ class TestScheduler : public Scheduler */ mrc::coroutines::Task<> yield_until(std::chrono::time_point<std::chrono::steady_clock> time) override; + /** + * Returns the time according to the scheduler. Time may be progressed by resume_next, resume_for, and resume_until. + * + * @return the current time according to the scheduler. + */ + std::chrono::time_point<std::chrono::steady_clock> time(); + /** * Immediately resumes the next-in-queue coroutine handle. * diff --git a/cpp/mrc/src/public/coroutines/task_container.cpp b/cpp/mrc/src/public/coroutines/task_container.cpp index 317f489f9..85a765517 100644 --- a/cpp/mrc/src/public/coroutines/task_container.cpp +++ b/cpp/mrc/src/public/coroutines/task_container.cpp @@ -30,9 +30,10 @@ namespace mrc::coroutines { -TaskContainer::TaskContainer(std::shared_ptr<Scheduler> e) : +TaskContainer::TaskContainer(std::shared_ptr<Scheduler> e, std::size_t max_concurrent_tasks) : m_scheduler_lifetime(std::move(e)), - m_scheduler(m_scheduler_lifetime.get()) + m_scheduler(m_scheduler_lifetime.get()), + m_max_concurrent_tasks(max_concurrent_tasks) { if (m_scheduler_lifetime == nullptr) { @@ -43,17 +44,17 @@ TaskContainer::TaskContainer(std::shared_ptr<Scheduler> e) : TaskContainer::~TaskContainer() { // This will hang the current thread.. 
but if tasks are not complete thats also pretty bad. - while (!this->empty()) + while (not empty()) { - this->garbage_collect(); + garbage_collect(); } } auto TaskContainer::start(Task<void>&& user_task, GarbageCollectPolicy cleanup) -> void { - m_size.fetch_add(1, std::memory_order::relaxed); + auto lock = std::unique_lock(m_mutex); - std::scoped_lock lk{m_mutex}; + m_size += 1; if (cleanup == GarbageCollectPolicy::yes) { @@ -64,48 +65,42 @@ auto TaskContainer::start(Task<void>&& user_task, GarbageCollectPolicy cleanup) auto pos = m_tasks.emplace(m_tasks.end(), std::nullopt); auto task = make_cleanup_task(std::move(user_task), pos); *pos = std::move(task); + m_next_tasks.push(pos); - // Start executing from the cleanup task to schedule the user's task onto the thread pool. - pos->value().resume(); + auto current_task_count = m_size - m_next_tasks.size(); + + if (m_max_concurrent_tasks == 0 or current_task_count < m_max_concurrent_tasks) + { + try_start_next_task(std::move(lock)); + } } auto TaskContainer::garbage_collect() -> std::size_t { - std::scoped_lock lk{m_mutex}; + auto lock = std::scoped_lock(m_mutex); return gc_internal(); } -auto TaskContainer::delete_task_size() const -> std::size_t -{ - std::atomic_thread_fence(std::memory_order::acquire); - return m_tasks_to_delete.size(); -} - -auto TaskContainer::delete_tasks_empty() const -> bool +auto TaskContainer::size() -> std::size_t { - std::atomic_thread_fence(std::memory_order::acquire); - return m_tasks_to_delete.empty(); + auto lock = std::scoped_lock(m_mutex); + return m_size; } -auto TaskContainer::size() const -> std::size_t -{ - return m_size.load(std::memory_order::relaxed); -} - -auto TaskContainer::empty() const -> bool +auto TaskContainer::empty() -> bool { return size() == 0; } -auto TaskContainer::capacity() const -> std::size_t +auto TaskContainer::capacity() -> std::size_t { - std::atomic_thread_fence(std::memory_order::acquire); + auto lock = std::scoped_lock(m_mutex); return m_tasks.size(); 
} auto TaskContainer::garbage_collect_and_yield_until_empty() -> Task<void> { - while (!empty()) + while (not empty()) { garbage_collect(); co_await m_scheduler->yield(); @@ -115,22 +110,44 @@ auto TaskContainer::garbage_collect_and_yield_until_empty() -> Task<void> TaskContainer::TaskContainer(Scheduler& e) : m_scheduler(&e) {} auto TaskContainer::gc_internal() -> std::size_t { - std::size_t deleted{0}; - if (!m_tasks_to_delete.empty()) + if (m_tasks_to_delete.empty()) + { + return 0; + } + + std::size_t delete_count = m_tasks_to_delete.size(); + + for (const auto& pos : m_tasks_to_delete) { - for (const auto& pos : m_tasks_to_delete) + // Destroy the cleanup task and the user task. + if (pos->has_value()) { - // Destroy the cleanup task and the user task. - if (pos->has_value()) - { - pos->value().destroy(); - } - m_tasks.erase(pos); + pos->value().destroy(); } - deleted = m_tasks_to_delete.size(); - m_tasks_to_delete.clear(); + + m_tasks.erase(pos); + } + + m_tasks_to_delete.clear(); + + return delete_count; +} + +void TaskContainer::try_start_next_task(std::unique_lock<std::mutex> lock) +{ + if (m_next_tasks.empty()) + { + // no tasks to process + return; } - return deleted; + + auto pos = m_next_tasks.front(); + m_next_tasks.pop(); + + // release the lock before starting the task + lock.unlock(); + + pos->value().resume(); } auto TaskContainer::make_cleanup_task(Task<void> user_task, task_position_t pos) -> Task<void> @@ -155,11 +172,14 @@ auto TaskContainer::make_cleanup_task(Task<void> user_task, task_position_t pos) LOG(ERROR) << "coro::task_container user_task had unhandle exception, not derived from std::exception.\n"; } - std::scoped_lock lk{m_mutex}; + auto lock = std::unique_lock(m_mutex); m_tasks_to_delete.push_back(pos); // This has to be done within scope lock to make sure this coroutine task completes before the // task container object destructs -- if it was waiting on .empty() to become true. 
- m_size.fetch_sub(1, std::memory_order::relaxed); + m_size -= 1; + + try_start_next_task(std::move(lock)); + co_return; } diff --git a/cpp/mrc/src/public/coroutines/test_scheduler.cpp b/cpp/mrc/src/public/coroutines/test_scheduler.cpp index 0cc3ef130..fba53c250 100644 --- a/cpp/mrc/src/public/coroutines/test_scheduler.cpp +++ b/cpp/mrc/src/public/coroutines/test_scheduler.cpp @@ -17,6 +17,7 @@ #include "mrc/coroutines/test_scheduler.hpp" +#include <chrono> #include <compare> namespace mrc::coroutines { @@ -56,8 +57,15 @@ mrc::coroutines::Task<> TestScheduler::yield_until(std::chrono::time_point<std:: co_return co_await TestScheduler::Operation{this, time}; } +std::chrono::time_point<std::chrono::steady_clock> TestScheduler::time() +{ + return m_time; +} + bool TestScheduler::resume_next() { + using namespace std::chrono_literals; + if (m_queue.empty()) { return false; @@ -69,6 +77,11 @@ bool TestScheduler::resume_next() m_time = handle.second; + if (not m_queue.empty()) + { + m_time = m_queue.top().second; + } + handle.first.resume(); return true; diff --git a/cpp/mrc/tests/coroutines/test_task_container.cpp b/cpp/mrc/tests/coroutines/test_task_container.cpp index a55f88039..3a5a1bbf0 100644 --- a/cpp/mrc/tests/coroutines/test_task_container.cpp +++ b/cpp/mrc/tests/coroutines/test_task_container.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,9 +15,78 @@ * limitations under the License. 
*/ +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/coroutines/task_container.hpp" +#include "mrc/coroutines/test_scheduler.hpp" + #include <gtest/gtest.h> +#include <chrono> +#include <coroutine> +#include <cstdint> +#include <memory> +#include <ratio> +#include <thread> +#include <vector> + class TestCoroTaskContainer : public ::testing::Test {}; TEST_F(TestCoroTaskContainer, LifeCycle) {} + +TEST_F(TestCoroTaskContainer, MaxSimultaneousTasks) +{ + using namespace std::chrono_literals; + + const int32_t num_threads = 16; + const int32_t num_tasks_per_thread = 16; + const int32_t num_tasks = num_threads * num_tasks_per_thread; + const int32_t max_concurrent_tasks = 2; + + auto on = std::make_shared<mrc::coroutines::TestScheduler>(); + auto task_container = mrc::coroutines::TaskContainer(on, max_concurrent_tasks); + + auto start_time = on->time(); + + std::vector<std::chrono::time_point<std::chrono::steady_clock>> execution_times; + + auto delay = [](std::shared_ptr<mrc::coroutines::TestScheduler> on, + std::vector<std::chrono::time_point<std::chrono::steady_clock>>& execution_times) + -> mrc::coroutines::Task<> { + co_await on->yield_for(100ms); + execution_times.emplace_back(on->time()); + }; + + std::vector<std::thread> threads; + + for (auto i = 0; i < num_threads; i++) + { + threads.emplace_back([&]() { + for (auto i = 0; i < num_tasks_per_thread; i++) + { + task_container.start(delay(on, execution_times)); + } + }); + } + + for (auto& thread : threads) + { + thread.join(); + } + + auto task = task_container.garbage_collect_and_yield_until_empty(); + + task.resume(); + + while (on->resume_next()) {} + + mrc::coroutines::sync_wait(task); + + ASSERT_EQ(execution_times.size(), num_tasks); + + for (auto i = 0; i < execution_times.size(); i++) + { + ASSERT_EQ(execution_times[i], start_time + (i / max_concurrent_tasks + 1) * 100ms) << "Failed at index " << i; + } +} diff --git a/dependencies.yaml b/dependencies.yaml index 
a14046b6c..329cc8106 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -63,7 +63,7 @@ dependencies: - librmm=24.02 - libxml2=2.11.6 # 2.12 has a bug preventing round-trip serialization in hwloc - ninja=1.11 - - nlohmann_json=3.9 + - nlohmann_json=3.11 - numactl-libs-cos7-x86_64 - pkg-config=0.29 - pybind11-stubgen=0.10 diff --git a/python/mrc/_pymrc/CMakeLists.txt b/python/mrc/_pymrc/CMakeLists.txt index 8e9d12310..adfc03c21 100644 --- a/python/mrc/_pymrc/CMakeLists.txt +++ b/python/mrc/_pymrc/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(pymrc src/utilities/acquire_gil.cpp src/utilities/deserializers.cpp src/utilities/function_wrappers.cpp + src/utilities/json_values.cpp src/utilities/object_cache.cpp src/utilities/object_wrappers.cpp src/utilities/serializers.cpp diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp index 77541d06b..929e37ac5 100644 --- a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -209,12 +209,6 @@ class AsyncioRunnable : public AsyncSink<InputT>, std::shared_ptr<mrc::coroutines::Scheduler> on) = 0; std::stop_source m_stop_source; - - /** - * @brief A semaphore used to control the number of outstanding operations. Acquire one before - * beginning a task, and release it when finished. 
- */ - std::counting_semaphore<8> m_task_tickets{8}; }; template <typename InputT, typename OutputT> @@ -269,9 +263,16 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx) loop.attr("close")(); } - // Need to drop the output edges - mrc::node::SourceProperties<OutputT>::release_edge_connection(); - mrc::node::SinkProperties<InputT>::release_edge_connection(); + // Sync all progress engines if there are more than one + ctx.barrier(); + + // Only drop the output edges if we are rank 0 + if (ctx.rank() == 0) + { + // Need to drop the output edges + mrc::node::SourceProperties<OutputT>::release_edge_connection(); + mrc::node::SinkProperties<InputT>::release_edge_connection(); + } if (exception != nullptr) { @@ -282,14 +283,12 @@ void AsyncioRunnable<InputT, OutputT>::run(mrc::runnable::Context& ctx) template <typename InputT, typename OutputT> coroutines::Task<> AsyncioRunnable<InputT, OutputT>::main_task(std::shared_ptr<mrc::coroutines::Scheduler> scheduler) { - coroutines::TaskContainer outstanding_tasks(scheduler); + coroutines::TaskContainer outstanding_tasks(scheduler, 8); ExceptionCatcher catcher{}; while (not m_stop_source.stop_requested() and not catcher.has_exception()) { - m_task_tickets.acquire(); - InputT data; auto read_status = co_await this->read_async(data); @@ -335,8 +334,6 @@ coroutines::Task<> AsyncioRunnable<InputT, OutputT>::process_one(InputT value, { catcher.push_exception(std::current_exception()); } - - m_task_tickets.release(); } template <typename InputT, typename OutputT> diff --git a/python/mrc/_pymrc/include/pymrc/types.hpp b/python/mrc/_pymrc/include/pymrc/types.hpp index fcaa9942b..5446ec28a 100644 --- a/python/mrc/_pymrc/include/pymrc/types.hpp +++ b/python/mrc/_pymrc/include/pymrc/types.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,9 +21,12 @@ #include "mrc/segment/object.hpp" +#include <nlohmann/json_fwd.hpp> #include <rxcpp/rx.hpp> -#include <functional> +#include <functional> // for function +#include <map> +#include <string> namespace mrc::pymrc { @@ -37,4 +40,16 @@ using PyNode = mrc::segment::ObjectProperties; using PyObjectOperateFn = std::function<PyObjectObservable(PyObjectObservable source)>; // NOLINTEND(readability-identifier-naming) +using python_map_t = std::map<std::string, pybind11::object>; + +/** + * @brief Unserializable handler function type, invoked by `cast_from_pyobject` when an object cannot be serialized to + * JSON. Implementations should return a valid json object, or throw an exception if the object cannot be serialized. + * @param source : pybind11 object + * @param path : string json path to object + * @return nlohmann::json. + */ +using unserializable_handler_fn_t = + std::function<nlohmann::json(const pybind11::object& /* source*/, const std::string& /* path */)>; + } // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp b/python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp new file mode 100644 index 000000000..24bbd52a2 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/utilities/json_values.hpp @@ -0,0 +1,157 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/types.hpp" // for python_map_t & unserializable_handler_fn_t + +#include <nlohmann/json.hpp> +#include <pybind11/pytypes.h> // for PYBIND11_EXPORT & pybind11::object + +#include <cstddef> // for size_t +#include <string> +// IWYU wants us to use the pybind11.h for the PYBIND11_EXPORT macro, but we already have it in pytypes.h +// IWYU pragma: no_include <pybind11/pybind11.h> + +namespace mrc::pymrc { + +#pragma GCC visibility push(default) + +/** + * @brief Immutable container for holding Python values as JSON objects if possible, and as pybind11::object otherwise. + * The container can be copied and moved, but the underlying JSON object is immutable. + **/ +class PYBIND11_EXPORT JSONValues +{ + public: + JSONValues(); + JSONValues(pybind11::object values); + JSONValues(nlohmann::json values); + + JSONValues(const JSONValues& other) = default; + JSONValues(JSONValues&& other) = default; + ~JSONValues() = default; + + JSONValues& operator=(const JSONValues& other) = default; + JSONValues& operator=(JSONValues&& other) = default; + + /** + * @brief Sets a value in the JSON object at the specified path with the provided Python object. If `value` is + * serializable as JSON it will be stored as JSON, otherwise it will be stored as-is. + * @param path The path in the JSON object where the value should be set. + * @param value The Python object to set. + * @throws std::runtime_error If the path is invalid. + * @return A new JSONValues object with the updated value. 
+ */ + JSONValues set_value(const std::string& path, const pybind11::object& value) const; + + /** + * @brief Sets a value in the JSON object at the specified path with the provided JSON object. + * @param path The path in the JSON object where the value should be set. + * @param value The JSON object to set. + * @throws std::runtime_error If the path is invalid. + * @return A new JSONValues object with the updated value. + */ + JSONValues set_value(const std::string& path, nlohmann::json value) const; + + /** + * @brief Sets a value in the JSON object at the specified path with the provided JSONValues object. + * @param path The path in the JSON object where the value should be set. + * @param value The JSONValues object to set. + * @throws std::runtime_error If the path is invalid. + * @return A new JSONValues object with the updated value. + */ + JSONValues set_value(const std::string& path, const JSONValues& value) const; + + /** + * @brief Returns the number of unserializable Python objects. + * @return The number of unserializable Python objects. + */ + std::size_t num_unserializable() const; + + /** + * @brief Checks if there are any unserializable Python objects. + * @return True if there are unserializable Python objects, false otherwise. + */ + bool has_unserializable() const; + + /** + * @brief Convert to a Python object. + * @return The Python object representation of the values. + */ + pybind11::object to_python() const; + + /** + * @brief Returns a constant reference to the underlying JSON object. Any unserializable Python objects, will be + * represented in the JSON object with a string place-holder with the value `"**pymrc_placeholder"`. + * @return A constant reference to the JSON object. + */ + nlohmann::json::const_reference view_json() const; + + /** + * @brief Converts the JSON object to a JSON object. If any unserializable Python objects are present, the + * `unserializable_handler_fn` will be invoked to handle the object. 
+ * @param unserializable_handler_fn Optional function to handle unserializable objects. + * @return A copy of the JSON object with each unserializable value replaced by the handler's result. + */ + nlohmann::json to_json(unserializable_handler_fn_t unserializable_handler_fn) const; + + /** + * @brief Converts a Python object to a JSON string. Convenience function that matches the + * `unserializable_handler_fn_t` signature. Convenient for use with `to_json` and `get_json`. + * @param obj The Python object to convert. + * @param path The path in the JSON object where the value should be set. + * @return The JSON string representation of the Python object. + */ + static nlohmann::json stringify(const pybind11::object& obj, const std::string& path); + + /** + * @brief Returns the object at the specified path as a Python object. + * @param path Path to the specified object. + * @throws std::runtime_error If the path does not exist or is not a valid path. + * @return Python representation of the object at the specified path. + */ + pybind11::object get_python(const std::string& path) const; + + /** + * @brief Returns the object at the specified path. If the object is an unserializable Python object the + * `unserializable_handler_fn` will be invoked. + * @param path Path to the specified object. + * @param unserializable_handler_fn Function to handle unserializable objects. + * @throws std::runtime_error If the path does not exist or is not a valid path. + * @return The JSON object at the specified path. + */ + nlohmann::json get_json(const std::string& path, unserializable_handler_fn_t unserializable_handler_fn) const; + + /** + * @brief Return a new JSONValues object with the value at the specified path. + * @param path Path to the specified object. + * @throws std::runtime_error If the path does not exist or is not a valid path. + * @return The value at the specified path. 
+ */ + JSONValues operator[](const std::string& path) const; + + private: + JSONValues(nlohmann::json&& values, python_map_t&& py_objects); + nlohmann::json unserializable_handler(const pybind11::object& obj, const std::string& path); + + nlohmann::json m_serialized_values; + python_map_t m_py_objects; +}; + +#pragma GCC visibility pop +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp b/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp index 2721eb5db..68c106064 100644 --- a/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp +++ b/python/mrc/_pymrc/include/pymrc/utilities/object_cache.hpp @@ -17,6 +17,8 @@ #pragma once +#include "pymrc/types.hpp" + #include <pybind11/pytypes.h> #include <cstddef> @@ -95,7 +97,7 @@ class __attribute__((visibility("default"))) PythonObjectCache */ void atexit_callback(); - std::map<std::string, pybind11::object> m_object_cache; + python_map_t m_object_cache; }; #pragma GCC visibility pop diff --git a/python/mrc/_pymrc/include/pymrc/utils.hpp b/python/mrc/_pymrc/include/pymrc/utils.hpp index fbfe2e02f..714605e6a 100644 --- a/python/mrc/_pymrc/include/pymrc/utils.hpp +++ b/python/mrc/_pymrc/include/pymrc/utils.hpp @@ -17,6 +17,8 @@ #pragma once +#include "pymrc/types.hpp" + #include <nlohmann/json_fwd.hpp> #include <pybind11/pybind11.h> #include <pybind11/pytypes.h> @@ -31,8 +33,25 @@ namespace mrc::pymrc { #pragma GCC visibility push(default) pybind11::object cast_from_json(const nlohmann::json& source); + +/** + * @brief Convert a pybind11 object to a JSON object. If the object cannot be serialized, a pybind11::type_error + * exception will be thrown. + * @param source : pybind11 object + * @return nlohmann::json. + */ nlohmann::json cast_from_pyobject(const pybind11::object& source); +/** + * @brief Convert a pybind11 object to a JSON object. If the object cannot be serialized, the unserializable_handler_fn + * will be invoked to handle the object. 
+ * @param source : pybind11 object + * @param unserializable_handler_fn : unserializable_handler_fn_t + * @return nlohmann::json. + */ +nlohmann::json cast_from_pyobject(const pybind11::object& source, + unserializable_handler_fn_t unserializable_handler_fn); + void import_module_object(pybind11::module_&, const std::string&, const std::string&); void import_module_object(pybind11::module_& dest, const pybind11::module_& mod); diff --git a/python/mrc/_pymrc/src/utilities/json_values.cpp b/python/mrc/_pymrc/src/utilities/json_values.cpp new file mode 100644 index 000000000..0a898e4d9 --- /dev/null +++ b/python/mrc/_pymrc/src/utilities/json_values.cpp @@ -0,0 +1,309 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pymrc/utilities/json_values.hpp" + +#include "pymrc/utilities/acquire_gil.hpp" +#include "pymrc/utils.hpp" + +#include "mrc/utils/string_utils.hpp" // for MRC_CONCAT_STR + +#include <boost/algorithm/string.hpp> // for split +#include <glog/logging.h> +#include <pybind11/cast.h> + +#include <functional> // for function +#include <iterator> // for next +#include <map> // for map +#include <sstream> // for operator<< & stringstream +#include <stdexcept> // for runtime_error +#include <utility> // for move +#include <vector> // for vector + +// We already have <boost/algorithm/string.hpp> included we don't need these others, it is also the only public header +// with a definition for boost::is_any_of, so even if we replaced string.hpp with these others we would still need to +// include string.hpp or a detail/ header +// IWYU pragma: no_include <boost/algorithm/string/classification.hpp> +// IWYU pragma: no_include <boost/algorithm/string/split.hpp> +// IWYU pragma: no_include <boost/iterator/iterator_facade.hpp> + +namespace py = pybind11; +using namespace std::string_literals; + +namespace { + +std::vector<std::string> split_path(const std::string& path) +{ + std::vector<std::string> path_parts; + boost::split(path_parts, path, boost::is_any_of("/")); + return path_parts; +} + +struct PyFoundObject +{ + py::object obj; + py::object index = py::none(); +}; + +PyFoundObject find_object_at_path(py::object& obj, + std::vector<std::string>::const_iterator path, + std::vector<std::string>::const_iterator path_end) +{ + // Terminal case + const auto& path_str = *path; + if (path_str.empty()) + { + return PyFoundObject(obj); + } + + // Nested object, since obj is a de-serialized python object the only valid container types will be dict and + // list. There are one of two possibilities here: + // 1. The next_path is terminal and we should assign value to the container + // 2. 
The next_path is not terminal and we should recurse into the container + auto next_path = std::next(path); + + if (py::isinstance<py::dict>(obj) || py::isinstance<py::list>(obj)) + { + py::object index; + if (py::isinstance<py::dict>(obj)) + { + index = py::cast(path_str); + } + else + { + index = py::cast(std::stoul(path_str)); + } + + if (next_path == path_end) + { + return PyFoundObject{obj, index}; + } + + py::object next_obj = obj[index]; + return find_object_at_path(next_obj, next_path, path_end); + } + + throw std::runtime_error("Invalid path"); +} + +PyFoundObject find_object_at_path(py::object& obj, const std::string& path) +{ + auto path_parts = split_path(path); + + // Since our paths always begin with a '/', the first element will always be empty in the case where path="/" + // path_parts will be {"", ""} and we can skip the first element + auto itr = path_parts.cbegin(); + return find_object_at_path(obj, std::next(itr), path_parts.cend()); +} + +void patch_object(py::object& obj, const std::string& path, const py::object& value) +{ + if (path == "/") + { + // Special case for the root object since find_object_at_path will return a copy not a reference we need to + // perform the assignment here + obj = value; + } + else + { + auto found = find_object_at_path(obj, path); + DCHECK(!found.index.is_none()); + found.obj[found.index] = value; + } +} + +std::string validate_path(const std::string& path) +{ + if (path.empty() || path[0] != '/') + { + return "/" + path; + } + + return path; +} +} // namespace + +namespace mrc::pymrc { +JSONValues::JSONValues() : JSONValues(nlohmann::json()) {} + +JSONValues::JSONValues(py::object values) +{ + AcquireGIL gil; + m_serialized_values = cast_from_pyobject(values, [this](const py::object& source, const std::string& path) { + return this->unserializable_handler(source, path); + }); +} + +JSONValues::JSONValues(nlohmann::json values) : m_serialized_values(std::move(values)) {} + +JSONValues::JSONValues(nlohmann::json&& 
values, python_map_t&& py_objects) : + m_serialized_values(std::move(values)), + m_py_objects(std::move(py_objects)) +{} + +std::size_t JSONValues::num_unserializable() const +{ + return m_py_objects.size(); +} + +bool JSONValues::has_unserializable() const +{ + return !m_py_objects.empty(); +} + +py::object JSONValues::to_python() const +{ + AcquireGIL gil; + py::object results = cast_from_json(m_serialized_values); + for (const auto& [path, obj] : m_py_objects) + { + DCHECK(path[0] == '/'); + DVLOG(10) << "Restoring object at path: " << path; + patch_object(results, path, obj); + } + + return results; +} + +nlohmann::json::const_reference JSONValues::view_json() const +{ + return m_serialized_values; +} + +nlohmann::json JSONValues::to_json(unserializable_handler_fn_t unserializable_handler_fn) const +{ + // start with a copy + nlohmann::json json_doc = m_serialized_values; + nlohmann::json patches = nlohmann::json::array(); + for (const auto& [path, obj] : m_py_objects) + { + nlohmann::json patch{{"op", "replace"}, {"path", path}, {"value", unserializable_handler_fn(obj, path)}}; + patches.emplace_back(std::move(patch)); + } + + if (!patches.empty()) + { + json_doc.patch_inplace(patches); + } + + return json_doc; +} + +JSONValues JSONValues::operator[](const std::string& path) const +{ + auto validated_path = validate_path(path); + + if (validated_path == "/") + { + return *this; // Return a copy of the object + } + + nlohmann::json::json_pointer node_json_ptr(validated_path); + if (!m_serialized_values.contains(node_json_ptr)) + { + throw std::runtime_error(MRC_CONCAT_STR("Path: '" << path << "' not found in json")); + } + + // take a copy of the sub-object + nlohmann::json value = m_serialized_values[node_json_ptr]; + python_map_t py_objects; + for (const auto& [py_path, obj] : m_py_objects) + { + if (py_path.find(validated_path) == 0) + { + py_objects[py_path] = obj; + } + } + + return {std::move(value), std::move(py_objects)}; +} + +pybind11::object 
JSONValues::get_python(const std::string& path) const +{ + return (*this)[path].to_python(); +} + +nlohmann::json JSONValues::get_json(const std::string& path, + unserializable_handler_fn_t unserializable_handler_fn) const +{ + return (*this)[path].to_json(unserializable_handler_fn); +} + +nlohmann::json JSONValues::stringify(const pybind11::object& obj, const std::string& path) +{ + AcquireGIL gil; + return py::str(obj).cast<std::string>(); +} + +JSONValues JSONValues::set_value(const std::string& path, const pybind11::object& value) const +{ + AcquireGIL gil; + py::object py_obj = this->to_python(); + patch_object(py_obj, validate_path(path), value); + return {py_obj}; +} + +JSONValues JSONValues::set_value(const std::string& path, nlohmann::json value) const +{ + // Two possibilities: + // 1) We don't have any unserializable objects, in which case we can just update the JSON object + // 2) We do have unserializable objects, in which case we need to cast value to python and call the python + // version of set_value + + if (!has_unserializable()) + { + // The add operation will update an existing value if it exists, or add a new value if it does not + // ref: https://datatracker.ietf.org/doc/html/rfc6902#section-4.1 + nlohmann::json patch{{"op", "add"}, {"path", validate_path(path)}, {"value", value}}; + nlohmann::json patches = nlohmann::json::array({std::move(patch)}); + auto new_values = m_serialized_values.patch(std::move(patches)); + return {std::move(new_values)}; + } + + AcquireGIL gil; + py::object py_obj = cast_from_json(value); + return set_value(path, py_obj); +} + +JSONValues JSONValues::set_value(const std::string& path, const JSONValues& value) const +{ + if (value.has_unserializable()) + { + AcquireGIL gil; + py::object py_obj = value.to_python(); + return set_value(path, py_obj); + } + + return set_value(path, value.to_json([](const py::object& source, const std::string& path) { + DLOG(FATAL) << "Should never be called"; + return nlohmann::json(); 
// unreachable but needed to satisfy the signature + })); +} + +nlohmann::json JSONValues::unserializable_handler(const py::object& obj, const std::string& path) +{ + /* We don't know how to serialize the Object, throw it into m_py_objects and return a place-holder */ + + // Take a non-const copy of the object + py::object non_const_copy = obj; + DVLOG(10) << "Storing unserializable object at path: " << path; + m_py_objects[path] = std::move(non_const_copy); + + return "**pymrc_placeholder"s; +} + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/src/utils.cpp b/python/mrc/_pymrc/src/utils.cpp index 02b94a269..22379b594 100644 --- a/python/mrc/_pymrc/src/utils.cpp +++ b/python/mrc/_pymrc/src/utils.cpp @@ -28,12 +28,12 @@ #include <pyerrors.h> #include <warnings.h> +#include <functional> // for function #include <sstream> #include <string> #include <utility> namespace mrc::pymrc { - namespace py = pybind11; using nlohmann::json; @@ -139,7 +139,9 @@ py::object cast_from_json(const json& source) // throw std::runtime_error("Unsupported conversion type."); } -json cast_from_pyobject_impl(const py::object& source, const std::string& parent_path = "") +json cast_from_pyobject_impl(const py::object& source, + unserializable_handler_fn_t unserializable_handler_fn, + const std::string& parent_path = "") { // Dont return via initializer list with JSON. 
It performs type deduction and gives different results // NOLINTBEGIN(modernize-return-braced-init-list) @@ -147,6 +149,7 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent { return json(); } + if (py::isinstance<py::dict>(source)) { const auto py_dict = source.cast<py::dict>(); @@ -155,34 +158,40 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent { std::string key{p.first.cast<std::string>()}; std::string path{parent_path + "/" + key}; - json_obj[key] = cast_from_pyobject_impl(p.second.cast<py::object>(), path); + json_obj[key] = cast_from_pyobject_impl(p.second.cast<py::object>(), unserializable_handler_fn, path); } return json_obj; } + if (py::isinstance<py::list>(source) || py::isinstance<py::tuple>(source)) { const auto py_list = source.cast<py::list>(); auto json_arr = json::array(); for (const auto& p : py_list) { - json_arr.push_back(cast_from_pyobject_impl(p.cast<py::object>(), parent_path)); + std::string path{parent_path + "/" + std::to_string(json_arr.size())}; + json_arr.push_back(cast_from_pyobject_impl(p.cast<py::object>(), unserializable_handler_fn, path)); } return json_arr; } + if (py::isinstance<py::bool_>(source)) { return json(py::cast<bool>(source)); } + if (py::isinstance<py::int_>(source)) { return json(py::cast<long>(source)); } + if (py::isinstance<py::float_>(source)) { return json(py::cast<double>(source)); } + if (py::isinstance<py::str>(source)) { return json(py::cast<std::string>(source)); @@ -198,6 +207,11 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent path = "/"; } + if (unserializable_handler_fn != nullptr) + { + return unserializable_handler_fn(source, path); + } + error_message << "Object (" << py::str(source).cast<std::string>() << ") of type: " << get_py_type_name(source) << " at path: " << path << " is not JSON serializable"; @@ -208,9 +222,14 @@ json cast_from_pyobject_impl(const py::object& source, const std::string& parent 
// NOLINTEND(modernize-return-braced-init-list) } +json cast_from_pyobject(const py::object& source, unserializable_handler_fn_t unserializable_handler_fn) +{ + return cast_from_pyobject_impl(source, unserializable_handler_fn); +} + json cast_from_pyobject(const py::object& source) { - return cast_from_pyobject_impl(source); + return cast_from_pyobject_impl(source, nullptr); } void show_deprecation_warning(const std::string& deprecation_message, ssize_t stack_level) diff --git a/python/mrc/_pymrc/tests/CMakeLists.txt b/python/mrc/_pymrc/tests/CMakeLists.txt index 02186de90..c056bb2cc 100644 --- a/python/mrc/_pymrc/tests/CMakeLists.txt +++ b/python/mrc/_pymrc/tests/CMakeLists.txt @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,6 +24,7 @@ add_executable(test_pymrc test_asyncio_runnable.cpp test_codable_pyobject.cpp test_executor.cpp + test_json_values.cpp test_main.cpp test_object_cache.cpp test_pickle_wrapper.cpp diff --git a/python/mrc/_pymrc/tests/test_json_values.cpp b/python/mrc/_pymrc/tests/test_json_values.cpp new file mode 100644 index 000000000..b6ad784f4 --- /dev/null +++ b/python/mrc/_pymrc/tests/test_json_values.cpp @@ -0,0 +1,544 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "test_pymrc.hpp" + +#include "pymrc/types.hpp" +#include "pymrc/utilities/json_values.hpp" + +#include <gtest/gtest.h> +#include <nlohmann/json.hpp> +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> // IWYU pragma: keep + +#include <array> +#include <cstddef> // for size_t +#include <initializer_list> // for initializer_list +#include <stdexcept> +#include <string> +#include <utility> // for pair +#include <vector> +// We already included pybind11.h don't need these others +// IWYU pragma: no_include <pybind11/cast.h> +// IWYU pragma: no_include <pybind11/eval.h> +// IWYU pragma: no_include <pybind11/pytypes.h> + +namespace py = pybind11; +using namespace mrc::pymrc; +using namespace std::string_literals; +using namespace pybind11::literals; // to bring in the `_a` literal + +PYMRC_TEST_CLASS(JSONValues); + +py::dict mk_py_dict() +{ + // return a simple python dict with a nested dict, a list, an integer, and a float + std::array<std::string, 3> alphabet = {"a", "b", "c"}; + return py::dict("this"_a = py::dict("is"_a = "a test"s), + "alphabet"_a = py::cast(alphabet), + "ncc"_a = 1701, + "cost"_a = 47.47); +} + +nlohmann::json mk_json() +{ + // return a simple json object comparable to that returned by mk_py_dict + return {{"this", {{"is", "a test"}}}, {"alphabet", {"a", "b", "c"}}, {"ncc", 1701}, {"cost", 47.47}}; +} + +py::object mk_decimal(const std::string& value = "1.0"s) +{ + // return a Python decimal.Decimal object, as a simple object without a supported JSON serialization + return 
py::module_::import("decimal").attr("Decimal")(value); +} + +TEST_F(TestJSONValues, DefaultConstructor) +{ + JSONValues j; + + EXPECT_EQ(j.to_json(JSONValues::stringify), nlohmann::json()); + EXPECT_TRUE(j.to_python().is_none()); +} + +TEST_F(TestJSONValues, ToPythonSerializable) +{ + auto py_dict = mk_py_dict(); + + JSONValues j{py_dict}; + auto result = j.to_python(); + + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the object and not stored it +} + +TEST_F(TestJSONValues, ToPythonFromJSON) +{ + py::dict py_expected_results = mk_py_dict(); + + nlohmann::json json_input = mk_json(); + JSONValues j{json_input}; + auto result = j.to_python(); + + EXPECT_TRUE(result.equal(py_expected_results)); +} + +TEST_F(TestJSONValues, ToJSONFromPython) +{ + auto expected_results = mk_json(); + + py::dict py_input = mk_py_dict(); + + JSONValues j{py_input}; + auto result = j.to_json(JSONValues::stringify); + + EXPECT_EQ(result, expected_results); +} + +TEST_F(TestJSONValues, ToJSONFromPythonUnserializable) +{ + std::string dec_val{"2.2"}; + auto expected_results = mk_json(); + expected_results["other"] = dec_val; + + py::dict py_input = mk_py_dict(); + py_input["other"] = mk_decimal(dec_val); + + JSONValues j{py_input}; + EXPECT_EQ(j.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, ToJSONFromJSON) +{ + JSONValues j{mk_json()}; + auto result = j.to_json(JSONValues::stringify); + + EXPECT_EQ(result, mk_json()); +} + +TEST_F(TestJSONValues, ToPythonRootUnserializable) +{ + py::object py_dec = mk_decimal(); + + JSONValues j{py_dec}; + auto result = j.to_python(); + + EXPECT_TRUE(result.equal(py_dec)); + EXPECT_TRUE(result.is(py_dec)); // Ensure we stored the object +} + +TEST_F(TestJSONValues, ToPythonSimpleDict) +{ + py::object py_dec = mk_decimal(); + py::dict py_dict; + py_dict[py::str("test"s)] = py_dec; + + JSONValues j{py_dict}; + py::dict result = j.to_python(); + + 
EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the dict and not stored it + + py::object result_dec = result["test"]; + EXPECT_TRUE(result_dec.is(py_dec)); // Ensure we stored the decimal object +} + +TEST_F(TestJSONValues, ToPythonNestedDictUnserializable) +{ + // decimal.Decimal is not serializable + py::object py_dec1 = mk_decimal("1.1"); + py::object py_dec2 = mk_decimal("1.2"); + py::object py_dec3 = mk_decimal("1.3"); + + std::vector<py::object> py_values = {py::cast(1), py::cast(2), py_dec3, py::cast(4)}; + py::list py_list = py::cast(py_values); + + // Test with object in a nested dict + py::dict py_dict("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec1))), + "other"_a = py_dec2, + "nested_list"_a = py_list); + + JSONValues j{py_dict}; + auto result = j.to_python(); + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the object and not stored it + + // Individual Decimal instances should be stored and thus pass an `is` test + py::object result_dec1 = result["a"]["b"]["c"]["d"]; + EXPECT_TRUE(result_dec1.is(py_dec1)); + + py::object result_dec2 = result["other"]; + EXPECT_TRUE(result_dec2.is(py_dec2)); + + py::list nested_list = result["nested_list"]; + py::object result_dec3 = nested_list[2]; + EXPECT_TRUE(result_dec3.is(py_dec3)); +} + +TEST_F(TestJSONValues, ToPythonList) +{ + py::object py_dec = mk_decimal("1.1"s); + + std::vector<py::object> py_values = {py::cast(1), py::cast(2), py_dec, py::cast(4)}; + py::list py_list = py::cast(py_values); + + JSONValues j{py_list}; + py::list result = j.to_python(); + EXPECT_TRUE(result.equal(py_list)); + py::object result_dec = result[2]; + EXPECT_TRUE(result_dec.is(py_dec)); +} + +TEST_F(TestJSONValues, ToPythonMultipleTypes) +{ + // Test with multiple types not json serializable: module, class, function, generator + py::object py_mod = py::module_::import("decimal"); + py::object 
py_cls = py_mod.attr("Decimal"); + py::object globals = py::globals(); + py::exec( + R"( + def gen_fn(): + yield 1 + )", + globals); + + py::object py_fn = globals["gen_fn"]; + py::object py_gen = py_fn(); + + std::vector<std::pair<std::size_t, py::object>> expected_list_objs = {{1, py_mod}, + {3, py_cls}, + {5, py_fn}, + {7, py_gen}}; + + std::vector<py::object> py_values = + {py::cast(0), py_mod, py::cast(2), py_cls, py::cast(4), py_fn, py::cast(6), py_gen}; + py::list py_list = py::cast(py_values); + + std::vector<std::pair<std::string, py::object>> expected_dict_objs = {{"module", py_mod}, + {"class", py_cls}, + {"function", py_fn}, + {"generator", py_gen}}; + + // Test with object in a nested dict + py::dict py_dict("module"_a = py_mod, + "class"_a = py_cls, + "function"_a = py_fn, + "generator"_a = py_gen, + "nested_list"_a = py_list); + + JSONValues j{py_dict}; + auto result = j.to_python(); + EXPECT_TRUE(result.equal(py_dict)); + EXPECT_FALSE(result.is(py_dict)); // Ensure we actually serialized the object and not stored it + + for (const auto& [key, value] : expected_dict_objs) + { + py::object result_value = result[key.c_str()]; + EXPECT_TRUE(result_value.is(value)); + } + + py::list nested_list = result["nested_list"]; + for (const auto& [index, value] : expected_list_objs) + { + py::object result_value = nested_list[index]; + EXPECT_TRUE(result_value.is(value)); + } +} + +TEST_F(TestJSONValues, NumUnserializable) +{ + { + JSONValues j{mk_json()}; + EXPECT_EQ(j.num_unserializable(), 0); + EXPECT_FALSE(j.has_unserializable()); + } + { + JSONValues j{mk_py_dict()}; + EXPECT_EQ(j.num_unserializable(), 0); + EXPECT_FALSE(j.has_unserializable()); + } + { + // Test with object in a nested dict + py::object py_dec = mk_decimal(); + { + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec))), "other"_a = 2); + + JSONValues j{d}; + EXPECT_EQ(j.num_unserializable(), 1); + EXPECT_TRUE(j.has_unserializable()); + } + { + // Storing the same 
object twice should count twice + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec))), "other"_a = py_dec); + + JSONValues j{d}; + EXPECT_EQ(j.num_unserializable(), 2); + EXPECT_TRUE(j.has_unserializable()); + } + { + py::object py_dec2 = mk_decimal("2.0"); + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = py_dec, "e"_a = py_dec2))), + "other"_a = py_dec); + + JSONValues j{d}; + EXPECT_EQ(j.num_unserializable(), 3); + EXPECT_TRUE(j.has_unserializable()); + } + } +} + +TEST_F(TestJSONValues, SetValueNewKeyJSON) +{ + // Set to new key that doesn't exist + auto expected_results = mk_json(); + expected_results["other"] = mk_json(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/other", mk_json()); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueExistingKeyJSON) +{ + // Set to existing key + auto expected_results = mk_json(); + expected_results["this"] = mk_json(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/this", mk_json()); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueNewKeyJSONWithUnserializable) +{ + // Set to new key that doesn't exist + auto expected_results = mk_py_dict(); + expected_results["other"] = mk_py_dict(); + expected_results["dec"] = mk_decimal(); + + auto input = mk_py_dict(); + input["dec"] = mk_decimal(); + + JSONValues values{input}; + auto new_values = values.set_value("/other", mk_json()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueExistingKeyJSONWithUnserializable) +{ + // Set to existing key + auto expected_results = mk_py_dict(); + expected_results["dec"] = mk_decimal(); + expected_results["this"] = mk_py_dict(); + + auto input = mk_py_dict(); + input["dec"] = mk_decimal(); + + JSONValues values{input}; + auto new_values = values.set_value("/this", 
mk_json()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueNewKeyPython) +{ + // Set to new key that doesn't exist + auto expected_results = mk_py_dict(); + expected_results["other"] = mk_decimal(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/other", mk_decimal()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueNestedUnsupportedPython) +{ + JSONValues values{mk_json()}; + EXPECT_THROW(values.set_value("/other/nested", mk_decimal()), py::error_already_set); +} + +TEST_F(TestJSONValues, SetValueNestedUnsupportedJSON) +{ + JSONValues values{mk_json()}; + EXPECT_THROW(values.set_value("/other/nested", nlohmann::json(1.0)), nlohmann::json::out_of_range); +} + +TEST_F(TestJSONValues, SetValueExistingKeyPython) +{ + // Set to existing key + auto expected_results = mk_py_dict(); + expected_results["this"] = mk_decimal(); + + JSONValues values{mk_json()}; + auto new_values = values.set_value("/this", mk_decimal()); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, SetValueNewKeyJSONDefaultConstructed) +{ + nlohmann::json expected_results{{"other", mk_json()}}; + + JSONValues values; + auto new_values = values.set_value("/other", mk_json()); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueJSONValues) +{ + // Set to new key that doesn't exist + auto expected_results = mk_json(); + expected_results["other"] = mk_json(); + + JSONValues values1{mk_json()}; + JSONValues values2{mk_json()}; + auto new_values = values1.set_value("/other", values2); + EXPECT_EQ(new_values.to_json(JSONValues::stringify), expected_results); +} + +TEST_F(TestJSONValues, SetValueJSONValuesWithUnserializable) +{ + // Set to new key that doesn't exist + auto expected_results = mk_py_dict(); + expected_results["other"] = py::dict("dec"_a = mk_decimal()); + + JSONValues 
values1{mk_json()}; + + auto input_dict = py::dict("dec"_a = mk_decimal()); + JSONValues values2{input_dict}; + + auto new_values = values1.set_value("/other", values2); + EXPECT_TRUE(new_values.to_python().equal(expected_results)); +} + +TEST_F(TestJSONValues, GetJSON) +{ + using namespace nlohmann; + const auto json_doc = mk_json(); + std::vector<std::string> paths = {"/", "/this", "/this/is", "/alphabet", "/ncc", "/cost"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + json::json_pointer jp; + if (path != "/") + { + jp = json::json_pointer(path); + } + + EXPECT_TRUE(json_doc.contains(jp)) << "Path: '" << path << "' not found in json"; + EXPECT_EQ(value.get_json(path, JSONValues::stringify), json_doc[jp]); + } + } +} + +TEST_F(TestJSONValues, GetJSONError) +{ + std::vector<std::string> paths = {"/doesntexist", "/this/fake"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + EXPECT_THROW(value.get_json(path, JSONValues::stringify), std::runtime_error); + } + } +} + +TEST_F(TestJSONValues, GetPython) +{ + const auto py_dict = mk_py_dict(); + + // <path, expected_result> + std::vector<std::pair<std::string, py::object>> tests = {{"/", py_dict}, + {"/this", py::dict("is"_a = "a test"s)}, + {"/this/is", py::str("a test"s)}, + {"/alphabet", py_dict["alphabet"]}, + {"/ncc", py::int_(1701)}, + {"/cost", py::float_(47.47)}}; + + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& p : tests) + { + const auto& path = p.first; + const auto& expected_result = p.second; + EXPECT_TRUE(value.get_python(path).equal(expected_result)); + } + } +} + +TEST_F(TestJSONValues, GetPythonError) +{ + std::vector<std::string> paths = {"/doesntexist", "/this/fake"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + 
EXPECT_THROW(value.get_python(path), std::runtime_error) << "Expected failure with path: '" << path << "'"; + } + } +} + +TEST_F(TestJSONValues, SubscriptOpt) +{ + using namespace nlohmann; + const auto json_doc = mk_json(); + std::vector<std::string> values = {"", "this", "this/is", "alphabet", "ncc", "cost"}; + std::vector<std::string> paths; + for (const auto& value : values) + { + paths.push_back(value); + paths.push_back("/" + value); + } + + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + auto jv = value[path]; + + json::json_pointer jp; + if (!path.empty() && path != "/") + { + std::string json_path = path; + if (json_path[0] != '/') + { + json_path = "/"s + json_path; + } + + jp = json::json_pointer(json_path); + } + + EXPECT_EQ(jv.to_json(JSONValues::stringify), json_doc[jp]); + } + } +} + +TEST_F(TestJSONValues, SubscriptOptError) +{ + std::vector<std::string> paths = {"/doesntexist", "/this/fake"}; + for (const auto& value : {JSONValues{mk_json()}, JSONValues{mk_py_dict()}}) + { + for (const auto& path : paths) + { + EXPECT_THROW(value[path], std::runtime_error); + } + } +} + +TEST_F(TestJSONValues, Stringify) +{ + auto dec_val = mk_decimal("2.2"s); + EXPECT_EQ(JSONValues::stringify(dec_val, "/"s), nlohmann::json("2.2"s)); +} diff --git a/python/mrc/_pymrc/tests/test_utils.cpp b/python/mrc/_pymrc/tests/test_utils.cpp index e518bbd87..7606b6502 100644 --- a/python/mrc/_pymrc/tests/test_utils.cpp +++ b/python/mrc/_pymrc/tests/test_utils.cpp @@ -32,6 +32,7 @@ #include <array> #include <cfloat> #include <climits> +#include <cstddef> // for size_t #include <map> #include <memory> #include <string> @@ -159,6 +160,47 @@ TEST_F(TestUtils, CastFromPyObjectSerializeErrors) EXPECT_THROW(pymrc::cast_from_pyobject(d), py::type_error); } +TEST_F(TestUtils, CastFromPyObjectUnserializableHandlerFn) +{ + // Test to verify that cast_from_pyobject calls the unserializable_handler_fn when encountering an 
object that it + // does not know how to serialize + + bool handler_called{false}; + pymrc::unserializable_handler_fn_t handler_fn = [&handler_called](const py::object& source, + const std::string& path) { + handler_called = true; + return nlohmann::json(py::cast<float>(source)); + }; + + // decimal.Decimal is not serializable + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + EXPECT_EQ(pymrc::cast_from_pyobject(o, handler_fn), nlohmann::json(1.0)); + EXPECT_TRUE(handler_called); +} + +TEST_F(TestUtils, CastFromPyObjectUnserializableHandlerFnNestedObj) +{ + std::size_t handler_call_count{0}; + + // Test with object in a nested dict + pymrc::unserializable_handler_fn_t handler_fn = [&handler_call_count](const py::object& source, + const std::string& path) { + ++handler_call_count; + return nlohmann::json(py::cast<float>(source)); + }; + + // decimal.Decimal is not serializable + py::object Decimal = py::module_::import("decimal").attr("Decimal"); + py::object o = Decimal("1.0"); + + py::dict d("a"_a = py::dict("b"_a = py::dict("c"_a = py::dict("d"_a = o))), "other"_a = o); + nlohmann::json expected_results = {{"a", {{"b", {{"c", {{"d", 1.0}}}}}}}, {"other", 1.0}}; + + EXPECT_EQ(pymrc::cast_from_pyobject(d, handler_fn), expected_results); + EXPECT_EQ(handler_call_count, 2); +} + TEST_F(TestUtils, GetTypeName) { // invalid objects should return an empty string