Skip to content

Commit

Permalink
mlx - passing most layers tests (#20862)
Browse files Browse the repository at this point in the history
* adding ops functions and passing most layer tests

* passing most layers tests

* fix for tensorflow tests, handling mlx array slicing in random crop
  • Loading branch information
acsweet authored Feb 8, 2025
1 parent b9f7141 commit 7237ec7
Show file tree
Hide file tree
Showing 19 changed files with 984 additions and 358 deletions.
7 changes: 4 additions & 3 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"int64": mx.int64,
"bfloat16": mx.bfloat16,
"bool": mx.bool_,
"complex64": mx.complex64,
}


Expand Down Expand Up @@ -376,9 +377,9 @@ def random_seed_dtype():
return "uint32"


def reverse_sequence(xs):
indices = mx.arange(xs.shape[0] - 1, -1, -1)
return mx.take(xs, indices, axis=0)
def reverse_sequence(xs, axis=0):
indices = mx.arange(xs.shape[axis] - 1, -1, -1)
return mx.take(xs, indices, axis=axis)


def flip(x, axis=None):
Expand Down
Loading

0 comments on commit 7237ec7

Please sign in to comment.