Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sqla2 type hints iain #142

Merged
merged 3 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions and classes to create and populate the target database."""
import logging
from typing import Any, Dict, Generator, List, Tuple
from typing import Any, Generator, Mapping, Sequence, Tuple

from sqlalchemy import Connection, insert
from sqlalchemy.exc import IntegrityError
Expand All @@ -10,14 +10,16 @@
from sqlsynthgen.settings import get_settings
from sqlsynthgen.utils import create_db_engine, get_sync_engine

Story = Generator[Tuple[str, Dict[str, Any]], Dict[str, Any], None]
Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]


def create_db_tables(metadata: MetaData) -> None:
"""Create tables described by the sqlalchemy metadata object."""
settings = get_settings()
dst_dsn: str = settings.dst_dsn or ""
assert dst_dsn != "", "Missing DST_DSN setting."

engine = get_sync_engine(create_db_engine(settings.dst_dsn)) # type: ignore
engine = get_sync_engine(create_db_engine(dst_dsn))

# Create schema, if necessary.
if settings.dst_schema:
Expand All @@ -27,21 +29,19 @@ def create_db_tables(metadata: MetaData) -> None:
connection.execute(CreateSchema(schema_name, if_not_exists=True))

# Recreate the engine, this time with a schema specified
engine = get_sync_engine(
create_db_engine(settings.dst_dsn, schema_name=schema_name) # type: ignore
)
engine = get_sync_engine(create_db_engine(dst_dsn, schema_name=schema_name))

metadata.create_all(engine)


def create_db_vocab(vocab_dict: Dict[str, FileUploader]) -> None:
def create_db_vocab(vocab_dict: Mapping[str, FileUploader]) -> None:
"""Load vocabulary tables from files."""
settings = get_settings()
dst_dsn: str = settings.dst_dsn or ""
assert dst_dsn != "", "Missing DST_DSN setting."

dst_engine = get_sync_engine(
create_db_engine(
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
)
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
)

with dst_engine.connect() as dst_conn:
Expand All @@ -55,18 +55,18 @@ def create_db_vocab(vocab_dict: Dict[str, FileUploader]) -> None:


def create_db_data(
sorted_tables: list[Table],
table_generator_dict: dict[str, TableGenerator],
story_generator_list: list[dict[str, Any]],
sorted_tables: Sequence[Table],
table_generator_dict: Mapping[str, TableGenerator],
story_generator_list: Sequence[Mapping[str, Any]],
num_passes: int,
) -> None:
"""Connect to a database and populate it with data."""
settings = get_settings()
dst_dsn: str = settings.dst_dsn or ""
assert dst_dsn != "", "Missing DST_DSN setting."

dst_engine = get_sync_engine(
create_db_engine(
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
)
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
)

with dst_engine.connect() as dst_conn:
Expand All @@ -81,8 +81,8 @@ def create_db_data(

def _populate_story(
story: Story,
table_dict: Dict[str, Table],
table_generator_dict: Dict[str, TableGenerator],
table_dict: Mapping[str, Table],
table_generator_dict: Mapping[str, TableGenerator],
dst_conn: Connection,
) -> None:
"""Write to the database all the rows created by the given story."""
Expand Down Expand Up @@ -121,17 +121,17 @@ def _populate_story(

def populate(
dst_conn: Connection,
tables: list[Table],
table_generator_dict: dict[str, TableGenerator],
story_generator_list: list[dict[str, Any]],
tables: Sequence[Table],
table_generator_dict: Mapping[str, TableGenerator],
story_generator_list: Sequence[Mapping[str, Any]],
) -> None:
"""Populate a database schema with synthetic data."""
table_dict = {table.name: table for table in tables}
# Generate stories
# Each story generator returns a python generator (an unfortunate naming clash with
# what we call generators). Iterating over it yields individual rows for the
# database. First, collect all of the python generators into a single list.
stories: List[Story] = sum(
stories: list[Story] = sum(
[
[sg["name"](dst_conn) for _ in range(sg["num_stories_per_pass"])]
for sg in story_generator_list
Expand Down
71 changes: 35 additions & 36 deletions sqlsynthgen/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
from sys import stderr
from types import ModuleType
from typing import Any, Dict, Final, List, Optional, Tuple
from typing import Any, Final, Mapping, Optional, Sequence, Tuple

import pandas as pd
import snsql
Expand All @@ -25,7 +25,7 @@
from sqlsynthgen.settings import get_settings
from sqlsynthgen.utils import create_db_engine, download_table, get_sync_engine

PROVIDER_IMPORTS: Final[List[str]] = []
PROVIDER_IMPORTS: Final[list[str]] = []
for entry_name, entry in inspect.getmembers(providers, inspect.isclass):
if issubclass(entry, BaseProvider) and entry.__module__ == "sqlsynthgen.providers":
PROVIDER_IMPORTS.append(entry_name)
Expand All @@ -49,14 +49,14 @@ class FunctionCall:
"""Contains the ssg.py content related function calls."""

function_name: str
argument_values: List[str]
argument_values: list[str]


@dataclass
class RowGeneratorInfo:
"""Contains the ssg.py content related to row generators of a table."""

variable_names: List[str]
variable_names: list[str]
function_call: FunctionCall
primary_key: bool = False

Expand All @@ -68,8 +68,8 @@ class TableGeneratorInfo:
class_name: str
table_name: str
rows_per_pass: int
row_gens: List[RowGeneratorInfo] = field(default_factory=list)
unique_constraints: List[UniqueConstraint] = field(default_factory=list)
row_gens: list[RowGeneratorInfo] = field(default_factory=list)
unique_constraints: list[UniqueConstraint] = field(default_factory=list)


@dataclass
Expand Down Expand Up @@ -100,38 +100,38 @@ def _orm_class_from_table_name(

def _get_function_call(
function_name: str,
positional_arguments: Optional[List[Any]] = None,
keyword_arguments: Optional[Dict[str, Any]] = None,
positional_arguments: Optional[Sequence[Any]] = None,
keyword_arguments: Optional[Mapping[str, Any]] = None,
) -> FunctionCall:
if positional_arguments is None:
positional_arguments = []

if keyword_arguments is None:
keyword_arguments = {}

argument_values: List[str] = [str(value) for value in positional_arguments]
argument_values: list[str] = [str(value) for value in positional_arguments]
argument_values += [f"{key}={value}" for key, value in keyword_arguments.items()]

return FunctionCall(function_name=function_name, argument_values=argument_values)


def _get_row_generator(
table_config: dict[str, Any],
) -> tuple[List[RowGeneratorInfo], list[str]]:
table_config: Mapping[str, Any],
) -> tuple[list[RowGeneratorInfo], list[str]]:
"""Get the row generators information, for the given table."""
row_gen_info: List[RowGeneratorInfo] = []
config: List[Dict[str, Any]] = table_config.get("row_generators", {})
row_gen_info: list[RowGeneratorInfo] = []
config: list[dict[str, Any]] = table_config.get("row_generators", {})
columns_covered = []
for gen_conf in config:
name: str = gen_conf["name"]
columns_assigned = gen_conf["columns_assigned"]
keyword_arguments: Dict[str, Any] = gen_conf.get("kwargs", {})
positional_arguments: List[str] = gen_conf.get("args", [])
keyword_arguments: Mapping[str, Any] = gen_conf.get("kwargs", {})
positional_arguments: Sequence[str] = gen_conf.get("args", [])

if isinstance(columns_assigned, str):
columns_assigned = [columns_assigned]

variable_names: List[str] = columns_assigned
variable_names: list[str] = columns_assigned
try:
columns_covered += columns_assigned
except TypeError:
Expand All @@ -158,9 +158,9 @@ def _get_default_generator(

# If it's a foreign key column, pull random values from the column it
# references.
variable_names: List[str] = []
variable_names: list[str] = []
generator_function: str = ""
generator_arguments: List[str] = []
generator_arguments: list[str] = []

if column.foreign_keys:
if len(column.foreign_keys) > 1:
Expand Down Expand Up @@ -202,19 +202,19 @@ def _get_default_generator(
)


def _get_provider_for_column(column: Column) -> Tuple[List[str], str, List[str]]:
def _get_provider_for_column(column: Column) -> Tuple[list[str], str, list[str]]:
"""
Get a default Mimesis provider and its arguments for a SQL column type.

