Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang committed Feb 5, 2025
1 parent 572d22a commit 60bc298
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions reproduce/ks_mismatch_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,19 @@ def T(t):
def calc_grad_cpu(q, k, in_grad):
    """Build the q·kᵀ product and its gradients with ops placed on the CPU.

    The work is done inside an XLA ``jit_scope`` opened with
    ``separate_compiled_gradients=True`` and under ``tf.device("/CPU:0")``.

    Shapes, as recorded by the original author's comments:
        q  - (batch, nh, seqlen, dim)
        k  - (batch, nh, seqlen, dim)
        qk - (batch, nh, seqlen, seqlen)

    Args:
        q: Query tensor; passed through ``T`` (a helper defined elsewhere
            in this file) before the matmul.
        k: Key tensor; also passed through ``T`` and then used as the
            transposed right-hand operand of the matmul.
        in_grad: Upstream gradient handed to ``tf.gradients`` as ``grad_ys``.

    Returns:
        A list ``[qk, grad_q, grad_k]``: the matmul result followed by its
        gradients with respect to ``q`` and ``k``.
    """
    xla_scope = tf.xla.experimental.jit_scope(separate_compiled_gradients=True)
    with xla_scope, tf.device("/CPU:0"):
        lhs = T(q)
        rhs = T(k)
        scores = tf.matmul(lhs, rhs, transpose_b=True)
        # Differentiate w.r.t. the untransformed inputs, seeded with in_grad.
        d_q, d_k = tf.gradients(scores, [q, k], grad_ys=in_grad)
    return [scores, d_q, d_k]

def calc_grad_gpu(q, k, in_grad):
    """Build the q·kᵀ product and its gradients with ops placed on GPU 0.

    The work is done inside an XLA ``jit_scope`` opened with
    ``separate_compiled_gradients=True`` and under ``tf.device("/GPU:0")``.

    Shapes, as recorded by the original author's comments:
        q  - (batch, nh, seqlen, dim)
        k  - (batch, nh, seqlen, dim)
        qk - (batch, nh, seqlen, seqlen)

    Args:
        q: Query tensor; passed through ``T`` (a helper defined elsewhere
            in this file) before the matmul.
        k: Key tensor; also passed through ``T`` and then used as the
            transposed right-hand operand of the matmul.
        in_grad: Upstream gradient handed to ``tf.gradients`` as ``grad_ys``.

    Returns:
        A list ``[qk, grad_q, grad_k]``: the matmul result followed by its
        gradients with respect to ``q`` and ``k``.
    """
    xla_scope = tf.xla.experimental.jit_scope(separate_compiled_gradients=True)
    with xla_scope, tf.device("/GPU:0"):
        lhs = T(q)
        rhs = T(k)
        scores = tf.matmul(lhs, rhs, transpose_b=True)
        # Differentiate w.r.t. the untransformed inputs, seeded with in_grad.
        d_q, d_k = tf.gradients(scores, [q, k], grad_ys=in_grad)
    return [scores, d_q, d_k]
Expand Down

0 comments on commit 60bc298

Please sign in to comment.