Skip to content

Commit

Permalink
Update db.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thompson0012 committed Jan 24, 2022
1 parent 673e6d0 commit 52847da
Showing 1 changed file with 49 additions and 8 deletions.
57 changes: 49 additions & 8 deletions pyemits/common/io/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from sqlmodel import create_engine, Session, SQLModel
from sqlmodel.engine.create import _FutureEngine
from typing import Union, Optional
from pyemits.common.validation import raise_if_incorrect_type
from pyemits.common.validation import raise_if_incorrect_type, raise_if_not_all_value_contains, \
raise_if_not_all_element_type_uniform, check_all_element_type_uniform, raise_if_value_not_contains
from typing import List


class DBConnectionBase:
Expand Down Expand Up @@ -56,23 +58,62 @@ def get_db_inspector(self):
inspector = inspect(self._db_engine)
return inspector

def get_schemas(self):
def get_schemas(self, schemas='all', tables='all'):
inspector = self.get_db_inspector()
schemas = inspector.get_schema_names()

from collections import defaultdict
schema_containers = defaultdict(dict)
for schema in schemas:
# print("schema: %s" % schema)
for table_name in inspector.get_table_names(schema=schema):
schema_containers[schema][table_name] = inspector.get_columns(table_name, schema=schema)

return schema_containers
schemas = _validate_schema_names(inspector, schemas)

return _get_schemas(inspector, schema_containers, schemas, tables)

def get_tables_names(self):
inspector = self.get_db_inspector()
return inspector.get_table_names()


def _get_schemas(inspector, schema_containers, schemas: Union[str, List[str]], tables: Union[str, List[List[str]]]):
schema_list = _validate_schema_names(inspector, schemas)
if check_all_element_type_uniform(tables, list):
for schema, table in zip(schema_list, tables):
table_names = _validate_table_names(inspector, schema, table)
for sub_table_names in table_names:
schema_containers[schema][sub_table_names] = inspector.get_columns(sub_table_names, schema=schema)
return schema_containers

elif check_all_element_type_uniform(tables, str) or tables == 'all':
for schema in schema_list:
table_names = _validate_table_names(inspector, schema, tables)
for table_name in table_names:
schema_containers[schema][table_name] = inspector.get_columns(table_name, schema=schema)

return schema_containers

raise ValueError


def _validate_schema_names(inspector, schemas: List[str]):
if schemas == 'all':
return inspector.get_schema_names()

if isinstance(schemas, list):
raise_if_not_all_value_contains(schemas, inspector.get_schema_names())
return schemas
raise ValueError('schemas must be "all" or a list of string')


def _validate_table_names(inspector, schema: str, tables: List[str]):
if tables == 'all':
return inspector.get_table_names(schema=schema)

if isinstance(tables, list):
if check_all_element_type_uniform(tables, str):
raise_if_value_not_contains(tables, inspector.get_table_names(schema=schema))
return tables
elif check_all_element_type_uniform(tables, list):
for sub_tab in tables:
print(sub_tab, inspector.get_table_names(schema=schema))
raise_if_value_not_contains(sub_tab, inspector.get_table_names(schema=schema))
return tables
raise ValueError('tables name are not existed in database, pls verify')

0 comments on commit 52847da

Please sign in to comment.