Skip to content

Commit

Permalink
model specification
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Jan 25, 2025
1 parent e0ced4f commit 94c5936
Show file tree
Hide file tree
Showing 17 changed files with 1,036 additions and 193 deletions.
136 changes: 108 additions & 28 deletions finetrainers/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from .conditions import SUPPORTED_CONDITIONS, ConditionType
from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS
from .models import SUPPORTED_MODEL_CONFIGS

Expand Down Expand Up @@ -33,6 +34,16 @@ class Args:
storage requirements.
cache_dir (`str`, defaults to `None`):
The directory where the downloaded models and datasets will be stored, or loaded from.
text_encoder_id (`str`, defaults to `None`):
Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
text_encoder_2_id (`str`, defaults to `None`):
Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
text_encoder_3_id (`str`, defaults to `None`):
Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
transformer_id (`str`, defaults to `None`):
Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`.
vae_id (`str`, defaults to `None`):
Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`.
text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for the text encoder when generating text embeddings.
text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Expand Down Expand Up @@ -134,6 +145,8 @@ class Args:
Type of training to perform. Choose between ['lora', 'full-finetune'].
seed (`int`, defaults to `42`):
A seed for reproducible training.
conditions (`List[str]`, defaults to `[]`):
List of conditions to use for training. To get a list of conditions, run `python train.py --list_conditions`.
batch_size (`int`, defaults to `1`):
Per-device batch size.
train_epochs (`int`, defaults to `1`):
Expand Down Expand Up @@ -244,6 +257,11 @@ class Args:
revision: Optional[str] = None
variant: Optional[str] = None
cache_dir: Optional[str] = None
text_encoder_id: Optional[str] = None
text_encoder_2_id: Optional[str] = None
text_encoder_3_id: Optional[str] = None
transformer_id: Optional[str] = None
vae_id: Optional[str] = None
text_encoder_dtype: torch.dtype = torch.bfloat16
text_encoder_2_dtype: torch.dtype = torch.bfloat16
text_encoder_3_dtype: torch.dtype = torch.bfloat16
Expand Down Expand Up @@ -271,8 +289,6 @@ class Args:
image_resolution_buckets: List[Tuple[int, int]] = None
video_resolution_buckets: List[Tuple[int, int, int]] = None
video_reshape_mode: Optional[str] = None
caption_dropout_p: float = 0.00
caption_dropout_technique: str = "empty"
precompute_conditions: bool = False
remove_common_llm_caption_prefixes: bool = False

Expand All @@ -295,6 +311,7 @@ class Args:
# Training arguments
training_type: str = None
seed: int = 42
conditions: List[str] = []
batch_size: int = 1
train_epochs: int = 1
train_steps: int = None
Expand Down Expand Up @@ -349,6 +366,11 @@ class Args:
nccl_timeout: int = 1800 # 30 minutes
report_to: str = "wandb"

# Condition-specific arguments
# 1. Caption Dropout
caption_dropout_p: float = 0.00
caption_dropout_technique: str = "empty"

def to_dict(self) -> Dict[str, Any]:
return {
"model_arguments": {
Expand All @@ -357,6 +379,11 @@ def to_dict(self) -> Dict[str, Any]:
"revision": self.revision,
"variant": self.variant,
"cache_dir": self.cache_dir,
"text_encoder_id": self.text_encoder_id,
"text_encoder_2_id": self.text_encoder_2_id,
"text_encoder_3_id": self.text_encoder_3_id,
"transformer_id": self.transformer_id,
"vae_id": self.vae_id,
"text_encoder_dtype": self.text_encoder_dtype,
"text_encoder_2_dtype": self.text_encoder_2_dtype,
"text_encoder_3_dtype": self.text_encoder_3_dtype,
Expand All @@ -375,8 +402,6 @@ def to_dict(self) -> Dict[str, Any]:
"image_resolution_buckets": self.image_resolution_buckets,
"video_resolution_buckets": self.video_resolution_buckets,
"video_reshape_mode": self.video_reshape_mode,
"caption_dropout_p": self.caption_dropout_p,
"caption_dropout_technique": self.caption_dropout_technique,
"precompute_conditions": self.precompute_conditions,
"remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes,
},
Expand All @@ -399,6 +424,7 @@ def to_dict(self) -> Dict[str, Any]:
"training_arguments": {
"training_type": self.training_type,
"seed": self.seed,
"conditions": self.conditions,
"batch_size": self.batch_size,
"train_epochs": self.train_epochs,
"train_steps": self.train_steps,
Expand Down Expand Up @@ -450,17 +476,29 @@ def to_dict(self) -> Dict[str, Any]:
"nccl_timeout": self.nccl_timeout,
"report_to": self.report_to,
},
"condition_arguments": {
"text": {
"caption_dropout_p": self.caption_dropout_p,
"caption_dropout_technique": self.caption_dropout_technique,
},
},
}


