Fix gpt2, t5 and fnet under mixed precision #958

Merged 1 commit on Apr 6, 2023
9 changes: 9 additions & 0 deletions keras_nlp/conftest.py
@@ -17,6 +17,7 @@
import pytest
import tensorflow as tf
from packaging import version
from tensorflow import keras


@pytest.fixture(scope="session")
@@ -53,9 +54,17 @@ def pytest_addoption(parser):
default=False,
help="run tpu tests",
)
parser.addoption(
"--mixed_precision",
action="store_true",
default=False,
help="run with mixed precision",
)


def pytest_configure(config):
if config.getoption("--mixed_precision"):
keras.mixed_precision.set_global_policy("mixed_float16")
config.addinivalue_line(
"markers", "large: mark test as being slow or requiring a network"
)
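The new --mixed_precision option sets the global Keras dtype policy before any tests run, so every layer computes in float16 while keeping float32 variables. A minimal standalone sketch (not part of this diff) of what enabling the policy means for a layer:

    from tensorflow import keras

    keras.mixed_precision.set_global_policy("mixed_float16")
    layer = keras.layers.Dense(4)
    print(layer.compute_dtype)   # float16 -- activations and outputs
    print(layer.variable_dtype)  # float32 -- weights stay full precision

With the option in place, the suite can be exercised under mixed precision with, for example, pytest keras_nlp --mixed_precision.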
4 changes: 3 additions & 1 deletion keras_nlp/layers/cached_multi_head_attention.py
@@ -84,7 +84,9 @@ def call(
value = dynamic_update_slice(value_cache, value, start)
cache = tf.stack((key, value), axis=1)

query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self._key_dim)))
query = tf.multiply(
query, 1.0 / tf.math.sqrt(tf.cast(self._key_dim, query.dtype))
)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(
attention_scores, attention_mask
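The original scaling used float(self._key_dim), which always produces a float32 scalar; under mixed_float16 the query is float16 and tf.multiply refuses to mix the two dtypes. A standalone reproduction sketch (toy shapes, not the layer code):

    import tensorflow as tf

    query = tf.random.uniform((2, 4, 8), dtype=tf.float16)
    key_dim = 8
    try:
        # float32 scalar vs. float16 tensor -> Mul raises an error
        tf.multiply(query, 1.0 / tf.math.sqrt(float(key_dim)))
    except tf.errors.InvalidArgumentError as e:
        print("mixed dtypes rejected:", type(e).__name__)
    # Casting the scalar to the query's dtype, as the fix does, works:
    scaled = tf.multiply(query, 1.0 / tf.math.sqrt(tf.cast(key_dim, query.dtype)))
    print(scaled.dtype)  # float16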
9 changes: 7 additions & 2 deletions keras_nlp/layers/cached_multi_head_attention_test.py
@@ -39,8 +39,13 @@ def test_cache_call_is_correct(self, eager):
key_dim = 4

layer = CachedMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
x = tf.random.uniform(shape=[batch_size, seq_len, num_heads * key_dim])
cache = tf.zeros([batch_size, 2, seq_len, num_heads, key_dim])
dtype = layer.compute_dtype
x = tf.random.uniform(
shape=[batch_size, seq_len, num_heads * key_dim], dtype=dtype
)
cache = tf.zeros(
[batch_size, 2, seq_len, num_heads, key_dim], dtype=dtype
)
# Use a causal mask.
mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
outputs = tf.zeros_like(x)
6 changes: 3 additions & 3 deletions keras_nlp/layers/f_net_encoder.py
@@ -132,9 +132,9 @@ def fourier_transform(input):
# Apply FFT on the input and take the real part.
# Before applying the Fourier transform, convert the dtype of the
# input tensor to complex64.
input = tf.cast(input, tf.complex64)
mixing_output = tf.math.real(tf.signal.fft2d(input))
return mixing_output
x = tf.cast(input, tf.complex64)
mixing_output = tf.math.real(tf.signal.fft2d(x))
return tf.cast(mixing_output, input.dtype)

def add_and_norm(input1, input2, norm_layer):
return norm_layer(input1 + input2)
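tf.signal.fft2d only accepts complex input, so the Fourier mixing has to round-trip through complex64; the fix casts the real part back to the incoming dtype so the residual add that follows never mixes float32 with float16. A standalone sketch of the round trip (toy shapes, not the encoder code):

    import tensorflow as tf

    x = tf.random.uniform((2, 8, 16), dtype=tf.float16)  # mixed-precision activations
    mixed = tf.math.real(tf.signal.fft2d(tf.cast(x, tf.complex64)))  # float32 here
    mixed = tf.cast(mixed, x.dtype)  # cast back so x + mixed is valid
    print((x + mixed).dtype)  # float16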
8 changes: 4 additions & 4 deletions keras_nlp/layers/position_embedding_test.py
@@ -44,8 +44,8 @@ def test_static_layer_output_shape(self):
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [None, sequence_length, feature_size]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype)
# The output dtype for this layer should match the compute dtype.
self.assertEqual(test_layer.compute_dtype, output_tensor.dtype)

def test_more_than_3_dimensions_static(self):
# Create a 4-dimensional input (the first dimension is implicit).
@@ -68,8 +68,8 @@ def test_more_than_3_dimensions_static(self):
feature_size,
]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
# The default output dtype for this layer should be tf.float32.
self.assertEqual(tf.float32, output_tensor.dtype)
# The output dtype for this layer should match the compute dtype.
self.assertEqual(test_layer.compute_dtype, output_tensor.dtype)

def test_float16_dtype(self):
# Create a 3-dimensional input (the first dimension is implicit).
37 changes: 14 additions & 23 deletions keras_nlp/layers/sine_position_encoding_test.py
@@ -79,29 +79,20 @@ def test_output_correct_values(self):
output = model(input)

# compare position encoding values for positions 0 and 3
expected_encoding_position_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
expected_encoding_position_3 = [
0.14112,
-0.9899925,
0.1387981,
0.9903207,
0.00646326,
0.99997914,
]
self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
self.assertAllClose(output[0, 3, :], expected_encoding_position_3)
expected_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
expected_3 = [0.14112, -0.98999, 0.13879, 0.99032, 0.00646, 0.99997]
self.assertAllClose(output[0, 0, :], expected_0, atol=0.01, rtol=0.01)
self.assertAllClose(output[0, 3, :], expected_3, atol=0.01, rtol=0.01)

def test_ragged_tensor_with_3_dimensions(self):
feature_size = 2
test_layer = sine_position_encoding.SinePositionEncoding()
# Create a 3-dimensional ragged input (the first dimension is implicit).
input_tensor = keras.Input(
shape=(None, feature_size), dtype=tf.float32, ragged=True
)
input_tensor = keras.Input(shape=(None, feature_size), ragged=True)
output_tensor = test_layer(input_tensor)
model = keras.Model(input_tensor, output_tensor)

input_data = tf.ragged.constant(
inputs = tf.ragged.constant(
[
[[1.0, 1.0], [1.0, 1.0]],
[],
@@ -111,7 +102,7 @@ def test_ragged_tensor_with_3_dimensions(self):
ragged_rank=1,
inner_shape=(2,),
)
expected_output_data = tf.ragged.constant(
expected_outputs = tf.ragged.constant(
[
[[0.0, 1.0], [0.84147096, 0.5403023]],
[],
@@ -121,20 +112,20 @@ def test_ragged_tensor_with_3_dimensions(self):
ragged_rank=1,
inner_shape=(2,),
)
output_data = model.predict(input_data)
self.assertAllClose(output_data, expected_output_data)
outputs = model.predict(inputs)
self.assertAllClose(outputs, expected_outputs, atol=0.01, rtol=0.01)

def test_ragged_tensor_with_4_dimensions(self):
feature_size = 2
test_layer = sine_position_encoding.SinePositionEncoding()
# Create a 4-dimensional ragged input (the first dimension is implicit).
input_tensor = keras.Input(
shape=(None, None, feature_size), dtype=tf.float32, ragged=True
shape=(None, None, feature_size), ragged=True
)
output_tensor = test_layer(input_tensor)
model = keras.Model(input_tensor, output_tensor)

input_data = tf.ragged.constant(
inputs = tf.ragged.constant(
[
[
[[1.0, 1.0], [1.0, 1.0]],
@@ -148,7 +139,7 @@ def test_ragged_tensor_with_4_dimensions(self):
ragged_rank=2,
inner_shape=(2,),
)
expected_output_data = tf.ragged.constant(
expected_outputs = tf.ragged.constant(
[
[
[[0.0, 1.0], [0.84147096, 0.5403023]],
@@ -166,8 +157,8 @@ def test_get_config_and_from_config(self):
ragged_rank=2,
inner_shape=(2,),
)
output_data = model.predict(input_data)
self.assertAllClose(output_data, expected_output_data)
outputs = model.predict(inputs)
self.assertAllClose(outputs, expected_outputs, atol=0.01, rtol=0.01)

def test_get_config_and_from_config(self):
pos_encoding = sine_position_encoding.SinePositionEncoding(
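The looser atol/rtol of 0.01 reflects float16 precision: with a 10-bit mantissa, values near 1.0 carry roughly three significant decimal digits, so exact float32 reference values cannot be matched when the suite runs with --mixed_precision. A quick numeric check (NumPy used here purely for illustration):

    import numpy as np

    ref = 0.9899925  # one of the float32 reference values above
    half = np.float16(ref)
    print(half)                    # ~0.9902 after rounding to float16
    print(abs(float(half) - ref))  # ~2.4e-4, comfortably inside atol=0.01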
9 changes: 7 additions & 2 deletions keras_nlp/layers/transformer_decoder_test.py
@@ -280,8 +280,13 @@ def test_cached_decoding_is_correct(self, eager):
intermediate_dim=4,
num_heads=num_heads,
)
x = tf.random.uniform(shape=[batch_size, seq_len, num_heads * head_dim])
cache = tf.zeros([batch_size, 2, seq_len, num_heads, head_dim])
dtype = layer.compute_dtype
x = tf.random.uniform(
shape=[batch_size, seq_len, num_heads * head_dim], dtype=dtype
)
cache = tf.zeros(
[batch_size, 2, seq_len, num_heads, head_dim], dtype=dtype
)
outputs = tf.zeros_like(x)

def call(outputs, cache):
2 changes: 1 addition & 1 deletion keras_nlp/models/albert/albert_masked_lm_test.py
@@ -140,4 +140,4 @@ def test_saved_model(self, save_format, filename):
self.assertIsInstance(restored_model, AlbertMaskedLM)
# Check that output matches.
restored_output = restored_model.predict(self.raw_batch)
self.assertAllClose(model_output, restored_output)
self.assertAllClose(model_output, restored_output, atol=0.01, rtol=0.01)
2 changes: 1 addition & 1 deletion keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -261,7 +261,7 @@ def _build_cache(self, prompt):
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
cache = tf.zeros(shape)
cache = tf.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, cache = self.call_with_cache(prompt, cache, 0)
return cache
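Allocating the key/value cache in self.compute_dtype keeps it consistent with the float16 attention projections that are written into it during generation. A hypothetical usage sketch; the preset name and the generate() call shown are illustrative, not part of this diff:

    from tensorflow import keras
    import keras_nlp

    keras.mixed_precision.set_global_policy("mixed_float16")
    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
    print(causal_lm.generate("The quick brown fox", max_length=16))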
3 changes: 2 additions & 1 deletion keras_nlp/models/t5/t5_multi_head_attention.py
@@ -273,7 +273,8 @@ def project(
if position_bias is None:
if not self.use_relative_attention_bias:
position_bias = tf.zeros(
(1, self.num_heads, real_seq_length, key_length)
(1, self.num_heads, real_seq_length, key_length),
self.compute_dtype,
)
else:
position_bias = self.compute_bias(real_seq_length, key_length)
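The T5 change follows the same pattern as the others: any tensor a layer creates inside call() should use self.compute_dtype so it combines cleanly with float16 activations. A toy layer illustrating the pattern (not the T5 code):

    import tensorflow as tf
    from tensorflow import keras

    class AddZeroBias(keras.layers.Layer):
        def call(self, scores):
            # Created in compute_dtype, so the add below never mixes dtypes.
            bias = tf.zeros(tf.shape(scores), dtype=self.compute_dtype)
            return scores + bias

    keras.mixed_precision.set_global_policy("mixed_float16")
    layer = AddZeroBias()
    print(layer(tf.random.uniform((2, 4))).dtype)  # float16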