This codebase is implemented in JAX and is based on EasyLM.
`mesh_dim` refers to the mesh used by JAX to parallelize computation across multiple accelerators and hosts. Please refer to the EasyLM parallelization documentation for configuration.
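For example, following the EasyLM convention (please verify against the EasyLM documentation), the mesh is given as a comma-separated list of axis sizes (data, fully sharded data, and model parallelism), where -1 lets JAX infer that axis from the device count:

```
--mesh_dim='1,-1,1'
```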
`seq_length` and `global_batch_size` determine the total number of tokens per batch (fixed to 0.5 million in our paper).
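For instance, one hypothetical combination that meets the 0.5 million token budget is shown below (2048 × 256 = 524,288 tokens per batch); the actual values used per experiment are listed in the paper:

```
--seq_length=2048 --global_batch_size=256
```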
`load_model_config` is used to load a default config from `model.py`.
`update_model_config` is used to update a default config. To update specific keys, pass a dictionary to the flag:

```
--update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)"
```
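Putting these flags together, a minimal launch sketch might look like the following. Both the entry-point name `train.py` and the `'125m'` config key are placeholders, not confirmed names from this repo; substitute the actual training script and config name:

```
# Hypothetical invocation: the script name and the '125m' config key are placeholders.
python3 train.py \
    --mesh_dim='1,-1,1' \
    --seq_length=2048 \
    --global_batch_size=256 \
    --load_model_config='125m' \
    --update_model_config="dict(seq_modeling_block='ttt_linear', ttt_base_lr=1.0)"
```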
All additional hyperparameters are specified in Appendix C of our paper.
All model configuration flags can be found in `model.py`. Here are a few important details to note:
We implement four TTT choices for the `seq_modeling_block`:

- `ttt_linear` and `ttt_mlp`, which specify TTT layers within the Mamba backbone.
- `ttt_linear_base` and `ttt_mlp_base`, which specify TTT layers within the Transformer backbone.
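For example, selecting the TTT-Linear layer within the Transformer backbone only requires updating `seq_modeling_block`, using the flag syntax shown above:

```
--update_model_config="dict(seq_modeling_block='ttt_linear_base')"
```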
- For all `ttt_linear` experiments, `ttt_base_lr` is set to 1.0.
- For all `ttt_mlp` experiments:
  - `ttt_base_lr` is set to 0.1.
  - `ttt_base_lr_init` is set to 0.01.
  - `ttt_base_lr_warmup` is set to the total number of outer loop warmup steps.
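As a sketch, the `ttt_mlp` settings above translate into a single flag. The warmup value of 2000 is an illustrative placeholder; set it to match the total outer loop warmup steps of your run:

```
--update_model_config="dict(seq_modeling_block='ttt_mlp', ttt_base_lr=0.1, ttt_base_lr_init=0.01, ttt_base_lr_warmup=2000)"
```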