# TODO(aryan): handle more informative messages
_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv
_LIST_MODELS = "--list_models"
_LIST_CONDITIONS = "--list_conditions"


def parse_arguments() -> Args:
parser = argparse.ArgumentParser()

if _IS_ARGUMENTS_REQUIRED:
special_args = [_LIST_MODELS, _LIST_CONDITIONS]
if any(arg in sys.argv for arg in special_args):
_add_helper_arguments(parser)
args = parser.parse_args()
_display_helper_messages(args)
sys.exit(0)
else:
_add_model_arguments(parser)
_add_dataset_arguments(parser)
_add_dataloader_arguments(parser)
Expand All @@ -470,14 +508,11 @@ def parse_arguments() -> Args:
_add_validation_arguments(parser)
_add_miscellaneous_arguments(parser)

args = parser.parse_args()
return _map_to_args_type(args)
else:
_add_helper_arguments(parser)
args, remaining_args = parser.parse_known_args()
_add_opt_in_arguments(parser, remaining_args, args)

args = parser.parse_args()
_display_helper_messages(args)
sys.exit(0)
return _map_to_args_type(args)


def validate_args(args: Args):
Expand Down Expand Up @@ -519,6 +554,15 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--text_encoder_id", type=str, default=None, help="Identifier for the text encoder model.")
parser.add_argument(
"--text_encoder_2_id", type=str, default=None, help="Identifier for the second text encoder model."
)
parser.add_argument(
"--text_encoder_3_id", type=str, default=None, help="Identifier for the third text encoder model."
)
parser.add_argument("--transformer_id", type=str, default=None, help="Identifier for the transformer model.")
parser.add_argument("--vae_id", type=str, default=None, help="Identifier for the VAE model.")
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
Expand Down Expand Up @@ -616,19 +660,6 @@ def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int
default=None,
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument(
"--caption_dropout_p",
type=float,
default=0.00,
help="Probability of dropout for the caption tokens.",
)
parser.add_argument(
"--caption_dropout_technique",
type=str,
default="empty",
choices=["empty", "zero"],
help="Technique to use for caption dropout.",
)
parser.add_argument(
"--precompute_conditions",
action="store_true",
Expand Down Expand Up @@ -728,6 +759,14 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
help="Type of training to perform. Choose between ['lora', 'full-finetune']",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--conditions",
type=str,
default=[],
nargs="+",
choices=SUPPORTED_CONDITIONS,
help="List of conditions to use for training.",
)
parser.add_argument(
"--batch_size",
type=int,
Expand Down Expand Up @@ -997,12 +1036,38 @@ def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
)


