
Fix gpt2, t5 and fnet under mixed precision (#958)
Also add a testing mode for mixed precision, though the suite does not fully
pass under it yet due to saving and other output errors.
mattdangerw authored Apr 6, 2023
1 parent a3deb52 commit b0c457e
Showing 10 changed files with 51 additions and 38 deletions.
9 changes: 9 additions & 0 deletions keras_nlp/conftest.py
@@ -17,6 +17,7 @@
import pytest
import tensorflow as tf
from packaging import version
+from tensorflow import keras


@pytest.fixture(scope="session")
@@ -53,9 +54,17 @@ def pytest_addoption(parser):
        default=False,
        help="run tpu tests",
    )
+    parser.addoption(
+        "--mixed_precision",
+        action="store_true",
+        default=False,
+        help="run with mixed precision",
+    )


def pytest_configure(config):
+    if config.getoption("--mixed_precision"):
+        keras.mixed_precision.set_global_policy("mixed_float16")
    config.addinivalue_line(
        "markers", "large: mark test as being slow or requiring a network"
    )
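For orientation, the new option flips the whole suite into mixed precision before any test runs. A minimal sketch of the effect (the pytest invocation is inferred from the option name and is not quoted from the commit):

```python
# Assumed invocation: pytest keras_nlp/ --mixed_precision
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

# Under "mixed_float16", layers compute in float16 but keep float32 weights.
layer = keras.layers.Dense(4)
layer.build((None, 8))
print(layer.compute_dtype)  # "float16"
print(layer.dtype)          # "float32" (variable dtype)
```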
4 changes: 3 additions & 1 deletion keras_nlp/layers/cached_multi_head_attention.py
@@ -84,7 +84,9 @@ def call(
value = dynamic_update_slice(value_cache, value, start)
cache = tf.stack((key, value), axis=1)

-        query = tf.multiply(query, 1.0 / tf.math.sqrt(float(self._key_dim)))
+        query = tf.multiply(
+            query, 1.0 / tf.math.sqrt(tf.cast(self._key_dim, query.dtype))
+        )
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = self._masked_softmax(
attention_scores, attention_mask
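The one-line change above matters because TensorFlow ops do not auto-promote dtypes: `tf.math.sqrt(float(self._key_dim))` yields a float32 tensor, which cannot be multiplied with a float16 query under mixed precision. A small standalone sketch of the pattern (not taken from the commit):

```python
import tensorflow as tf

key_dim = 64
query = tf.random.uniform((2, 8, 4, key_dim), dtype=tf.float16)

# float16 * float32 raises an InvalidArgumentError, so the scale factor is
# built in the query's own dtype instead of a hard-coded float32.
scale = 1.0 / tf.math.sqrt(tf.cast(key_dim, query.dtype))
query = tf.multiply(query, scale)  # stays float16
```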
9 changes: 7 additions & 2 deletions keras_nlp/layers/cached_multi_head_attention_test.py
@@ -39,8 +39,13 @@ def test_cache_call_is_correct(self, eager):
key_dim = 4

layer = CachedMultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
-        x = tf.random.uniform(shape=[batch_size, seq_len, num_heads * key_dim])
-        cache = tf.zeros([batch_size, 2, seq_len, num_heads, key_dim])
+        dtype = layer.compute_dtype
+        x = tf.random.uniform(
+            shape=[batch_size, seq_len, num_heads * key_dim], dtype=dtype
+        )
+        cache = tf.zeros(
+            [batch_size, 2, seq_len, num_heads, key_dim], dtype=dtype
+        )
# Use a causal mask.
mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
outputs = tf.zeros_like(x)
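Deriving the test inputs' dtype from `layer.compute_dtype` rather than hard-coding float32 lets the same test pass under both the default policy and `--mixed_precision`. Roughly, with a stock Keras layer standing in for the cached attention layer:

```python
import tensorflow as tf
from tensorflow import keras

layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
# compute_dtype tracks the global policy: float32 normally, float16 under
# mixed_float16, so inputs built from it always match the layer.
x = tf.random.uniform((2, 5, 8), dtype=layer.compute_dtype)
out = layer(x, x)
```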
6 changes: 3 additions & 3 deletions keras_nlp/layers/f_net_encoder.py
@@ -132,9 +132,9 @@ def fourier_transform(input):
# Apply FFT on the input and take the real part.
# Before we apply fourier transform, let's convert the dtype of the
# input tensor to complex64.
-            input = tf.cast(input, tf.complex64)
-            mixing_output = tf.math.real(tf.signal.fft2d(input))
-            return mixing_output
+            x = tf.cast(input, tf.complex64)
+            mixing_output = tf.math.real(tf.signal.fft2d(x))
+            return tf.cast(mixing_output, input.dtype)

def add_and_norm(input1, input2, norm_layer):
return norm_layer(input1 + input2)
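The FNet change follows from `tf.signal.fft2d` accepting only complex inputs and `tf.math.real` handing back float32: without the final cast, the mixing output would clash with the float16 residual path under mixed precision. A standalone sketch of the same pattern:

```python
import tensorflow as tf

def fourier_transform(x):
    # fft2d needs a complex tensor, so up-cast first.
    mixed = tf.math.real(tf.signal.fft2d(tf.cast(x, tf.complex64)))
    # The real part comes back as float32; cast it back so the rest of the
    # encoder keeps running in the layer's compute dtype.
    return tf.cast(mixed, x.dtype)

x = tf.random.uniform((2, 8, 16), dtype=tf.float16)
print(fourier_transform(x).dtype)  # float16
```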
8 changes: 4 additions & 4 deletions keras_nlp/layers/position_embedding_test.py
@@ -44,8 +44,8 @@ def test_static_layer_output_shape(self):
# to be the same as the input shape in all dimensions save batch.
expected_output_shape = [None, sequence_length, feature_size]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
-        # The default output dtype for this layer should be tf.float32.
-        self.assertEqual(tf.float32, output_tensor.dtype)
+        # The output dtype for this layer should match the compute dtype.
+        self.assertEqual(test_layer.compute_dtype, output_tensor.dtype)

def test_more_than_3_dimensions_static(self):
# Create a 4-dimensional input (the first dimension is implicit).
@@ -68,8 +68,8 @@ def test_more_than_3_dimensions_static(self):
feature_size,
]
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
-        # The default output dtype for this layer should be tf.float32.
-        self.assertEqual(tf.float32, output_tensor.dtype)
+        # The output dtype for this layer should match the compute dtype.
+        self.assertEqual(test_layer.compute_dtype, output_tensor.dtype)

def test_float16_dtype(self):
# Create a 3-dimensional input (the first dimension is implicit).
37 changes: 14 additions & 23 deletions keras_nlp/layers/sine_position_encoding_test.py
@@ -79,29 +79,20 @@ def test_output_correct_values(self):
output = model(input)

        # compare position encoding values for positions 0 and 3
-        expected_encoding_position_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
-        expected_encoding_position_3 = [
-            0.14112,
-            -0.9899925,
-            0.1387981,
-            0.9903207,
-            0.00646326,
-            0.99997914,
-        ]
-        self.assertAllClose(output[0, 0, :], expected_encoding_position_0)
-        self.assertAllClose(output[0, 3, :], expected_encoding_position_3)
+        expected_0 = [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
+        expected_3 = [0.14112, -0.98999, 0.13879, 0.99032, 0.00646, 0.99997]
+        self.assertAllClose(output[0, 0, :], expected_0, atol=0.01, rtol=0.01)
+        self.assertAllClose(output[0, 3, :], expected_3, atol=0.01, rtol=0.01)

def test_ragged_tensor_with_3_dimensions(self):
feature_size = 2
test_layer = sine_position_encoding.SinePositionEncoding()
# Create a 3-dimensional ragged input (the first dimension is implicit).
-        input_tensor = keras.Input(
-            shape=(None, feature_size), dtype=tf.float32, ragged=True
-        )
+        input_tensor = keras.Input(shape=(None, feature_size), ragged=True)
output_tensor = test_layer(input_tensor)
model = keras.Model(input_tensor, output_tensor)

-        input_data = tf.ragged.constant(
+        inputs = tf.ragged.constant(
[
[[1.0, 1.0], [1.0, 1.0]],
[],
@@ -111,7 +102,7 @@ def test_ragged_tensor_with_3_dimensions(self):
ragged_rank=1,
inner_shape=(2,),
)
-        expected_output_data = tf.ragged.constant(
+        expected_outputs = tf.ragged.constant(
[
[[0.0, 1.0], [0.84147096, 0.5403023]],
[],
@@ -121,20 +112,20 @@ def test_ragged_tensor_with_4_dimensions(self):
ragged_rank=1,
inner_shape=(2,),
)
-        output_data = model.predict(input_data)
-        self.assertAllClose(output_data, expected_output_data)
+        outputs = model.predict(inputs)
+        self.assertAllClose(outputs, expected_outputs, atol=0.01, rtol=0.01)

def test_ragged_tensor_with_4_dimensions(self):
feature_size = 2
test_layer = sine_position_encoding.SinePositionEncoding()
# Create a 4-dimensional ragged input (the first dimension is implicit).
input_tensor = keras.Input(
-            shape=(None, None, feature_size), dtype=tf.float32, ragged=True
+            shape=(None, None, feature_size), ragged=True
)
output_tensor = test_layer(input_tensor)
model = keras.Model(input_tensor, output_tensor)

-        input_data = tf.ragged.constant(
+        inputs = tf.ragged.constant(
[
[
[[1.0, 1.0], [1.0, 1.0]],
@@ -148,7 +139,7 @@ def test_ragged_tensor_with_4_dimensions(self):
ragged_rank=2,
inner_shape=(2,),
)
-        expected_output_data = tf.ragged.constant(
+        expected_outputs = tf.ragged.constant(
[
[
[[0.0, 1.0], [0.84147096, 0.5403023]],
@@ -166,8 +157,8 @@ def test_ragged_tensor_with_4_dimensions(self):
ragged_rank=2,
inner_shape=(2,),
)
-        output_data = model.predict(input_data)
-        self.assertAllClose(output_data, expected_output_data)
+        outputs = model.predict(inputs)
+        self.assertAllClose(outputs, expected_outputs, atol=0.01, rtol=0.01)

def test_get_config_and_from_config(self):
pos_encoding = sine_position_encoding.SinePositionEncoding(
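The looser `atol`/`rtol` above reflect float16's precision: with a 10-bit mantissa, values of order one are only accurate to roughly 1e-3, so exact float32 reference values can miss the old default tolerance. A quick check of the scale involved (a sketch, not part of the commit):

```python
import numpy as np

# Machine epsilon for float16 is ~9.8e-4, so comparisons against float32
# references need tolerances well above assertAllClose's ~1e-6 default.
print(np.finfo(np.float16).eps)  # 0.000977
print(np.finfo(np.float32).eps)  # 1.1920929e-07
```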
9 changes: 7 additions & 2 deletions keras_nlp/layers/transformer_decoder_test.py
@@ -280,8 +280,13 @@ def test_cached_decoding_is_correct(self, eager):
intermediate_dim=4,
num_heads=num_heads,
)
-        x = tf.random.uniform(shape=[batch_size, seq_len, num_heads * head_dim])
-        cache = tf.zeros([batch_size, 2, seq_len, num_heads, head_dim])
+        dtype = layer.compute_dtype
+        x = tf.random.uniform(
+            shape=[batch_size, seq_len, num_heads * head_dim], dtype=dtype
+        )
+        cache = tf.zeros(
+            [batch_size, 2, seq_len, num_heads, head_dim], dtype=dtype
+        )
outputs = tf.zeros_like(x)

def call(outputs, cache):
2 changes: 1 addition & 1 deletion keras_nlp/models/albert/albert_masked_lm_test.py
@@ -140,4 +140,4 @@ def test_saved_model(self, save_format, filename):
self.assertIsInstance(restored_model, AlbertMaskedLM)
# Check that output matches.
restored_output = restored_model.predict(self.raw_batch)
-        self.assertAllClose(model_output, restored_output)
+        self.assertAllClose(model_output, restored_output, atol=0.01, rtol=0.01)
2 changes: 1 addition & 1 deletion keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -261,7 +261,7 @@ def _build_cache(self, prompt):
num_heads = self.backbone.num_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
-        cache = tf.zeros(shape)
+        cache = tf.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, cache = self.call_with_cache(prompt, cache, 0)
return cache
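Allocating the cache in `self.compute_dtype` means the float16 key/value projections written during generation match the cache's dtype instead of colliding with a float32 buffer. As a usage sketch (the preset name and generate arguments are assumptions, not part of this commit):

```python
import keras_nlp
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

# Assumed preset; any GPT-2 preset should exercise the same cache path.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
print(causal_lm.generate("The quick brown fox", max_length=30))
```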
3 changes: 2 additions & 1 deletion keras_nlp/models/t5/t5_multi_head_attention.py
@@ -273,7 +273,8 @@ def project(
if position_bias is None:
if not self.use_relative_attention_bias:
position_bias = tf.zeros(
-                    (1, self.num_heads, real_seq_length, key_length)
+                    (1, self.num_heads, real_seq_length, key_length),
+                    self.compute_dtype,
)
else:
position_bias = self.compute_bias(real_seq_length, key_length)
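Same idea for T5: `tf.zeros` defaults to float32, and the zero position bias is added directly to the float16 attention scores, so it has to be created in the layer's compute dtype. A minimal illustration of the failure mode being avoided (shapes are hypothetical):

```python
import tensorflow as tf

scores = tf.random.uniform((1, 8, 16, 16), dtype=tf.float16)

# tf.zeros(...) with no dtype would be float32, and `scores + bias` would
# then raise a dtype mismatch; matching the scores' dtype keeps the add valid.
position_bias = tf.zeros((1, 8, 16, 16), dtype=scores.dtype)
scores = scores + position_bias
```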
