Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: threadpool configuration #2671

Merged
merged 2 commits on Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/zarr/core/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def _get_executor() -> ThreadPoolExecutor:
global _executor
if not _executor:
max_workers = config.get("threading.max_workers", None)
print(max_workers)
# if max_workers is not None and max_workers > 0:
# raise ValueError(max_workers)
logger.debug("Creating Zarr ThreadPoolExecutor with max_workers=%s", max_workers)
_executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="zarr_pool")
_get_loop().set_default_executor(_executor)
return _executor
Expand Down Expand Up @@ -118,6 +116,9 @@ def sync(
# NB: if the loop is not running *yet*, it is OK to submit work
# and we will wait for it
loop = _get_loop()
if _executor is None and config.get("threading.max_workers", None) is not None:
# trigger executor creation and attach to loop
_ = _get_executor()
if not isinstance(loop, asyncio.AbstractEventLoop):
raise TypeError(f"loop cannot be of type {type(loop)}")
if loop.is_closed():
Expand Down Expand Up @@ -153,6 +154,7 @@ def _get_loop() -> asyncio.AbstractEventLoop:
# repeat the check just in case the loop got filled between the
# previous two calls from another thread
if loop[0] is None:
logger.debug("Creating Zarr event loop")
new_loop = asyncio.new_event_loop()
loop[0] = new_loop
iothread[0] = threading.Thread(target=new_loop.run_forever, name="zarr_io")
Expand Down
18 changes: 14 additions & 4 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_get_lock,
_get_loop,
cleanup_resources,
loop,
sync,
)
from zarr.storage import MemoryStore
Expand Down Expand Up @@ -148,11 +149,20 @@ def test_open_positional_args_deprecate():


@pytest.mark.parametrize("workers", [None, 1, 2])
def test_get_executor(clean_state, workers) -> None:
def test_threadpool_executor(clean_state, workers: int | None) -> None:
with zarr.config.set({"threading.max_workers": workers}):
e = _get_executor()
if workers is not None and workers != 0:
assert e._max_workers == workers
_ = zarr.zeros(shape=(1,)) # trigger executor creation
assert loop != [None] # confirm loop was created
if workers is None:
# confirm no executor was created if no workers were specified
# (this is the default behavior)
assert loop[0]._default_executor is None
else:
# confirm executor was created and attached to loop as the default executor
# note: Python has no public API for retrieving a loop's default executor, so we
# read the private `_default_executor` attribute
assert _get_executor() is loop[0]._default_executor
assert _get_executor()._max_workers == workers


def test_cleanup_resources_idempotent() -> None:
Expand Down
Loading