def _add_opt_in_arguments(
    parser: argparse.ArgumentParser, remaining_args: List[str], args: argparse.Namespace
) -> None:
    """Register arguments that only exist when the matching conditions were requested.

    A sub-command group is attached to `parser`; each opt-in feature contributes its own sub-command
    whose options are registered only if one of the feature's conditions appears in `args.conditions`.

    Args:
        parser (`argparse.ArgumentParser`):
            Parser to which the condition-specific sub-commands are added.
        remaining_args (`List[str]`):
            Command-line tokens that were not consumed by an earlier parse. Currently unused; kept so
            the call signature stays stable.
        args (`argparse.Namespace`):
            Result of an earlier parse. Only `args.conditions` is read.
    """
    condition_subparser = parser.add_subparsers(dest="condition_subparser", help="Condition-specific arguments.")

    # Caption dropout arguments
    if any(
        condition_name in args.conditions
        for condition_name in [ConditionType.CAPTION_TEXT_DROPOUT, ConditionType.CAPTION_EMBEDDING_DROPOUT]
    ):
        caption_dropout_parser = condition_subparser.add_parser(
            "caption_dropout", help="Arguments for caption dropout."
        )
        caption_dropout_parser.add_argument(
            "--caption_dropout_p",
            type=float,
            default=0.00,
            help="Probability of dropout for the caption tokens.",
        )
        # The argument mapping step reads `caption_dropout_technique` whenever a caption dropout
        # condition is active, so it has to be registered next to `caption_dropout_p`.
        caption_dropout_parser.add_argument(
            "--caption_dropout_technique",
            type=str,
            default="empty",
            choices=["empty", "zero"],
            help="Technique to use for caption dropout.",
        )


