Skip to content

Commit

Permalink
Switch deberta to use the "int" dtype
Browse files Browse the repository at this point in the history
This will be int32 on jax and torch, but int64 on tf, which is what we
need for proper accelerator support
  • Loading branch information
mattdangerw committed Nov 13, 2023
1 parent 11bece8 commit f4712da
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras_nlp/models/deberta_v3/disentangled_self_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def _get_log_pos(abs_pos, mid):
x1=rel_pos,
x2=log_pos * sign,
)
bucket_pos = ops.cast(bucket_pos, dtype="int64")
bucket_pos = ops.cast(bucket_pos, dtype="int")

return bucket_pos

def _get_rel_pos(self, num_positions):
ids = ops.arange(num_positions, dtype="int64")
ids = ops.arange(num_positions, dtype="int")
query_ids = ops.expand_dims(ids, axis=-1)
key_ids = ops.expand_dims(ids, axis=0)
key_ids = ops.repeat(key_ids, repeats=num_positions, axis=0)
Expand Down

0 comments on commit f4712da

Please sign in to comment.