Skip to content

Commit

Permalink
Enabled and fixed python type annotations in some modules in the `exe…
Browse files Browse the repository at this point in the history
…cutors` package.

This revealed two Python typing-related issues:
1. Every subclass of `executor_base.Executor` was using a property on the corresponding subclass of `executor_value_base.ExecutorValue` without having access to that property. This was fixed by:
2. The methods on the subclasses of `federating_executor.FederatingStrategy` (`FederatedComposingStrategy` and `FederatedResolvingStrategy`) were annotated to return types specific to that subclass. This return type conflicted with the return types of the shared implementations defined in `executor_utils`.
3. The type annotation of `target_executors` parameter of `FederatedResolvingStrategy` was incorrect.

* Enabled pytype in the following modules:
    * `federated_composing_strategy`
    * `federated_resolving_strategy`
    * `federating_executor`
    * `reference_resolving_executor`
    * `remote_executor`
    * `thread_delegating_executor`
    * `value_serialization`

To fix issue #1:

* Promoted the `reference` property from the subclasses of `executor_value_base.ExecutorValue` to `executor_value_base.ExecutorValue`.
* Renamed usage of this API to be consistent.

To fix issue #2:

* Updated the methods on the subclasses of `federating_executor.FederatingStrategy` (`FederatedComposingStrategy` and `FederatedResolvingStrategy`) to accept and return `executor_value_base.ExecutorValue` types.

Note: That this does make the return types more generic, but I believe this is the intended interface. An alternative fix could be to update the `executor_utils` function to take and return `TypeVar`'s (i.e. generic functions). However, I don't think this is the intended interface.

To fix issue #3:

* Updated the type annotation of the `target_executors` parameter to `dict[placements.PlacementLiteral, Union[list[executor_base.Executor], executor_base.Executor]]`.

PiperOrigin-RevId: 502890245
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Jan 23, 2023
1 parent 6743420 commit e68335a
Show file tree
Hide file tree
Showing 20 changed files with 363 additions and 379 deletions.
17 changes: 8 additions & 9 deletions tensorflow_federated/python/core/backends/xla/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ def __init__(self, value, type_spec, backend):
self._value = to_representation_for_type(value, type_spec, backend)

@property
def internal_representation(self):
"""Returns internal representation of the value embedded in the executor."""
def reference(self):
return self._value

