From afc1c2b3e8369e53f1ed94a09cf3d082704704e5 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Fri, 1 Dec 2023 16:54:03 -0600 Subject: [PATCH 01/14] Use LOCALS dict for context --- tests/conftest.py | 15 +++++++-------- tests/schema.py | 10 ++++++---- tests/schema_advanced.py | 6 +++--- tests/schema_simple.py | 10 ++++++---- tests/test_relation_u.py | 2 +- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e13a1363..109bda6c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,6 @@ import pytest from . import PREFIX, schema, schema_simple, schema_advanced -namespace = locals() - @pytest.fixture(scope="session") def connection_root(): @@ -64,11 +62,12 @@ def connection_test(connection_root): connection.close() -@pytest.fixture(scope="module") +@pytest.fixture def schema_any(connection_test): schema_any = dj.Schema( - PREFIX + "_test1", schema.__dict__, connection=connection_test + PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test ) + assert schema.LOCALS_ANY, "LOCALS_ANY is empty" schema_any(schema.TTest) schema_any(schema.TTest2) schema_any(schema.TTest3) @@ -109,10 +108,10 @@ def schema_any(connection_test): schema_any.drop() -@pytest.fixture(scope="module") +@pytest.fixture def schema_simp(connection_test): schema = dj.Schema( - PREFIX + "_relational", schema_simple.__dict__, connection=connection_test + PREFIX + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test ) schema(schema_simple.IJ) schema(schema_simple.JI) @@ -136,10 +135,10 @@ def schema_simp(connection_test): schema.drop() -@pytest.fixture(scope="module") +@pytest.fixture def schema_adv(connection_test): schema = dj.Schema( - PREFIX + "_advanced", schema_advanced.__dict__, connection=connection_test + PREFIX + "_advanced", schema_advanced.LOCALS_ADVANCED, connection=connection_test ) schema(schema_advanced.Person) schema(schema_advanced.Parent) diff --git a/tests/schema.py b/tests/schema.py index 864c5efe..7bc4dccd 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -7,8 +7,6 @@ import datajoint as dj import inspect -LOCALS_ANY = locals() - class TTest(dj.Lookup): """ @@ -33,7 +31,7 @@ class TTest2(dj.Manual): class TTest3(dj.Manual): definition = """ - key : int + key : int --- value : varchar(300) """ @@ -41,7 +39,7 @@ class TTest3(dj.Manual): class NullableNumbers(dj.Manual): definition = """ - key : int + key : int --- fvalue = null : float dvalue = null : double @@ -450,3 +448,7 @@ class Longblob(dj.Manual): --- data: longblob """ + + +LOCALS_ANY = {k: v for k, v in locals().items() if inspect.isclass(v)} + diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py index 104e4d1e..726fc819 100644 --- a/tests/schema_advanced.py +++ b/tests/schema_advanced.py @@ -1,7 +1,5 @@ import datajoint as dj - -LOCALS_ADVANCED = locals() - +import inspect class Person(dj.Manual): definition = """ @@ -135,3 +133,5 @@ class GlobalSynapse(dj.Manual): -> Cell.proj(pre_slice="slice", pre_cell="cell") -> Cell.proj(post_slice="slice", post_cell="cell") """ + +LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)} diff --git a/tests/schema_simple.py b/tests/schema_simple.py index bb5c21ff..7742ba1c 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -9,8 +9,7 @@ import faker import numpy as np from datetime import date, timedelta - -LOCALS_SIMPLE = locals() +import inspect class IJ(dj.Lookup): @@ -237,8 +236,8 @@ class ReservedWord(dj.Manual): # Test of SQL reserved words key : int --- - in : varchar(25) - from : varchar(25) + in : varchar(25) + from : varchar(25) int : int select : varchar(25) """ @@ -260,3 +259,6 @@ class OutfitPiece(dj.Part, dj.Lookup): piece: varchar(20) """ contents = [(0, "jeans"), (0, "sneakers"), (0, "polo")] + + +LOCALS_SIMPLE = {k: v for k, v in locals().items() if inspect.isclass(v)} diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py index d225bccb..3494f4bf 100644 --- a/tests/test_relation_u.py +++ b/tests/test_relation_u.py @@ -17,7 +17,7 @@ def setup_class(request, schema_any): request.cls.img = Image() request.cls.trash = UberTrash() - +@pytest.mark.skip(reason="temporary") class TestU: """ Test tables: insert, delete From a59466e23328d0906737d3cdb1830662a92aefd5 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Fri, 1 Dec 2023 16:59:03 -0600 Subject: [PATCH 02/14] Clean up imports for test_blob --- tests/schema.py | 2 +- tests/schema_advanced.py | 1 + tests/schema_simple.py | 1 + tests/test_blob.py | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/schema.py b/tests/schema.py index 7bc4dccd..13ff945a 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -451,4 +451,4 @@ class Longblob(dj.Manual): LOCALS_ANY = {k: v for k, v in locals().items() if inspect.isclass(v)} - +__all__ = list(LOCALS_ANY.keys()) diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py index 726fc819..f925e497 100644 --- a/tests/schema_advanced.py +++ b/tests/schema_advanced.py @@ -135,3 +135,4 @@ class GlobalSynapse(dj.Manual): """ LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_ADVANCED.keys()) diff --git a/tests/schema_simple.py b/tests/schema_simple.py index 7742ba1c..addd70c2 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -262,3 +262,4 @@ class OutfitPiece(dj.Part, dj.Lookup): LOCALS_SIMPLE = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_SIMPLE.keys()) diff --git a/tests/test_blob.py b/tests/test_blob.py index 23de7be7..e5548898 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -7,7 +7,7 @@ from datajoint.blob import pack, unpack from numpy.testing import assert_array_equal from pytest import approx -from .schema import * +from .schema import Longblob def test_pack(): From 980e818c7e373561467fe31f679e10d324536859 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Fri, 1 Dec 2023 21:57:13 -0600 Subject: [PATCH 03/14] Clean up imports for test_blob_matlab --- tests/test_blob_matlab.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_blob_matlab.py b/tests/test_blob_matlab.py index 06154b1f..575e6b0b 100644 --- a/tests/test_blob_matlab.py +++ b/tests/test_blob_matlab.py @@ -16,15 +16,15 @@ class Blob(dj.Manual): """ -@pytest.fixture(scope="module") +@pytest.fixture def schema(connection_test): - schema = dj.Schema(PREFIX + "_test1", locals(), connection=connection_test) + schema = dj.Schema(PREFIX + "_test1", dict(Blob=Blob), connection=connection_test) schema(Blob) yield schema schema.drop() -@pytest.fixture(scope="module") +@pytest.fixture def insert_blobs_func(schema): def insert_blobs(): """ @@ -63,7 +63,7 @@ def insert_blobs(): yield insert_blobs -@pytest.fixture(scope="class") +@pytest.fixture def setup_class(schema, insert_blobs_func): assert not dj.config["safemode"], "safemode must be disabled" Blob().delete() From 4ffbca2011749dc7007e05ae65977d1be1a62620 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Fri, 1 Dec 2023 22:56:06 -0600 Subject: [PATCH 04/14] Clean up recently migrated pytests --- tests/test_connection.py | 2 +- tests/test_erd.py | 4 +- tests/test_json.py | 415 +++++++++++++++++----------------- tests/test_plugin.py | 8 +- tests/test_relation_u.py | 43 ++-- tests/test_schema_keywords.py | 2 +- tests/test_utils.py | 8 - 7 files changed, 239 insertions(+), 243 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 795d3761..a73677ae 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,7 +12,7 @@ @pytest.fixture def schema(connection_test): - schema = dj.Schema(PREFIX + "_transactions", locals(), connection=connection_test) + schema = dj.Schema(PREFIX + "_transactions", context=dict(), connection=connection_test) yield schema schema.drop() diff --git a/tests/test_erd.py b/tests/test_erd.py index f1274ec1..aebf62ea 100644 --- a/tests/test_erd.py +++ b/tests/test_erd.py @@ -45,13 +45,13 @@ def test_erd_algebra(schema_simp): def test_repr_svg(schema_adv): - erd = dj.ERD(schema_adv, context=locals()) + erd = dj.ERD(schema_adv, context=dict()) svg = erd._repr_svg_() assert svg.startswith("") def test_make_image(schema_simp): - erd = dj.ERD(schema_simp, context=locals()) + erd = dj.ERD(schema_simp, context=dict()) img = erd.make_image() assert img.ndim == 3 and img.shape[2] in (3, 4) diff --git a/tests/test_json.py b/tests/test_json.py index 760475a1..37a33c82 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,3 +1,4 @@ +import pytest import inspect from datajoint.declare import declare import datajoint as dj @@ -5,213 +6,215 @@ from packaging.version import Version from . import PREFIX -if Version(dj.conn().query("select @@version;").fetchone()[0]) >= Version("8.0.0"): - schema = dj.Schema(PREFIX + "_json") - Team = None - - def setup(): - global Team - - @schema - class Team(dj.Lookup): - definition = """ - name: varchar(40) - --- - car=null: json - unique index(car.name:char(20)) - uniQue inDex ( name, car.name:char(20), (json_value(`car`, _utf8mb4'$.length' returning decimal(4, 1))) ) - """ - contents = [ - ( - "engineering", +if Version(dj.conn().query("select @@version;").fetchone()[0]) < Version("8.0.0"): + pytest.skip("skipping windows-only tests", allow_module_level=True) + + +class Team(dj.Lookup): + definition = """ + name: varchar(40) + --- + car=null: json + unique index(car.name:char(20)) + uniQue inDex ( name, car.name:char(20), (json_value(`car`, _utf8mb4'$.length' returning decimal(4, 1))) ) + """ + contents = [ + ( + "engineering", + { + "name": "Rever", + "length": 20.5, + "inspected": True, + "tire_pressure": [32, 31, 33, 34], + "headlights": [ { - "name": "Rever", - "length": 20.5, - "inspected": True, - "tire_pressure": [32, 31, 33, 34], - "headlights": [ - { - "side": "left", - "hyper_white": None, - }, - { - "side": "right", - "hyper_white": None, - }, - ], + "side": "left", + "hyper_white": None, }, - ), - ( - "business", { - "name": "Chaching", - "length": 100, - "safety_inspected": False, - "tire_pressure": [34, 30, 27, 32], - "headlights": [ - { - "side": "left", - "hyper_white": True, - }, - { - "side": "right", - "hyper_white": True, - }, - ], + "side": "right", + "hyper_white": None, }, - ), - ( - "marketing", - None, - ), - ] - - def teardown(): - schema.drop() - - def test_insert_update(): - car = { - "name": "Discovery", - "length": 22.9, - "inspected": None, - "tire_pressure": [35, 36, 34, 37], - "headlights": [ - { - "side": "left", - "hyper_white": True, - }, - { - "side": "right", - "hyper_white": True, - }, - ], - } - - Team.insert1({"name": "research", "car": car}) - q = Team & {"name": "research"} - assert q.fetch1("car") == car - - car.update({"length": 23}) - Team.update1({"name": "research", "car": car}) - assert q.fetch1("car") == car - - try: - Team.insert1({"name": "hr", "car": car}) - raise Exception("Inserted non-unique car name.") - except dj.DataJointError: - pass - - q.delete_quick() - assert not q - - def test_describe(): - rel = Team() - context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 - - def test_restrict(): - # dict - assert (Team & {"car.name": "Chaching"}).fetch1("name") == "business" - - assert (Team & {"car.length": 20.5}).fetch1("name") == "engineering" - - assert (Team & {"car.inspected": "true"}).fetch1("name") == "engineering" - - assert (Team & {"car.inspected:unsigned": True}).fetch1("name") == "engineering" - - assert (Team & {"car.safety_inspected": "false"}).fetch1("name") == "business" - - assert (Team & {"car.safety_inspected:unsigned": False}).fetch1( - "name" - ) == "business" - - assert (Team & {"car.headlights[0].hyper_white": None}).fetch( - "name", order_by="name", as_dict=True - ) == [ - {"name": "engineering"}, - {"name": "marketing"}, - ] # if entire record missing, JSON key is missing, or value set to JSON null - - assert (Team & {"car": None}).fetch1("name") == "marketing" - - assert (Team & {"car.tire_pressure": [34, 30, 27, 32]}).fetch1( - "name" - ) == "business" - - assert ( - Team & {"car.headlights[1]": {"side": "right", "hyper_white": True}} - ).fetch1("name") == "business" - - # sql operators - assert (Team & "`car`->>'$.name' LIKE '%ching%'").fetch1( - "name" - ) == "business", "Missing substring" - - assert (Team & "`car`->>'$.length' > 30").fetch1("name") == "business", "<= 30" - - assert ( - Team & "JSON_VALUE(`car`, '$.safety_inspected' RETURNING UNSIGNED) = 0" - ).fetch1("name") == "business", "Has `safety_inspected` set to `true`" - - assert (Team & "`car`->>'$.headlights[0].hyper_white' = 'null'").fetch1( - "name" - ) == "engineering", "Has 1st `headlight` with `hyper_white` not set to `null`" - - assert (Team & "`car`->>'$.inspected' IS NOT NULL").fetch1( - "name" - ) == "engineering", "Missing `inspected` key" - - assert (Team & "`car`->>'$.tire_pressure' = '[34, 30, 27, 32]'").fetch1( - "name" - ) == "business", "`tire_pressure` array did not match" - - assert ( - Team - & """`car`->>'$.headlights[1]' = '{"side": "right", "hyper_white": true}'""" - ).fetch1("name") == "business", "2nd `headlight` object did not match" - - def test_proj(): - # proj necessary since we need to rename indexed value into a proper attribute name - assert Team.proj(car_length="car.length").fetch( - as_dict=True, order_by="car_length" - ) == [ - {"name": "marketing", "car_length": None}, - {"name": "business", "car_length": "100"}, - {"name": "engineering", "car_length": "20.5"}, - ] - - assert Team.proj(car_length="car.length:decimal(4, 1)").fetch( - as_dict=True, order_by="car_length" - ) == [ - {"name": "marketing", "car_length": None}, - {"name": "engineering", "car_length": 20.5}, - {"name": "business", "car_length": 100.0}, - ] - - assert Team.proj( - car_width="JSON_VALUE(`car`, '$.length' RETURNING float) - 15" - ).fetch(as_dict=True, order_by="car_width") == [ - {"name": "marketing", "car_width": None}, - {"name": "engineering", "car_width": 5.5}, - {"name": "business", "car_width": 85.0}, - ] - - assert ( - (Team & {"name": "engineering"}).proj(car_tire_pressure="car.tire_pressure") - ).fetch1("car_tire_pressure") == "[32, 31, 33, 34]" - - assert np.array_equal( - Team.proj(car_inspected="car.inspected").fetch( - "car_inspected", order_by="name" - ), - np.array([None, "true", None]), - ) - - assert np.array_equal( - Team.proj(car_inspected="car.inspected:unsigned").fetch( - "car_inspected", order_by="name" - ), - np.array([None, 1, None]), - ) + ], + }, + ), + ( + "business", + { + "name": "Chaching", + "length": 100, + "safety_inspected": False, + "tire_pressure": [34, 30, 27, 32], + "headlights": [ + { + "side": "left", + "hyper_white": True, + }, + { + "side": "right", + "hyper_white": True, + }, + ], + }, + ), + ( + "marketing", + None, + ), + ] + + +@pytest.fixture +def schema(connection_test): + schema = dj.Schema(PREFIX + "_json", context=dict(), connection=connection_test) + schema(Team) + yield schema + schema.drop() + + +def test_insert_update(schema): + car = { + "name": "Discovery", + "length": 22.9, + "inspected": None, + "tire_pressure": [35, 36, 34, 37], + "headlights": [ + { + "side": "left", + "hyper_white": True, + }, + { + "side": "right", + "hyper_white": True, + }, + ], + } + + Team.insert1({"name": "research", "car": car}) + q = Team & {"name": "research"} + assert q.fetch1("car") == car + + car.update({"length": 23}) + Team.update1({"name": "research", "car": car}) + assert q.fetch1("car") == car + + try: + Team.insert1({"name": "hr", "car": car}) + raise Exception("Inserted non-unique car name.") + except dj.DataJointError: + pass + + q.delete_quick() + assert not q + +def test_describe(schema): + rel = Team() + context = inspect.currentframe().f_globals + s1 = declare(rel.full_table_name, rel.definition, context) + s2 = declare(rel.full_table_name, rel.describe(), context) + assert s1 == s2 + +def test_restrict(schema): + # dict + assert (Team & {"car.name": "Chaching"}).fetch1("name") == "business" + + assert (Team & {"car.length": 20.5}).fetch1("name") == "engineering" + + assert (Team & {"car.inspected": "true"}).fetch1("name") == "engineering" + + assert (Team & {"car.inspected:unsigned": True}).fetch1("name") == "engineering" + + assert (Team & {"car.safety_inspected": "false"}).fetch1("name") == "business" + + assert (Team & {"car.safety_inspected:unsigned": False}).fetch1( + "name" + ) == "business" + + assert (Team & {"car.headlights[0].hyper_white": None}).fetch( + "name", order_by="name", as_dict=True + ) == [ + {"name": "engineering"}, + {"name": "marketing"}, + ] # if entire record missing, JSON key is missing, or value set to JSON null + + assert (Team & {"car": None}).fetch1("name") == "marketing" + + assert (Team & {"car.tire_pressure": [34, 30, 27, 32]}).fetch1( + "name" + ) == "business" + + assert ( + Team & {"car.headlights[1]": {"side": "right", "hyper_white": True}} + ).fetch1("name") == "business" + + # sql operators + assert (Team & "`car`->>'$.name' LIKE '%ching%'").fetch1( + "name" + ) == "business", "Missing substring" + + assert (Team & "`car`->>'$.length' > 30").fetch1("name") == "business", "<= 30" + + assert ( + Team & "JSON_VALUE(`car`, '$.safety_inspected' RETURNING UNSIGNED) = 0" + ).fetch1("name") == "business", "Has `safety_inspected` set to `true`" + + assert (Team & "`car`->>'$.headlights[0].hyper_white' = 'null'").fetch1( + "name" + ) == "engineering", "Has 1st `headlight` with `hyper_white` not set to `null`" + + assert (Team & "`car`->>'$.inspected' IS NOT NULL").fetch1( + "name" + ) == "engineering", "Missing `inspected` key" + + assert (Team & "`car`->>'$.tire_pressure' = '[34, 30, 27, 32]'").fetch1( + "name" + ) == "business", "`tire_pressure` array did not match" + + assert ( + Team + & """`car`->>'$.headlights[1]' = '{"side": "right", "hyper_white": true}'""" + ).fetch1("name") == "business", "2nd `headlight` object did not match" + +def test_proj(schema): + # proj necessary since we need to rename indexed value into a proper attribute name + assert Team.proj(car_length="car.length").fetch( + as_dict=True, order_by="car_length" + ) == [ + {"name": "marketing", "car_length": None}, + {"name": "business", "car_length": "100"}, + {"name": "engineering", "car_length": "20.5"}, + ] + + assert Team.proj(car_length="car.length:decimal(4, 1)").fetch( + as_dict=True, order_by="car_length" + ) == [ + {"name": "marketing", "car_length": None}, + {"name": "engineering", "car_length": 20.5}, + {"name": "business", "car_length": 100.0}, + ] + + assert Team.proj( + car_width="JSON_VALUE(`car`, '$.length' RETURNING float) - 15" + ).fetch(as_dict=True, order_by="car_width") == [ + {"name": "marketing", "car_width": None}, + {"name": "engineering", "car_width": 5.5}, + {"name": "business", "car_width": 85.0}, + ] + + assert ( + (Team & {"name": "engineering"}).proj(car_tire_pressure="car.tire_pressure") + ).fetch1("car_tire_pressure") == "[32, 31, 33, 34]" + + assert np.array_equal( + Team.proj(car_inspected="car.inspected").fetch( + "car_inspected", order_by="name" + ), + np.array([None, "true", None]), + ) + + assert np.array_equal( + Team.proj(car_inspected="car.inspected:unsigned").fetch( + "car_inspected", order_by="name" + ), + np.array([None, 1, None]), + ) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index f70f4c2e..e4122411 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -1,3 +1,4 @@ +import pytest import datajoint.errors as djerr import datajoint.plugin as p import pkg_resources @@ -22,7 +23,8 @@ def test_normal_djerror(): assert e.__cause__ is None -def test_verified_djerror(category="connection"): +@pytest.mark.parametrize('category', ('connection', )) +def test_verified_djerror(category): try: curr_plugins = getattr(p, "{}_plugins".format(category)) setattr( @@ -39,8 +41,8 @@ def test_verified_djerror(category="connection"): def test_verified_djerror_type(): test_verified_djerror(category="type") - -def test_unverified_djerror(category="connection"): +@pytest.mark.parametrize('category', ('connection', )) +def test_unverified_djerror(category): try: curr_plugins = getattr(p, "{}_plugins".format(category)) setattr( diff --git a/tests/test_relation_u.py b/tests/test_relation_u.py index 3494f4bf..50997662 100644 --- a/tests/test_relation_u.py +++ b/tests/test_relation_u.py @@ -5,25 +5,24 @@ from .schema_simple import * -@pytest.fixture(scope="class") -def setup_class(request, schema_any): - request.cls.user = User() - request.cls.language = Language() - request.cls.subject = Subject() - request.cls.experiment = Experiment() - request.cls.trial = Trial() - request.cls.ephys = Ephys() - request.cls.channel = Ephys.Channel() - request.cls.img = Image() - request.cls.trash = UberTrash() - -@pytest.mark.skip(reason="temporary") class TestU: """ Test tables: insert, delete """ - def test_restriction(self, setup_class): + @classmethod + def setup_class(cls): + cls.user = User() + cls.language = Language() + cls.subject = Subject() + cls.experiment = Experiment() + cls.trial = Trial() + cls.ephys = Ephys() + cls.channel = Ephys.Channel() + cls.img = Image() + cls.trash = UberTrash() + + def test_restriction(self, schema_any): language_set = {s[1] for s in self.language.contents} rel = dj.U("language") & self.language assert list(rel.heading.names) == ["language"] @@ -35,15 +34,15 @@ def test_restriction(self, setup_class): assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key) assert list((dj.U("start_time") & self.trial).primary_key) == ["start_time"] - def test_invalid_restriction(self, setup_class): + def test_invalid_restriction(self, schema_any): with raises(dj.DataJointError): result = dj.U("color") & dict(color="red") - def test_ineffective_restriction(self, setup_class): + def test_ineffective_restriction(self, schema_any): rel = self.language & dj.U("language") assert rel.make_sql() == self.language.make_sql() - def test_join(self, setup_class): + def test_join(self, schema_any): rel = self.experiment * dj.U("experiment_date") assert self.experiment.primary_key == ["subject_id", "experiment_id"] assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] @@ -52,16 +51,16 @@ def test_join(self, setup_class): assert self.experiment.primary_key == ["subject_id", "experiment_id"] assert rel.primary_key == self.experiment.primary_key + ["experiment_date"] - def test_invalid_join(self, setup_class): + def test_invalid_join(self, schema_any): with raises(dj.DataJointError): rel = dj.U("language") * dict(language="English") - def test_repr_without_attrs(self, setup_class): + def test_repr_without_attrs(self, schema_any): """test dj.U() display""" query = dj.U().aggr(Language, n="count(*)") repr(query) - def test_aggregations(self, setup_class): + def test_aggregations(self, schema_any): lang = Language() # test total aggregation on expression object n1 = dj.U().aggr(lang, n="count(*)").fetch1("n") @@ -73,13 +72,13 @@ def test_aggregations(self, setup_class): assert len(rel) == len(set(l[1] for l in Language.contents)) assert (rel & 'language="English"').fetch1("number_of_speakers") == 3 - def test_argmax(self, setup_class): + def test_argmax(self, schema_any): rel = TTest() # get the tuples corresponding to the maximum value mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" assert mx.fetch("value")[0] == max(rel.fetch("value")) - def test_aggr(self, setup_class, schema_simp): + def test_aggr(self, schema_any, schema_simp): rel = ArgmaxTest() amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") diff --git a/tests/test_schema_keywords.py b/tests/test_schema_keywords.py index c8b7d5a2..1cad98ef 100644 --- a/tests/test_schema_keywords.py +++ b/tests/test_schema_keywords.py @@ -33,7 +33,7 @@ class D(B): source = A -@pytest.fixture(scope="module") +@pytest.fixture def schema(connection_test): schema = dj.Schema(PREFIX + "_keywords", connection=connection_test) schema(A) diff --git a/tests/test_utils.py b/tests/test_utils.py index 936badb1..04325db5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,14 +6,6 @@ import pytest -def setup(): - pass - - -def teardown(): - pass - - def test_from_camel_case(): assert from_camel_case("AllGroups") == "all_groups" with pytest.raises(DataJointError): From 33e21cf0ade77a3ee912374fba1d1ea5217b9cba Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 14:21:39 +0000 Subject: [PATCH 05/14] WIP test_adapted_attributes migration --- tests/__init__.py | 22 +++++-- tests/conftest.py | 85 +++++++++++++++++++++++- tests/schema_adapted.py | 61 +++++++++++++++++ tests/test_adapted_attributes.py | 108 +++++++++++++++++++++++++++++++ 4 files changed, 271 insertions(+), 5 deletions(-) create mode 100644 tests/schema_adapted.py create mode 100644 tests/test_adapted_attributes.py diff --git a/tests/__init__.py b/tests/__init__.py index de57f6ea..219f7f5c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,10 +3,24 @@ import pytest import os -PREFIX = "djtest" +PREFIX = os.environ.get("DJ_TEST_DB_PREFIX", "djtest") + +# Connection for testing +CONN_INFO = dict( + host=os.environ.get("DJ_TEST_HOST", "fakeservices.datajoint.io"), + user=os.environ.get("DJ_TEST_USER", "datajoint"), + password=os.environ.get("DJ_TEST_PASSWORD", "datajoint"), +) CONN_INFO_ROOT = dict( - host=os.getenv("DJ_HOST"), - user=os.getenv("DJ_USER"), - password=os.getenv("DJ_PASS"), + host=os.environ.get("DJ_HOST", "fakeservices.datajoint.io"), + user=os.environ.get("DJ_USER", "root"), + password=os.environ.get("DJ_PASS", "simple"), +) + +S3_CONN_INFO = dict( + endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"), + access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"), + secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"), + bucket=os.environ.get("S3_BUCKET", "datajoint.test"), ) diff --git a/tests/conftest.py b/tests/conftest.py index 109bda6c..97c71c1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,21 @@ import datajoint as dj from packaging import version import os +import minio +import urllib3 +import certifi +import shutil import pytest -from . import PREFIX, schema, schema_simple, schema_advanced +import networkx as nx +import json +from pathlib import Path +import tempfile +from datajoint import errors +from . import ( + PREFIX, CONN_INFO, S3_CONN_INFO, + schema, schema_simple, schema_advanced, schema_adapted +) + @pytest.fixture(scope="session") @@ -151,3 +164,73 @@ def schema_adv(connection_test): schema(schema_advanced.GlobalSynapse) yield schema schema.drop() + + +@pytest.fixture +def adapted_graph_instance(): + yield schema_adapted.GraphAdapter() + +@pytest.fixture +def enable_adapted_types(monkeypatch): + monkeypatch.setenv('ADAPTED_TYPE_SWITCH', 'TRUE') + yield + monkeypatch.delenv('ADAPTED_TYPE_SWITCH', raising=True) + +@pytest.fixture +def enable_filepath_feature(monkeypatch): + monkeypatch.setenv('FILEPATH_FEATURE_SWITCH', 'TRUE') + yield + monkeypatch.delenv('FILEPATH_FEATURE_SWITCH', raising=True) + +@pytest.fixture +def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): + stores_config = { + "repo-s3": dict( + S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() + ) + } + dj.config["stores"] = stores_config + schema_name = PREFIX + "_test_custom_datatype" + layout_to_filepath = schema_adapted.LayoutToFilepath() + context = { + **schema_adapted.LOCALS_ADAPTED, + 'graph': adapted_graph_instance, + 'layout_to_filepath': layout_to_filepath, + } + schema = dj.schema(schema_name, context=context, connection=connection_test) + + + # instantiate for use as a datajoint type + # TODO: remove? + graph = adapted_graph_instance + + schema(schema_adapted.Connectivity) + # errors._switch_filepath_types(True) + schema(schema_adapted.Layout) + yield schema + # errors._switch_filepath_types(False) + +@pytest.fixture +def httpClient(): + # Initialize httpClient with relevant timeout. + httpClient = urllib3.PoolManager( + timeout=30, + cert_reqs="CERT_REQUIRED", + ca_certs=certifi.where(), + retries=urllib3.Retry( + total=3, backoff_factor=0.2, status_forcelist=[500, 502, 503, 504] + ), + ) + yield httpClient + +@pytest.fixture +def minioClient(): + # Initialize minioClient with an endpoint and access/secret keys. + minioClient = minio.Minio( + S3_CONN_INFO["endpoint"], + access_key=S3_CONN_INFO["access_key"], + secret_key=S3_CONN_INFO["secret_key"], + secure=True, + http_client=httpClient, + ) + yield minioClient diff --git a/tests/schema_adapted.py b/tests/schema_adapted.py new file mode 100644 index 00000000..559c1423 --- /dev/null +++ b/tests/schema_adapted.py @@ -0,0 +1,61 @@ +import datajoint as dj +import inspect +import networkx as nx +import json +from pathlib import Path +import tempfile + + +class GraphAdapter(dj.AttributeAdapter): + attribute_type = "longblob" # this is how the attribute will be declared + + @staticmethod + def get(obj): + # convert edge list into a graph + return nx.Graph(obj) + + @staticmethod + def put(obj): + # convert graph object into an edge list + assert isinstance(obj, nx.Graph) + return list(obj.edges) + + +class LayoutToFilepath(dj.AttributeAdapter): + """ + An adapted data type that saves a graph layout into fixed filepath + """ + + attribute_type = "filepath@repo-s3" + + @staticmethod + def get(path): + with open(path, "r") as f: + return json.load(f) + + @staticmethod + def put(layout): + path = Path(dj.config["stores"]["repo-s3"]["stage"], "layout.json") + with open(str(path), "w") as f: + json.dump(layout, f) + return path + + +class Connectivity(dj.Manual): + definition = """ + connid : int + --- + conn_graph = null : + """ + +class Layout(dj.Manual): + definition = """ + # stores graph layout + -> Connectivity + --- + layout: + """ + + +LOCALS_ADAPTED = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_ADAPTED.keys()) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py new file mode 100644 index 00000000..0c1d9ea0 --- /dev/null +++ b/tests/test_adapted_attributes.py @@ -0,0 +1,108 @@ +import os +import pytest +import datajoint as dj +import networkx as nx +from itertools import zip_longest +# from . import schema_adapted as adapted +from .schema_adapted import Connectivity, Layout + + +def test_adapted_type(schema_ad): + assert os.environ['ADAPTED_TYPE_SWITCH'] == 'TRUE' + c = Connectivity() + graphs = [ + nx.lollipop_graph(4, 2), + nx.star_graph(5), + nx.barbell_graph(3, 1), + nx.cycle_graph(5), + ] + c.insert((i, g) for i, g in enumerate(graphs)) + returned_graphs = c.fetch("conn_graph", order_by="connid") + for g1, g2 in zip(graphs, returned_graphs): + assert isinstance(g2, nx.Graph) + assert len(g1.edges) == len(g2.edges) + assert 0 == len(nx.symmetric_difference(g1, g2).edges) + c.delete() + + +# adapted_graph_instance? +def test_adapted_filepath_type(schema_ad): + # https://github.com/datajoint/datajoint-python/issues/684 + + # dj.errors._switch_adapted_types(True) + # dj.errors._switch_filepath_types(True) + + c = Connectivity() + c.delete() + c.insert1((0, nx.lollipop_graph(4, 2))) + + layout = nx.spring_layout(c.fetch1("conn_graph")) + # make json friendly + layout = {str(k): [round(r, ndigits=4) for r in v] for k, v in layout.items()} + t = Layout() + t.insert1((0, layout)) + result = t.fetch1("layout") + # TODO: may fail, used to be assert_dict_equal + assert result == layout + + t.delete() + c.delete() + + # dj.errors._switch_filepath_types(False) + # dj.errors._switch_adapted_types(False) + + +# test spawned classes +# TODO: separate fixture +# local_schema = dj.Schema(adapted.schema_name) +# local_schema.spawn_missing_classes() + +@pytest.mark.skip(reason='temp') +def test_adapted_spawned(): + dj.errors._switch_adapted_types(True) + c = Connectivity() # a spawned class + graphs = [ + nx.lollipop_graph(4, 2), + nx.star_graph(5), + nx.barbell_graph(3, 1), + nx.cycle_graph(5), + ] + c.insert((i, g) for i, g in enumerate(graphs)) + returned_graphs = c.fetch("conn_graph", order_by="connid") + for g1, g2 in zip(graphs, returned_graphs): + assert isinstance(g2, nx.Graph) + assert len(g1.edges) == len(g2.edges) + assert 0 == len(nx.symmetric_difference(g1, g2).edges) + c.delete() + dj.errors._switch_adapted_types(False) + + +# test with virtual module +# TODO: separate fixture +# virtual_module = dj.VirtualModule( +# "virtual_module", adapted.schema_name, add_objects={"graph": graph} +# ) + + +@pytest.mark.skip(reason='temp') +def test_adapted_virtual(): + dj.errors._switch_adapted_types(True) + c = virtual_module.Connectivity() + graphs = [ + nx.lollipop_graph(4, 2), + nx.star_graph(5), + nx.barbell_graph(3, 1), + nx.cycle_graph(5), + ] + c.insert((i, g) for i, g in enumerate(graphs)) + c.insert1({"connid": 100}) # test work with NULLs + returned_graphs = c.fetch("conn_graph", order_by="connid") + for g1, g2 in zip_longest(graphs, returned_graphs): + if g1 is None: + assert g2 is None + else: + assert isinstance(g2, nx.Graph) + assert len(g1.edges) == len(g2.edges) + assert 0 == len(nx.symmetric_difference(g1, g2).edges) + c.delete() + dj.errors._switch_adapted_types(False) From 3177773e848a5e95521c7969c9bd53f410f1744f Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 08:42:26 -0600 Subject: [PATCH 06/14] Use correct env var names for feature switches --- tests/conftest.py | 14 ++++++++++---- tests/test_adapted_attributes.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 97c71c1e..86f34114 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,9 @@ from pathlib import Path import tempfile from datajoint import errors +from datajoint.errors import ( + ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH +) from . import ( PREFIX, CONN_INFO, S3_CONN_INFO, schema, schema_simple, schema_advanced, schema_adapted @@ -18,6 +21,7 @@ + @pytest.fixture(scope="session") def connection_root(): """Root user database connection.""" @@ -172,18 +176,19 @@ def adapted_graph_instance(): @pytest.fixture def enable_adapted_types(monkeypatch): - monkeypatch.setenv('ADAPTED_TYPE_SWITCH', 'TRUE') + monkeypatch.setenv(ADAPTED_TYPE_SWITCH, 'TRUE') yield - monkeypatch.delenv('ADAPTED_TYPE_SWITCH', raising=True) + monkeypatch.delenv(ADAPTED_TYPE_SWITCH, raising=True) @pytest.fixture def enable_filepath_feature(monkeypatch): - monkeypatch.setenv('FILEPATH_FEATURE_SWITCH', 'TRUE') + monkeypatch.setenv(FILEPATH_FEATURE_SWITCH, 'TRUE') yield - monkeypatch.delenv('FILEPATH_FEATURE_SWITCH', raising=True) + monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) @pytest.fixture def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): + assert os.environ.get(ADAPTED_TYPE_SWITCH) == 'TRUE', 'must have adapted types enabled in environment' stores_config = { "repo-s3": dict( S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() @@ -209,6 +214,7 @@ def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapt schema(schema_adapted.Layout) yield schema # errors._switch_filepath_types(False) + schema.drop() @pytest.fixture def httpClient(): diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index 0c1d9ea0..beb69414 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -8,7 +8,7 @@ def test_adapted_type(schema_ad): - assert os.environ['ADAPTED_TYPE_SWITCH'] == 'TRUE' + assert os.environ[dj.errors.ADAPTED_TYPE_SWITCH] == 'TRUE' c = Connectivity() graphs = [ nx.lollipop_graph(4, 2), From 1f1575a74329a111c66b602336760410f51c783a Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 08:56:38 -0600 Subject: [PATCH 07/14] WIP migrating test_adapted_attributes tests/test_adapted_attributes.py::test_adapted_filepath_type throws datajoint/s3.py:54: BucketInaccessible --- tests/conftest.py | 29 ------------------------ tests/test_adapted_attributes.py | 38 ++++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 86f34114..67b02fbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -186,35 +186,6 @@ def enable_filepath_feature(monkeypatch): yield monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) -@pytest.fixture -def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): - assert os.environ.get(ADAPTED_TYPE_SWITCH) == 'TRUE', 'must have adapted types enabled in environment' - stores_config = { - "repo-s3": dict( - S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() - ) - } - dj.config["stores"] = stores_config - schema_name = PREFIX + "_test_custom_datatype" - layout_to_filepath = schema_adapted.LayoutToFilepath() - context = { - **schema_adapted.LOCALS_ADAPTED, - 'graph': adapted_graph_instance, - 'layout_to_filepath': layout_to_filepath, - } - schema = dj.schema(schema_name, context=context, connection=connection_test) - - - # instantiate for use as a datajoint type - # TODO: remove? - graph = adapted_graph_instance - - schema(schema_adapted.Connectivity) - # errors._switch_filepath_types(True) - schema(schema_adapted.Layout) - yield schema - # errors._switch_filepath_types(False) - schema.drop() @pytest.fixture def httpClient(): diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index beb69414..8657efee 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -1,11 +1,44 @@ import os import pytest +import tempfile import datajoint as dj +from datajoint.errors import ADAPTED_TYPE_SWITCH import networkx as nx from itertools import zip_longest -# from . import schema_adapted as adapted +from . import schema_adapted from .schema_adapted import Connectivity, Layout - +from . import PREFIX, S3_CONN_INFO + + +@pytest.fixture +def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): + assert os.environ.get(ADAPTED_TYPE_SWITCH) == 'TRUE', 'must have adapted types enabled in environment' + stores_config = { + "repo-s3": dict( + S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() + ) + } + dj.config["stores"] = stores_config + schema_name = PREFIX + "_test_custom_datatype" + layout_to_filepath = schema_adapted.LayoutToFilepath() + context = { + **schema_adapted.LOCALS_ADAPTED, + 'graph': adapted_graph_instance, + 'layout_to_filepath': layout_to_filepath, + } + schema = dj.schema(schema_name, context=context, connection=connection_test) + + + # instantiate for use as a datajoint type + # TODO: remove? + graph = adapted_graph_instance + + schema(schema_adapted.Connectivity) + # errors._switch_filepath_types(True) + schema(schema_adapted.Layout) + yield schema + # errors._switch_filepath_types(False) + schema.drop() def test_adapted_type(schema_ad): assert os.environ[dj.errors.ADAPTED_TYPE_SWITCH] == 'TRUE' @@ -26,6 +59,7 @@ def test_adapted_type(schema_ad): # adapted_graph_instance? +# @pytest.mark.skip(reason='misconfigured s3 fixtures') def test_adapted_filepath_type(schema_ad): # https://github.com/datajoint/datajoint-python/issues/684 From 21854dad18a3db00eb586bb3718d427a26f1d2df Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 09:09:57 -0600 Subject: [PATCH 08/14] Migrate test_adapted_attributes: module scoped fixtures for now --- tests/conftest.py | 25 ++++++++--------------- tests/test_adapted_attributes.py | 34 ++++++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 67b02fbf..aed3ca46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,15 @@ schema, schema_simple, schema_advanced, schema_adapted ) +@pytest.fixture(scope="session") +def monkeysession(): + with pytest.MonkeyPatch.context() as mp: + yield mp +@pytest.fixture(scope="module") +def monkeymodule(): + with pytest.MonkeyPatch.context() as mp: + yield mp @pytest.fixture(scope="session") @@ -170,23 +178,6 @@ def schema_adv(connection_test): schema.drop() -@pytest.fixture -def adapted_graph_instance(): - yield schema_adapted.GraphAdapter() - -@pytest.fixture -def enable_adapted_types(monkeypatch): - monkeypatch.setenv(ADAPTED_TYPE_SWITCH, 'TRUE') - yield - monkeypatch.delenv(ADAPTED_TYPE_SWITCH, raising=True) - -@pytest.fixture -def enable_filepath_feature(monkeypatch): - monkeypatch.setenv(FILEPATH_FEATURE_SWITCH, 'TRUE') - yield - monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) - - @pytest.fixture def httpClient(): # Initialize httpClient with relevant timeout. diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index 8657efee..7e275c5a 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -2,7 +2,7 @@ import pytest import tempfile import datajoint as dj -from datajoint.errors import ADAPTED_TYPE_SWITCH +from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH import networkx as nx from itertools import zip_longest from . import schema_adapted @@ -10,9 +10,28 @@ from . import PREFIX, S3_CONN_INFO -@pytest.fixture -def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): - assert os.environ.get(ADAPTED_TYPE_SWITCH) == 'TRUE', 'must have adapted types enabled in environment' +@pytest.fixture(scope='module') +def adapted_graph_instance(): + yield schema_adapted.GraphAdapter() + + +@pytest.fixture(scope='module') +def enable_adapted_types(monkeymodule): + monkeymodule.setenv(ADAPTED_TYPE_SWITCH, 'TRUE') + yield + monkeymodule.delenv(ADAPTED_TYPE_SWITCH, raising=True) + + +@pytest.fixture(scope='module') +def enable_filepath_feature(monkeymodule): + monkeymodule.setenv(FILEPATH_FEATURE_SWITCH, 'TRUE') + yield + monkeymodule.delenv(FILEPATH_FEATURE_SWITCH, raising=True) + + + +@pytest.fixture(scope='module') +def schema_ad(connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): stores_config = { "repo-s3": dict( S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() @@ -40,9 +59,12 @@ def schema_ad(monkeypatch, connection_test, adapted_graph_instance, enable_adapt # errors._switch_filepath_types(False) schema.drop() -def test_adapted_type(schema_ad): +@pytest.fixture(scope='module') +def c(schema_ad): + yield Connectivity() + +def test_adapted_type(schema_ad, c): assert os.environ[dj.errors.ADAPTED_TYPE_SWITCH] == 'TRUE' - c = Connectivity() graphs = [ nx.lollipop_graph(4, 2), nx.star_graph(5), From cd584bce1f41c666f35eb38b16d4ee97b968e946 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 09:27:09 -0600 Subject: [PATCH 09/14] Add @dimitri-yatsenko suggested changes on #1116 --- tests/schema_adapted.py | 2 +- tests/schema_advanced.py | 2 +- tests/schema_simple.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/schema_adapted.py b/tests/schema_adapted.py index 559c1423..68a7e965 100644 --- a/tests/schema_adapted.py +++ b/tests/schema_adapted.py @@ -58,4 +58,4 @@ class Layout(dj.Manual): LOCALS_ADAPTED = {k: v for k, v in locals().items() if inspect.isclass(v)} -__all__ = list(LOCALS_ADAPTED.keys()) +__all__ = list(LOCALS_ADAPTED) diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py index f925e497..649ff186 100644 --- a/tests/schema_advanced.py +++ b/tests/schema_advanced.py @@ -135,4 +135,4 @@ class GlobalSynapse(dj.Manual): """ LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)} -__all__ = list(LOCALS_ADVANCED.keys()) +__all__ = list(LOCALS_ADVANCED) diff --git a/tests/schema_simple.py b/tests/schema_simple.py index addd70c2..e751a9c6 100644 --- a/tests/schema_simple.py +++ b/tests/schema_simple.py @@ -262,4 +262,4 @@ class OutfitPiece(dj.Part, dj.Lookup): LOCALS_SIMPLE = {k: v for k, v in locals().items() if inspect.isclass(v)} -__all__ = list(LOCALS_SIMPLE.keys()) +__all__ = list(LOCALS_SIMPLE) From 93fa858e567b68fab2446566b1817a3c4f6aa8fe Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 09:35:07 -0600 Subject: [PATCH 10/14] Migrate test_adapted_attributes::test_adapted_spawned --- tests/schema.py | 2 +- tests/test_adapted_attributes.py | 53 ++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/tests/schema.py b/tests/schema.py index 13ff945a..140a34bb 100644 --- a/tests/schema.py +++ b/tests/schema.py @@ -451,4 +451,4 @@ class Longblob(dj.Manual): LOCALS_ANY = {k: v for k, v in locals().items() if inspect.isclass(v)} -__all__ = list(LOCALS_ANY.keys()) +__all__ = list(LOCALS_ANY) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index 7e275c5a..e6ce5679 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -10,42 +10,47 @@ from . import PREFIX, S3_CONN_INFO -@pytest.fixture(scope='module') +@pytest.fixture def adapted_graph_instance(): yield schema_adapted.GraphAdapter() -@pytest.fixture(scope='module') -def enable_adapted_types(monkeymodule): - monkeymodule.setenv(ADAPTED_TYPE_SWITCH, 'TRUE') +@pytest.fixture +def enable_adapted_types(monkeypatch): + monkeypatch.setenv(ADAPTED_TYPE_SWITCH, 'TRUE') yield - monkeymodule.delenv(ADAPTED_TYPE_SWITCH, raising=True) + monkeypatch.delenv(ADAPTED_TYPE_SWITCH, raising=True) -@pytest.fixture(scope='module') -def enable_filepath_feature(monkeymodule): - monkeymodule.setenv(FILEPATH_FEATURE_SWITCH, 'TRUE') +@pytest.fixture +def enable_filepath_feature(monkeypatch): + monkeypatch.setenv(FILEPATH_FEATURE_SWITCH, 'TRUE') yield - monkeymodule.delenv(FILEPATH_FEATURE_SWITCH, raising=True) + monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) +@pytest.fixture +def schema_name_custom_datatype(): + schema_name = PREFIX + "_test_custom_datatype" + return schema_name -@pytest.fixture(scope='module') -def schema_ad(connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature): +@pytest.fixture +def schema_ad( + schema_name_custom_datatype, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature +): stores_config = { "repo-s3": dict( S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() ) } dj.config["stores"] = stores_config - schema_name = PREFIX + "_test_custom_datatype" layout_to_filepath = schema_adapted.LayoutToFilepath() context = { **schema_adapted.LOCALS_ADAPTED, 'graph': adapted_graph_instance, 'layout_to_filepath': layout_to_filepath, } - schema = dj.schema(schema_name, context=context, connection=connection_test) + schema = dj.schema(schema_name_custom_datatype, context=context, connection=connection_test) # instantiate for use as a datajoint type @@ -59,7 +64,7 @@ def schema_ad(connection_test, adapted_graph_instance, enable_adapted_types, ena # errors._switch_filepath_types(False) schema.drop() -@pytest.fixture(scope='module') +@pytest.fixture def c(schema_ad): yield Connectivity() @@ -81,7 +86,7 @@ def test_adapted_type(schema_ad, c): # adapted_graph_instance? -# @pytest.mark.skip(reason='misconfigured s3 fixtures') +@pytest.mark.skip(reason='misconfigured s3 fixtures') def test_adapted_filepath_type(schema_ad): # https://github.com/datajoint/datajoint-python/issues/684 @@ -108,14 +113,17 @@ def test_adapted_filepath_type(schema_ad): # dj.errors._switch_adapted_types(False) -# test spawned classes -# TODO: separate fixture -# local_schema = dj.Schema(adapted.schema_name) -# local_schema.spawn_missing_classes() +@pytest.fixture +def local_schema(schema_ad, schema_name_custom_datatype): + """Fixture for testing spawned classes""" + local_schema = dj.Schema(schema_name_custom_datatype) + local_schema.spawn_missing_classes() + yield local_schema + local_schema.drop() -@pytest.mark.skip(reason='temp') -def test_adapted_spawned(): - dj.errors._switch_adapted_types(True) + +# @pytest.mark.skip(reason='temp') +def test_adapted_spawned(local_schema, enable_adapted_types): c = Connectivity() # a spawned class graphs = [ nx.lollipop_graph(4, 2), @@ -130,7 +138,6 @@ def test_adapted_spawned(): assert len(g1.edges) == len(g2.edges) assert 0 == len(nx.symmetric_difference(g1, g2).edges) c.delete() - dj.errors._switch_adapted_types(False) # test with virtual module From 07631f9a0944c16c25d0aee1c29ac7788298db68 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 09:42:59 -0600 Subject: [PATCH 11/14] All passing in test_adapted_attributes::test_adapted_spawned except s3 --- tests/test_adapted_attributes.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index e6ce5679..626cb969 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -122,9 +122,8 @@ def local_schema(schema_ad, schema_name_custom_datatype): local_schema.drop() -# @pytest.mark.skip(reason='temp') -def test_adapted_spawned(local_schema, enable_adapted_types): - c = Connectivity() # a spawned class +def test_adapted_spawned(local_schema, enable_adapted_types, c): + # c = Connectivity() # a spawned class graphs = [ nx.lollipop_graph(4, 2), nx.star_graph(5), @@ -140,17 +139,20 @@ def test_adapted_spawned(local_schema, enable_adapted_types): c.delete() -# test with virtual module -# TODO: separate fixture -# virtual_module = dj.VirtualModule( -# "virtual_module", adapted.schema_name, add_objects={"graph": graph} -# ) - - -@pytest.mark.skip(reason='temp') -def test_adapted_virtual(): - dj.errors._switch_adapted_types(True) - c = virtual_module.Connectivity() +@pytest.fixture +def schema_virtual_module(schema_ad, schema_name_custom_datatype, adapted_graph_instance): + """Fixture for testing virtual modules""" + # virtual_module = dj.VirtualModule( + # "virtual_module", adapted.schema_name, add_objects={"graph": graph} + # ) + schema_virtual_module = dj.VirtualModule( + "virtual_module", schema_name_custom_datatype, add_objects={"graph": adapted_graph_instance} + ) + return schema_virtual_module + + +def test_adapted_virtual(schema_virtual_module): + c = schema_virtual_module.Connectivity() graphs = [ nx.lollipop_graph(4, 2), nx.star_graph(5), @@ -168,4 +170,3 @@ def test_adapted_virtual(): assert len(g1.edges) == len(g2.edges) assert 0 == len(nx.symmetric_difference(g1, g2).edges) c.delete() - dj.errors._switch_adapted_types(False) From a6ca9339641d40732c06b346d7e4a8a18184f2b1 Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Mon, 4 Dec 2023 09:47:05 -0600 Subject: [PATCH 12/14] Clean up fixtures --- tests/test_adapted_attributes.py | 69 ++++++++++++-------------------- 1 file changed, 25 insertions(+), 44 deletions(-) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index 626cb969..2ec0c239 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -34,9 +34,11 @@ def schema_name_custom_datatype(): schema_name = PREFIX + "_test_custom_datatype" return schema_name + @pytest.fixture def schema_ad( - schema_name_custom_datatype, connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature + schema_name_custom_datatype, connection_test, adapted_graph_instance, + enable_adapted_types, enable_filepath_feature ): stores_config = { "repo-s3": dict( @@ -51,25 +53,33 @@ def schema_ad( 'layout_to_filepath': layout_to_filepath, } schema = dj.schema(schema_name_custom_datatype, context=context, connection=connection_test) - - - # instantiate for use as a datajoint type - # TODO: remove? graph = adapted_graph_instance - schema(schema_adapted.Connectivity) - # errors._switch_filepath_types(True) schema(schema_adapted.Layout) yield schema - # errors._switch_filepath_types(False) schema.drop() + +@pytest.fixture +def local_schema(schema_ad, schema_name_custom_datatype): + """Fixture for testing spawned classes""" + local_schema = dj.Schema(schema_name_custom_datatype) + local_schema.spawn_missing_classes() + yield local_schema + local_schema.drop() + + @pytest.fixture -def c(schema_ad): - yield Connectivity() +def schema_virtual_module(schema_ad, schema_name_custom_datatype, adapted_graph_instance): + """Fixture for testing virtual modules""" + schema_virtual_module = dj.VirtualModule( + "virtual_module", schema_name_custom_datatype, add_objects={"graph": adapted_graph_instance} + ) + return schema_virtual_module -def test_adapted_type(schema_ad, c): - assert os.environ[dj.errors.ADAPTED_TYPE_SWITCH] == 'TRUE' + +def test_adapted_type(schema_ad): + c = Connectivity() graphs = [ nx.lollipop_graph(4, 2), nx.star_graph(5), @@ -85,14 +95,9 @@ def test_adapted_type(schema_ad, c): c.delete() -# adapted_graph_instance? @pytest.mark.skip(reason='misconfigured s3 fixtures') def test_adapted_filepath_type(schema_ad): - # https://github.com/datajoint/datajoint-python/issues/684 - - # dj.errors._switch_adapted_types(True) - # dj.errors._switch_filepath_types(True) - + """https://github.com/datajoint/datajoint-python/issues/684""" c = Connectivity() c.delete() c.insert1((0, nx.lollipop_graph(4, 2))) @@ -105,25 +110,12 @@ def test_adapted_filepath_type(schema_ad): result = t.fetch1("layout") # TODO: may fail, used to be assert_dict_equal assert result == layout - t.delete() c.delete() - # dj.errors._switch_filepath_types(False) - # dj.errors._switch_adapted_types(False) - -@pytest.fixture -def local_schema(schema_ad, schema_name_custom_datatype): - """Fixture for testing spawned classes""" - local_schema = dj.Schema(schema_name_custom_datatype) - local_schema.spawn_missing_classes() - yield local_schema - local_schema.drop() - - -def test_adapted_spawned(local_schema, enable_adapted_types, c): - # c = Connectivity() # a spawned class +def test_adapted_spawned(local_schema, enable_adapted_types): + c = Connectivity() # a spawned class graphs = [ nx.lollipop_graph(4, 2), nx.star_graph(5), @@ -139,17 +131,6 @@ def test_adapted_spawned(local_schema, enable_adapted_types, c): c.delete() -@pytest.fixture -def schema_virtual_module(schema_ad, schema_name_custom_datatype, adapted_graph_instance): - """Fixture for testing virtual modules""" - # virtual_module = dj.VirtualModule( - # "virtual_module", adapted.schema_name, add_objects={"graph": graph} - # ) - schema_virtual_module = dj.VirtualModule( - "virtual_module", schema_name_custom_datatype, add_objects={"graph": adapted_graph_instance} - ) - return schema_virtual_module - def test_adapted_virtual(schema_virtual_module): c = schema_virtual_module.Connectivity() From 81ef9a9a1e97f7015f478414d3a418572cf7751b Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Tue, 5 Dec 2023 17:28:54 +0000 Subject: [PATCH 13/14] Add @A-Baji suggestions for SCHEMA_NAME --- tests/test_adapted_attributes.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index 2ec0c239..ffa85f8a 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -9,6 +9,8 @@ from .schema_adapted import Connectivity, Layout from . import PREFIX, S3_CONN_INFO +SCHEMA_NAME = PREFIX + "_test_custom_datatype" + @pytest.fixture def adapted_graph_instance(): @@ -29,15 +31,9 @@ def enable_filepath_feature(monkeypatch): monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) -@pytest.fixture -def schema_name_custom_datatype(): - schema_name = PREFIX + "_test_custom_datatype" - return schema_name - - @pytest.fixture def schema_ad( - schema_name_custom_datatype, connection_test, adapted_graph_instance, + connection_test, adapted_graph_instance, enable_adapted_types, enable_filepath_feature ): stores_config = { @@ -52,7 +48,7 @@ def schema_ad( 'graph': adapted_graph_instance, 'layout_to_filepath': layout_to_filepath, } - schema = dj.schema(schema_name_custom_datatype, context=context, connection=connection_test) + schema = dj.schema(SCHEMA_NAME, context=context, connection=connection_test) graph = adapted_graph_instance schema(schema_adapted.Connectivity) schema(schema_adapted.Layout) @@ -61,19 +57,19 @@ def schema_ad( @pytest.fixture -def local_schema(schema_ad, schema_name_custom_datatype): +def local_schema(schema_ad): """Fixture for testing spawned classes""" - local_schema = dj.Schema(schema_name_custom_datatype) + local_schema = dj.Schema(SCHEMA_NAME) local_schema.spawn_missing_classes() yield local_schema local_schema.drop() @pytest.fixture -def schema_virtual_module(schema_ad, schema_name_custom_datatype, adapted_graph_instance): +def schema_virtual_module(schema_ad, adapted_graph_instance): """Fixture for testing virtual modules""" schema_virtual_module = dj.VirtualModule( - "virtual_module", schema_name_custom_datatype, add_objects={"graph": adapted_graph_instance} + "virtual_module", SCHEMA_NAME, add_objects={"graph": adapted_graph_instance} ) return schema_virtual_module From eff463dd239d911d094ffef35ca553796af0473d Mon Sep 17 00:00:00 2001 From: Ethan Ho Date: Tue, 5 Dec 2023 17:29:04 +0000 Subject: [PATCH 14/14] Format with black --- tests/conftest.py | 20 ++++++++++++++------ tests/schema_adapted.py | 1 + tests/schema_advanced.py | 2 ++ tests/test_adapted_attributes.py | 22 +++++++++++++--------- tests/test_connection.py | 4 +++- tests/test_json.py | 7 ++++--- tests/test_plugin.py | 5 +++-- 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index aed3ca46..2c4063a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,19 +11,24 @@ from pathlib import Path import tempfile from datajoint import errors -from datajoint.errors import ( - ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH -) +from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH from . import ( - PREFIX, CONN_INFO, S3_CONN_INFO, - schema, schema_simple, schema_advanced, schema_adapted + PREFIX, + CONN_INFO, + S3_CONN_INFO, + schema, + schema_simple, + schema_advanced, + schema_adapted, ) + @pytest.fixture(scope="session") def monkeysession(): with pytest.MonkeyPatch.context() as mp: yield mp + @pytest.fixture(scope="module") def monkeymodule(): with pytest.MonkeyPatch.context() as mp: @@ -163,7 +168,9 @@ def schema_simp(connection_test): @pytest.fixture def schema_adv(connection_test): schema = dj.Schema( - PREFIX + "_advanced", schema_advanced.LOCALS_ADVANCED, connection=connection_test + PREFIX + "_advanced", + schema_advanced.LOCALS_ADVANCED, + connection=connection_test, ) schema(schema_advanced.Person) schema(schema_advanced.Parent) @@ -191,6 +198,7 @@ def httpClient(): ) yield httpClient + @pytest.fixture def minioClient(): # Initialize minioClient with an endpoint and access/secret keys. diff --git a/tests/schema_adapted.py b/tests/schema_adapted.py index 68a7e965..ab9a02e7 100644 --- a/tests/schema_adapted.py +++ b/tests/schema_adapted.py @@ -48,6 +48,7 @@ class Connectivity(dj.Manual): conn_graph = null : """ + class Layout(dj.Manual): definition = """ # stores graph layout diff --git a/tests/schema_advanced.py b/tests/schema_advanced.py index 649ff186..6a35cb34 100644 --- a/tests/schema_advanced.py +++ b/tests/schema_advanced.py @@ -1,6 +1,7 @@ import datajoint as dj import inspect + class Person(dj.Manual): definition = """ person_id : int @@ -134,5 +135,6 @@ class GlobalSynapse(dj.Manual): -> Cell.proj(post_slice="slice", post_cell="cell") """ + LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)} __all__ = list(LOCALS_ADVANCED) diff --git a/tests/test_adapted_attributes.py b/tests/test_adapted_attributes.py index ffa85f8a..29d77347 100644 --- a/tests/test_adapted_attributes.py +++ b/tests/test_adapted_attributes.py @@ -19,34 +19,39 @@ def adapted_graph_instance(): @pytest.fixture def enable_adapted_types(monkeypatch): - monkeypatch.setenv(ADAPTED_TYPE_SWITCH, 'TRUE') + monkeypatch.setenv(ADAPTED_TYPE_SWITCH, "TRUE") yield monkeypatch.delenv(ADAPTED_TYPE_SWITCH, raising=True) @pytest.fixture def enable_filepath_feature(monkeypatch): - monkeypatch.setenv(FILEPATH_FEATURE_SWITCH, 'TRUE') + monkeypatch.setenv(FILEPATH_FEATURE_SWITCH, "TRUE") yield monkeypatch.delenv(FILEPATH_FEATURE_SWITCH, raising=True) @pytest.fixture def schema_ad( - connection_test, adapted_graph_instance, - enable_adapted_types, enable_filepath_feature + connection_test, + adapted_graph_instance, + enable_adapted_types, + enable_filepath_feature, ): stores_config = { "repo-s3": dict( - S3_CONN_INFO, protocol="s3", location="adapted/repo", stage=tempfile.mkdtemp() + S3_CONN_INFO, + protocol="s3", + location="adapted/repo", + stage=tempfile.mkdtemp(), ) } dj.config["stores"] = stores_config layout_to_filepath = schema_adapted.LayoutToFilepath() context = { **schema_adapted.LOCALS_ADAPTED, - 'graph': adapted_graph_instance, - 'layout_to_filepath': layout_to_filepath, + "graph": adapted_graph_instance, + "layout_to_filepath": layout_to_filepath, } schema = dj.schema(SCHEMA_NAME, context=context, connection=connection_test) graph = adapted_graph_instance @@ -91,7 +96,7 @@ def test_adapted_type(schema_ad): c.delete() -@pytest.mark.skip(reason='misconfigured s3 fixtures') +@pytest.mark.skip(reason="misconfigured s3 fixtures") def test_adapted_filepath_type(schema_ad): """https://github.com/datajoint/datajoint-python/issues/684""" c = Connectivity() @@ -127,7 +132,6 @@ def test_adapted_spawned(local_schema, enable_adapted_types): c.delete() - def test_adapted_virtual(schema_virtual_module): c = schema_virtual_module.Connectivity() graphs = [ diff --git a/tests/test_connection.py b/tests/test_connection.py index a73677ae..8cdbbbff 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,7 +12,9 @@ @pytest.fixture def schema(connection_test): - schema = dj.Schema(PREFIX + "_transactions", context=dict(), connection=connection_test) + schema = dj.Schema( + PREFIX + "_transactions", context=dict(), connection=connection_test + ) yield schema schema.drop() diff --git a/tests/test_json.py b/tests/test_json.py index 37a33c82..c1caaeed 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -107,6 +107,7 @@ def test_insert_update(schema): q.delete_quick() assert not q + def test_describe(schema): rel = Team() context = inspect.currentframe().f_globals @@ -114,6 +115,7 @@ def test_describe(schema): s2 = declare(rel.full_table_name, rel.describe(), context) assert s1 == s2 + def test_restrict(schema): # dict assert (Team & {"car.name": "Chaching"}).fetch1("name") == "business" @@ -139,9 +141,7 @@ def test_restrict(schema): assert (Team & {"car": None}).fetch1("name") == "marketing" - assert (Team & {"car.tire_pressure": [34, 30, 27, 32]}).fetch1( - "name" - ) == "business" + assert (Team & {"car.tire_pressure": [34, 30, 27, 32]}).fetch1("name") == "business" assert ( Team & {"car.headlights[1]": {"side": "right", "hyper_white": True}} @@ -175,6 +175,7 @@ def test_restrict(schema): & """`car`->>'$.headlights[1]' = '{"side": "right", "hyper_white": true}'""" ).fetch1("name") == "business", "2nd `headlight` object did not match" + def test_proj(schema): # proj necessary since we need to rename indexed value into a proper attribute name assert Team.proj(car_length="car.length").fetch( diff --git a/tests/test_plugin.py b/tests/test_plugin.py index e4122411..ddb8b3bf 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -23,7 +23,7 @@ def test_normal_djerror(): assert e.__cause__ is None -@pytest.mark.parametrize('category', ('connection', )) +@pytest.mark.parametrize("category", ("connection",)) def test_verified_djerror(category): try: curr_plugins = getattr(p, "{}_plugins".format(category)) @@ -41,7 +41,8 @@ def test_verified_djerror(category): def test_verified_djerror_type(): test_verified_djerror(category="type") -@pytest.mark.parametrize('category', ('connection', )) + +@pytest.mark.parametrize("category", ("connection",)) def test_unverified_djerror(category): try: curr_plugins = getattr(p, "{}_plugins".format(category))