Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PLAT-138 Follow Up to #1114 #1116

Merged
merged 16 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,24 @@
import pytest
import os

PREFIX = "djtest"
PREFIX = os.environ.get("DJ_TEST_DB_PREFIX", "djtest")

# Connection for testing
CONN_INFO = dict(
host=os.environ.get("DJ_TEST_HOST", "fakeservices.datajoint.io"),
user=os.environ.get("DJ_TEST_USER", "datajoint"),
password=os.environ.get("DJ_TEST_PASSWORD", "datajoint"),
)

CONN_INFO_ROOT = dict(
host=os.getenv("DJ_HOST"),
user=os.getenv("DJ_USER"),
password=os.getenv("DJ_PASS"),
host=os.environ.get("DJ_HOST", "fakeservices.datajoint.io"),
user=os.environ.get("DJ_USER", "root"),
password=os.environ.get("DJ_PASS", "simple"),
)

S3_CONN_INFO = dict(
endpoint=os.environ.get("S3_ENDPOINT", "fakeservices.datajoint.io"),
access_key=os.environ.get("S3_ACCESS_KEY", "datajoint"),
secret_key=os.environ.get("S3_SECRET_KEY", "datajoint"),
bucket=os.environ.get("S3_BUCKET", "datajoint.test"),
)
74 changes: 66 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,38 @@
import datajoint as dj
from packaging import version
import os
import minio
import urllib3
import certifi
import shutil
import pytest
from . import PREFIX, schema, schema_simple, schema_advanced
import networkx as nx
import json
from pathlib import Path
import tempfile
from datajoint import errors
from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH
from . import (
PREFIX,
CONN_INFO,
S3_CONN_INFO,
schema,
schema_simple,
schema_advanced,
schema_adapted,
)

namespace = locals()

@pytest.fixture(scope="session")
def monkeysession():
with pytest.MonkeyPatch.context() as mp:
yield mp


@pytest.fixture(scope="module")
def monkeymodule():
with pytest.MonkeyPatch.context() as mp:
yield mp


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -64,11 +92,12 @@ def connection_test(connection_root):
connection.close()


@pytest.fixture(scope="module")
@pytest.fixture
def schema_any(connection_test):
schema_any = dj.Schema(
PREFIX + "_test1", schema.__dict__, connection=connection_test
PREFIX + "_test1", schema.LOCALS_ANY, connection=connection_test
)
assert schema.LOCALS_ANY, "LOCALS_ANY is empty"
schema_any(schema.TTest)
schema_any(schema.TTest2)
schema_any(schema.TTest3)
Expand Down Expand Up @@ -109,10 +138,10 @@ def schema_any(connection_test):
schema_any.drop()


@pytest.fixture(scope="module")
@pytest.fixture
def schema_simp(connection_test):
schema = dj.Schema(
PREFIX + "_relational", schema_simple.__dict__, connection=connection_test
PREFIX + "_relational", schema_simple.LOCALS_SIMPLE, connection=connection_test
)
schema(schema_simple.IJ)
schema(schema_simple.JI)
Expand All @@ -136,10 +165,12 @@ def schema_simp(connection_test):
schema.drop()


@pytest.fixture(scope="module")
@pytest.fixture
def schema_adv(connection_test):
schema = dj.Schema(
PREFIX + "_advanced", schema_advanced.__dict__, connection=connection_test
PREFIX + "_advanced",
schema_advanced.LOCALS_ADVANCED,
connection=connection_test,
)
schema(schema_advanced.Person)
schema(schema_advanced.Parent)
Expand All @@ -152,3 +183,30 @@ def schema_adv(connection_test):
schema(schema_advanced.GlobalSynapse)
yield schema
schema.drop()


@pytest.fixture
def httpClient():
# Initialize httpClient with relevant timeout.
httpClient = urllib3.PoolManager(
timeout=30,
cert_reqs="CERT_REQUIRED",
ca_certs=certifi.where(),
retries=urllib3.Retry(
total=3, backoff_factor=0.2, status_forcelist=[500, 502, 503, 504]
),
)
yield httpClient


@pytest.fixture
def minioClient():
# Initialize minioClient with an endpoint and access/secret keys.
minioClient = minio.Minio(
S3_CONN_INFO["endpoint"],
access_key=S3_CONN_INFO["access_key"],
secret_key=S3_CONN_INFO["secret_key"],
secure=True,
http_client=httpClient,
)
yield minioClient
10 changes: 6 additions & 4 deletions tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import datajoint as dj
import inspect

LOCALS_ANY = locals()


class TTest(dj.Lookup):
"""
Expand All @@ -33,15 +31,15 @@ class TTest2(dj.Manual):

class TTest3(dj.Manual):
definition = """
key : int
key : int
---
value : varchar(300)
"""


class NullableNumbers(dj.Manual):
definition = """
key : int
key : int
---
fvalue = null : float
dvalue = null : double
Expand Down Expand Up @@ -450,3 +448,7 @@ class Longblob(dj.Manual):
---
data: longblob
"""


LOCALS_ANY = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_ANY)
62 changes: 62 additions & 0 deletions tests/schema_adapted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import datajoint as dj
import inspect
import networkx as nx
import json
from pathlib import Path
import tempfile


class GraphAdapter(dj.AttributeAdapter):
attribute_type = "longblob" # this is how the attribute will be declared

@staticmethod
def get(obj):
# convert edge list into a graph
return nx.Graph(obj)

@staticmethod
def put(obj):
# convert graph object into an edge list
assert isinstance(obj, nx.Graph)
return list(obj.edges)


class LayoutToFilepath(dj.AttributeAdapter):
"""
An adapted data type that saves a graph layout into fixed filepath
"""

attribute_type = "filepath@repo-s3"

@staticmethod
def get(path):
with open(path, "r") as f:
return json.load(f)

@staticmethod
def put(layout):
path = Path(dj.config["stores"]["repo-s3"]["stage"], "layout.json")
with open(str(path), "w") as f:
json.dump(layout, f)
return path


class Connectivity(dj.Manual):
definition = """
connid : int
---
conn_graph = null : <graph>
"""


class Layout(dj.Manual):
definition = """
# stores graph layout
-> Connectivity
---
layout: <layout_to_filepath>
"""


LOCALS_ADAPTED = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_ADAPTED)
7 changes: 5 additions & 2 deletions tests/schema_advanced.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datajoint as dj

LOCALS_ADVANCED = locals()
import inspect


class Person(dj.Manual):
Expand Down Expand Up @@ -135,3 +134,7 @@ class GlobalSynapse(dj.Manual):
-> Cell.proj(pre_slice="slice", pre_cell="cell")
-> Cell.proj(post_slice="slice", post_cell="cell")
"""


LOCALS_ADVANCED = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_ADVANCED)
11 changes: 7 additions & 4 deletions tests/schema_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import faker
import numpy as np
from datetime import date, timedelta

LOCALS_SIMPLE = locals()
import inspect


class IJ(dj.Lookup):
Expand Down Expand Up @@ -237,8 +236,8 @@ class ReservedWord(dj.Manual):
# Test of SQL reserved words
key : int
---
in : varchar(25)
from : varchar(25)
in : varchar(25)
from : varchar(25)
int : int
select : varchar(25)
"""
Expand All @@ -260,3 +259,7 @@ class OutfitPiece(dj.Part, dj.Lookup):
piece: varchar(20)
"""
contents = [(0, "jeans"), (0, "sneakers"), (0, "polo")]


LOCALS_SIMPLE = {k: v for k, v in locals().items() if inspect.isclass(v)}
__all__ = list(LOCALS_SIMPLE)
Loading