From 0a7cbdf58f422708923ed95e85e6fbb83c8fbaf8 Mon Sep 17 00:00:00 2001 From: sroda Date: Wed, 18 Jan 2023 15:06:17 +0200 Subject: [PATCH] Change the code to save all listen params so we will know how to remove them when uninstument --- .../instrumentation/sqlalchemy/__init__.py | 23 +++++----------- .../instrumentation/sqlalchemy/engine.py | 26 +++++++++++++------ .../tests/test_sqlalchemy.py | 19 ++++++++++++++ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py index 975c9ca385..b19de5ec96 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py @@ -119,8 +119,6 @@ class SQLAlchemyInstrumentor(BaseInstrumentor): See `BaseInstrumentor` """ - engines = [] - def instrumentation_dependencies(self) -> Collection[str]: return _instruments @@ -160,22 +158,17 @@ def _instrument(self, **kwargs): "create_async_engine", _wrap_create_async_engine(tracer_provider, enable_commenter), ) - - self.engines = [] if kwargs.get("engine") is not None: - self.engines.append( - EngineTracer( - _get_tracer(tracer_provider), - kwargs.get("engine"), - kwargs.get("enable_commenter", False), - kwargs.get("commenter_options", {}), - ) + return EngineTracer( + _get_tracer(tracer_provider), + kwargs.get("engine"), + kwargs.get("enable_commenter", False), + kwargs.get("commenter_options", {}), ) - return self.engines[0] if kwargs.get("engines") is not None and isinstance( kwargs.get("engines"), Sequence ): - self.engines = [ + return [ EngineTracer( _get_tracer(tracer_provider), engine, @@ -184,7 +177,6 @@ def _instrument(self, **kwargs): ) for engine in kwargs.get("engines") ] - return self.engines return None @@ -194,5 +186,4 @@ def _uninstrument(self, **kwargs): unwrap(Engine, "connect") if parse_version(sqlalchemy.__version__).release >= (1, 4): unwrap(sqlalchemy.ext.asyncio, "create_async_engine") - for engine in self.engines: - engine.remove_event_listeners() + EngineTracer.remove_all_event_listeners() diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index 62eba4b08d..0ed51d3a6c 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -98,6 +98,8 @@ def _wrap_connect_internal(func, module, args, kwargs): class EngineTracer: + _removeEventListenerParams = [] + def __init__( self, tracer, engine, enable_commenter=False, commenter_options=None ): @@ -108,16 +110,24 @@ def __init__( self.commenter_options = commenter_options if commenter_options else {} self._leading_comment_remover = re.compile(r"^/\*.*?\*/") - listen( + self._register_event_listener( engine, "before_cursor_execute", self._before_cur_exec, retval=True ) - listen(engine, "after_cursor_execute", _after_cur_exec) - listen(engine, "handle_error", _handle_error) - - def remove_event_listeners(self): - remove(self.engine, "before_cursor_execute", self._before_cur_exec) - remove(self.engine, "after_cursor_execute", _after_cur_exec) - remove(self.engine, "handle_error", _handle_error) + self._register_event_listener( + engine, "after_cursor_execute", _after_cur_exec + ) + self._register_event_listener(engine, "handle_error", _handle_error) + + @classmethod + def _register_event_listener(cls, target, identifier, fn, *args, **kw): + listen(target, identifier, fn, *args, **kw) + cls._removeEventListenerParams.append((target, identifier, fn)) + + @classmethod + def remove_all_event_listeners(cls): + for removeParams in cls._removeEventListenerParams: + remove(*removeParams) + cls._removeEventListenerParams.clear() def _operation_name(self, db_name, statement): parts = [] diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py index 9788fe1cc1..68061ccf50 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py @@ -255,6 +255,25 @@ def test_uninstrument(self): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 0) + def test_uninstrument_without_engine(self): + SQLAlchemyInstrumentor().instrument( + tracer_provider=self.tracer_provider + ) + from sqlalchemy import create_engine + + engine = create_engine("sqlite:///:memory:") + + cnx = engine.connect() + cnx.execute("SELECT 1 + 1;").fetchall() + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 2) + + self.memory_exporter.clear() + SQLAlchemyInstrumentor().uninstrument() + cnx.execute("SELECT 1 + 1;").fetchall() + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + def test_no_op_tracer_provider(self): engine = create_engine("sqlite:///:memory:") SQLAlchemyInstrumentor().instrument(