Args:
column: SQLAlchemy column object

Returns:
Tuple[str, str, List[str]]: Tuple containing the variable names to assign to,
Tuple[str, str, list[str]]: Tuple containing the variable names to assign to,
generator function and any generator arguments.
"""
variable_names: List[str] = [column.name]
generator_arguments: List[str] = []
variable_names: list[str] = [column.name]
generator_arguments: list[str] = []

column_type = type(column.type)
column_size: Optional[int] = getattr(column.type, "length", None)
Expand Down Expand Up @@ -291,7 +291,7 @@ def _enforce_unique_constraints(table_data: TableGeneratorInfo) -> None:


def _get_generator_for_table(
tables_module: ModuleType, table_config: dict[str, Any], table: Table
tables_module: ModuleType, table_config: Mapping[str, Any], table: Table
) -> TableGeneratorInfo:
"""Get generator information for the given table."""
unique_constraints = [
Expand All @@ -318,7 +318,7 @@ def _get_generator_for_table(
return table_data


def _get_story_generators(config: dict) -> List[StoryGeneratorInfo]:
def _get_story_generators(config: Mapping) -> list[StoryGeneratorInfo]:
"""Get story generators."""
generators = []
for gen in config.get("story_generators", []):
Expand All @@ -339,7 +339,7 @@ def _get_story_generators(config: dict) -> List[StoryGeneratorInfo]:

def make_table_generators(
tables_module: ModuleType,
config: dict,
config: Mapping,
src_stats_filename: Optional[str],
overwrite_files: bool = False,
) -> str:
Expand All @@ -359,14 +359,13 @@ def make_table_generators(
story_generator_module_name = config.get("story_generators_module", None)

settings = get_settings()
engine = get_sync_engine(
create_db_engine(
settings.src_dsn, schema_name=settings.src_schema # type: ignore
)
)
src_dsn: str = settings.src_dsn or ""
assert src_dsn != "", "Missing SRC_DSN setting."

engine = get_sync_engine(create_db_engine(src_dsn, schema_name=settings.src_schema))

tables: List[TableGeneratorInfo] = []
vocabulary_tables: List[VocabularyTableGeneratorInfo] = []
tables: list[TableGeneratorInfo] = []
vocabulary_tables: list[VocabularyTableGeneratorInfo] = []

for table in tables_module.Base.metadata.sorted_tables:
table_config = config.get("tables", {}).get(table.name, {})
Expand Down Expand Up @@ -398,7 +397,7 @@ def make_table_generators(
)


def generate_ssg_content(template_context: Dict[str, Any]) -> str:
def generate_ssg_content(template_context: Mapping[str, Any]) -> str:
"""Generate the content of the ssg.py file as a string."""
environment: Environment = Environment(
loader=FileSystemLoader(TEMPLATE_DIRECTORY),
Expand Down Expand Up @@ -467,8 +466,8 @@ def make_tables_file(db_dsn: str, schema_name: Optional[str]) -> str:


async def make_src_stats(
dsn: str, config: dict, schema_name: Optional[str] = None
) -> Dict[str, List[dict]]:
dsn: str, config: Mapping, schema_name: Optional[str] = None
) -> dict[str, list[dict]]:
"""Run the src-stats queries specified by the configuration.

Query the src database with the queries in the src-stats block of the `config`
Expand All @@ -485,7 +484,7 @@ async def make_src_stats(
use_asyncio = config.get("use-asyncio", False)
engine = create_db_engine(dsn, schema_name=schema_name, use_asyncio=use_asyncio)

async def execute_query(query_block: Dict[str, Any]) -> Any:
async def execute_query(query_block: Mapping[str, Any]) -> Any:
"""Execute query in query_block."""
query = text(query_block["query"])
if isinstance(engine, AsyncEngine):
Expand Down
2 changes: 1 addition & 1 deletion sqlsynthgen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def create_db_engine(
db_dsn: str,
schema_name: Optional[str] = None,
use_asyncio: bool = False,
**kwargs: dict,
**kwargs: Any,
) -> MaybeAsyncEngine:
"""Create a SQLAlchemy Engine."""
if use_asyncio:
Expand Down
Loading