Convert to bfloat16 weights
mattdangerw committed Mar 8, 2024
1 parent 0c5fa70 commit 9bad97a
Showing 1 changed file with 21 additions and 1 deletion.
tools/checkpoint_conversion/convert_gemma_checkpoints.py (21 additions & 1 deletion)

@@ -11,6 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Convert Gemma flax checkpoints to the Keras format.
+
+Setup:
+pip install -r requirements.txt
+pip install git+https://github.com/google-deepmind/gemma.git
+python pip_build.py --install
+
+Usage:
+cd tools/checkpoint_conversion
+python convert_gemma_checkpoints.py --preset gemma_2b_en
+"""

import os

@@ -19,6 +31,7 @@
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import kagglehub # noqa: E402
+import keras # noqa: E402
import numpy as np # noqa: E402
import sentencepiece # noqa: E402
from absl import app # noqa: E402
@@ -40,7 +53,10 @@


flags.DEFINE_string(
-    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}', required=True
+    "preset",
+    None,
+    f'Must be one of {",".join(PRESET_MAP.keys())}',
+    required=True,
)


@@ -170,6 +186,10 @@ def main(_):

    print(f"🏃 Converting {preset}")

+    # Currently all flax weights are bfloat16 (which also makes for much
+    # faster downloads). We follow suit with the Keras weights.
+    keras.config.set_floatx("bfloat16")

    handle = PRESET_MAP[preset]
    flax_dir = download_flax_model(handle)
    proto_path = flax_dir + "/tokenizer.model"
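Note on the change itself: keras.config.set_floatx sets the global default dtype, and Keras creates every variable built afterwards in that dtype, so the converted Gemma weights end up stored as bfloat16 (half the size of float32). A minimal sketch of the effect, separate from the conversion script:

    import keras

    # After this call, the global floatx default is "bfloat16".
    keras.config.set_floatx("bfloat16")

    # Any layer built from here on allocates its variables in bfloat16.
    layer = keras.layers.Dense(4)
    layer.build((None, 8))
    print(layer.kernel.dtype)  # bfloat16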

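For reference, the reflowed flags.DEFINE_string call relies on absl's required=True, which makes the script exit with a usage error at startup when --preset is missing instead of failing later on a None value. A standalone sketch of the pattern (the help string here is illustrative, not from the script):

    from absl import app, flags

    FLAGS = flags.FLAGS

    flags.DEFINE_string(
        "preset",
        None,
        "Which model preset to convert.",
        required=True,
    )


    def main(_):
        # absl has already validated --preset before main() runs.
        print(f"Converting {FLAGS.preset}")


    if __name__ == "__main__":
        app.run(main)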