Skip to content

Commit

Permalink
Added an option to use the pure python implementation. (#1137)
Browse files Browse the repository at this point in the history
* Added an option to use the pure python implementation.
  • Loading branch information
gabrieldemarmiesse authored Mar 7, 2020
1 parent 7b02285 commit 21d0574
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 16 deletions.
44 changes: 37 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

| Build Type | Status |
| --- | --- |
| **MacOS CPU** | [![Status](https://github.com/tensorflow/addons/workflows/macos-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amacos-nightly) |
| **Windows CPU** | [![Status](https://github.com/tensorflow/addons/workflows/windows-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Awindows-nightly) |
| **Ubuntu CPU** | [![Status](https://github.com/tensorflow/addons/workflows/manylinux-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amanylinux-nightly) |
| **Ubuntu GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.html) |
| **MacOS** | [![Status](https://github.com/tensorflow/addons/workflows/macos-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amacos-nightly) |
| **Windows** | [![Status](https://github.com/tensorflow/addons/workflows/windows-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Awindows-nightly) |
| **Ubuntu** | [![Status](https://github.com/tensorflow/addons/workflows/manylinux-nightly/badge.svg)](https://github.com/tensorflow/addons/actions?query=workflow%3Amanylinux-nightly) |
| **Ubuntu custom GPU ops** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/addons/ubuntu-gpu-py3.html) |

**TensorFlow Addons** is a repository of contributions that conform to
well-established API patterns, but implement new functionality
Expand Down Expand Up @@ -113,9 +113,39 @@ TF-Addons. In order to achieve these we require that our additions
conform to established API patterns seen in core TensorFlow.

#### GPU/CPU Custom-Ops
A major benefit of TensorFlow Addons is that there are precompiled ops. Should
a CUDA 10.1 installation not be found then the op will automatically fall back to
a CPU implementation.
A major benefit of TensorFlow Addons is that there are precompiled ops for CPU/GPU.
Currently however, GPU custom ops only work for Linux distributions. For this reason Windows and MacOS will fallback to pure TensorFlow Python implementations whenever possible.

The order of priority in MacOS/Windows:
1) Pure TensorFlow + Python implementation (work on cpu+gpu)
2) C++ implementation for CPU

The order of priority for Linux:
1) CUDA implementation
2) C++ implementation
3) Pure TensorFlow + Python implementation (work on cpu+gpu)

If you want to change the default priority, "C++ and CUDA" VS "pure TF Python",
you can either set the variable `TF_ADDONS_PY_OPS` from the command line or in
your code.

For example, if you're on linux and you have compatibility problems with the compiled ops,
and you want to give priority to the Python implementation
you can do:

From the command line:
```
export TF_ADDONS_PY_OPS=1
```

or in your code:

```
import tensorflow_addons as tfa
tfa.options.TF_ADDONS_PY_OPS=True
```

This variable will default to `True` on Windows and Mac, and `False` for Linux.

#### Proxy Maintainership
Addons has been designed to compartmentalize subpackages and submodules so
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ py_library(
name = "tensorflow_addons",
data = [
"__init__.py",
"options.py",
"version.py",
],
deps = [
Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/activations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ py_library(
"tanhshrink.py",
],
data = [
"//tensorflow_addons:options.py",
"//tensorflow_addons/custom_ops/activations:_activation_ops.so",
"//tensorflow_addons/utils",
],
Expand Down
14 changes: 13 additions & 1 deletion tensorflow_addons/activations/hardshrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons import options

_activation_so = LazySO("custom_ops/activations/_activation_ops.so")

Expand All @@ -40,6 +41,18 @@ def hardshrink(
A `Tensor`. Has the same type as `x`.
"""
x = tf.convert_to_tensor(x)

if not options.TF_ADDONS_PY_OPS:
try:
return _hardshrink_custom_op(x, lower, upper)
except tf.errors.NotFoundError:
options.warn_fallback("hardshrink")

return _hardshrink_py(x, lower, upper)


def _hardshrink_custom_op(x, lower=-0.5, upper=0.5):
"""Alias with lazy loading of the .so file"""
return _activation_so.ops.addons_hardshrink(x, lower, upper)


Expand All @@ -59,7 +72,6 @@ def _hardshrink_py(
" not be higher than the value "
"variable upper, which is {} .".format(lower, upper)
)
x = tf.convert_to_tensor(x)
mask_lower = x < lower
mask_upper = upper < x
mask = tf.logical_or(mask_lower, mask_upper)
Expand Down
14 changes: 7 additions & 7 deletions tensorflow_addons/activations/hardshrink_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import hardshrink
from tensorflow_addons.utils import test_utils
from tensorflow_addons.activations.hardshrink import _hardshrink_custom_op
from tensorflow_addons.activations.hardshrink import _hardshrink_py
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class HardshrinkTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid(self):
with self.assertRaisesOpError("lower must be less than or equal to upper."):
y = hardshrink(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
y = _hardshrink_custom_op(tf.ones(shape=(1, 2, 3)), lower=2.0, upper=-2.0)
self.evaluate(y)

@parameterized.named_parameters(
Expand All @@ -35,11 +35,11 @@ def test_invalid(self):
def test_hardshrink(self, dtype):
x = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=dtype)
expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
self.assertAllCloseAccordingToType(hardshrink(x), expected_result)
self.assertAllCloseAccordingToType(_hardshrink_custom_op(x), expected_result)

expected_result = tf.constant([-2.0, 0.0, 0.0, 0.0, 2.0], dtype=dtype)
self.assertAllCloseAccordingToType(
hardshrink(x, lower=-1.0, upper=1.0), expected_result
_hardshrink_custom_op(x, lower=-1.0, upper=1.0), expected_result
)

@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
Expand All @@ -51,7 +51,7 @@ def test_theoretical_gradients(self, dtype):
# Avoid these two points to make gradients smooth.
x = tf.constant([-2.0, -1.5, 0.0, 1.5, 2.0], dtype=dtype)

theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
theoretical, numerical = tf.test.compute_gradient(_hardshrink_custom_op, [x])
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
Expand All @@ -68,7 +68,7 @@ def verify_funcs_are_equivalent(self, dtype):

with tf.GradientTape(persistent=True) as t:
t.watch(x)
y_native = hardshrink(x, lower, upper)
y_native = _hardshrink_custom_op(x, lower, upper)
y_py = _hardshrink_py(x, lower, upper)

self.assertAllCloseAccordingToType(y_native, y_py)
Expand Down
49 changes: 49 additions & 0 deletions tensorflow_addons/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import platform
import warnings
import traceback

try:
TF_ADDONS_PY_OPS = bool(int(os.environ["TF_ADDONS_PY_OPS"]))
except KeyError:
if platform.system() == "Linux":
TF_ADDONS_PY_OPS = False
else:
TF_ADDONS_PY_OPS = True


FALLBACK_WARNING_TEMPLATE = """{}
The {} C++/CUDA custom op could not be loaded.
For this reason, Addons will fallback to an implementation written
in Python with public TensorFlow ops. There worst you might experience with
this is a moderate slowdown on GPU. There can be multiple
reason for this loading error, one of them may be an ABI incompatibility between
the TensorFlow installed on your system and the TensorFlow used to compile
TensorFlow Addons' custom ops. The stacktrace generated when loading the
shared object file was displayed above.
If you want this warning to disappear, either make sure the TensorFlow installed
is compatible with this version of Addons, or tell TensorFlow Addons to
prefer using Python implementations and not custom C++/CUDA ones. You can do that
by changing the TF_ADDONS_PY_OPS flag
either with the environment variable:
```bash
TF_ADDONS_PY_OPS=1 python my_script.py
```
or in your code, after your imports:
```python
import tensorflow_addons as tfa
import ...
import ...
tfa.options.TF_ADDONS_PY_OPS = True
```
"""


def warn_fallback(op_name):
warning_msg = FALLBACK_WARNING_TEMPLATE.format(traceback.format_exc(), op_name)
warnings.warn(warning_msg, RuntimeWarning)
global TF_ADDONS_PY_OPS
TF_ADDONS_PY_OPS = True
2 changes: 1 addition & 1 deletion tools/ci_build/verify/check_typing_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
modules_list = []
for attr_name in dir(tensorflow_addons):
attr = getattr(tensorflow_addons, attr_name)
if isinstance(attr, ModuleType):
if isinstance(attr, ModuleType) and attr is not tensorflow_addons.options:
modules_list.append(attr)


Expand Down

0 comments on commit 21d0574

Please sign in to comment.