Skip to content

Commit

Permalink
Enable and fix some Python typing errors in the executors package.
Browse files Browse the repository at this point in the history
This revealed two Python typing-related issues:
1. Every subclass of `executor_base.Executor` was using a property on the corresponding subclass of `executor_value_base.ExecutorValue` without having access to that property. This was fixed by:
2. The methods on the subclasses of `federating_executor.FederatingStrategy` (`FederatedComposingStrategy` and `FederatedResolvingStrategy`) were annotated to return types specific to that subclass. This return type conflicted with the return types of the shared implementations defined in `executor_utils`.
3. The type annotation of `target_executors` parameter of `FederatedResolvingStrategy` was incorrect.

* Enabled pytype in the following modules:
    * `federated_composing_strategy`
    * `federated_resolving_strategy`
    * `federating_executor`
    * `reference_resolving_executor`
    * `remote_executor`
    * `thread_delegating_executor`
    * `value_serialization`

To fix issue #1:

* Promoted the `reference` property from the subclasses of `executor_value_base.ExecutorValue` to `executor_value_base.ExecutorValue`.
* Renamed usage of this API to be consistent.

To fix issue #2:

* Updated the methods on the subclasses of `federating_executor.FederatingStrategy` (`FederatedComposingStrategy` and `FederatedResolvingStrategy`) to accept and return `executor_value_base.ExecutorValue` types.

Note: That this does make the return types more generic, but I believe this is the intended interface. An alternative fix could be to update the `executor_utils` function to take and return `TypeVar`'s (i.e. generic functions). However, I don't think this is the intended interface.

To fix issue #3:

* Updated the type annotation of the `target_executors` parameter to `dict[placements.PlacementLiteral, Union[list[executor_base.Executor], executor_base.Executor]]`.

PiperOrigin-RevId: 502890245
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Jan 18, 2023
1 parent 975b2d6 commit c731295
Show file tree
Hide file tree
Showing 18 changed files with 331 additions and 353 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def type_signature(self) -> computation_types.Type:
return self._type_signature

@property
def ref(self) -> int:
"""Hands out a reference to self without transferring ownership."""
def reference(self) -> int:
return self._owned_value_id.ref

