diff --git a/lore/io/connection.py b/lore/io/connection.py index 306b680..8a8c7bb 100644 --- a/lore/io/connection.py +++ b/lore/io/connection.py @@ -438,7 +438,7 @@ def __prepare(self, sql=None, extract=None, filename=None, **kwargs): return sql - def __connection_execute(self, sql, bindings): + def _connection_execute(self, sql, bindings): if self._use_psycopg2: with self._connection.engine.raw_connection().connection as conn: with conn.cursor() as cursor: @@ -446,7 +446,7 @@ def __connection_execute(self, sql, bindings): try: return ResultWrapper(cursor.fetchall()) except psycopg2.ProgrammingError as e: - if 'no results to fetch' in str(y): + if 'no results to fetch' in str(e): return None raise e else: @@ -454,12 +454,12 @@ def __connection_execute(self, sql, bindings): def __execute(self, sql, bindings): try: - return self.__connection_execute(sql, bindings) + return self._connection_execute(sql, bindings) except (sqlalchemy.exc.DBAPIError, Psycopg2OperationalError, SnowflakeProgrammingError) as e: if not self._transactions and (isinstance(e, Psycopg2OperationalError) or e.connection_invalidated): logger.warning('Reconnect and retry due to invalid connection') self.close() - return self.__connection_execute(sql, bindings) + return self._connection_execute(sql, bindings) elif not self._transactions and (isinstance(e, SnowflakeProgrammingError) or e.connection_invalidated): if hasattr(e, 'msg') and e.msg and "authenticate" in e.msg.lower(): logger.warning('Reconnect and retry due to unauthenticated connection') diff --git a/tests/unit/io/test_connection.py b/tests/unit/io/test_connection.py index c36421a..acbdd87 100644 --- a/tests/unit/io/test_connection.py +++ b/tests/unit/io/test_connection.py @@ -12,7 +12,7 @@ from sqlalchemy import event from sqlalchemy.engine import Engine import pandas - +import psycopg2 import lore @@ -137,7 +137,7 @@ def insert(delay=0): lore.io.main.execute(sql='insert into tests_autocommit values (1), (2), (3)') posts.append(lore.io.main.select(sql='select count(*) from tests_autocommit')[0][0]) time.sleep(delay) - except sqlalchemy.exc.IntegrityError as ex: + except psycopg2.IntegrityError as ex: thrown.append(True) slow = Thread(target=insert, args=(1,)) @@ -167,17 +167,17 @@ def test_close(self): lore.io.main.select(sql='select count(*) from tests_close') def test_reconnect_and_retry(self): - original_execute = lore.io.main._connection.execute + original_execute = lore.io.main._connection_execute def raise_dbapi_error_on_first_call(sql, bindings): - lore.io.main._connection.execute = original_execute + lore.io.main._connection_execute = original_execute e = lore.io.connection.Psycopg2OperationalError('server closed the connection unexpectedly. This probably means the server terminated abnormally before or while processing the request.') raise sqlalchemy.exc.DBAPIError('select 1', [], e, True) exceptions = lore.env.STDOUT_EXCEPTIONS lore.env.STDOUT_EXCEPTIONS = False connection = lore.io.main._connection - lore.io.main._connection.execute = raise_dbapi_error_on_first_call + lore.io.main._connection_execute = raise_dbapi_error_on_first_call result = lore.io.main.select(sql='select 1') lore.env.STDOUT_EXCEPTIONS = exceptions @@ -192,17 +192,17 @@ def test_tuple_interpolation(self): self.assertEqual(len(temps), 3) def test_reconnect_and_retry_on_expired_connection(self): - original_execute = lore.io.main._connection.execute + original_execute = lore.io.main._connection_execute def raise_snowflake_programming_error_on_first_call(sql, bindings): - lore.io.main._connection.execute = original_execute + lore.io.main._connection_execute = original_execute e = lore.io.connection.SnowflakeProgrammingError('Authentication token has expired. The user must authenticate again') raise sqlalchemy.exc.DBAPIError('select 1', [], e, True) exceptions = lore.env.STDOUT_EXCEPTIONS lore.env.STDOUT_EXCEPTIONS = False connection = lore.io.main._connection - lore.io.main._connection.execute = raise_snowflake_programming_error_on_first_call + lore.io.main._connection_execute = raise_snowflake_programming_error_on_first_call result = lore.io.main.select(sql='select 1') lore.env.STDOUT_EXCEPTIONS = exceptions