@property
Expand Down Expand Up @@ -184,12 +183,12 @@ async def create_call(self, comp, arg=None):
if arg is not None:
py_typecheck.check_type(arg, XlaValue)
py_typecheck.check_type(comp.type_signature, computation_types.FunctionType)
py_typecheck.check_type(comp.internal_representation,
py_typecheck.check_type(comp.reference,
runtime.ComputationCallable)
if comp.type_signature.parameter is not None:
result = comp.internal_representation(arg.internal_representation)
result = comp.reference(arg.reference)
else:
result = comp.internal_representation()
result = comp.reference()
return XlaValue(result, comp.type_signature.result, self._backend)

@tracing.trace
Expand All @@ -198,7 +197,7 @@ async def create_struct(self, elements):
type_elements = []
for k, v in structure.iter_elements(structure.from_container(elements)):
py_typecheck.check_type(v, XlaValue)
val_elements.append((k, v.internal_representation))
val_elements.append((k, v.reference))
type_elements.append((k, v.type_signature))
struct_val = structure.Struct(val_elements)
struct_type = computation_types.StructType([
Expand All @@ -210,20 +209,20 @@ async def create_struct(self, elements):
async def create_selection(self, source, index=None, name=None):
py_typecheck.check_type(source, XlaValue)
py_typecheck.check_type(source.type_signature, computation_types.StructType)
py_typecheck.check_type(source.internal_representation, structure.Struct)
py_typecheck.check_type(source.reference, structure.Struct)
if index is not None:
py_typecheck.check_type(index, int)
if name is not None:
raise ValueError(
'Cannot simultaneously specify name {} and index {}.'.format(
name, index))
else:
return XlaValue(source.internal_representation[index],
return XlaValue(source.reference[index],
source.type_signature[index], self._backend)
elif name is not None:
py_typecheck.check_type(name, str)
return XlaValue(
getattr(source.internal_representation, str(name)),
getattr(source.reference, str(name)),
getattr(source.type_signature, str(name)), self._backend)
else:
raise ValueError('Must specify either name or index.')
Expand Down
16 changes: 8 additions & 8 deletions tensorflow_federated/python/core/backends/xla/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def test_create_compute_int32(self):
int_val = asyncio.run(ex.create_value(10, np.int32))
self.assertIsInstance(int_val, executor.XlaValue)
self.assertEqual(str(int_val.type_signature), 'int32')
self.assertIsInstance(int_val.internal_representation, np.int32)
self.assertEqual(int_val.internal_representation, 10)
self.assertIsInstance(int_val.reference, np.int32)
self.assertEqual(int_val.reference, 10)
result = asyncio.run(int_val.compute())
self.assertEqual(result, 10)

Expand All @@ -107,8 +107,8 @@ def test_create_compute_2xint32_struct(self):
struct_val = asyncio.run(ex.create_struct([x_val, y_val]))
self.assertIsInstance(struct_val, executor.XlaValue)
self.assertEqual(str(struct_val.type_signature), '<int32,int32>')
self.assertIsInstance(struct_val.internal_representation, structure.Struct)
self.assertEqual(str(struct_val.internal_representation), '<10,20>')
self.assertIsInstance(struct_val.reference, structure.Struct)
self.assertEqual(str(struct_val.reference), '<10,20>')
result = asyncio.run(struct_val.compute())
self.assertEqual(str(result), '<10,20>')

Expand All @@ -123,8 +123,8 @@ def test_create_and_invoke_noarg_comp_returning_int32(self):
comp_val = asyncio.run(ex.create_value(comp_pb, comp_type))
self.assertIsInstance(comp_val, executor.XlaValue)
self.assertEqual(str(comp_val.type_signature), str(comp_type))
self.assertTrue(callable(comp_val.internal_representation))
result = comp_val.internal_representation()
self.assertTrue(callable(comp_val.reference))
result = comp_val.reference()
self.assertEqual(result, 10)
call_val = asyncio.run(ex.create_call(comp_val))
self.assertIsInstance(call_val, executor.XlaValue)
Expand Down Expand Up @@ -163,9 +163,9 @@ def test_selection(self):
self.assertIsInstance(struct_val, executor.XlaValue)
self.assertEqual(str(struct_val.type_signature), '<a=int32,b=int32>')
by_index_val = asyncio.run(ex.create_selection(struct_val, index=0))
self.assertEqual(by_index_val.internal_representation, 10)
self.assertEqual(by_index_val.reference, 10)
by_name_val = asyncio.run(ex.create_selection(struct_val, name='b'))
self.assertEqual(by_name_val.internal_representation, 20)
self.assertEqual(by_name_val.reference, 20)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def type_signature(self) -> computation_types.Type:
return self._type_signature

@property
def ref(self) -> int:
"""Hands out a reference to self without transferring ownership."""
def reference(self) -> int:
return self._owned_value_id.ref

@tracing.trace
Expand Down Expand Up @@ -119,9 +118,9 @@ async def create_call(
fn: CppToPythonExecutorValue,
arg: Optional[CppToPythonExecutorValue] = None
) -> CppToPythonExecutorValue:
fn_ref = fn.ref
fn_ref = fn.reference
if arg is not None:
arg_ref = arg.ref
arg_ref = arg.reference
else:
arg_ref = None
try:
Expand All @@ -139,7 +138,7 @@ async def create_struct(
id_list = []
type_list = []
for name, value in structure.iter_elements(executor_value_struct):
id_list.append(value.ref)
id_list.append(value.reference)
type_list.append((name, value.type_signature))
try:
struct_id = self._cpp_executor.create_struct(id_list)
Expand All @@ -153,7 +152,9 @@ async def create_struct(
async def create_selection(self, source: CppToPythonExecutorValue,
index: int) -> CppToPythonExecutorValue:
try:
selection_id = self._cpp_executor.create_selection(source.ref, index)
selection_id = self._cpp_executor.create_selection(
source.reference, index
)
except Exception as e: # pylint: disable=broad-except
_handle_error(e)
selection_type = source.type_signature[index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def to_representation_for_type(
return to_representation_for_type(
value, tf_function_cache, type_spec=type_spec, device=None)
elif isinstance(value, EagerValue):
return value.internal_representation
return value.reference
elif isinstance(value, executor_value_base.ExecutorValue):
raise TypeError(
'Cannot accept a value embedded within a non-eager executor.')
Expand Down Expand Up @@ -568,12 +568,7 @@ def __init__(self, value, type_spec):
self._value = value

@property
def internal_representation(self):
"""Returns a representation of the eager value embedded in the executor.
This property is only intended for use by the eager executor and tests. Not
for consumption by consumers of the executor interface.
"""
def reference(self):
return self._value

@property
Expand Down Expand Up @@ -619,8 +614,8 @@ class EagerTFExecutor(executor_base.Executor):
One further implementation detail is worth noting. Like all executors, this
executor embeds incoming data as an instance of an executor-specific class,
here the `EagerValue`. All `EagerValues` are assumed in this implmentation
to have an `internal_representation` which is a fixed point under the action
of `to_representation_for_type` with type the `type_signature` attribute of
to have an `reference` which is a fixed point under the action of
`to_representation_for_type` with type the `type_signature` attribute of
the `EagerValue`. This invariant is introduced by normalization in
`create_value`, and is respected by the form of returned `EagerValues` in all
other methods this executor exposes.
Expand Down Expand Up @@ -711,11 +706,10 @@ async def create_call(self, comp, arg=None):
comp.type_signature))
if comp.type_signature.parameter is not None:
return EagerValue(
comp.internal_representation(arg.internal_representation),
comp.type_signature.result)
comp.reference(arg.reference), comp.type_signature.result
)
elif arg is None:
return EagerValue(comp.internal_representation(),
comp.type_signature.result)
return EagerValue(comp.reference(), comp.type_signature.result)
else:
raise TypeError('Cannot pass an argument to a no-argument function.')

Expand All @@ -734,7 +728,7 @@ async def create_struct(self, elements):
type_elements = []
for k, v in elements:
py_typecheck.check_type(v, EagerValue)
val_elements.append((k, v.internal_representation))
val_elements.append((k, v.reference))
type_elements.append((k, v.type_signature))
return EagerValue(
structure.Struct(val_elements),
Expand All @@ -759,10 +753,9 @@ async def create_selection(self, source, index):
"""
py_typecheck.check_type(source, EagerValue)
py_typecheck.check_type(source.type_signature, computation_types.StructType)
py_typecheck.check_type(source.internal_representation, structure.Struct)
py_typecheck.check_type(source.reference, structure.Struct)
py_typecheck.check_type(index, int)
return EagerValue(source.internal_representation[index],
source.type_signature[index])
return EagerValue(source.reference[index], source.type_signature[index])

def close(self):
pass
Loading

0 comments on commit e68335a

Please sign in to comment.