Skip to content

Commit

Permalink
feat: validate default values
Browse files Browse the repository at this point in the history
  • Loading branch information
sisp committed Apr 7, 2023
1 parent 91f21c7 commit 7d90d7d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 18 deletions.
39 changes: 23 additions & 16 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def _check_type(cls, v, values):

def get_default(self) -> Any:
"""Get the default value for this question, casted to its expected type."""
cast_fn = self.get_cast_fn()
try:
result = self.answers.init[self.var_name]
except KeyError:
Expand All @@ -240,7 +239,7 @@ def get_default(self) -> Any:
if self.default is MISSING:
return MISSING
result = self.render_value(self.default)
result = cast_answer_type(result, cast_fn)
result = cast_answer_type(result, self.get_type_name())
return result

def get_default_rendered(self) -> Union[bool, str, Choice, None, MissingType]:
Expand Down Expand Up @@ -300,7 +299,7 @@ def _formatted_choices(self) -> Sequence[Choice]:

def filter_answer(self, answer) -> Any:
"""Cast the answer to the desired type."""
return cast_answer_type(answer, self.get_cast_fn())
return cast_answer_type(answer, self.get_type_name())

def get_message(self) -> str:
"""Get the message that will be printed to the user."""
Expand Down Expand Up @@ -360,16 +359,12 @@ def get_questionary_structure(self) -> AnyByStrDict:
result.update({"type": questionary_type})
return result

def get_cast_fn(self) -> Callable:
"""Obtain function to cast user answer to desired type."""
type_name = self.get_type_name()
if type_name not in CAST_STR_TO_NATIVE:
raise InvalidTypeError("Invalid question type")
return CAST_STR_TO_NATIVE.get(type_name, parse_yaml_string)

def get_type_name(self) -> str:
"""Render the type name and return it."""
return self.render_value(self.type)
type_name = self.render_value(self.type)
if type_name not in CAST_STR_TO_NATIVE:
raise InvalidTypeError("Invalid question type")
return type_name

def get_multiline(self) -> bool:
"""Get the value for multiline."""
Expand Down Expand Up @@ -415,10 +410,10 @@ def render_value(

def parse_answer(self, answer: Any) -> Any:
"""Parse the answer according to the question's type."""
cast_fn = self.get_cast_fn()
ans = cast_answer_type(answer, cast_fn)
type_name = self.get_type_name()
ans = cast_answer_type(answer, type_name)
choice_values = {
cast_answer_type(choice.value, cast_fn)
cast_answer_type(choice.value, type_name)
for choice in self._formatted_choices
}
if choice_values and ans not in choice_values:
Expand Down Expand Up @@ -450,13 +445,25 @@ def load_answersfile_data(
return {}


def cast_answer_type(answer: Any, type_fn: Callable) -> Any:
def cast_answer_type(answer: Any, type_name: str) -> Any:
"""Cast answer to expected type."""
try:
type_fn = CAST_STR_TO_NATIVE[type_name]
except KeyError:
raise InvalidTypeError("Invalid answer type")
# Only JSON or YAML questions support `None` as an answer
if answer is None and type_name not in {"json", "yaml"}:
raise TypeError(
f'Invalid answer of type "{type(answer)}" to question of type '
f'"{type_name}"'
)
try:
return type_fn(answer)
except (TypeError, AttributeError):
# JSON or YAML failed because it wasn't a string; no need to convert
return answer
if type_name in {"json", "yaml"}:
return answer
raise


CAST_STR_TO_NATIVE: Mapping[str, Callable] = {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_complex_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def template_path(tmp_path_factory: pytest.TempPathFactory) -> str:
three: third
choose_number:
help: This must be a number
default: null
default: -1.1
type: float
choices:
- -1.1
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_api(template_path: str, tmp_path: Path) -> None:
choose_list: "first"
choose_tuple: "second"
choose_dict: "third"
choose_number: null
choose_number: -1.1
minutes_under_water: 10
optional_value: null
"""
Expand Down
79 changes: 79 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,82 @@ def test_required_choice_question_without_data(
)
with pytest.raises(ValueError, match='Question "question" is required'):
copier.copy(str(src), dst, defaults=True)


@pytest.mark.parametrize(
"type_name, default, expected",
[
("str", "string", does_not_raise()),
("str", "1.0", does_not_raise()),
("str", 1.0, does_not_raise()),
("str", None, pytest.raises(TypeError)),
("int", 1, does_not_raise()),
("int", 1.0, does_not_raise()),
("int", "1", does_not_raise()),
("int", "1.0", pytest.raises(ValueError)),
("int", "no-int", pytest.raises(ValueError)),
("int", None, pytest.raises(TypeError)),
("float", 1.1, does_not_raise()),
("float", 1, does_not_raise()),
("float", "1.1", does_not_raise()),
("float", "no-float", pytest.raises(ValueError)),
("float", None, pytest.raises(TypeError)),
("bool", True, does_not_raise()),
("bool", False, does_not_raise()),
("bool", "y", does_not_raise()),
("bool", "n", does_not_raise()),
("bool", None, pytest.raises(TypeError)),
("json", '"string"', does_not_raise()),
("json", "1", does_not_raise()),
("json", 1, does_not_raise()),
("json", "1.1", does_not_raise()),
("json", 1.1, does_not_raise()),
("json", "true", does_not_raise()),
("json", True, does_not_raise()),
("json", "false", does_not_raise()),
("json", False, does_not_raise()),
("json", "{}", does_not_raise()),
("json", {}, does_not_raise()),
("json", "[]", does_not_raise()),
("json", [], does_not_raise()),
("json", "null", does_not_raise()),
("json", None, does_not_raise()),
("yaml", '"string"', does_not_raise()),
("yaml", "string", does_not_raise()),
("yaml", "1", does_not_raise()),
("yaml", 1, does_not_raise()),
("yaml", "1.1", does_not_raise()),
("yaml", 1.1, does_not_raise()),
("yaml", "true", does_not_raise()),
("yaml", True, does_not_raise()),
("yaml", "false", does_not_raise()),
("yaml", False, does_not_raise()),
("yaml", "{}", does_not_raise()),
("yaml", {}, does_not_raise()),
("yaml", "[]", does_not_raise()),
("yaml", [], does_not_raise()),
("yaml", "null", does_not_raise()),
("yaml", None, does_not_raise()),
],
)
def test_validate_default_value(
tmp_path_factory: pytest.TempPathFactory,
type_name: str,
default: Any,
expected: ContextManager[None],
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): yaml.dump(
{
"q": {
"type": type_name,
"default": default,
}
}
)
}
)
with expected:
copier.copy(str(src), dst, defaults=True)

0 comments on commit 7d90d7d

Please sign in to comment.