@tracing.trace
Expand Down Expand Up @@ -119,9 +118,9 @@ async def create_call(
fn: CppToPythonExecutorValue,
arg: Optional[CppToPythonExecutorValue] = None
) -> CppToPythonExecutorValue:
fn_ref = fn.ref
fn_ref = fn.reference
if arg is not None:
arg_ref = arg.ref
arg_ref = arg.reference
else:
arg_ref = None
try:
Expand All @@ -139,7 +138,7 @@ async def create_struct(
id_list = []
type_list = []
for name, value in structure.iter_elements(executor_value_struct):
id_list.append(value.ref)
id_list.append(value.reference)
type_list.append((name, value.type_signature))
try:
struct_id = self._cpp_executor.create_struct(id_list)
Expand All @@ -153,7 +152,9 @@ async def create_struct(
async def create_selection(self, source: CppToPythonExecutorValue,
index: int) -> CppToPythonExecutorValue:
try:
selection_id = self._cpp_executor.create_selection(source.ref, index)
selection_id = self._cpp_executor.create_selection(
source.reference, index
)
except Exception as e: # pylint: disable=broad-except
_handle_error(e)
selection_type = source.type_signature[index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def to_representation_for_type(
return to_representation_for_type(
value, tf_function_cache, type_spec=type_spec, device=None)
elif isinstance(value, EagerValue):
return value.internal_representation
return value.reference
elif isinstance(value, executor_value_base.ExecutorValue):
raise TypeError(
'Cannot accept a value embedded within a non-eager executor.')
Expand Down Expand Up @@ -568,12 +568,7 @@ def __init__(self, value, type_spec):
self._value = value

@property
def internal_representation(self):
"""Returns a representation of the eager value embedded in the executor.
This property is only intended for use by the eager executor and tests. Not
for consumption by consumers of the executor interface.
"""
def reference(self):
return self._value

@property
Expand Down Expand Up @@ -619,8 +614,8 @@ class EagerTFExecutor(executor_base.Executor):
One further implementation detail is worth noting. Like all executors, this
executor embeds incoming data as an instance of an executor-specific class,
here the `EagerValue`. All `EagerValues` are assumed in this implmentation
to have an `internal_representation` which is a fixed point under the action
of `to_representation_for_type` with type the `type_signature` attribute of
to have an `reference` which is a fixed point under the action of
`to_representation_for_type` with type the `type_signature` attribute of
the `EagerValue`. This invariant is introduced by normalization in
`create_value`, and is respected by the form of returned `EagerValues` in all
other methods this executor exposes.
Expand Down Expand Up @@ -711,11 +706,10 @@ async def create_call(self, comp, arg=None):
comp.type_signature))
if comp.type_signature.parameter is not None:
return EagerValue(
comp.internal_representation(arg.internal_representation),
comp.type_signature.result)
comp.reference(arg.reference), comp.type_signature.result
)
elif arg is None:
return EagerValue(comp.internal_representation(),
comp.type_signature.result)
return EagerValue(comp.reference(), comp.type_signature.result)
else:
raise TypeError('Cannot pass an argument to a no-argument function.')

Expand All @@ -734,7 +728,7 @@ async def create_struct(self, elements):
type_elements = []
for k, v in elements:
py_typecheck.check_type(v, EagerValue)
val_elements.append((k, v.internal_representation))
val_elements.append((k, v.reference))
type_elements.append((k, v.type_signature))
return EagerValue(
structure.Struct(val_elements),
Expand All @@ -759,10 +753,9 @@ async def create_selection(self, source, index):
"""
py_typecheck.check_type(source, EagerValue)
py_typecheck.check_type(source.type_signature, computation_types.StructType)
py_typecheck.check_type(source.internal_representation, structure.Struct)
py_typecheck.check_type(source.reference, structure.Struct)
py_typecheck.check_type(index, int)
return EagerValue(source.internal_representation[index],
source.type_signature[index])
return EagerValue(source.reference[index], source.type_signature[index])

def close(self):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,8 @@ def test_eager_value_constructor_with_int_constant(self):
10, {}, int_tensor_type)
v = eager_tf_executor.EagerValue(normalized_value, int_tensor_type)
self.assertEqual(str(v.type_signature), 'int32')
self.assertIsInstance(v.internal_representation, tf.Tensor)
self.assertEqual(v.internal_representation, 10)
self.assertIsInstance(v.reference, tf.Tensor)
self.assertEqual(v.reference, 10)

def test_executor_constructor_fails_if_not_in_eager_mode(self):
with tf.Graph().as_default():
Expand All @@ -427,9 +427,9 @@ def test_executor_create_value_int(self):
ex = eager_tf_executor.EagerTFExecutor()
val = asyncio.run(ex.create_value(10, tf.int32))
self.assertIsInstance(val, eager_tf_executor.EagerValue)
self.assertIsInstance(val.internal_representation, tf.Tensor)
self.assertIsInstance(val.reference, tf.Tensor)
self.assertEqual(str(val.type_signature), 'int32')
self.assertEqual(val.internal_representation, 10)
self.assertEqual(val.reference, 10)

def test_executor_create_value_raises_on_lambda(self):
ex = eager_tf_executor.EagerTFExecutor()
Expand Down Expand Up @@ -458,15 +458,15 @@ def test_executor_create_value_unnamed_int_pair(self):
}], [tf.int32, collections.OrderedDict([('a', tf.int32)])]))
self.assertIsInstance(val, eager_tf_executor.EagerValue)
self.assertEqual(str(val.type_signature), '<int32,<a=int32>>')
self.assertIsInstance(val.internal_representation, structure.Struct)
self.assertLen(val.internal_representation, 2)
self.assertIsInstance(val.internal_representation[0], tf.Tensor)
self.assertIsInstance(val.internal_representation[1], structure.Struct)
self.assertLen(val.internal_representation[1], 1)
self.assertEqual(dir(val.internal_representation[1]), ['a'])
self.assertIsInstance(val.internal_representation[1][0], tf.Tensor)
self.assertEqual(val.internal_representation[0], 10)
self.assertEqual(val.internal_representation[1][0], 20)
self.assertIsInstance(val.reference, structure.Struct)
self.assertLen(val.reference, 2)
self.assertIsInstance(val.reference[0], tf.Tensor)
self.assertIsInstance(val.reference[1], structure.Struct)
self.assertLen(val.reference[1], 1)
self.assertEqual(dir(val.reference[1]), ['a'])
self.assertIsInstance(val.reference[1][0], tf.Tensor)
self.assertEqual(val.reference[0], 10)
self.assertEqual(val.reference[1][0], 20)

def test_executor_create_value_named_type_unnamed_value(self):
ex = eager_tf_executor.EagerTFExecutor()
Expand All @@ -475,12 +475,12 @@ def test_executor_create_value_named_type_unnamed_value(self):
collections.OrderedDict(a=tf.int32, b=tf.int32)))
self.assertIsInstance(val, eager_tf_executor.EagerValue)
self.assertEqual(str(val.type_signature), '<a=int32,b=int32>')
self.assertIsInstance(val.internal_representation, structure.Struct)
self.assertLen(val.internal_representation, 2)
self.assertIsInstance(val.internal_representation[0], tf.Tensor)
self.assertIsInstance(val.internal_representation[1], tf.Tensor)
self.assertEqual(val.internal_representation[0], 10)
self.assertEqual(val.internal_representation[1], 20)
self.assertIsInstance(val.reference, structure.Struct)
self.assertLen(val.reference, 2)
self.assertIsInstance(val.reference[0], tf.Tensor)
self.assertIsInstance(val.reference[1], tf.Tensor)
self.assertEqual(val.reference[0], 10)
self.assertEqual(val.reference[1], 20)

def test_executor_create_value_no_arg_computation(self):
ex = eager_tf_executor.EagerTFExecutor()
Expand All @@ -495,8 +495,8 @@ def comp():
computation_types.FunctionType(None, tf.int32)))
self.assertIsInstance(val, eager_tf_executor.EagerValue)
self.assertEqual(str(val.type_signature), '( -> int32)')
self.assertTrue(callable(val.internal_representation))
result = val.internal_representation()
self.assertTrue(callable(val.reference))
result = val.reference()
self.assertIsInstance(result, tf.Tensor)
self.assertEqual(result, 1000)

Expand All @@ -516,9 +516,9 @@ def comp(a, b):
('b', tf.int32)]), tf.int32)))
self.assertIsInstance(val, eager_tf_executor.EagerValue)
self.assertEqual(str(val.type_signature), '(<a=int32,b=int32> -> int32)')
self.assertTrue(callable(val.internal_representation))
self.assertTrue(callable(val.reference))
arg = structure.Struct([('a', tf.constant(10)), ('b', tf.constant(10))])
result = val.internal_representation(arg)
result = val.reference(arg)
self.assertIsInstance(result, tf.Tensor)
self.assertEqual(result, 20)

Expand All @@ -537,8 +537,8 @@ def comp(a, b):
result = asyncio.run(ex.create_call(comp, arg))
self.assertIsInstance(result, eager_tf_executor.EagerValue)
self.assertEqual(str(result.type_signature), 'int32')
self.assertIsInstance(result.internal_representation, tf.Tensor)
self.assertEqual(result.internal_representation, 30)
self.assertIsInstance(result.reference, tf.Tensor)
self.assertEqual(result.reference, 30)

def test_dynamic_lookup_table_usage(self):

Expand Down Expand Up @@ -566,8 +566,8 @@ def comp(table_args, to_lookup):
result_1 = asyncio.run(ex.create_call(comp, arg_1))
result_2 = asyncio.run(ex.create_call(comp, arg_2))

self.assertEqual(self.evaluate(result_1.internal_representation), 0)
self.assertEqual(self.evaluate(result_2.internal_representation), 3)
self.assertEqual(self.evaluate(result_1.reference), 0)
self.assertEqual(self.evaluate(result_2.reference), 3)

# TODO(b/137602785): bring GPU test back after the fix for `wrap_function`.
@tensorflow_test_utils.skip_test_for_gpu
Expand All @@ -585,9 +585,8 @@ def comp(ds):
result = asyncio.run(ex.create_call(comp, arg))
self.assertIsInstance(result, eager_tf_executor.EagerValue)
self.assertEqual(str(result.type_signature), 'int32*')
self.assertIn('Dataset', type(result.internal_representation).__name__)
self.assertCountEqual([x.numpy() for x in result.internal_representation],
[10, 20])
self.assertIn('Dataset', type(result.reference).__name__)
self.assertCountEqual([x.numpy() for x in result.reference], [10, 20])

# TODO(b/137602785): bring GPU test back after the fix for `wrap_function`.
@tensorflow_test_utils.skip_test_for_gpu
Expand All @@ -612,9 +611,8 @@ def comp(ds):
result = asyncio.run(ex.create_call(comp, arg))
self.assertIsInstance(result, eager_tf_executor.EagerValue)
self.assertEqual(str(result.type_signature), 'int64*')
self.assertIn('Dataset', type(result.internal_representation).__name__)
self.assertCountEqual([x.numpy() for x in result.internal_representation],
[0, 1])
self.assertIn('Dataset', type(result.reference).__name__)
self.assertCountEqual([x.numpy() for x in result.reference], [0, 1])

# TODO(b/137602785): bring GPU test back after the fix for `wrap_function`.
@tensorflow_test_utils.skip_test_for_gpu
Expand All @@ -632,9 +630,8 @@ def comp(ds):
result = asyncio.run(ex.create_call(comp, arg))
self.assertIsInstance(result, eager_tf_executor.EagerValue)
self.assertEqual(str(result.type_signature), 'int32*')
self.assertIn('Dataset', type(result.internal_representation).__name__)
self.assertCountEqual([x.numpy() for x in result.internal_representation],
[10, 10, 10])
self.assertIn('Dataset', type(result.reference).__name__)
self.assertCountEqual([x.numpy() for x in result.reference], [10, 10, 10])

# TODO(b/137602785): bring GPU test back after the fix for `wrap_function`.
@tensorflow_test_utils.skip_test_for_gpu
Expand All @@ -652,8 +649,8 @@ def comp(ds):
result = asyncio.run(ex.create_call(comp, arg))
self.assertIsInstance(result, eager_tf_executor.EagerValue)
self.assertEqual(str(result.type_signature), 'int32')
self.assertIsInstance(result.internal_representation, tf.Tensor)
self.assertEqual(result.internal_representation, 90)
self.assertIsInstance(result.reference, tf.Tensor)
self.assertEqual(result.reference, 90)

# TODO(b/137602785): bring GPU test back after the fix for `wrap_function`.
@tensorflow_test_utils.skip_test_for_gpu
Expand All @@ -675,12 +672,12 @@ def comp(ds):
result = asyncio.run(ex.create_call(comp, arg))
self.assertIsInstance(result, eager_tf_executor.EagerValue)
self.assertEqual(str(result.type_signature), '<a=int32,b=int32>')
self.assertIsInstance(result.internal_representation, structure.Struct)
self.assertCountEqual(dir(result.internal_representation), ['a', 'b'])
self.assertIsInstance(result.internal_representation.a, tf.Tensor)
self.assertIsInstance(result.internal_representation.b, tf.Tensor)
self.assertEqual(result.internal_representation.a, 60)
self.assertEqual(result.internal_representation.b, 15)
self.assertIsInstance(result.reference, structure.Struct)
self.assertCountEqual(dir(result.reference), ['a', 'b'])
self.assertIsInstance(result.reference.a, tf.Tensor)
self.assertIsInstance(result.reference.b, tf.Tensor)
self.assertEqual(result.reference.a, 60)
self.assertEqual(result.reference.b, 15)

def test_executor_create_struct_and_selection(self):
ex = eager_tf_executor.EagerTFExecutor()
Expand All @@ -693,24 +690,24 @@ async def gather_values(values):
v3 = asyncio.run(
ex.create_struct(collections.OrderedDict([('a', v1), ('b', v2)])))
self.assertIsInstance(v3, eager_tf_executor.EagerValue)
self.assertIsInstance(v3.internal_representation, structure.Struct)
self.assertLen(v3.internal_representation, 2)
self.assertCountEqual(dir(v3.internal_representation), ['a', 'b'])
self.assertIsInstance(v3.internal_representation[0], tf.Tensor)
self.assertIsInstance(v3.internal_representation[1], tf.Tensor)
self.assertIsInstance(v3.reference, structure.Struct)
self.assertLen(v3.reference, 2)
self.assertCountEqual(dir(v3.reference), ['a', 'b'])
self.assertIsInstance(v3.reference[0], tf.Tensor)
self.assertIsInstance(v3.reference[1], tf.Tensor)
self.assertEqual(str(v3.type_signature), '<a=int32,b=int32>')
self.assertEqual(v3.internal_representation[0], 10)
self.assertEqual(v3.internal_representation[1], 20)
self.assertEqual(v3.reference[0], 10)
self.assertEqual(v3.reference[1], 20)
v4 = asyncio.run(ex.create_selection(v3, 0))
self.assertIsInstance(v4, eager_tf_executor.EagerValue)
self.assertIsInstance(v4.internal_representation, tf.Tensor)
self.assertIsInstance(v4.reference, tf.Tensor)
self.assertEqual(str(v4.type_signature), 'int32')
self.assertEqual(v4.internal_representation, 10)
self.assertEqual(v4.reference, 10)
v5 = asyncio.run(ex.create_selection(v3, 1))
self.assertIsInstance(v5, eager_tf_executor.EagerValue)
self.assertIsInstance(v5.internal_representation, tf.Tensor)
self.assertIsInstance(v5.reference, tf.Tensor)
self.assertEqual(str(v5.type_signature), 'int32')
self.assertEqual(v5.internal_representation, 20)
self.assertEqual(v5.reference, 20)

def test_executor_compute(self):
ex = eager_tf_executor.EagerTFExecutor()
Expand Down Expand Up @@ -764,9 +761,8 @@ def _generate_items():
val = asyncio.run(ex.create_value(_generate_items, type_spec))
self.assertIsInstance(val, eager_tf_executor.EagerValue)
self.assertEqual(str(val.type_signature), str(type_spec))
self.assertIn('Dataset', type(val.internal_representation).__name__)
self.assertCountEqual([x.numpy() for x in val.internal_representation],
[2, 5, 10])
self.assertIn('Dataset', type(val.reference).__name__)
self.assertCountEqual([x.numpy() for x in val.reference], [2, 5, 10])


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def __init__(self, v, t):
self._v = v
self._t = t

@property
def reference(self):
return self._v

@property
def type_signature(self):
return self._t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def index(self):
return self._index

@property
def value(self):
def reference(self):
return self._value

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@
from tensorflow_federated.python.core.impl.types import typed_object


class ExecutorValue(typed_object.TypedObject, metaclass=abc.ABCMeta):
class ExecutorValue(abc.ABC, typed_object.TypedObject):
"""Represents the abstract interface for values embedded within executors.
The embedded values may represent computations in-flight that may materialize
in the future or fail before they materialize.
"""

@property
@abc.abstractmethod
def reference(self):
"""Returns a reference to the value without transferring ownership."""
raise NotImplementedError

@abc.abstractmethod
async def compute(self):
"""A coroutine that asynchronously returns the computed form of the value.
Expand Down
Loading

0 comments on commit c731295

Please sign in to comment.