Skip to content

Commit

Permalink
Add type hints to Starlette instrumentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Nov 25, 2024
1 parent 5c5fc73 commit 45989bb
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
API
---
"""
# pyright: reportPrivateUsage=false

from typing import Collection
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Collection, cast

from starlette import applications
from starlette.routing import Match
Expand All @@ -184,18 +187,29 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.starlette.package import _instruments
from opentelemetry.instrumentation.starlette.version import __version__
from opentelemetry.metrics import get_meter
from opentelemetry.metrics import MeterProvider, get_meter
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import get_tracer
from opentelemetry.trace import TracerProvider, get_tracer
from opentelemetry.util.http import get_excluded_urls

if TYPE_CHECKING:
from typing import NotRequired, TypedDict, Unpack

class InstrumentKwargs(TypedDict):
tracer_provider: NotRequired[TracerProvider]
meter_provider: NotRequired[MeterProvider]
server_request_hook: NotRequired[ServerRequestHook]
client_request_hook: NotRequired[ClientRequestHook]
client_response_hook: NotRequired[ClientResponseHook]


_excluded_urls = get_excluded_urls("STARLETTE")


class StarletteInstrumentor(BaseInstrumentor):
"""An instrumentor for starlette
"""An instrumentor for Starlette.
See `BaseInstrumentor`
See `BaseInstrumentor`.
"""

_original_starlette = None
Expand All @@ -206,8 +220,8 @@ def instrument_app(
server_request_hook: ServerRequestHook = None,
client_request_hook: ClientRequestHook = None,
client_response_hook: ClientResponseHook = None,
meter_provider=None,
tracer_provider=None,
meter_provider: MeterProvider | None = None,
tracer_provider: TracerProvider | None = None,
):
"""Instrument an uninstrumented Starlette application."""
tracer = get_tracer(
Expand Down Expand Up @@ -242,34 +256,24 @@ def instrument_app(

@staticmethod
def uninstrument_app(app: applications.Starlette):
app.user_middleware = [
x
for x in app.user_middleware
if x.cls is not OpenTelemetryMiddleware
]
app.user_middleware = [x for x in app.user_middleware if x.cls is not OpenTelemetryMiddleware]
app.middleware_stack = app.build_middleware_stack()
app._is_instrumented_by_opentelemetry = False

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Unpack[InstrumentKwargs]):
self._original_starlette = applications.Starlette
_InstrumentedStarlette._tracer_provider = kwargs.get("tracer_provider")
_InstrumentedStarlette._server_request_hook = kwargs.get(
"server_request_hook"
)
_InstrumentedStarlette._client_request_hook = kwargs.get(
"client_request_hook"
)
_InstrumentedStarlette._client_response_hook = kwargs.get(
"client_response_hook"
)
_InstrumentedStarlette._server_request_hook = kwargs.get("server_request_hook")
_InstrumentedStarlette._client_request_hook = kwargs.get("client_request_hook")
_InstrumentedStarlette._client_response_hook = kwargs.get("client_response_hook")
_InstrumentedStarlette._meter_provider = kwargs.get("_meter_provider")

applications.Starlette = _InstrumentedStarlette

def _uninstrument(self, **kwargs):
def _uninstrument(self, **kwargs: Any):
"""uninstrumenting all created apps by user"""
for instance in _InstrumentedStarlette._instrumented_starlette_apps:
self.uninstrument_app(instance)
Expand All @@ -278,14 +282,14 @@ def _uninstrument(self, **kwargs):


class _InstrumentedStarlette(applications.Starlette):
_tracer_provider = None
_meter_provider = None
_tracer_provider: TracerProvider | None = None
_meter_provider: MeterProvider | None = None
_server_request_hook: ServerRequestHook = None
_client_request_hook: ClientRequestHook = None
_client_response_hook: ClientResponseHook = None
_instrumented_starlette_apps = set()
_instrumented_starlette_apps: set[applications.Starlette] = set()

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
tracer = get_tracer(
__name__,
Expand Down Expand Up @@ -318,21 +322,22 @@ def __del__(self):
_InstrumentedStarlette._instrumented_starlette_apps.remove(self)


def _get_route_details(scope):
def _get_route_details(scope: dict[str, Any]) -> str | None:
"""
Function to retrieve Starlette route from scope.
Function to retrieve Starlette route from ASGI scope.
TODO: there is currently no way to retrieve http.route from
a starlette application from scope.
See: https://github.com/encode/starlette/pull/804
Args:
scope: A Starlette scope
scope: The ASGI scope that contains the Starlette application in the "app" key.
Returns:
A string containing the route or None
The path to the route if found, otherwise None.
"""
app = scope["app"]
route = None
app = cast(applications.Starlette, scope["app"])
route: str | None = None

for starlette_route in app.routes:
match, _ = starlette_route.matches(scope)
Expand All @@ -344,18 +349,18 @@ def _get_route_details(scope):
return route


def _get_default_span_details(scope):
"""
Callback to retrieve span name and attributes from scope.
def _get_default_span_details(scope: dict[str, Any]) -> tuple[str, dict[str, Any]]:
"""Callback to retrieve span name and attributes from ASGI scope.
Args:
scope: A Starlette scope
scope: The ASGI scope that contains the Starlette application in the "app" key.
Returns:
A tuple of span name and attributes
A tuple of span name and attributes.
"""
route = _get_route_details(scope)
method = scope.get("method", "")
attributes = {}
method: str = scope.get("method", "")
attributes: dict[str, Any] = {}
if route:
attributes[SpanAttributes.HTTP_ROUTE] = route
if method and route: # http
Expand Down

0 comments on commit 45989bb

Please sign in to comment.