From 2b3383db3692cc46c2e262d4f5800584e906fc1f Mon Sep 17 00:00:00 2001 From: Ben Harling Date: Wed, 22 Aug 2018 09:45:06 +0100 Subject: [PATCH] Support binding connection in sqlalchemy as well as engine --- .../ext/sqlalchemy/util/decorators.py | 7 ++++- tests/ext/sqlalchemy/test_query.py | 30 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/aws_xray_sdk/ext/sqlalchemy/util/decorators.py b/aws_xray_sdk/ext/sqlalchemy/util/decorators.py index d5a646e8..a32a891c 100644 --- a/aws_xray_sdk/ext/sqlalchemy/util/decorators.py +++ b/aws_xray_sdk/ext/sqlalchemy/util/decorators.py @@ -4,6 +4,7 @@ from future.standard_library import install_aliases install_aliases() from urllib.parse import urlparse, uses_netloc +from sqlalchemy.engine.base import Connection def decorate_all_functions(function_decorator): @@ -86,7 +87,11 @@ def wrapper(*args, **kw): # } def parse_bind(bind): """Parses a connection string and creates SQL trace metadata""" - m = re.match(r"Engine\((.*?)\)", str(bind)) + if isinstance(bind, Connection): + engine = bind.engine + else: + engine = bind + m = re.match(r"Engine\((.*?)\)", str(engine)) if m is not None: u = urlparse(m.group(1)) # Add Scheme to uses_netloc or // will be missing from url. diff --git a/tests/ext/sqlalchemy/test_query.py b/tests/ext/sqlalchemy/test_query.py index 2edb205a..c664b724 100644 --- a/tests/ext/sqlalchemy/test_query.py +++ b/tests/ext/sqlalchemy/test_query.py @@ -21,7 +21,12 @@ class User(Base): @pytest.fixture() -def session(): +def engine(): + return create_engine('sqlite:///:memory:') + + +@pytest.fixture() +def session(engine): """Test Fixture to Create DataBase Tables and start a trace segment""" engine = create_engine('sqlite:///:memory:') xray_recorder.configure(service='test', sampling=False, context=Context()) @@ -35,6 +40,21 @@ def session(): xray_recorder.clear_trace_entities() +@pytest.fixture() +def connection(engine): + conn = engine.connect() + xray_recorder.configure(service='test', sampling=False, context=Context()) + xray_recorder.clear_trace_entities() + xray_recorder.begin_segment('SQLAlchemyTest') + Session = XRaySessionMaker(bind=conn) + Base.metadata.create_all(engine) + session = Session() + yield session + xray_recorder.end_segment() + xray_recorder.clear_trace_entities() + + + def test_all(capsys, session): """ Test calling all() on get all records. Verify we run the query and return the SQL as metdata""" @@ -46,6 +66,14 @@ def test_all(capsys, session): assert subsegment['sql']['url'] +def test_supports_connection(capsys, connection): + """ Test that XRaySessionMaker supports connection as well as engine""" + connection.query(User).all() + subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', + 'sqlalchemy.orm.query.all') + assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.all' + + def test_add(capsys, session): """ Test calling add() on insert a row. Verify we that we capture trace for the add"""