diff --git a/src/hutch_bunny/core/query_solvers.py b/src/hutch_bunny/core/query_solvers.py index 37e9247..c400a35 100644 --- a/src/hutch_bunny/core/query_solvers.py +++ b/src/hutch_bunny/core/query_solvers.py @@ -25,6 +25,7 @@ import hutch_bunny.core.settings as settings from hutch_bunny.core.constants import DISTRIBUTION_TYPE_FILE_NAMES_MAP + # Class for availability queries class AvailibilityQuerySolver: subqueries = list() @@ -80,6 +81,7 @@ def __init__(self, db_manager: SyncDBManager, query: AvailabilityQuery) -> None: should be based in one table but it is actually present in another """ + def _find_concepts(self) -> dict: concept_ids = set() for group in self.query.cohort.groups: @@ -113,9 +115,10 @@ def _find_concepts(self) -> dict: the person was between a certain age. - #TODO - not sure this is implemented here """ + def _solve_rules(self) -> None: - #get the list of concepts to build the query constraints + # get the list of concepts to build the query constraints concepts = self._find_concepts() # This is related to the logic within a group. This is used in the subsequent for loop to determine how @@ -257,17 +260,18 @@ def _solve_rules(self) -> None: (1) call solve_rules that takes each group and adds those results to the sub_queries list (2) this function then iterates through the list of groups to resolve the logic (AND/OR) between groups """ + def solve_query(self) -> int: - #resolve within the group + # resolve within the group self._solve_rules() merge_method = lambda x: "inner" if x == "AND" else "outer" - #seed the dataframe with the first + # seed the dataframe with the first group0_df = self.subqueries[0] group0_df.rename({"person_id": "person_id_0"}, inplace=True, axis=1) - #for the next, rename columns to give a unique key, then merge based on the merge_method value + # for the next, rename columns to give a unique key, then merge based on the merge_method value for i, df in enumerate(self.subqueries[1:], start=1): df.rename({"person_id": f"person_id_{i}"}, inplace=True, axis=1) group0_df = group0_df.merge( @@ -284,9 +288,10 @@ class BaseDistributionQuerySolver: def solve_query(self) -> Tuple[str, int]: raise NotImplementedError -# class for distriubtion queries + +# class for distribution queries class CodeDistributionQuerySolver(BaseDistributionQuerySolver): - #todo - can the following be placed somewhere once as its repeated for all classes handling queries + # todo - can the following be placed somewhere once as its repeated for all classes handling queries allowed_domains_map = { "Condition": ConditionOccurrence, "Ethnicity": Person, @@ -346,53 +351,50 @@ def solve_query(self) -> Tuple[str, int]: concepts = list() categories = list() biobanks = list() + omop_desc = list() with self.db_manager.engine.connect() as con: - #todo - rename k, as this is a domain id that is being used - for k in self.allowed_domains_map: - + for domain_id in self.allowed_domains_map: # get the right table and column based on the domain - table = self.allowed_domains_map[k] - concept_col = self.domain_concept_id_map[k] + table = self.allowed_domains_map[domain_id] + concept_col = self.domain_concept_id_map[domain_id] # gets a list of all concepts within this given table and their respective counts - stmnt = select(func.count(table.person_id), concept_col).group_by( - concept_col + stmnt = ( + select( + func.count(table.person_id), + Concept.concept_id, + Concept.concept_name, + ) + .join(Concept, concept_col == Concept.concept_id) + .group_by(Concept.concept_id, Concept.concept_name) ) res = pd.read_sql(stmnt, con) counts.extend(res.iloc[:, 0]) concepts.extend(res.iloc[:, 1]) - + omop_desc.extend(res.iloc[:, 2]) # add the same category and collection if, for the number of results received - categories.extend([k] * len(res)) + categories.extend([domain_id] * len(res)) biobanks.extend([self.query.collection] * len(res)) - df["COUNT"] = counts - df["OMOP"] = concepts - df["CATEGORY"] = categories - df["CODE"] = df["OMOP"].apply(lambda x: f"OMOP:{x}") - df["BIOBANK"] = biobanks + df["COUNT"] = counts + df["OMOP"] = concepts + df["CATEGORY"] = categories + df["CODE"] = df["OMOP"].apply(lambda x: f"OMOP:{x}") + df["BIOBANK"] = biobanks + df["OMOP_DESCR"] = omop_desc - # Get descriptions - #todo - not sure why this can be included in the SQL output above, it would need a join to the concept table - concept_query = select(Concept.concept_id, Concept.concept_name).where( - Concept.concept_id.in_(concepts) - ) - concepts_df = pd.read_sql_query( - concept_query, con=con - ) - for _, row in concepts_df.iterrows(): - df.loc[df["OMOP"] == row["concept_id"], "OMOP_DESCR"] = row["concept_name"] - # replace NaN values with empty string - df = df.fillna('') - # Convert df to tab separated string - results = list(["\t".join(df.columns)]) - for _, row in df.iterrows(): - results.append("\t".join([str(r) for r in row.values])) + # replace NaN values with empty string + df = df.fillna("") + # Convert df to tab separated string + results = list(["\t".join(df.columns)]) + for _, row in df.iterrows(): + results.append("\t".join([str(r) for r in row.values])) return os.linesep.join(results), len(df) -#todo - i *think* the only diference between this one and generic is that the allowed_domain list is different. Could we not just have the one class and functions that have this passed in? + +# todo - i *think* the only diference between this one and generic is that the allowed_domain list is different. Could we not just have the one class and functions that have this passed in? class DemographicsDistributionQuerySolver(BaseDistributionQuerySolver): allowed_domains_map = { "Gender": Person, diff --git a/tests/test_code_distribution_query.py b/tests/test_code_distribution_query.py new file mode 100644 index 0000000..4568cb5 --- /dev/null +++ b/tests/test_code_distribution_query.py @@ -0,0 +1,105 @@ +import pytest +from hutch_bunny.core.query_solvers import DistributionQuery, solve_distribution +from hutch_bunny.core.db_manager import SyncDBManager +from hutch_bunny.core.rquest_dto.result import RquestResult +from hutch_bunny.core.rquest_dto.file import File +from dotenv import load_dotenv +import os +import hutch_bunny.core.settings as settings +import hutch_bunny.core.setting_database as db_settings + +load_dotenv() + +pytestmark = pytest.mark.skipif( + os.environ.get("CI") is not None, reason="Skip integration tests in CI" +) + + +@pytest.fixture +def db_manager(): + datasource_db_port = os.getenv("DATASOURCE_DB_PORT") + return SyncDBManager( + username=os.getenv("DATASOURCE_DB_USERNAME"), + password=os.getenv("DATASOURCE_DB_PASSWORD"), + host=os.getenv("DATASOURCE_DB_HOST"), + port=(int(datasource_db_port) if datasource_db_port is not None else None), + database=os.getenv("DATASOURCE_DB_DATABASE"), + drivername=db_settings.expand_short_drivers( + os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER) + ), + schema=os.getenv("DATASOURCE_DB_SCHEMA"), + ) + + +@pytest.fixture +def distribution_query(): + return DistributionQuery( + owner="user1", + code="GENERAL", + analysis="DISTRIBUTION", + uuid="unique_id", + collection="collection_id", + ) + + +@pytest.fixture +def distribution_example(): + return RquestResult( + uuid="unique_id", + status="ok", + collection_id="collection_id", + count=1, + datasets_count=1, + files=[ + File( + name="code.distribution", + data="", + description="Result of code.distribution analysis", + size=0.308, + type_="BCOS", + sensitive=True, + reference="", + ) + ], + message="", + protocol_version="v2", + ) + + +@pytest.fixture +def distribution_result(db_manager, distribution_query): + db_manager.list_tables() + return solve_distribution(db_manager=db_manager, query=distribution_query) + + +def test_solve_distribution_returns_result(distribution_result): + assert isinstance(distribution_result, RquestResult) + + +def test_solve_distribution_is_ok(distribution_result): + assert distribution_result.status == "ok" + + +def test_solve_distribution_files_count(distribution_result): + assert len(distribution_result.files) == 1 + + +def test_solve_distribution_files_type(distribution_result): + assert isinstance(distribution_result.files[0], File) + + +def test_solve_distribution_match_query(distribution_result, distribution_example): + assert distribution_result.files[0].name == distribution_example.files[0].name + assert distribution_result.files[0].type_ == distribution_example.files[0].type_ + assert ( + distribution_result.files[0].description + == distribution_example.files[0].description + ) + assert ( + distribution_result.files[0].sensitive + == distribution_example.files[0].sensitive + ) + assert ( + distribution_result.files[0].reference + == distribution_example.files[0].reference + ) diff --git a/tests/test_demographics_distribution_query.py b/tests/test_demographics_distribution_query.py index 613b6f1..5180b3d 100644 --- a/tests/test_demographics_distribution_query.py +++ b/tests/test_demographics_distribution_query.py @@ -26,7 +26,7 @@ def distribution_example(): File( name="demographics.distribution", data="", - description="Result of code.distribution anaylsis", + description="Result of demographics.distribution analysis", size=0.308, type_="BCOS", sensitive=True,