Skip to content

Commit

Permalink
[ext.httpx] Call inject_trace_header with correct subsegment
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-k committed Oct 11, 2022
1 parent 3c04255 commit 2447967
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 68 deletions.
76 changes: 31 additions & 45 deletions aws_xray_sdk/ext/httpx/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.models import http
from aws_xray_sdk.ext.util import UNKNOWN_HOSTNAME, inject_trace_header
from aws_xray_sdk.ext.util import inject_trace_header, get_hostname


def patch():
Expand Down Expand Up @@ -32,54 +32,40 @@ def __init__(self, transport: httpx.BaseTransport):
self._wrapped_transport = transport

def handle_request(self, request: httpx.Request) -> httpx.Response:
def httpx_processor(return_value, exception, subsegment, stack, **kwargs):
subsegment.put_http_meta(http.METHOD, request.method)
subsegment.put_http_meta(
http.URL,
str(request.url.copy_with(password=None, query=None, fragment=None)),
)

if return_value is not None:
subsegment.put_http_meta(http.STATUS, return_value.status_code)
elif exception:
subsegment.add_exception(exception, stack)

inject_trace_header(request.headers, xray_recorder.current_subsegment())
return xray_recorder.record_subsegment(
wrapped=self._wrapped_transport.handle_request,
instance=self._wrapped_transport,
args=(request,),
kwargs={},
name=request.url.host or UNKNOWN_HOSTNAME,
namespace="remote",
meta_processor=httpx_processor,
)
with xray_recorder.in_subsegment(
get_hostname(str(request.url)), namespace="remote"
) as subsegment:
if subsegment is not None:
subsegment.put_http_meta(http.METHOD, request.method)
subsegment.put_http_meta(
http.URL,
str(request.url.copy_with(password=None, query=None, fragment=None)),
)
inject_trace_header(request.headers, subsegment)

response = self._wrapped_transport.handle_request(request)
if subsegment is not None:
subsegment.put_http_meta(http.STATUS, response.status_code)
return response


class AsyncInstrumentedTransport(httpx.AsyncBaseTransport):
def __init__(self, transport: httpx.AsyncBaseTransport):
self._wrapped_transport = transport

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
def httpx_processor(return_value, exception, subsegment, stack, **kwargs):
subsegment.put_http_meta(http.METHOD, request.method)
subsegment.put_http_meta(
http.URL,
str(request.url.copy_with(password=None, query=None, fragment=None)),
)

if return_value is not None:
subsegment.put_http_meta(http.STATUS, return_value.status_code)
elif exception:
subsegment.add_exception(exception, stack)

inject_trace_header(request.headers, xray_recorder.current_subsegment())
return await xray_recorder.record_subsegment_async(
wrapped=self._wrapped_transport.handle_async_request,
instance=self._wrapped_transport,
args=(request,),
kwargs={},
name=request.url.host or UNKNOWN_HOSTNAME,
namespace="remote",
meta_processor=httpx_processor,
)
async with xray_recorder.in_subsegment_async(
get_hostname(str(request.url)), namespace="remote"
) as subsegment:
if subsegment is not None:
subsegment.put_http_meta(http.METHOD, request.method)
subsegment.put_http_meta(
http.URL,
str(request.url.copy_with(password=None, query=None, fragment=None)),
)
inject_trace_header(request.headers, subsegment)

response = await self._wrapped_transport.handle_async_request(request)
if subsegment is not None:
subsegment.put_http_meta(http.STATUS, response.status_code)
return response
40 changes: 30 additions & 10 deletions tests/ext/httpx/test_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ def test_ok(use_client):
url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code)
if use_client:
with httpx.Client() as client:
client.get(url)
response = client.get(url)
else:
httpx.get(url)
response = httpx.get(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert get_hostname(url) == BASE_URL
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)

