From 75835452824f047df2c2d57521c32fae45e763b5 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:00:21 +0800 Subject: [PATCH] Use lower precision in DPA --- keras/src/backend/numpy/nn.py | 14 ++++++++------ keras/src/backend/tensorflow/nn.py | 8 ++------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index b5ab2e04e9a..3b14d864ff7 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -1096,12 +1096,14 @@ def _apply_masks(logits, mask, is_causal): def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): + original_dtype = key.dtype logits_dtype = np.promote_types(query.dtype, np.float32) - logits = np.einsum( - "BTNH,BSNH->BNTS", - query.astype(logits_dtype), - key.astype(logits_dtype), - ) + if backend.standardize_dtype(key.dtype) == "bfloat16": + # `np.einsum` doesn't support bfloat16 + key = key.astype("float32") + value = value.astype("float32") + logits = np.einsum("BTNH,BSNH->BNTS", query, key) + logits = logits.astype(logits_dtype) logits *= np.array(scale, dtype=logits.dtype) if bias is not None: @@ -1111,7 +1113,7 @@ def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): # Softmax and it is always carried out in fp32. padded_logits = padded_logits.astype(np.float32) - probs = softmax(padded_logits, axis=-1).astype(key.dtype) + probs = softmax(padded_logits, axis=-1).astype(original_dtype) encoded_dtype = probs.dtype if backend.standardize_dtype(probs.dtype) == "bfloat16": # `np.einsum` doesn't support bfloat16 diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index e9fe8447466..325bd21b69e 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -1015,12 +1015,8 @@ def _apply_masks(logits, mask, is_causal): def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale): logits_dtype = backend.result_type(query.dtype, "float32") - logits = tf.einsum( - "BTNH,BSNH->BNTS", - tf.cast(query, dtype=logits_dtype), - tf.cast(key, dtype=logits_dtype), - optimize="optimal", - ) + logits = tf.einsum("BTNH,BSNH->BNTS", query, key, optimize="optimal") + logits = tf.cast(logits, logits_dtype) logits = tf.multiply(logits, tf.cast(scale, logits.dtype)) if bias is not None: