-
Notifications
You must be signed in to change notification settings - Fork 88
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
1,036 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from enum import Enum | ||
from typing import Any, Dict | ||
|
||
from ..utils import get_parameter_names | ||
from .condition_utils import Condition, get_condition_parameters_from_dict | ||
from .text import CaptionEmbeddingDropoutCondition, CaptionTextDropoutCondition, T5Condition | ||
|
||
|
||
class ConditionType(str, Enum): | ||
# Text conditions | ||
CLIP = "clip" | ||
T5 = "t5" | ||
|
||
# Dropout conditions | ||
CAPTION_TEXT_DROPOUT = "caption_text_dropout" | ||
CAPTION_EMBEDDING_DROPOUT = "caption_embedding_dropout" | ||
|
||
|
||
SUPPORTED_CONDITIONS = {condition_type.value for condition_type in ConditionType.__members__.values()} | ||
|
||
# fmt: off | ||
_CONDITION_TYPE_TO_CONDITION_MAPPING = { | ||
# Text conditions | ||
ConditionType.T5: T5Condition, | ||
|
||
# Dropout conditions | ||
ConditionType.CAPTION_EMBEDDING_DROPOUT: CaptionEmbeddingDropoutCondition, | ||
ConditionType.CAPTION_TEXT_DROPOUT: CaptionTextDropoutCondition, | ||
} | ||
# fmt: on | ||
|
||
|
||
def get_condition_cls(condition_type: ConditionType) -> Condition: | ||
return _CONDITION_TYPE_TO_CONDITION_MAPPING[condition_type] | ||
|
||
|
||
def get_condition(condition_type: ConditionType, condition_parameters: Dict[str, Any]) -> Condition: | ||
condition_cls = get_condition_cls(condition_type) | ||
accepted_parameters = get_parameter_names(condition_cls.__init__) | ||
condition_parameters = get_condition_parameters_from_dict(accepted_parameters, condition_parameters) | ||
return condition_cls(**condition_parameters) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import Any, Dict, Set | ||
|
||
|
||
class Condition: | ||
def __init__(self, *args, **kwargs) -> None: | ||
pass | ||
|
||
def __call__(self, *args, **kwargs) -> None: | ||
raise NotImplementedError(f"Condition::__call__ is not implemented for {self.__class__.__name__}") | ||
|
||
|
||
def get_condition_parameters_from_dict(accepted_parameters: Set[str], parameters: Dict[str, Any]) -> Dict[str, Any]: | ||
return {k: v for k, v in parameters.items() if k in accepted_parameters} |
Oops, something went wrong.