From 4b594eb3df8cdb22c04c2d5c8ce0b026a45b110b Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Wed, 12 May 2021 13:43:59 +0300 Subject: [PATCH] Allow passing an existing ctx to scope_ctx --- tiledb/ctx.py | 11 ++++++++--- tiledb/tests/test_libtiledb.py | 34 +++++++++++++++++++++++----------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/tiledb/ctx.py b/tiledb/ctx.py index 1b3ffa16aa..f4755d8d40 100644 --- a/tiledb/ctx.py +++ b/tiledb/ctx.py @@ -7,15 +7,20 @@ @contextmanager -def scope_ctx(config=None): +def scope_ctx(ctx_or_config=None): """ Context manager for setting the default `tiledb.Ctx` context variable when entering a block of code and restoring it to its previous value when exiting the block. - :param config: :py:class:`tiledb.Config` object or dictionary with config parameters. + :param ctx_or_config: :py:class:`tiledb.Ctx` or :py:class:`tiledb.Config` object + or dictionary with config parameters. :return: Ctx """ - token = _ctx_var.set(tiledb.Ctx(config)) + if not isinstance(ctx_or_config, tiledb.Ctx): + ctx = tiledb.Ctx(ctx_or_config) + else: + ctx = ctx_or_config + token = _ctx_var.set(ctx) try: yield _ctx_var.get() finally: diff --git a/tiledb/tests/test_libtiledb.py b/tiledb/tests/test_libtiledb.py index 5da63bd8d4..f6daa7d2af 100644 --- a/tiledb/tests/test_libtiledb.py +++ b/tiledb/tests/test_libtiledb.py @@ -3907,17 +3907,29 @@ def test_default_ctx(self): def test_scope_ctx(self): key = "sm.tile_cache_size" ctx0 = tiledb.default_ctx() - assert ctx0.config()[key] == "10000000" - with tiledb.scope_ctx({key: 42}) as ctx1: - assert ctx1 is tiledb.default_ctx() - assert ctx1.config()[key] == "42" - with tiledb.scope_ctx({key: 6712}) as ctx2: - assert ctx2 is tiledb.default_ctx() - assert ctx2.config()[key] == "6712" - assert ctx1 is tiledb.default_ctx() - assert ctx1.config()[key] == "42" - assert ctx0 is tiledb.default_ctx() - assert ctx0.config()[key] == "10000000" + new_config_dict = {key: 42} + new_config = tiledb.Config({key: 78}) + new_ctx = tiledb.Ctx({key: 61}) + + assert tiledb.default_ctx() is ctx0 + assert tiledb.default_ctx().config()[key] == "10000000" + + with tiledb.scope_ctx(new_config_dict) as ctx1: + assert tiledb.default_ctx() is ctx1 + assert tiledb.default_ctx().config()[key] == "42" + with tiledb.scope_ctx(new_config) as ctx2: + assert tiledb.default_ctx() is ctx2 + assert tiledb.default_ctx().config()[key] == "78" + with tiledb.scope_ctx(new_ctx) as ctx3: + assert tiledb.default_ctx() is ctx3 is new_ctx + assert tiledb.default_ctx().config()[key] == "61" + assert tiledb.default_ctx() is ctx2 + assert tiledb.default_ctx().config()[key] == "78" + assert tiledb.default_ctx() is ctx1 + assert tiledb.default_ctx().config()[key] == "42" + + assert tiledb.default_ctx() is ctx0 + assert tiledb.default_ctx().config()[key] == "10000000" def test_init_config(self): self.assertEqual(-1, init_test_wrapper())