diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 3e9f3ef7b4..bcd2a99004 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -817,26 +817,42 @@ def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len sep_embed = args.sep_embed position_embedding_base = 10000 max_position_embeddings = 2048 if "rope_theta" in hf_config: position_embedding_base = hf_config["rope_theta"] - if "max_position_embeddings" in hf_config: - max_position_embeddings = hf_config["max_position_embeddings"] - config = LlamaConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. + if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + else: + raise Exception("The model config should contain information about maximum sequence length.") + + # If there is a user-provided maximum sequence length, override hf config. + if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len param_manager = ParamManager() bb = relax.BlockBuilder()