
update predictor length init
wj-Mcat committed Jan 2, 2024
1 parent ea9f5b3 commit f4cd00c
Showing 1 changed file with 34 additions and 32 deletions.
66 changes: 34 additions & 32 deletions llm/predictor.py
@@ -696,6 +696,40 @@ def create_predictor(
     if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
         tokenizer.pad_token = tokenizer.unk_token
 
+    config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
+
+    max_position_embeddings = get_model_max_position_embeddings(config)
+    if max_position_embeddings is None:
+        max_position_embeddings = 2048
+        logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048")
+
+    if predictor_args.src_length is None:
+        if predictor_args.max_length is None:
+            predictor_args.src_length = get_default_max_encoding_length(config)
+            predictor_args.max_length = get_default_max_decoding_length(config)
+        else:
+            predictor_args.src_length = max_position_embeddings - predictor_args.max_length
+            if predictor_args.src_length <= 0:
+                raise ValueError(
+                    f"--max_length<{predictor_args.max_length}> param should be smaller "
+                    f"than max_position_embeddings<{max_position_embeddings}>"
+                )
+    else:
+        if predictor_args.max_length is None:
+            predictor_args.max_length = max_position_embeddings - predictor_args.src_length
+            if predictor_args.max_length <= 0:
+                raise ValueError(
+                    f"--src_length<{predictor_args.src_length}> param should be smaller "
+                    f"than max_position_embeddings<{max_position_embeddings}>"
+                )
+        else:
+            if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
+                raise ValueError(
+                    f"The sum of src_length<{predictor_args.src_length}> and "
+                    f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
+                    f"the maximum position embedding size<{max_position_embeddings}>"
+                )
+
     # update config parameter for inference predictor
     if predictor_args.decode_strategy == "greedy_search":
         predictor_args.top_p = 0.0
@@ -889,38 +923,6 @@ def create_predictor(
     else:
         raise ValueError("the `mode` should be one of [dynamic, static]")
 
-    max_position_embeddings = get_model_max_position_embeddings(predictor.model_config)
-    if max_position_embeddings is None:
-        max_position_embeddings = 2048
-        logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048")
-
-    if predictor.config.src_length is None:
-        if predictor.config.max_length is None:
-            predictor.config.src_length = get_default_max_encoding_length(predictor.model_config)
-            predictor.config.max_length = get_default_max_decoding_length(predictor.model_config)
-        else:
-            predictor.config.src_length = max_position_embeddings - predictor.config.max_length
-            if predictor.config.src_length <= 0:
-                raise ValueError(
-                    f"--max_length<{predictor.config.max_length}> param should be smaller "
-                    f"than max_position_embeddings<{max_position_embeddings}>"
-                )
-    else:
-        if predictor.config.max_length is None:
-            predictor.config.max_length = max_position_embeddings - predictor.config.src_length
-            if predictor.config.max_length <= 0:
-                raise ValueError(
-                    f"--src_length<{predictor.config.src_length}> param should be smaller "
-                    f"than max_position_embeddings<{max_position_embeddings}>"
-                )
-        else:
-            if predictor.config.src_length + predictor.config.max_length > max_position_embeddings:
-                raise ValueError(
-                    f"The sum of src_length<{predictor.config.src_length}> and "
-                    f"max_length<{predictor.config.max_length}> should be smaller than or equal to "
-                    f"the maximum position embedding size<{max_position_embeddings}>"
-                )
-
     return predictor


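In effect, this commit moves the src_length / max_length initialization from after predictor construction (where it patched predictor.config using predictor.model_config) to before it, so the lengths are resolved on predictor_args against a freshly loaded AutoConfig. Below is a minimal, self-contained sketch of the resolution rules for readers skimming the diff; the helper name resolve_lengths and its default values are hypothetical, and the real code falls back to get_default_max_encoding_length / get_default_max_decoding_length rather than fixed numbers.

# Minimal sketch (not the PaddleNLP implementation) of the length-resolution
# rules shown in the diff above. Names and defaults here are illustrative.

DEFAULT_MAX_POSITION_EMBEDDINGS = 2048  # fallback used when config.json lacks the field


def resolve_lengths(src_length=None, max_length=None, max_position_embeddings=None,
                    default_src=1024, default_max=1024):
    """Return (src_length, max_length) following the rules in the diff above."""
    if max_position_embeddings is None:
        max_position_embeddings = DEFAULT_MAX_POSITION_EMBEDDINGS

    if src_length is None:
        if max_length is None:
            # Neither flag given: fall back to model-specific defaults.
            src_length, max_length = default_src, default_max
        else:
            # Only --max_length given: src_length gets the remaining budget.
            src_length = max_position_embeddings - max_length
            if src_length <= 0:
                raise ValueError("--max_length must be smaller than max_position_embeddings")
    elif max_length is None:
        # Only --src_length given: max_length gets the remaining budget.
        max_length = max_position_embeddings - src_length
        if max_length <= 0:
            raise ValueError("--src_length must be smaller than max_position_embeddings")
    elif src_length + max_length > max_position_embeddings:
        # Both given: their sum must fit within the position-embedding budget.
        raise ValueError("src_length + max_length must not exceed max_position_embeddings")
    return src_length, max_length


# Example: a 4096-position model with only --max_length 1024 set resolves
# src_length to 4096 - 1024 = 3072.
print(resolve_lengths(max_length=1024, max_position_embeddings=4096))  # (3072, 1024)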
