diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py
index 7bc0274..9fd92c2 100644
--- a/src/together/cli/api/finetune.py
+++ b/src/together/cli/api/finetune.py
@@ -11,7 +11,7 @@
 from tabulate import tabulate
 
 from together import Together
-from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, FROM_STEP_TYPE
 from together.utils import (
     finetune_price_to_dollars,
     log_warn,
@@ -126,6 +126,12 @@ def fine_tuning(ctx: click.Context) -> None:
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
     "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
+@click.option(
+    "--from-step",
+    type=FROM_STEP_TYPE,
+    default="final",
+    help="Which checkpoint step to start the fine-tuning job from (`final` starts from the final checkpoint).",
+)
 def create(
     ctx: click.Context,
     training_file: str,
@@ -152,6 +158,7 @@ def create(
     wandb_name: str,
     confirm: bool,
     train_on_inputs: bool | Literal["auto"],
+    from_step: int | Literal["final"],
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -180,6 +187,7 @@ def create(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        from_step=from_step,
     )
 
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
diff --git a/src/together/cli/api/utils.py b/src/together/cli/api/utils.py
index 08dfe49..116de08 100644
--- a/src/together/cli/api/utils.py
+++ b/src/together/cli/api/utils.py
@@ -47,5 +47,26 @@ def convert(
         )
 
 
+class FromStepParamType(click.ParamType):
+    name = "from_step"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> int | Literal["final"] | None:
+        if value == "final":
+            return "final"
+        try:
+            return int(value)
+        except ValueError:
+            self.fail(
+                _("{value!r} is not a valid {type}.").format(
+                    value=value, type=self.name
+                ),
+                param,
+                ctx,
+            )
+
+
 INT_WITH_MAX = AutoIntParamType()
 BOOL_WITH_AUTO = BooleanWithAutoParamType()
+FROM_STEP_TYPE = FromStepParamType()
diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py
index b58cdae..4f90ac5 100644
--- a/src/together/resources/finetune.py
+++ b/src/together/resources/finetune.py
@@ -52,6 +52,7 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    from_step: int | Literal["final"] = "final",
 ) -> FinetuneRequest:
     if batch_size == "max":
         log_warn_once(
@@ -100,6 +101,9 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    if from_step == "final":
+        from_step = -1  # the request encodes "start from the final checkpoint" as -1
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
@@ -125,6 +129,7 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        from_step=from_step,
     )
 
     return finetune_request
@@ -162,6 +167,7 @@ def create(
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        from_step: int | Literal["final"] = "final",
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -207,6 +213,7 @@ def create(
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            from_step (int or "final"): The checkpoint step to start the fine-tuning job from. Defaults to "final" (the final checkpoint).
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
 
@@ -244,6 +251,7 @@ def create(
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            from_step=from_step,
         )
 
         if verbose:
diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py
index 05bc8c4..8a06919 100644
--- a/src/together/types/finetune.py
+++ b/src/together/types/finetune.py
@@ -178,6 +178,8 @@ class FinetuneRequest(BaseModel):
     training_type: FullTrainingType | LoRATrainingType | None = None
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
+    # checkpoint step to start from; -1 means the final checkpoint
+    from_step: int | None = -1
 
 
 class FinetuneResponse(BaseModel):
@@ -256,6 +258,7 @@ class FinetuneResponse(BaseModel):
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
     train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
+    from_step: int | None = -1
 
     @field_validator("training_type")
     @classmethod
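
Usage sketch, assuming the patch above is applied as-is: on the CLI the new option is passed as `--from-step 100` (or `--from-step final`, the default); through the Python client the equivalent is the `from_step` argument shown below. The training file ID and model name are placeholders.

    # Minimal sketch: resume a fine-tuning job from checkpoint step 100.
    # Passing "final" (the default) is mapped to -1 in createFinetuneRequest
    # before the request is sent to the API.
    from together import Together

    client = Together()
    job = client.fine_tuning.create(
        training_file="file-xxxxxxxx",  # placeholder file ID
        model="my-base-model",          # placeholder model name
        from_step=100,
    )
    print(job.id)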