From 52d6a43146518af85d5513c4d663ec3a8bc59bda Mon Sep 17 00:00:00 2001 From: Elliana May Date: Thu, 15 Feb 2024 18:52:03 +0800 Subject: [PATCH] fix: support views in has_table --- duckdb_engine/__init__.py | 9 ++++++--- duckdb_engine/tests/conftest.py | 7 ++++++- duckdb_engine/tests/test_basic.py | 18 ++++++++++-------- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index 982c42be..68e3cb80 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -411,10 +411,13 @@ def get_table_oid( # type: ignore[no-untyped-def] In the latter scenario the schema associated with the default database is used. """ s = """ - SELECT table_oid - FROM duckdb_tables() + SELECT oid, table_name + FROM ( + SELECT table_oid AS oid, table_name, database_name, schema_name FROM duckdb_tables() + UNION ALL BY NAME + SELECT view_oid AS oid , view_name AS table_name, database_name, schema_name FROM duckdb_views() + ) WHERE schema_name NOT LIKE 'pg\\_%' ESCAPE '\\' - AND table_name = :table_name """ sql, params = self._build_query_where(table_name=table_name, schema_name=schema) s += sql diff --git a/duckdb_engine/tests/conftest.py b/duckdb_engine/tests/conftest.py index cdd739c2..245ee60c 100644 --- a/duckdb_engine/tests/conftest.py +++ b/duckdb_engine/tests/conftest.py @@ -5,7 +5,7 @@ from pytest import fixture, raises from sqlalchemy import create_engine from sqlalchemy.dialects import registry # type: ignore -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Dialect, Engine from sqlalchemy.engine.base import Connection from sqlalchemy.orm import Session, sessionmaker from typing_extensions import ParamSpec @@ -33,6 +33,11 @@ def conn(engine: Engine) -> Generator[Connection, None, None]: yield conn +@fixture() +def dialect(engine: Engine) -> Dialect: + return engine.dialect + + @fixture def session(engine: Engine) -> Session: return sessionmaker(bind=engine)() diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py index e323475c..5ad1d77a 100644 --- a/duckdb_engine/tests/test_basic.py +++ b/duckdb_engine/tests/test_basic.py @@ -30,7 +30,7 @@ types, ) from sqlalchemy.dialects import registry # type: ignore -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.declarative import declarative_base @@ -233,22 +233,24 @@ def test_get_table_names(inspector: Inspector, session: Session) -> None: assert inspector.has_table(table_name) -def test_get_views(engine: Engine) -> None: - con = engine.connect() - views = engine.dialect.get_view_names(con) +def test_get_views(conn: Connection, dialect: Dialect) -> None: + views = dialect.get_view_names(conn) assert views == [] - con.execute(text("create view test as select 1")) - con.execute( + conn.execute(text("create view test as select 1")) + conn.execute( text("create schema scheme; create view scheme.schema_test as select 1") ) - views = engine.dialect.get_view_names(con) + views = dialect.get_view_names(conn) assert views == ["test"] - views = engine.dialect.get_view_names(con, schema="scheme") + views = dialect.get_view_names(conn, schema="scheme") assert views == ["schema_test"] + assert dialect.has_table(conn, table_name="test") + assert dialect.has_table(conn, table_name="schema_test", schema="scheme") + @mark.skipif(os.uname().machine == "aarch64", reason="not supported on aarch64") @mark.remote_data