http_meta = subsegment.http
Expand All @@ -52,10 +55,13 @@ def test_error(use_client):
url = "http://{}/status/{}".format(BASE_URL, status_code)
if use_client:
with httpx.Client() as client:
client.post(url)
response = client.post(url)
else:
httpx.post(url)
response = httpx.post(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.error

Expand All @@ -71,10 +77,13 @@ def test_throttle(use_client):
url = "http://{}/status/{}".format(BASE_URL, status_code)
if use_client:
with httpx.Client() as client:
client.head(url)
response = client.head(url)
else:
httpx.head(url)
response = httpx.head(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.error
assert subsegment.throttle
Expand All @@ -91,10 +100,13 @@ def test_fault(use_client):
url = "http://{}/status/{}".format(BASE_URL, status_code)
if use_client:
with httpx.Client() as client:
client.put(url)
response = client.put(url)
else:
httpx.put(url)
response = httpx.put(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.fault

Expand All @@ -114,6 +126,7 @@ def test_nonexistent_domain(use_client):
httpx.get("http://doesnt.exist")

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.fault

exception = subsegment.cause["exceptions"][0]
Expand All @@ -131,6 +144,7 @@ def test_invalid_url(use_client):
httpx.get(url)

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.fault

Expand All @@ -152,6 +166,7 @@ def test_name_uses_hostname(use_client):
url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL)
client.get(url1)
subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.namespace == "remote"
assert subsegment.name == BASE_URL
http_meta1 = subsegment.http
assert http_meta1["request"]["url"] == strip_url(url1)
Expand All @@ -160,6 +175,7 @@ def test_name_uses_hostname(use_client):
url2 = "http://{}/".format(BASE_URL)
client.get(url2, params={"some": "payload", "not": "toBeIncluded"})
subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.namespace == "remote"
assert subsegment.name == BASE_URL
http_meta2 = subsegment.http
assert http_meta2["request"]["url"] == strip_url(url2)
Expand All @@ -171,6 +187,7 @@ def test_name_uses_hostname(use_client):
except httpx.ConnectError:
pass
subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.namespace == "remote"
assert subsegment.name == "subdomain." + BASE_URL
http_meta3 = subsegment.http
assert http_meta3["request"]["url"] == strip_url(url3)
Expand All @@ -186,10 +203,13 @@ def test_strip_http_url(use_client):
url = "http://{}/get?foo=bar".format(BASE_URL)
if use_client:
with httpx.Client() as client:
client.get(url)
response = client.get(url)
else:
httpx.get(url)
response = httpx.get(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)

http_meta = subsegment.http
Expand Down
42 changes: 29 additions & 13 deletions tests/ext/httpx/test_httpx_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ async def test_ok_async():
status_code = 200
url = "http://{}/status/{}?foo=bar".format(BASE_URL, status_code)
async with httpx.AsyncClient() as client:
await client.get(url)
response = await client.get(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert get_hostname(url) == BASE_URL
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)

http_meta = subsegment.http
Expand All @@ -48,8 +51,11 @@ async def test_error_async():
status_code = 400
url = "http://{}/status/{}".format(BASE_URL, status_code)
async with httpx.AsyncClient() as client:
await client.post(url)
response = await client.post(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.error

Expand All @@ -64,8 +70,11 @@ async def test_throttle_async():
status_code = 429
url = "http://{}/status/{}".format(BASE_URL, status_code)
async with httpx.AsyncClient() as client:
await client.head(url)
response = await client.head(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.error
assert subsegment.throttle
Expand All @@ -81,8 +90,11 @@ async def test_fault_async():
status_code = 500
url = "http://{}/status/{}".format(BASE_URL, status_code)
async with httpx.AsyncClient() as client:
await client.put(url)
response = await client.put(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.fault

Expand All @@ -94,13 +106,12 @@ async def test_fault_async():

@pytest.mark.asyncio
async def test_nonexistent_domain_async():
try:
with pytest.raises(httpx.ConnectError):
async with httpx.AsyncClient() as client:
await client.get("http://doesnt.exist")
except Exception:
# prevent uncatch exception from breaking test run
pass

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.fault

exception = subsegment.cause["exceptions"][0]
Expand All @@ -110,13 +121,12 @@ async def test_nonexistent_domain_async():
@pytest.mark.asyncio
async def test_invalid_url_async():
url = "KLSDFJKLSDFJKLSDJF"
try:
with pytest.raises(httpx.UnsupportedProtocol):
async with httpx.AsyncClient() as client:
await client.get(url)
except Exception:
# prevent uncatch exception from breaking test run
pass

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)
assert subsegment.fault

Expand All @@ -133,6 +143,7 @@ async def test_name_uses_hostname_async():
url1 = "http://{}/fakepath/stuff/koo/lai/ahh".format(BASE_URL)
await client.get(url1)
subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.namespace == "remote"
assert subsegment.name == BASE_URL
http_meta1 = subsegment.http
assert http_meta1["request"]["url"] == strip_url(url1)
Expand All @@ -141,6 +152,7 @@ async def test_name_uses_hostname_async():
url2 = "http://{}/".format(BASE_URL)
await client.get(url2, params={"some": "payload", "not": "toBeIncluded"})
subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.namespace == "remote"
assert subsegment.name == BASE_URL
http_meta2 = subsegment.http
assert http_meta2["request"]["url"] == strip_url(url2)
Expand All @@ -153,6 +165,7 @@ async def test_name_uses_hostname_async():
# This is an invalid url so we dont want to break the test
pass
subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.namespace == "remote"
assert subsegment.name == "subdomain." + BASE_URL
http_meta3 = subsegment.http
assert http_meta3["request"]["url"] == strip_url(url3)
Expand All @@ -164,8 +177,11 @@ async def test_strip_http_url_async():
status_code = 200
url = "http://{}/get?foo=bar".format(BASE_URL)
async with httpx.AsyncClient() as client:
await client.get(url)
response = await client.get(url)
assert "x-amzn-trace-id" in response._request.headers

subsegment = xray_recorder.current_segment().subsegments[0]
assert subsegment.namespace == "remote"
assert subsegment.name == get_hostname(url)

http_meta = subsegment.http
Expand Down

0 comments on commit 2447967

Please sign in to comment.