diff --git a/aws_xray_sdk/core/patcher.py b/aws_xray_sdk/core/patcher.py index c8616b72..1a700dd9 100644 --- a/aws_xray_sdk/core/patcher.py +++ b/aws_xray_sdk/core/patcher.py @@ -26,6 +26,7 @@ 'psycopg2', 'pg8000', 'sqlalchemy_core', + 'httpx', ) NO_DOUBLE_PATCH = ( @@ -40,6 +41,7 @@ 'psycopg2', 'pg8000', 'sqlalchemy_core', + 'httpx', ) _PATCHED_MODULES = set() diff --git a/aws_xray_sdk/ext/httpx/__init__.py b/aws_xray_sdk/ext/httpx/__init__.py new file mode 100644 index 00000000..4e8acac6 --- /dev/null +++ b/aws_xray_sdk/ext/httpx/__init__.py @@ -0,0 +1,3 @@ +from .patch import patch + +__all__ = ['patch'] diff --git a/aws_xray_sdk/ext/httpx/patch.py b/aws_xray_sdk/ext/httpx/patch.py new file mode 100644 index 00000000..dfcd9bf8 --- /dev/null +++ b/aws_xray_sdk/ext/httpx/patch.py @@ -0,0 +1,71 @@ +import httpx + +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.models import http +from aws_xray_sdk.ext.util import inject_trace_header, get_hostname + + +def patch(): + httpx.Client = _InstrumentedClient + httpx.AsyncClient = _InstrumentedAsyncClient + httpx._api.Client = _InstrumentedClient + + +class _InstrumentedClient(httpx.Client): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._original_transport = self._transport + self._transport = SyncInstrumentedTransport(self._transport) + + +class _InstrumentedAsyncClient(httpx.AsyncClient): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._original_transport = self._transport + self._transport = AsyncInstrumentedTransport(self._transport) + + +class SyncInstrumentedTransport(httpx.BaseTransport): + def __init__(self, transport: httpx.BaseTransport): + self._wrapped_transport = transport + + def handle_request(self, request: httpx.Request) -> httpx.Response: + with xray_recorder.in_subsegment( + get_hostname(str(request.url)), namespace="remote" + ) as subsegment: + if subsegment is not None: + subsegment.put_http_meta(http.METHOD, request.method) + subsegment.put_http_meta( + http.URL, + str(request.url.copy_with(password=None, query=None, fragment=None)), + ) + inject_trace_header(request.headers, subsegment) + + response = self._wrapped_transport.handle_request(request) + if subsegment is not None: + subsegment.put_http_meta(http.STATUS, response.status_code) + return response + + +class AsyncInstrumentedTransport(httpx.AsyncBaseTransport): + def __init__(self, transport: httpx.AsyncBaseTransport): + self._wrapped_transport = transport + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + async with xray_recorder.in_subsegment_async( + get_hostname(str(request.url)), namespace="remote" + ) as subsegment: + if subsegment is not None: + subsegment.put_http_meta(http.METHOD, request.method) + subsegment.put_http_meta( + http.URL, + str(request.url.copy_with(password=None, query=None, fragment=None)), + ) + inject_trace_header(request.headers, subsegment) + + response = await self._wrapped_transport.handle_async_request(request) + if subsegment is not None: + subsegment.put_http_meta(http.STATUS, response.status_code) + return response diff --git a/tests/ext/httpx/__init__.py b/tests/ext/httpx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ext/httpx/test_httpx.py b/tests/ext/httpx/test_httpx.py new file mode 100644 index 00000000..3bfeb967 --- /dev/null +++ b/tests/ext/httpx/test_httpx.py @@ -0,0 +1,218 @@ +import pytest + +import httpx +from aws_xray_sdk.core import patch +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.context import Context +from aws_xray_sdk.ext.util import strip_url, get_hostname + + +patch(("httpx",)) + +# httpbin.org is created by the same author of requests to make testing http easy. +BASE_URL = "httpbin.org" + + +@pytest.fixture(autouse=True) +def construct_ctx(): + """ + Clean up context storage on each test run and begin a segment + so that later subsegment can be attached. After each test run + it cleans up context storage again. + """ + xray_recorder.configure(service="test", sampling=False, context=Context()) + xray_recorder.clear_trace_entities() + xray_recorder.begin_segment("name") + yield + xray_recorder.clear_trace_entities() + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_ok(use_client): + status_code = 200 + url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.get(url) + else: + response = httpx.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert get_hostname(url) == BASE_URL + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_error(use_client): + status_code = 400 + url = "http://{}/status/{}".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.post(url) + else: + response = httpx.post(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "POST" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_throttle(use_client): + status_code = 429 + url = "http://{}/status/{}".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.head(url) + else: + response = httpx.head(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + assert subsegment.throttle + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "HEAD" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_fault(use_client): + status_code = 500 + url = "http://{}/status/{}".format(BASE_URL, status_code) + if use_client: + with httpx.Client() as client: + response = client.put(url) + else: + response = httpx.put(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "PUT" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_nonexistent_domain(use_client): + with pytest.raises(httpx.ConnectError): + if use_client: + with httpx.Client() as client: + client.get("http://doesnt.exist") + else: + httpx.get("http://doesnt.exist") + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.fault + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "ConnectError" + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_invalid_url(use_client): + url = "KLSDFJKLSDFJKLSDJF" + with pytest.raises(httpx.UnsupportedProtocol): + if use_client: + with httpx.Client() as client: + client.get(url) + else: + httpx.get(url) + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == "/{}".format(strip_url(url)) + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "UnsupportedProtocol" + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_name_uses_hostname(use_client): + if use_client: + client = httpx.Client() + else: + client = httpx + + try: + url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + client.get(url1) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta1 = subsegment.http + assert http_meta1["request"]["url"] == strip_url(url1) + assert http_meta1["request"]["method"].upper() == "GET" + + url2 = "http://{}/".format(BASE_URL) + client.get(url2, params={"some": "payload", "not": "toBeIncluded"}) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta2 = subsegment.http + assert http_meta2["request"]["url"] == strip_url(url2) + assert http_meta2["request"]["method"].upper() == "GET" + + url3 = "http://subdomain.{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + try: + client.get(url3) + except httpx.ConnectError: + pass + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == "subdomain." + BASE_URL + http_meta3 = subsegment.http + assert http_meta3["request"]["url"] == strip_url(url3) + assert http_meta3["request"]["method"].upper() == "GET" + finally: + if use_client: + client.close() + + +@pytest.mark.parametrize("use_client", (True, False)) +def test_strip_http_url(use_client): + status_code = 200 + url = "http://{}/get?foo=bar".format(BASE_URL) + if use_client: + with httpx.Client() as client: + response = client.get(url) + else: + response = httpx.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code diff --git a/tests/ext/httpx/test_httpx_async.py b/tests/ext/httpx/test_httpx_async.py new file mode 100644 index 00000000..c5d0560a --- /dev/null +++ b/tests/ext/httpx/test_httpx_async.py @@ -0,0 +1,190 @@ +import pytest + +import httpx +from aws_xray_sdk.core import patch +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core.context import Context +from aws_xray_sdk.ext.util import strip_url, get_hostname + + +patch(("httpx",)) + +# httpbin.org is created by the same author of requests to make testing http easy. +BASE_URL = "httpbin.org" + + +@pytest.fixture(autouse=True) +def construct_ctx(): + """ + Clean up context storage on each test run and begin a segment + so that later subsegment can be attached. After each test run + it cleans up context storage again. + """ + xray_recorder.configure(service="test", sampling=False, context=Context()) + xray_recorder.clear_trace_entities() + xray_recorder.begin_segment("name") + yield + xray_recorder.clear_trace_entities() + + +@pytest.mark.asyncio +async def test_ok_async(): + status_code = 200 + url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert get_hostname(url) == BASE_URL + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_error_async(): + status_code = 400 + url = "http://{}/status/{}".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.post(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "POST" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_throttle_async(): + status_code = 429 + url = "http://{}/status/{}".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.head(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.error + assert subsegment.throttle + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "HEAD" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_fault_async(): + status_code = 500 + url = "http://{}/status/{}".format(BASE_URL, status_code) + async with httpx.AsyncClient() as client: + response = await client.put(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "PUT" + assert http_meta["response"]["status"] == status_code + + +@pytest.mark.asyncio +async def test_nonexistent_domain_async(): + with pytest.raises(httpx.ConnectError): + async with httpx.AsyncClient() as client: + await client.get("http://doesnt.exist") + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.fault + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "ConnectError" + + +@pytest.mark.asyncio +async def test_invalid_url_async(): + url = "KLSDFJKLSDFJKLSDJF" + with pytest.raises(httpx.UnsupportedProtocol): + async with httpx.AsyncClient() as client: + await client.get(url) + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + assert subsegment.fault + + http_meta = subsegment.http + assert http_meta["request"]["url"] == "/{}".format(strip_url(url)) + + exception = subsegment.cause["exceptions"][0] + assert exception.type == "UnsupportedProtocol" + + +@pytest.mark.asyncio +async def test_name_uses_hostname_async(): + async with httpx.AsyncClient() as client: + url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + await client.get(url1) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta1 = subsegment.http + assert http_meta1["request"]["url"] == strip_url(url1) + assert http_meta1["request"]["method"].upper() == "GET" + + url2 = "http://{}/".format(BASE_URL) + await client.get(url2, params={"some": "payload", "not": "toBeIncluded"}) + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == BASE_URL + http_meta2 = subsegment.http + assert http_meta2["request"]["url"] == strip_url(url2) + assert http_meta2["request"]["method"].upper() == "GET" + + url3 = "http://subdomain.{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL) + try: + await client.get(url3) + except Exception: + # This is an invalid url so we dont want to break the test + pass + subsegment = xray_recorder.current_segment().subsegments[-1] + assert subsegment.namespace == "remote" + assert subsegment.name == "subdomain." + BASE_URL + http_meta3 = subsegment.http + assert http_meta3["request"]["url"] == strip_url(url3) + assert http_meta3["request"]["method"].upper() == "GET" + + +@pytest.mark.asyncio +async def test_strip_http_url_async(): + status_code = 200 + url = "http://{}/get?foo=bar".format(BASE_URL) + async with httpx.AsyncClient() as client: + response = await client.get(url) + assert "x-amzn-trace-id" in response._request.headers + + subsegment = xray_recorder.current_segment().subsegments[0] + assert subsegment.namespace == "remote" + assert subsegment.name == get_hostname(url) + + http_meta = subsegment.http + assert http_meta["request"]["url"] == strip_url(url) + assert http_meta["request"]["method"].upper() == "GET" + assert http_meta["response"]["status"] == status_code diff --git a/tox.ini b/tox.ini index 8d024bce..d513aec3 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,8 @@ envlist = py{27,34,35,36,37,38,39}-ext-httplib + py{37,38,39}-ext-httpx + py{27,34,35,36,37,38,39}-ext-pg8000 py{27,34,35,36,37,38,39}-ext-psycopg2 @@ -75,6 +77,9 @@ deps = ; Also, the stable version is only supported for Python 3.7+ ext-aiohttp: pytest-aiohttp < 1.0.0 + ext-httpx: httpx >= 0.20 + ext-httpx: pytest-asyncio >= 0.19 + ext-requests: requests ext-bottle: bottle >= 0.10 @@ -135,6 +140,8 @@ commands = ext-httplib: coverage run --append --source aws_xray_sdk -m pytest tests/ext/httplib + ext-httpx: coverage run --append --source aws_xray_sdk -m pytest tests/ext/httpx + ext-pg8000: coverage run --append --source aws_xray_sdk -m pytest tests/ext/pg8000 ext-psycopg2: coverage run --append --source aws_xray_sdk -m pytest tests/ext/psycopg2