def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the informational, boolean `--list_*` flags on `parser`.

    Args:
        parser (`argparse.ArgumentParser`):
            Parser that receives the `--list_models` and `--list_conditions` flags.
    """
    # (flag, help text) pairs, registered in this order.
    helper_flags = (
        ("--list_models", "List all the supported models."),
        ("--list_conditions", "List all the supported conditions."),
    )
    for flag, description in helper_flags:
        parser.add_argument(flag, action="store_true", help=description)


_DTYPE_MAP = {
Expand All @@ -1023,6 +1088,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.revision = args.revision
result_args.variant = args.variant
result_args.cache_dir = args.cache_dir
result_args.text_encoder_id = args.text_encoder_id
result_args.text_encoder_2_id = args.text_encoder_2_id
result_args.text_encoder_3_id = args.text_encoder_3_id
result_args.transformer_id = args.transformer_id
result_args.vae_id = args.vae_id
result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
Expand All @@ -1044,8 +1114,6 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS
result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
result_args.video_reshape_mode = args.video_reshape_mode
result_args.caption_dropout_p = args.caption_dropout_p
result_args.caption_dropout_technique = args.caption_dropout_technique
result_args.precompute_conditions = args.precompute_conditions
result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes

Expand All @@ -1068,6 +1136,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
# Training arguments
result_args.training_type = args.training_type
result_args.seed = args.seed
result_args.conditions = args.conditions
result_args.batch_size = args.batch_size
result_args.train_epochs = args.train_epochs
result_args.train_steps = args.train_steps
Expand Down Expand Up @@ -1147,6 +1216,14 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
result_args.nccl_timeout = args.nccl_timeout
result_args.report_to = args.report_to

# Condition-specific arguments
if any(
condition_name in args.conditions
for condition_name in [ConditionType.CAPTION_TEXT_DROPOUT, ConditionType.CAPTION_EMBEDDING_DROPOUT]
):
result_args.caption_dropout_p = args.caption_dropout_p
result_args.caption_dropout_technique = args.caption_dropout_technique

return result_args


Expand Down Expand Up @@ -1189,3 +1266,6 @@ def _display_helper_messages(args: argparse.Namespace):
print("Supported models:")
for index, model_name in enumerate(SUPPORTED_MODEL_CONFIGS.keys()):
print(f" {index + 1}. {model_name}")

elif args.list_conditions:
print(f"Supported conditions: {SUPPORTED_CONDITIONS}")
41 changes: 41 additions & 0 deletions finetrainers/conditions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from enum import Enum
from typing import Any, Dict

from ..utils import get_parameter_names
from .condition_utils import Condition, get_condition_parameters_from_dict
from .text import CaptionEmbeddingDropoutCondition, CaptionTextDropoutCondition, T5Condition


class ConditionType(str, Enum):
    """String-valued identifiers for the conditions known to this package.

    The enum also derives from `str`, so every member is itself a string equal to its value
    (for example, `ConditionType.T5 == "t5"`).
    """

    # Text conditions
    CLIP = "clip"
    T5 = "t5"

    # Dropout conditions
    CAPTION_TEXT_DROPOUT = "caption_text_dropout"
    CAPTION_EMBEDDING_DROPOUT = "caption_embedding_dropout"


# Set holding the plain string value of every `ConditionType` member.
SUPPORTED_CONDITIONS = {member.value for member in ConditionType}

# Maps a condition type to the class that implements it; `get_condition_cls` performs its lookup here.
# NOTE(review): `ConditionType.CLIP` has no entry - presumably it is not implemented yet; confirm.
# fmt: off
_CONDITION_TYPE_TO_CONDITION_MAPPING = {
    # Text conditions
    ConditionType.T5: T5Condition,

    # Dropout conditions
    ConditionType.CAPTION_EMBEDDING_DROPOUT: CaptionEmbeddingDropoutCondition,
    ConditionType.CAPTION_TEXT_DROPOUT: CaptionTextDropoutCondition,
}
# fmt: on


def get_condition_cls(condition_type: ConditionType) -> Condition:
    """Look up the class that implements `condition_type`.

    Args:
        condition_type (`ConditionType`):
            Condition whose implementing class is requested.

    Returns:
        The class registered for `condition_type` in `_CONDITION_TYPE_TO_CONDITION_MAPPING`. Note that
        this is the class object itself, not an instance of it.

    Raises:
        `KeyError`: If no implementation is registered for `condition_type`. Not every `ConditionType`
        member necessarily has an entry in the mapping.
    """
    try:
        return _CONDITION_TYPE_TO_CONDITION_MAPPING[condition_type]
    except KeyError:
        # Still a KeyError, as a failed dict lookup would be, but say what went wrong and what is available.
        implemented = sorted(registered.value for registered in _CONDITION_TYPE_TO_CONDITION_MAPPING)
        raise KeyError(
            f"No condition implementation is registered for {condition_type!r}. "
            f"Implemented conditions: {implemented}"
        ) from None


def get_condition(condition_type: ConditionType, condition_parameters: Dict[str, Any]) -> Condition:
    """Instantiate the condition registered for `condition_type`.

    Args:
        condition_type (`ConditionType`):
            Condition to build; resolved to a class through `get_condition_cls`.
        condition_parameters (`Dict[str, Any]`):
            Candidate keyword arguments. Entries whose key is not among the names reported for the
            class's `__init__` are dropped before the class is called.

    Returns:
        The object produced by calling the resolved class with the filtered keyword arguments.
    """
    cls = get_condition_cls(condition_type)
    # `get_parameter_names` presumably yields the parameter names of the constructor - confirm in `utils`.
    init_kwargs = get_condition_parameters_from_dict(
        get_parameter_names(cls.__init__),
        condition_parameters,
    )
    return cls(**init_kwargs)
13 changes: 13 additions & 0 deletions finetrainers/conditions/condition_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Any, Dict, Set


class Condition:
    """Base class for conditions; subclasses are expected to provide `__call__`."""

    def __init__(self, *args, **kwargs) -> None:
        """Accept any positional and keyword arguments and ignore them."""

    def __call__(self, *args, **kwargs) -> None:
        """Always raise `NotImplementedError`, naming the class of the instance that was called."""
        message = f"Condition::__call__ is not implemented for {self.__class__.__name__}"
        raise NotImplementedError(message)


def get_condition_parameters_from_dict(accepted_parameters: Set[str], parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of `parameters` whose key is in `accepted_parameters`.

    Args:
        accepted_parameters (`Set[str]`):
            Names that are allowed through.
        parameters (`Dict[str, Any]`):
            Candidate name/value pairs. This mapping is not modified.

    Returns:
        A new dict with the accepted entries, in the order they appear in `parameters`.
    """
    selected: Dict[str, Any] = {}
    for name, value in parameters.items():
        if name in accepted_parameters:
            selected[name] = value
    return selected
Loading

0 comments on commit 94c5936

Please sign in to comment.