Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: unique clients per function #1490

Merged
merged 9 commits into from
Aug 15, 2024
23 changes: 18 additions & 5 deletions azure_functions_worker/bindings/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def from_incoming_proto(
pytype: typing.Optional[type],
trigger_metadata: typing.Optional[typing.Dict[str, protos.TypedData]],
shmem_mgr: SharedMemoryManager,
function_name: str,
is_deferred_binding: typing.Optional[bool] = False) -> typing.Any:
binding = get_binding(binding, is_deferred_binding)
if trigger_metadata:
Expand Down Expand Up @@ -184,7 +185,8 @@ def from_incoming_proto(
pb=pb,
pytype=pytype,
datum=datum,
metadata=metadata)
metadata=metadata,
function_name=function_name)
return binding.decode(datum, trigger_metadata=metadata)
except NotImplementedError:
# Binding does not support the data.
Expand Down Expand Up @@ -281,29 +283,40 @@ def deferred_bindings_decode(binding: typing.Any,
pb: protos.ParameterBinding, *,
pytype: typing.Optional[type],
datum: typing.Any,
metadata: typing.Any):
metadata: typing.Any,
function_name: str):
"""
This cache holds deferred binding types (e.g. BlobClient, ContainerClient)
that have already been created, so that the worker can reuse the
previously created type without creating a new one.

For async types, the function_name is needed as a key to differentiate.
This prevents a known SDK issue where reusing a client across functions
can lose the session context and cause an error.

The cache key is based on: param name, type, resource, function_name

If cache is empty or key doesn't exist, deferred_binding_type is None
"""
global deferred_bindings_cache

if deferred_bindings_cache.get((pb.name,
pytype,
datum.value.content), None) is not None:
datum.value.content,
function_name), None) is not None:
return deferred_bindings_cache.get((pb.name,
pytype,
datum.value.content))
datum.value.content,
function_name))
else:
deferred_binding_type = binding.decode(datum,
trigger_metadata=metadata,
pytype=pytype)

deferred_bindings_cache[(pb.name,
pytype,
datum.value.content)] = deferred_binding_type
datum.value.content,
function_name)] = deferred_binding_type
return deferred_binding_type


Expand Down
2 changes: 2 additions & 0 deletions azure_functions_worker/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ async def _handle__invocation_request(self, request):
trigger_metadata=trigger_metadata,
pytype=pb_type_info.pytype,
shmem_mgr=self._shmem_mgr,
function_name=self._functions.get_function(
function_id).name,
is_deferred_binding=pb_type_info.deferred_bindings_enabled)

if http_v2_enabled:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,32 @@ def put_blob_bytes(req: func.HttpRequest, file: func.Out[bytes]) -> str:
@app.route(route="blob_cache")
def blob_cache(req: func.HttpRequest,
cachedClient: blob.BlobClient) -> str:
return cachedClient.download_blob(encoding='utf-8').readall()
return func.HttpResponse(repr(cachedClient))


@app.function_name(name="blob_cache2")
@app.blob_input(arg_name="cachedClient",
                path="python-worker-tests/test-blobclient-triggered.txt",
                connection="AzureWebJobsStorage")
@app.route(route="blob_cache2")
def blob_cache2(req: func.HttpRequest,
                cachedClient: blob.BlobClient) -> func.HttpResponse:
    """Respond with the repr() of the blob client bound to this function.

    The blob is never downloaded; only the textual representation of the
    injected client object is sent back as the response body.
    """
    client_repr = repr(cachedClient)
    return func.HttpResponse(client_repr)


@app.function_name(name="blob_cache3")
@app.blob_input(arg_name="cachedClient",
                path="python-worker-tests/test-blobclient-triggered.txt",
                connection="AzureWebJobsStorage")
@app.blob_input(arg_name="cachedClient2",
                path="python-worker-tests/test-blobclient-triggered.txt",
                connection="AzureWebJobsStorage")
@app.route(route="blob_cache3")
def blob_cache3(req: func.HttpRequest,
                cachedClient: blob.BlobClient,
                cachedClient2: blob.BlobClient) -> func.HttpResponse:
    """Respond with the repr() of both blob clients bound to this function.

    Both bindings use the same blob path and connection and differ only in
    their argument name. The body has the form
    "Client 1: <repr> | Client 2: <repr>".
    """
    # !r formats with repr(), matching the original string concatenation.
    body = f"Client 1: {cachedClient!r} | Client 2: {cachedClient2!r}"
    return func.HttpResponse(body)


