Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: validate default values #1075

Merged
merged 4 commits into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ def _check_type(cls, v, values):

def get_default(self) -> Any:
"""Get the default value for this question, casted to its expected type."""
cast_fn = self.get_cast_fn()
try:
result = self.answers.init[self.var_name]
except KeyError:
Expand All @@ -240,7 +239,7 @@ def get_default(self) -> Any:
if self.default is MISSING:
return MISSING
result = self.render_value(self.default)
result = cast_answer_type(result, cast_fn)
result = cast_answer_type(result, self.get_type_name())
return result

def get_default_rendered(self) -> Union[bool, str, Choice, None, MissingType]:
Expand Down Expand Up @@ -300,7 +299,7 @@ def _formatted_choices(self) -> Sequence[Choice]:

def filter_answer(self, answer) -> Any:
"""Cast the answer to the desired type."""
return cast_answer_type(answer, self.get_cast_fn())
return cast_answer_type(answer, self.get_type_name())

def get_message(self) -> str:
"""Get the message that will be printed to the user."""
Expand Down Expand Up @@ -360,16 +359,12 @@ def get_questionary_structure(self) -> AnyByStrDict:
result.update({"type": questionary_type})
return result

def get_cast_fn(self) -> Callable:
"""Obtain function to cast user answer to desired type."""
type_name = self.get_type_name()
if type_name not in CAST_STR_TO_NATIVE:
raise InvalidTypeError("Invalid question type")
return CAST_STR_TO_NATIVE.get(type_name, parse_yaml_string)

def get_type_name(self) -> str:
"""Render the type name and return it."""
return self.render_value(self.type)
type_name = self.render_value(self.type)
if type_name not in CAST_STR_TO_NATIVE:
raise InvalidTypeError("Invalid question type")
return type_name

def get_multiline(self) -> bool:
"""Get the value for multiline."""
Expand Down Expand Up @@ -415,10 +410,10 @@ def render_value(

def parse_answer(self, answer: Any) -> Any:
"""Parse the answer according to the question's type."""
cast_fn = self.get_cast_fn()
ans = cast_answer_type(answer, cast_fn)
type_name = self.get_type_name()
ans = cast_answer_type(answer, type_name)
choice_values = [
cast_answer_type(choice.value, cast_fn)
cast_answer_type(choice.value, type_name)
for choice in self._formatted_choices
]
if choice_values and ans not in choice_values:
Expand Down Expand Up @@ -450,13 +445,25 @@ def load_answersfile_data(
return {}


def cast_answer_type(answer: Any, type_fn: Callable) -> Any:
def cast_answer_type(answer: Any, type_name: str) -> Any:
"""Cast answer to expected type."""
try:
type_fn = CAST_STR_TO_NATIVE[type_name]
except KeyError as exc:
raise InvalidTypeError("Invalid answer type") from exc
# Only JSON or YAML questions support `None` as an answer
if answer is None and type_name not in {"json", "yaml"}:
raise TypeError(
f'Invalid answer of type "{type(answer)}" to question of type '
f'"{type_name}"'
)
try:
return type_fn(answer)
except (TypeError, AttributeError):
# JSON or YAML failed because it wasn't a string; no need to convert
return answer
if type_name in {"json", "yaml"}:
return answer
raise


CAST_STR_TO_NATIVE: Mapping[str, Callable] = {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_complex_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def template_path(tmp_path_factory: pytest.TempPathFactory) -> str:
three: third
choose_number:
help: This must be a number
default: null
default: -1.1
type: float
choices:
- -1.1
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_api(template_path: str, tmp_path: Path) -> None:
choose_list: "first"
choose_tuple: "second"
choose_dict: "third"
choose_number: null
choose_number: -1.1
minutes_under_water: 10
optional_value: null
"""
Expand Down
83 changes: 83 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,3 +563,86 @@ def test_required_choice_question_without_data(
)
with pytest.raises(ValueError, match='Question "question" is required'):
copier.copy(str(src), dst, defaults=True)


@pytest.mark.parametrize(
"type_name, default, expected",
[
("str", "string", does_not_raise()),
("str", "1.0", does_not_raise()),
("str", 1.0, does_not_raise()),
("str", None, pytest.raises(TypeError)),
("int", 1, does_not_raise()),
("int", 1.0, does_not_raise()),
("int", "1", does_not_raise()),
("int", "1.0", pytest.raises(ValueError)),
("int", "no-int", pytest.raises(ValueError)),
("int", None, pytest.raises(TypeError)),
("int", {}, pytest.raises(TypeError)),
("int", [], pytest.raises(TypeError)),
("float", 1.1, does_not_raise()),
("float", 1, does_not_raise()),
("float", "1.1", does_not_raise()),
("float", "no-float", pytest.raises(ValueError)),
("float", None, pytest.raises(TypeError)),
("float", {}, pytest.raises(TypeError)),
("float", [], pytest.raises(TypeError)),
("bool", True, does_not_raise()),
("bool", False, does_not_raise()),
("bool", "y", does_not_raise()),
("bool", "n", does_not_raise()),
("bool", None, pytest.raises(TypeError)),
("json", '"string"', does_not_raise()),
("json", "1", does_not_raise()),
("json", 1, does_not_raise()),
("json", "1.1", does_not_raise()),
("json", 1.1, does_not_raise()),
("json", "true", does_not_raise()),
("json", True, does_not_raise()),
("json", "false", does_not_raise()),
("json", False, does_not_raise()),
("json", "{}", does_not_raise()),
("json", {}, does_not_raise()),
("json", "[]", does_not_raise()),
("json", [], does_not_raise()),
("json", "null", does_not_raise()),
("json", None, does_not_raise()),
("yaml", '"string"', does_not_raise()),
("yaml", "string", does_not_raise()),
("yaml", "1", does_not_raise()),
("yaml", 1, does_not_raise()),
("yaml", "1.1", does_not_raise()),
("yaml", 1.1, does_not_raise()),
("yaml", "true", does_not_raise()),
("yaml", True, does_not_raise()),
("yaml", "false", does_not_raise()),
("yaml", False, does_not_raise()),
("yaml", "{}", does_not_raise()),
("yaml", {}, does_not_raise()),
("yaml", "[]", does_not_raise()),
("yaml", [], does_not_raise()),
("yaml", "null", does_not_raise()),
("yaml", None, does_not_raise()),
],
)
def test_validate_default_value(
tmp_path_factory: pytest.TempPathFactory,
type_name: str,
default: Any,
expected: ContextManager[None],
) -> None:
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yml"): yaml.dump(
{
"q": {
"type": type_name,
"default": default,
}
}
)
}
)
with expected:
copier.copy(str(src), dst, defaults=True)