diff --git a/src/cohere_finetune/base_model.py b/src/cohere_finetune/base_model.py
new file mode 100644
index 0000000..a51790e
--- /dev/null
+++ b/src/cohere_finetune/base_model.py
@@ -0,0 +1,114 @@
+import os
+from consts import CHAT_PROMPT_TEMPLATE_CMD_R, CHAT_PROMPT_TEMPLATE_CMD_R_08_2024
+from utils import load_file, logger
+
+
+def get_model_name_from_hf_config(hf_config_path: str) -> str:
+    """
+    Get the name of the model according to the config.json file in the HuggingFace checkpoint.
+
+    It distinguishes only the supported Cohere models based on their config.json files, i.e.,
+    it may not correctly identify a model when an unsupported Cohere model is used.
+    """
+    hf_config = load_file(hf_config_path)
+
+    if hf_config["architectures"] != ["CohereForCausalLM"]:
+        raise ValueError("The model is not one of Cohere's models for causal LM")
+
+    if hf_config["hidden_size"] == 8192 and hf_config["rope_theta"] == 8000000:
+        return "command-r"
+    elif hf_config["hidden_size"] == 8192 and hf_config["rope_theta"] == 4000000 and hf_config["max_position_embeddings"] == 131072:
+        return "command-r-08-2024"
+    elif hf_config["hidden_size"] == 12288 and hf_config["rope_theta"] == 75000000:
+        return "command-r-plus"
+    elif hf_config["hidden_size"] == 12288 and hf_config["rope_theta"] == 8000000:
+        return "command-r-plus-08-2024"
+    elif hf_config["hidden_size"] == 4096 and hf_config["rope_theta"] == 10000:
+        return "aya-expanse-8b"
+    elif hf_config["hidden_size"] == 8192 and hf_config["rope_theta"] == 4000000 and hf_config["max_position_embeddings"] == 8192:
+        return "aya-expanse-32b"
+    else:
+        raise ValueError("The model is not one of Cohere's models for causal LM that we support")
+
+
+def get_model_config_from_model_name_and_model_path(model_name: str, model_path: str | None) -> dict:
+    """
+    Get the config of the base model according to model_name and model_path,
+    which contains all the information about the base model that cohere-finetune will use.
+    """
+ """ + if model_name == "command-r": + return { + "model_name": model_name, + "prompt_template": CHAT_PROMPT_TEMPLATE_CMD_R, + "hf_model_name_or_path": "CohereForAI/c4ai-command-r-v01" if model_path is None else model_path, + "max_possible_max_sequence_length": 16384, + } + elif model_name == "command-r-08-2024": + return { + "model_name": model_name, + "prompt_template": CHAT_PROMPT_TEMPLATE_CMD_R_08_2024, + "hf_model_name_or_path": "CohereForAI/c4ai-command-r-08-2024" if model_path is None else model_path, + "max_possible_max_sequence_length": 16384, + } + elif model_name == "command-r-plus": + return { + "model_name": model_name, + "prompt_template": CHAT_PROMPT_TEMPLATE_CMD_R, + "hf_model_name_or_path": "CohereForAI/c4ai-command-r-plus" if model_path is None else model_path, + "max_possible_max_sequence_length": 16384, + } + elif model_name == "command-r-plus-08-2024": + return { + "model_name": model_name, + "prompt_template": CHAT_PROMPT_TEMPLATE_CMD_R_08_2024, + "hf_model_name_or_path": "CohereForAI/c4ai-command-r-plus-08-2024" if model_path is None else model_path, + "max_possible_max_sequence_length": 16384, + } + elif model_name == "aya-expanse-8b": + return { + "model_name": model_name, + "prompt_template": CHAT_PROMPT_TEMPLATE_CMD_R_08_2024, + "hf_model_name_or_path": "CohereForAI/aya-expanse-8b" if model_path is None else model_path, + "max_possible_max_sequence_length": 16384, + } + elif model_name == "aya-expanse-32b": + return { + "model_name": model_name, + "prompt_template": CHAT_PROMPT_TEMPLATE_CMD_R_08_2024, + "hf_model_name_or_path": "CohereForAI/aya-expanse-32b" if model_path is None else model_path, + "max_possible_max_sequence_length": 16384, + } + else: + raise ValueError(f"{model_name} is not a valid and supported model name") + + +class BaseModel: + """Base model for finetuning.""" + + def __init__(self, model_name_or_path: str) -> None: + """Initialize BaseModel.""" + try: + model_name = get_model_name_from_hf_config(os.path.join(model_name_or_path, "config.json")) + model_path = model_name_or_path + except FileNotFoundError: + model_name = model_name_or_path + model_path = None + + self.model_config = get_model_config_from_model_name_and_model_path(model_name, model_path) + logger.info(f"The base model config is as follows:\n{self.model_config}") + + def get_model_name(self) -> str: + """Get the name of the model.""" + return self.model_config["model_name"] + + def get_prompt_template(self) -> str: + """Get the prompt template for the model.""" + return self.model_config["prompt_template"] + + def get_hf_model_name_or_path(self) -> str: + """Get the HuggingFace model name or path for the model.""" + return self.model_config["hf_model_name_or_path"] + + def get_max_possible_max_sequence_length(self) -> int: + """Get the max possible max sequence length for the model.""" + return self.model_config["max_possible_max_sequence_length"] diff --git a/src/cohere_finetune/cohere_finetune_service.py b/src/cohere_finetune/cohere_finetune_service.py index 98f0c44..66bcacc 100644 --- a/src/cohere_finetune/cohere_finetune_service.py +++ b/src/cohere_finetune/cohere_finetune_service.py @@ -108,7 +108,7 @@ def finetune(self) -> None: return # Create and prepare the tokenizer - tokenizer = create_and_prepare_tokenizer(self.hyperparameters.base_model.get_hf_model_name()) + tokenizer = create_and_prepare_tokenizer(self.hyperparameters.base_model.get_hf_model_name_or_path()) # Preprocess the finetuning dataset by doing train eval split (if needed) and putting the texts in 
diff --git a/src/cohere_finetune/cohere_finetune_service.py b/src/cohere_finetune/cohere_finetune_service.py
index 98f0c44..66bcacc 100644
--- a/src/cohere_finetune/cohere_finetune_service.py
+++ b/src/cohere_finetune/cohere_finetune_service.py
@@ -108,7 +108,7 @@ def finetune(self) -> None:
             return
 
         # Create and prepare the tokenizer
-        tokenizer = create_and_prepare_tokenizer(self.hyperparameters.base_model.get_hf_model_name())
+        tokenizer = create_and_prepare_tokenizer(self.hyperparameters.base_model.get_hf_model_name_or_path())
 
         # Preprocess the finetuning dataset by doing train eval split (if needed) and putting the texts in template
         try:
diff --git a/src/cohere_finetune/configs.py b/src/cohere_finetune/configs.py
index 66d04c8..0765fe5 100644
--- a/src/cohere_finetune/configs.py
+++ b/src/cohere_finetune/configs.py
@@ -1,6 +1,7 @@
 import os
 import torch
-from consts import BaseModel, FinetuneStrategy, ParallelStrategy, FINETUNE_BACKEND_KEY, PATH_PREFIX_KEY
+from base_model import BaseModel
+from consts import FinetuneStrategy, ParallelStrategy, FINETUNE_BACKEND_KEY, PATH_PREFIX_KEY
 from typing import Any
 
 
@@ -87,7 +88,7 @@ class Hyperparameters(BaseConfig):
     def __init__(
         self,
         finetune_name: str,
-        base_model: BaseModel = BaseModel.COMMAND_R_08_2024,
+        base_model_name_or_path: str = "command-r-08-2024",
         parallel_strategy: ParallelStrategy = ParallelStrategy.FSDP,
         finetune_strategy: FinetuneStrategy = FinetuneStrategy.LORA,
         use_4bit_quantization: bool = False,
@@ -103,7 +104,7 @@
     ) -> None:
         """Initialize Hyperparameters."""
         self.finetune_name = finetune_name
-        self.base_model = BaseModel(base_model)
+        self.base_model = BaseModel(base_model_name_or_path)
        self.parallel_strategy = ParallelStrategy(parallel_strategy)
         self.finetune_strategy = FinetuneStrategy(finetune_strategy)
         self.use_4bit_quantization = use_4bit_quantization
diff --git a/src/cohere_finetune/consts.py b/src/cohere_finetune/consts.py
index 74c9595..b319b1f 100644
--- a/src/cohere_finetune/consts.py
+++ b/src/cohere_finetune/consts.py
@@ -10,49 +10,6 @@
 FINETUNE_BACKEND_KEY = "FINETUNE_BACKEND"
 
 
-class BaseModel(str, Enum):
-    """Base model for finetuning."""
-
-    COMMAND_R = "command-r"
-    COMMAND_R_08_2024 = "command-r-08-2024"
-    COMMAND_R_PLUS = "command-r-plus"
-    COMMAND_R_PLUS_08_2024 = "command-r-plus-08-2024"
-
-    @classmethod
-    def list_options(cls) -> str:
-        """List all model options."""
-        return ", ".join([item.name for item in BaseModel])
-
-    def get_prompt_template(self) -> str:
-        """Get the prompt template for the model."""
-        if self == BaseModel.COMMAND_R or self == BaseModel.COMMAND_R_PLUS:
-            return CHAT_PROMPT_TEMPLATE_CMD_R
-        else:
-            return CHAT_PROMPT_TEMPLATE_CMD_R_08_2024
-
-    def get_hf_model_name(self) -> str:
-        """Get the HuggingFace model name for the model."""
-        if self == BaseModel.COMMAND_R:
-            return "CohereForAI/c4ai-command-r-v01"
-        elif self == BaseModel.COMMAND_R_08_2024:
-            return "CohereForAI/c4ai-command-r-08-2024"
-        elif self == BaseModel.COMMAND_R_PLUS:
-            return "CohereForAI/c4ai-command-r-plus"
-        else:
-            return "CohereForAI/c4ai-command-r-plus-08-2024"
-
-    def get_max_possible_max_sequence_length(self) -> int:
-        """Get the max possible max sequence length for the model."""
-        if self == BaseModel.COMMAND_R:
-            return 16384
-        elif self == BaseModel.COMMAND_R_08_2024:
-            return 16384
-        elif self == BaseModel.COMMAND_R_PLUS:
-            return 16384
-        else:
-            return 16384
-
-
 class FinetuneStrategy(str, Enum):
     """Supported strategies for finetuning."""
 
diff --git a/src/cohere_finetune/finetune_backends/cohere_peft/peft_data.py b/src/cohere_finetune/finetune_backends/cohere_peft/peft_data.py
index 9ab8be8..552ff3c 100644
--- a/src/cohere_finetune/finetune_backends/cohere_peft/peft_data.py
+++ b/src/cohere_finetune/finetune_backends/cohere_peft/peft_data.py
@@ -100,7 +100,7 @@ def preprocess_hf_datasets(
     raw_datasets: DatasetDict,
     tokenizer: CohereTokenizerFast,
     apply_chat_template: bool = False,
-) -> (Dataset, Dataset):
+) -> tuple[Dataset, Dataset]:
     """
     Preprocess HuggingFace datasets by applying the template of Cohere on them, if the data has not been preprocessed.
 
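With the configs.py change above, the call site takes a plain string instead of the old enum; a minimal sketch of constructing Hyperparameters (the finetune name is a hypothetical example, and this assumes the keyword arguments elided from the hunk keep their defaults):

    from configs import Hyperparameters

    # base_model_name_or_path accepts a supported model name or a local
    # checkpoint directory; BaseModel resolves either form internally.
    hp = Hyperparameters(
        finetune_name="my-finetune",  # hypothetical run name
        base_model_name_or_path="command-r-plus-08-2024",
    )
    print(hp.base_model.get_hf_model_name_or_path())  # CohereForAI/c4ai-command-r-plus-08-2024
    print(hp.base_model.get_max_possible_max_sequence_length())  # 16384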
diff --git a/src/cohere_finetune/preprocess.py b/src/cohere_finetune/preprocess.py
index d435e15..7a9f65e 100644
--- a/src/cohere_finetune/preprocess.py
+++ b/src/cohere_finetune/preprocess.py
@@ -72,7 +72,7 @@ def get_valid_deduped_chats(chats: list[dict]) -> list[dict]:
     return valid_deduped_chats
 
 
-def train_eval_split(chats: list[dict], eval_percentage: float) -> (list[dict], list[dict]):
+def train_eval_split(chats: list[dict], eval_percentage: float) -> tuple[list[dict], list[dict]]:
     """Randomly split the chats into a training set and an evaluation set by eval_percentage."""
     n = len(chats)
     n_eval = int(n * eval_percentage)
@@ -209,7 +209,7 @@ def render_chat_context(liquid_template: Liquid, chat_context: ChatContext, incl
     )
 
 
-def pop_turns_to_fit_max_sequence_length(chat_context: ChatContext, max_sequence_length: int) -> (list[Turn], int):
+def pop_turns_to_fit_max_sequence_length(chat_context: ChatContext, max_sequence_length: int) -> tuple[list[Turn], int]:
     """Pop turns one by one from left to right until the total number of tokens <= max_sequence_length."""
     n_tokens_removed = 0
     for i in range(len(chat_context.turns) - 1):
diff --git a/src/cohere_finetune/train.py b/src/cohere_finetune/train.py
index 60cb230..e517b8b 100644
--- a/src/cohere_finetune/train.py
+++ b/src/cohere_finetune/train.py
@@ -54,7 +54,7 @@ def train_with_peft(path_config: CoherePeftPathConfig, hyperparameters: Hyperpar
         peft_cmd = ["python"]
 
     # Get the HuggingFace model name
-    model_name_or_path = hyperparameters.base_model.get_hf_model_name()
+    model_name_or_path = hyperparameters.base_model.get_hf_model_name_or_path()
 
     # Get the per_device_train_batch_size and per_device_eval_batch_size
     per_device_train_batch_size, r_train = divmod(
diff --git a/src/cohere_finetune/utils.py b/src/cohere_finetune/utils.py
index 2c5d255..6199c51 100644
--- a/src/cohere_finetune/utils.py
+++ b/src/cohere_finetune/utils.py
@@ -149,7 +149,7 @@ def save_file(x: Any, path: str, overwrite_ok: bool = False) -> None:
         raise NotImplementedError
 
 
-def get_lines(data_path: str) -> (list[str], int):
+def get_lines(data_path: str) -> tuple[list[str], int]:
     """Read the lines from a file, where a line will be dropped if we can't decode it."""
     with open(data_path, "rb") as file_bytes:
         lines = []
@@ -163,7 +163,7 @@ def get_lines(data_path: str) -> (list[str], int):
     return lines, n_dropped_lines
 
 
-def load_and_prepare_csv(data_path: str, column_types: dict) -> (pd.DataFrame, int):
+def load_and_prepare_csv(data_path: str, column_types: dict) -> tuple[pd.DataFrame, int]:
     """Load a CSV file as a Pandas dataframe, and do some basic data cleaning."""
     assert get_ext(data_path) == ".csv"
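The remaining hunks all make the same annotation fix: "(X, Y)" in a return annotation is just a tuple of classes, not a type, so type checkers such as mypy reject it, whereas "tuple[X, Y]" (PEP 585, Python 3.9+) is the correct built-in generic form. A toy illustration with a hypothetical helper:

    # "-> (list[int], list[int])" would evaluate to a plain 2-tuple at runtime
    # and is not a valid type annotation; tuple[...] is the correct spelling.
    def split_halves(xs: list[int]) -> tuple[list[int], list[int]]:
        """Toy helper showing the corrected annotation style."""
        mid = len(xs) // 2
        return xs[:mid], xs[mid:]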