Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support functools.partial functions in AsyncioInstrumentor.trace_to_thread #2911

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def func():
---
"""
import asyncio
import functools
import sys
from asyncio import futures
from timeit import default_timer
Expand Down Expand Up @@ -231,14 +232,17 @@ def wrap_taskgroup_create_task(method, instance, args, kwargs) -> None:
def trace_to_thread(self, func: callable):
"""Trace a function."""
start = default_timer()
func_name = getattr(func, '__name__', None)
if func_name is None and isinstance(func, functools.partial):
func_name = func.func.__name__
span = (
self._tracer.start_span(
f"{ASYNCIO_PREFIX} to_thread-" + func.__name__
f"{ASYNCIO_PREFIX} to_thread-" + func_name
)
if func.__name__ in self._to_thread_name_to_trace
if func_name in self._to_thread_name_to_trace
else None
)
attr = {"type": "to_thread", "name": func.__name__}
attr = {"type": "to_thread", "name": func_name}
exception = None
try:
attr["state"] = "finished"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import functools
import sys
from unittest import skipIf
from unittest.mock import patch
Expand Down Expand Up @@ -72,3 +73,37 @@ async def to_thread():
for point in metric.data.data_points:
self.assertEqual(point.attributes["type"], "to_thread")
self.assertEqual(point.attributes["name"], "multiply")

@skipIf(
sys.version_info < (3, 9), "to_thread is only available in Python 3.9+"
)
def test_to_thread_partial_func(self):
def multiply(x, y):
return x * y

double = functools.partial(multiply, 2)

async def to_thread():
result = await asyncio.to_thread(double, 3)
assert result == 6

with self._tracer.start_as_current_span("root"):
asyncio.run(to_thread())
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 2)
assert spans[0].name == "asyncio to_thread-multiply"
for metric in (
self.memory_metrics_reader.get_metrics_data()
.resource_metrics[0]
.scope_metrics[0]
.metrics
):
if metric.name == "asyncio.process.duration":
for point in metric.data.data_points:
self.assertEqual(point.attributes["type"], "to_thread")
self.assertEqual(point.attributes["name"], "multiply")
if metric.name == "asyncio.process.created":
for point in metric.data.data_points:
self.assertEqual(point.attributes["type"], "to_thread")
self.assertEqual(point.attributes["name"], "multiply")
Loading