From 202cc604e1fb7b863473756b0c6e515f458eb691 Mon Sep 17 00:00:00 2001 From: Hasier Rodriguez Date: Mon, 7 Jan 2019 14:08:40 +0000 Subject: [PATCH] Unpatch dbapi2, patch use custom cursor for Django and chunked_cursor --- CHANGELOG.rst | 2 +- README.md | 5 +- aws_xray_sdk/ext/dbapi2.py | 10 ++-- aws_xray_sdk/ext/django/db.py | 56 ++++++++++++++++--- tests/ext/django/test_db.py | 87 +++++++++++++++++++++++++++++ tests/ext/psycopg2/test_psycopg2.py | 25 +-------- tox.ini | 1 + 7 files changed, 145 insertions(+), 41 deletions(-) create mode 100644 tests/ext/django/test_db.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 21ce3194..00813e17 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,7 +4,7 @@ CHANGELOG unreleased ========== -* feature: Stream dbapi2 SQL queries and add flag to toggle their streaming +* feature: Stream Django ORM SQL queries and add flag to toggle their streaming 2.2.0 ===== diff --git a/README.md b/README.md index 86970710..13804828 100644 --- a/README.md +++ b/README.md @@ -256,7 +256,6 @@ By default, if no other value is provided to `.configure()`, SQL trace streaming for all the supported DB engines. Those currently are: - Any engine attached to the Django ORM. - Any engine attached to SQLAlchemy. -- SQLite3. The behaviour can be toggled by sending the appropriate `stream_sql` value, for example: ```python @@ -292,8 +291,8 @@ MIDDLEWARE = [ ``` #### SQL tracing If Django's ORM is patched - either using the `AUTO_INSTRUMENT = True` in your settings file -or explicitly calling `patch_db()` - the SQL query trace streaming can be enabled or disabled -updating the `STREAM_SQL` variable in your settings file. +or explicitly calling `patch_db()` - the SQL query trace streaming can then be enabled or +disabled updating the `STREAM_SQL` variable in your settings file. It is enabled by default. ### Add Flask middleware diff --git a/aws_xray_sdk/ext/dbapi2.py b/aws_xray_sdk/ext/dbapi2.py index 8ae4df49..c3ed8241 100644 --- a/aws_xray_sdk/ext/dbapi2.py +++ b/aws_xray_sdk/ext/dbapi2.py @@ -43,23 +43,23 @@ def __enter__(self): @xray_recorder.capture() def execute(self, query, *args, **kwargs): - add_sql_meta(self._xray_meta, query) + add_sql_meta(self._xray_meta) return self.__wrapped__.execute(query, *args, **kwargs) @xray_recorder.capture() def executemany(self, query, *args, **kwargs): - add_sql_meta(self._xray_meta, query) + add_sql_meta(self._xray_meta) return self.__wrapped__.executemany(query, *args, **kwargs) @xray_recorder.capture() def callproc(self, proc, args): - add_sql_meta(self._xray_meta, proc) + add_sql_meta(self._xray_meta) return self.__wrapped__.callproc(proc, args) -def add_sql_meta(meta, query): +def add_sql_meta(meta): subsegment = xray_recorder.current_subsegment() @@ -72,7 +72,5 @@ def add_sql_meta(meta, query): sql_meta = copy.copy(meta) if sql_meta.get('name', None): del sql_meta['name'] - if xray_recorder.stream_sql: - sql_meta['sanitized_query'] = query subsegment.set_sql(sql_meta) subsegment.namespace = 'remote' diff --git a/aws_xray_sdk/ext/django/db.py b/aws_xray_sdk/ext/django/db.py index 0a2c80d6..fdf7e27a 100644 --- a/aws_xray_sdk/ext/django/db.py +++ b/aws_xray_sdk/ext/django/db.py @@ -1,29 +1,62 @@ +import copy import logging import importlib from django.db import connections +from aws_xray_sdk.core import xray_recorder from aws_xray_sdk.ext.dbapi2 import XRayTracedCursor log = logging.getLogger(__name__) def patch_db(): - for conn in connections.all(): module = importlib.import_module(conn.__module__) _patch_conn(getattr(module, conn.__class__.__name__)) -def _patch_conn(conn): - - attr = '_xray_original_cursor' +class DjangoXRayTracedCursor(XRayTracedCursor): + def execute(self, query, *args, **kwargs): + if xray_recorder.stream_sql: + _previous_meta = copy.copy(self._xray_meta) + self._xray_meta['sanitized_query'] = query + result = super(DjangoXRayTracedCursor, self).execute(query, *args, **kwargs) + if xray_recorder.stream_sql: + self._xray_meta = _previous_meta + return result + + def executemany(self, query, *args, **kwargs): + if xray_recorder.stream_sql: + _previous_meta = copy.copy(self._xray_meta) + self._xray_meta['sanitized_query'] = query + result = super(DjangoXRayTracedCursor, self).executemany(query, *args, **kwargs) + if xray_recorder.stream_sql: + self._xray_meta = _previous_meta + return result + + def callproc(self, proc, args): + if xray_recorder.stream_sql: + _previous_meta = copy.copy(self._xray_meta) + self._xray_meta['sanitized_query'] = proc + result = super(DjangoXRayTracedCursor, self).callproc(proc, args) + if xray_recorder.stream_sql: + self._xray_meta = _previous_meta + return result + + +def _patch_cursor(cursor_name, conn): + attr = '_xray_original_{}'.format(cursor_name) if hasattr(conn, attr): - log.debug('django built-in db already patched') + log.debug('django built-in db {} already patched'.format(cursor_name)) + return + + if not hasattr(conn, cursor_name): + log.debug('django built-in db does not have {}'.format(cursor_name)) return - setattr(conn, attr, conn.cursor) + setattr(conn, attr, getattr(conn, cursor_name)) meta = {} @@ -45,7 +78,12 @@ def cursor(self, *args, **kwargs): if user: meta['user'] = user - return XRayTracedCursor( - self._xray_original_cursor(*args, **kwargs), meta) + original_cursor = getattr(self, attr)(*args, **kwargs) + return DjangoXRayTracedCursor(original_cursor, meta) + + setattr(conn, cursor_name, cursor) - conn.cursor = cursor + +def _patch_conn(conn): + _patch_cursor('cursor', conn) + _patch_cursor('chunked_cursor', conn) diff --git a/tests/ext/django/test_db.py b/tests/ext/django/test_db.py new file mode 100644 index 00000000..1c3e5439 --- /dev/null +++ b/tests/ext/django/test_db.py @@ -0,0 +1,87 @@ +import django + +import pytest + +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.context import Context +from aws_xray_sdk.ext.django.db import patch_db + + +@pytest.fixture(scope='module', autouse=True) +def setup(): + django.setup() + xray_recorder.configure(context=Context(), + context_missing='LOG_ERROR') + patch_db() + + +@pytest.fixture(scope='module') +def user_class(setup): + from django.db import models + from django_fake_model import models as f + + class User(f.FakeModel): + name = models.CharField(max_length=255) + password = models.CharField(max_length=255) + + return User + + +@pytest.fixture( + autouse=True, + params=[ + False, + True, + ] +) +@pytest.mark.django_db +def func_setup(request, user_class): + xray_recorder.stream_sql = request.param + xray_recorder.clear_trace_entities() + xray_recorder.begin_segment('name') + try: + user_class.create_table() + yield + finally: + xray_recorder.clear_trace_entities() + try: + user_class.delete_table() + finally: + xray_recorder.end_segment() + + +def _assert_query(sql_meta): + if xray_recorder.stream_sql: + assert 'sanitized_query' in sql_meta + assert sql_meta['sanitized_query'] + assert sql_meta['sanitized_query'].startswith('SELECT') + else: + if 'sanitized_query' in sql_meta: + assert sql_meta['sanitized_query'] + # Django internally executes queries for table checks, ignore those + assert not sql_meta['sanitized_query'].startswith('SELECT') + + +def test_all(user_class): + """ Test calling all() on get all records. + Verify we run the query and return the SQL as metadata""" + # Materialising the query executes the SQL + list(user_class.objects.all()) + subsegment = xray_recorder.current_segment().subsegments[-1] + sql = subsegment.sql + assert sql['database_type'] == 'sqlite' + _assert_query(sql) + + +def test_filter(user_class): + """ Test calling filter() to get filtered records. + Verify we run the query and return the SQL as metadata""" + # Materialising the query executes the SQL + list(user_class.objects.filter(password='mypassword!').all()) + subsegment = xray_recorder.current_segment().subsegments[-1] + sql = subsegment.sql + assert sql['database_type'] == 'sqlite' + _assert_query(sql) + if xray_recorder.stream_sql: + assert 'mypassword!' not in sql['sanitized_query'] + assert '"password" = %s' in sql['sanitized_query'] diff --git a/tests/ext/psycopg2/test_psycopg2.py b/tests/ext/psycopg2/test_psycopg2.py index 7b833097..c491d706 100644 --- a/tests/ext/psycopg2/test_psycopg2.py +++ b/tests/ext/psycopg2/test_psycopg2.py @@ -12,34 +12,20 @@ patch(('psycopg2',)) -@pytest.fixture( - autouse=True, - params=[ - False, - True, - ], -) -def construct_ctx(request): +@pytest.fixture(autouse=True) +def construct_ctx(): """ Clean up context storage on each test run and begin a segment so that later subsegment can be attached. After each test run it cleans up context storage again. """ - xray_recorder.configure(service='test', sampling=False, context=Context(), stream_sql=request.param) + xray_recorder.configure(service='test', sampling=False, context=Context()) xray_recorder.clear_trace_entities() xray_recorder.begin_segment('name') yield xray_recorder.clear_trace_entities() -def _assert_query(sql_meta, query): - if xray_recorder.stream_sql: - assert 'sanitized_query' in sql_meta - assert sql_meta['sanitized_query'] == query - else: - assert 'sanitized_query' not in sql_meta - - def test_execute_dsn_kwargs(): q = 'SELECT 1' with testing.postgresql.Postgresql() as postgresql: @@ -60,7 +46,6 @@ def test_execute_dsn_kwargs(): assert sql['user'] == dsn['user'] assert sql['url'] == url assert sql['database_version'] - _assert_query(sql, q) def test_execute_dsn_kwargs_alt_dbname(): @@ -87,7 +72,6 @@ def test_execute_dsn_kwargs_alt_dbname(): assert sql['user'] == dsn['user'] assert sql['url'] == url assert sql['database_version'] - _assert_query(sql, q) def test_execute_dsn_string(): @@ -110,7 +94,6 @@ def test_execute_dsn_string(): assert sql['user'] == dsn['user'] assert sql['url'] == url assert sql['database_version'] - _assert_query(sql, q) def test_execute_in_pool(): @@ -134,7 +117,6 @@ def test_execute_in_pool(): assert sql['user'] == dsn['user'] assert sql['url'] == url assert sql['database_version'] - _assert_query(sql, q) def test_execute_bad_query(): @@ -163,7 +145,6 @@ def test_execute_bad_query(): exception = subsegment.cause['exceptions'][0] assert exception.type == 'ProgrammingError' - _assert_query(sql, q) def test_register_extensions(): diff --git a/tox.ini b/tox.ini index 1bd78b23..0dbbe0be 100644 --- a/tox.ini +++ b/tox.ini @@ -18,6 +18,7 @@ deps = future # the sdk doesn't support earlier version of django django >= 1.10, <2.0 + django-fake-model pynamodb >= 3.3.1 psycopg2 pg8000