@app.function_name(name="invalid_connection_info")
Expand All @@ -265,5 +290,5 @@ def blob_cache(req: func.HttpRequest,
connection="NotARealConnectionString")
@app.route(route="invalid_connection_info")
def invalid_connection_info(req: func.HttpRequest,
client: blob.BlobClient) -> str:
return client.download_blob(encoding='utf-8').readall()
client: blob.BlobClient) -> func.HttpResponse:
return func.HttpResponse(repr(client))
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def test_deferred_bindings_enabled_decode(self):
datum = datumdef.Datum(value=sample_mbd, type='model_binding_data')

obj = meta.deferred_bindings_decode(binding=binding, pb=pb,
pytype=BlobClient, datum=datum, metadata={})
pytype=BlobClient, datum=datum, metadata={},
function_name="test_function")

self.assertIsNotNone(obj)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from tests.utils import testutils

from azure_functions_worker.bindings import meta


@unittest.skipIf(sys.version_info.minor <= 8, "The base extension"
"is only supported for 3.9+.")
Expand Down Expand Up @@ -174,16 +172,58 @@ def test_type_undefined(self):
self.assertEqual(r.text, 'test-data')

def test_caching(self):
# Cache is empty at the start
self.assertEqual(meta.deferred_bindings_cache, {})
'''
The cache returns the same type based on resource and function name.
Two different functions with clients that access the same resource
will have two different clients. This tests that the same client
is returned for each invocation and that the clients are different
between the two functions.
'''

r = self.webhost.request('GET', 'blob_cache')
r2 = self.webhost.request('GET', 'blob_cache2')
self.assertEqual(r.status_code, 200)
self.assertEqual(r2.status_code, 200)
client = r.text
client2 = r2.text
self.assertNotEqual(client, client2)

r = self.webhost.request('GET', 'blob_cache')
r2 = self.webhost.request('GET', 'blob_cache2')
self.assertEqual(r.status_code, 200)
self.assertEqual(r2.status_code, 200)
self.assertEqual(r.text, client)
self.assertEqual(r2.text, client2)
self.assertNotEqual(r.text, r2.text)

r = self.webhost.request('GET', 'blob_cache')
self.assertEqual(r.status_code, 200)
r2 = self.webhost.request('GET', 'blob_cache2')
self.assertEqual(r.status_code, 200)
self.assertEqual(r2.status_code, 200)
self.assertEqual(r.text, client)
self.assertEqual(r2.text, client2)
self.assertNotEqual(r.text, r2.text)

def test_caching_same_resource(self):
    '''
    The cache returns the same type based on param name.
    One function with two clients that access the same resource
    will have two different clients. This tests that the same clients
    are returned for each invocation and that the clients are different
    between the two bindings.
    '''

    def client_reprs(response):
        # The blob_cache3 body looks like
        # "Client 1: <repr> | Client 2: <repr>". The labels must be
        # removed before comparing: with them in place the two halves
        # always differ, so an assertNotEqual could never fail even if
        # both bindings were handed the very same client.
        parts = response.text.split(" | ")
        self.assertEqual(len(parts), 2)
        # maxsplit=1 keeps any ": " inside the repr itself intact; [-1]
        # falls back to the whole part if the label is ever missing.
        return [part.split(": ", 1)[-1] for part in parts]

    first = self.webhost.request('GET', 'blob_cache3')
    self.assertEqual(first.status_code, 200)
    first_clients = client_reprs(first)
    self.assertNotEqual(first_clients[0], first_clients[1])

    second = self.webhost.request('GET', 'blob_cache3')
    self.assertEqual(second.status_code, 200)
    second_clients = client_reprs(second)
    # Each binding must get its own cached client back on the second call.
    self.assertEqual(first_clients[0], second_clients[0])
    self.assertEqual(first_clients[1], second_clients[1])
    self.assertNotEqual(second_clients[0], second_clients[1])

def test_failed_client_creation(self):
r = self.webhost.request('GET', 'invalid_connection_info')
Expand Down
Loading