diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py index 0e1d7a14f8..517c5f4e3a 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_test.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -15,9 +15,9 @@ import os from unittest.mock import patch -import keras import pytest +from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM