Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor code Distribution SQL query #63

Merged
merged 6 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 39 additions & 37 deletions src/hutch_bunny/core/query_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import hutch_bunny.core.settings as settings
from hutch_bunny.core.constants import DISTRIBUTION_TYPE_FILE_NAMES_MAP


# Class for availability queries
class AvailibilityQuerySolver:
subqueries = list()
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(self, db_manager: SyncDBManager, query: AvailabilityQuery) -> None:
should be based in one table but it is actually present in another

"""

def _find_concepts(self) -> dict:
concept_ids = set()
for group in self.query.cohort.groups:
Expand Down Expand Up @@ -113,9 +115,10 @@ def _find_concepts(self) -> dict:
the person was between a certain age. - #TODO - not sure this is implemented here

"""

def _solve_rules(self) -> None:

#get the list of concepts to build the query constraints
# get the list of concepts to build the query constraints
concepts = self._find_concepts()

# This is related to the logic within a group. This is used in the subsequent for loop to determine how
Expand Down Expand Up @@ -257,17 +260,18 @@ def _solve_rules(self) -> None:
(1) call solve_rules that takes each group and adds those results to the sub_queries list
(2) this function then iterates through the list of groups to resolve the logic (AND/OR) between groups
"""

def solve_query(self) -> int:
#resolve within the group
# resolve within the group
self._solve_rules()

merge_method = lambda x: "inner" if x == "AND" else "outer"

#seed the dataframe with the first
# seed the dataframe with the first
group0_df = self.subqueries[0]
group0_df.rename({"person_id": "person_id_0"}, inplace=True, axis=1)

#for the next, rename columns to give a unique key, then merge based on the merge_method value
# for the next, rename columns to give a unique key, then merge based on the merge_method value
for i, df in enumerate(self.subqueries[1:], start=1):
df.rename({"person_id": f"person_id_{i}"}, inplace=True, axis=1)
group0_df = group0_df.merge(
Expand All @@ -284,9 +288,10 @@ class BaseDistributionQuerySolver:
def solve_query(self) -> Tuple[str, int]:
raise NotImplementedError

# class for distriubtion queries

# class for distribution queries
class CodeDistributionQuerySolver(BaseDistributionQuerySolver):
#todo - can the following be placed somewhere once as its repeated for all classes handling queries
# todo - can the following be placed somewhere once, as it's repeated for all classes handling queries?
allowed_domains_map = {
"Condition": ConditionOccurrence,
"Ethnicity": Person,
Expand Down Expand Up @@ -346,53 +351,50 @@ def solve_query(self) -> Tuple[str, int]:
concepts = list()
categories = list()
biobanks = list()
omop_desc = list()

with self.db_manager.engine.connect() as con:
#todo - rename k, as this is a domain id that is being used
for k in self.allowed_domains_map:

for domain_id in self.allowed_domains_map:
# get the right table and column based on the domain
table = self.allowed_domains_map[k]
concept_col = self.domain_concept_id_map[k]
table = self.allowed_domains_map[domain_id]
concept_col = self.domain_concept_id_map[domain_id]

# gets a list of all concepts within this given table and their respective counts
stmnt = select(func.count(table.person_id), concept_col).group_by(
concept_col
stmnt = (
select(
func.count(table.person_id),
Concept.concept_id,
Concept.concept_name,
)
.join(Concept, concept_col == Concept.concept_id)
.group_by(Concept.concept_id, Concept.concept_name)
)
res = pd.read_sql(stmnt, con)
counts.extend(res.iloc[:, 0])
concepts.extend(res.iloc[:, 1])

omop_desc.extend(res.iloc[:, 2])
# add the same category and collection if, for the number of results received
categories.extend([k] * len(res))
categories.extend([domain_id] * len(res))
biobanks.extend([self.query.collection] * len(res))

df["COUNT"] = counts
df["OMOP"] = concepts
df["CATEGORY"] = categories
df["CODE"] = df["OMOP"].apply(lambda x: f"OMOP:{x}")
df["BIOBANK"] = biobanks
df["COUNT"] = counts
df["OMOP"] = concepts
df["CATEGORY"] = categories
df["CODE"] = df["OMOP"].apply(lambda x: f"OMOP:{x}")
df["BIOBANK"] = biobanks
df["OMOP_DESCR"] = omop_desc

# Get descriptions
#todo - not sure why this can be included in the SQL output above, it would need a join to the concept table
concept_query = select(Concept.concept_id, Concept.concept_name).where(
Concept.concept_id.in_(concepts)
)
concepts_df = pd.read_sql_query(
concept_query, con=con
)
for _, row in concepts_df.iterrows():
df.loc[df["OMOP"] == row["concept_id"], "OMOP_DESCR"] = row["concept_name"]
# replace NaN values with empty string
df = df.fillna('')
# Convert df to tab separated string
results = list(["\t".join(df.columns)])
for _, row in df.iterrows():
results.append("\t".join([str(r) for r in row.values]))
# replace NaN values with empty string
df = df.fillna("")
# Convert df to tab separated string
results = list(["\t".join(df.columns)])
for _, row in df.iterrows():
results.append("\t".join([str(r) for r in row.values]))

return os.linesep.join(results), len(df)

#todo - i *think* the only diference between this one and generic is that the allowed_domain list is different. Could we not just have the one class and functions that have this passed in?

# todo - I *think* the only difference between this one and the generic solver is that the allowed_domain list differs. Could we not have a single class with functions that take the list as a parameter?
class DemographicsDistributionQuerySolver(BaseDistributionQuerySolver):
allowed_domains_map = {
"Gender": Person,
Expand Down
105 changes: 105 additions & 0 deletions tests/test_code_distribution_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
from hutch_bunny.core.query_solvers import DistributionQuery, solve_distribution
from hutch_bunny.core.db_manager import SyncDBManager
from hutch_bunny.core.rquest_dto.result import RquestResult
from hutch_bunny.core.rquest_dto.file import File
from dotenv import load_dotenv
import os
import hutch_bunny.core.settings as settings
import hutch_bunny.core.setting_database as db_settings

load_dotenv()

pytestmark = pytest.mark.skipif(
os.environ.get("CI") is not None, reason="Skip integration tests in CI"
)


@pytest.fixture
def db_manager():
datasource_db_port = os.getenv("DATASOURCE_DB_PORT")
return SyncDBManager(
username=os.getenv("DATASOURCE_DB_USERNAME"),
password=os.getenv("DATASOURCE_DB_PASSWORD"),
host=os.getenv("DATASOURCE_DB_HOST"),
port=(int(datasource_db_port) if datasource_db_port is not None else None),
database=os.getenv("DATASOURCE_DB_DATABASE"),
drivername=db_settings.expand_short_drivers(
os.getenv("DATASOURCE_DB_DRIVERNAME", settings.DEFAULT_DB_DRIVER)
),
schema=os.getenv("DATASOURCE_DB_SCHEMA"),
)


@pytest.fixture
def distribution_query():
return DistributionQuery(
owner="user1",
code="GENERAL",
analysis="DISTRIBUTION",
uuid="unique_id",
collection="collection_id",
)


@pytest.fixture
def distribution_example():
return RquestResult(
uuid="unique_id",
status="ok",
collection_id="collection_id",
count=1,
datasets_count=1,
files=[
File(
name="code.distribution",
data="",
description="Result of code.distribution analysis",
size=0.308,
type_="BCOS",
sensitive=True,
reference="",
)
],
message="",
protocol_version="v2",
)


@pytest.fixture
def distribution_result(db_manager, distribution_query):
db_manager.list_tables()
return solve_distribution(db_manager=db_manager, query=distribution_query)


def test_solve_distribution_returns_result(distribution_result):
assert isinstance(distribution_result, RquestResult)


def test_solve_distribution_is_ok(distribution_result):
assert distribution_result.status == "ok"


def test_solve_distribution_files_count(distribution_result):
assert len(distribution_result.files) == 1


def test_solve_distribution_files_type(distribution_result):
assert isinstance(distribution_result.files[0], File)


def test_solve_distribution_match_query(distribution_result, distribution_example):
assert distribution_result.files[0].name == distribution_example.files[0].name
assert distribution_result.files[0].type_ == distribution_example.files[0].type_
assert (
distribution_result.files[0].description
== distribution_example.files[0].description
)
assert (
distribution_result.files[0].sensitive
== distribution_example.files[0].sensitive
)
assert (
distribution_result.files[0].reference
== distribution_example.files[0].reference
)
2 changes: 1 addition & 1 deletion tests/test_demographics_distribution_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def distribution_example():
File(
name="demographics.distribution",
data="",
description="Result of code.distribution anaylsis",
description="Result of demographics.distribution analysis",
size=0.308,
type_="BCOS",
sensitive=True,
Expand Down
Loading