From 1e81fd5ecaecb38785fb490a5fa6e8922cfcb97a Mon Sep 17 00:00:00 2001 From: Sigurd Spieckermann <2206639+sisp@users.noreply.github.com> Date: Sat, 8 Jul 2023 12:36:21 +0200 Subject: [PATCH] feat: add support for computed values via skipped questions (#1220) * feat: add support for computed values via skipped questions * docs: apply review suggestions Co-authored-by: Jairo Llopis <973709+yajo@users.noreply.github.com> * docs: update inline comments * docs: fix formatting --------- Co-authored-by: Jairo Llopis <973709+yajo@users.noreply.github.com> --- copier/main.py | 17 +++++++----- copier/template.py | 5 ---- copier/user_data.py | 19 +++---------- docs/configuring.md | 13 ++++----- tests/test_complex_questions.py | 49 +++++++++++++++++++++++++++++++++ tests/test_prompt.py | 7 +++++ 6 files changed, 75 insertions(+), 35 deletions(-) diff --git a/copier/main.py b/copier/main.py index f9fbba55c..78a2f5ec0 100644 --- a/copier/main.py +++ b/copier/main.py @@ -257,6 +257,7 @@ def _answers_to_remember(self) -> Mapping: (str(k), v) for (k, v) in self.answers.combined.items() if not k.startswith("_") + and k not in self.answers.hidden and k not in self.template.secret_questions and k in self.template.questions_data and isinstance(k, JSONSerializable) @@ -413,7 +414,6 @@ def _render_allowed( def _ask(self) -> None: """Ask the questions of the questionary and record their answers.""" result = AnswersMap( - default=self.template.default_answers, user_defaults=self.user_defaults, init=self.data, last=self.subproject.last_answers, @@ -427,12 +427,14 @@ def _ask(self) -> None: var_name=var_name, **details, ) - # Skip a question when the skip condition is met and remove any data - # from the answers map, so no answer for this question is recorded - # in the answers file. + # Skip a question when the skip condition is met. if not question.get_when(): - result.remove(var_name) - continue + # Omit its answer from the answers file. + result.hide(var_name) + # Skip immediately to the next question when it has no default + # value. + if question.default is MISSING: + continue if var_name in result.init: # Try to parse the answer value. answer = question.parse_answer(result.init[var_name]) @@ -452,7 +454,8 @@ def _ask(self) -> None: raise ValueError(f'Question "{var_name}" is required') else: new_answer = unsafe_prompt( - [question.get_questionary_structure()], answers=result.combined + [question.get_questionary_structure()], + answers={question.var_name: question.get_default()}, )[question.var_name] except KeyboardInterrupt as err: raise CopierAnswersInterrupt(result, question, self.template) from err diff --git a/copier/template.py b/copier/template.py index 71af6f836..b67903b58 100644 --- a/copier/template.py +++ b/copier/template.py @@ -281,11 +281,6 @@ def config_data(self) -> AnyByStrDict: verify_copier_version(result["min_copier_version"]) return result - @cached_property - def default_answers(self) -> AnyByStrDict: - """Get default answers for template's questions.""" - return {key: value.get("default") for key, value in self.questions_data.items()} - @cached_property def envops(self) -> Mapping: """Get the Jinja configuration specified in the template, or default values. diff --git a/copier/user_data.py b/copier/user_data.py index 273631981..1d9154a07 100644 --- a/copier/user_data.py +++ b/copier/user_data.py @@ -103,11 +103,6 @@ class AnswersMap: last: Data from [the answers file][the-copier-answersyml-file]. - default: - Default data from the template. - - See [copier.template.Template.default_answers][]. - user_defaults: Default data from the user e.g. previously completed and restored data. @@ -115,7 +110,7 @@ class AnswersMap: """ # Private - removed: Set[str] = field(default_factory=set, init=False) + hidden: Set[str] = field(default_factory=set, init=False) # Public user: AnyByStrDict = field(default_factory=dict) @@ -123,34 +118,28 @@ class AnswersMap: metadata: AnyByStrDict = field(default_factory=dict) last: AnyByStrDict = field(default_factory=dict) user_defaults: AnyByStrDict = field(default_factory=dict) - default: AnyByStrDict = field(default_factory=dict) @property def combined(self) -> Mapping[str, Any]: """Answers combined from different sources, sorted by priority.""" - combined = dict( + return dict( ChainMap( self.user, self.init, self.metadata, self.last, self.user_defaults, - self.default, DEFAULT_DATA, ) ) - for key in self.removed: - if key in combined: - del combined[key] - return combined def old_commit(self) -> OptStr: """Commit when the project was updated from this template the last time.""" return self.last.get("_commit") - def remove(self, key: str) -> None: + def hide(self, key: str) -> None: """Remove an answer by key.""" - self.removed.add(key) + self.hidden.add(key) @dataclass(config=AllowArbitraryTypes) diff --git a/docs/configuring.md b/docs/configuring.md index f9ff63b6c..c8d8c4bcd 100644 --- a/docs/configuring.md +++ b/docs/configuring.md @@ -175,17 +175,14 @@ Supported keys: - **when**: Condition that, if `false`, skips the question. - If it is a boolean, it is used directly, but it's a bit absurd in that case. + If it is a boolean, it is used directly. Setting it to `false` is useful for + creating a computed value. If it is a string, it is converted to boolean using a parser similar to YAML, but - only for boolean values. + only for boolean values. The string can be [templated][prompt-templating]. - This is most useful when [templated](#prompt-templating). - - If a question is skipped, its answer will be: - - - The default value, if you're generating the project for the 1st time. - - The last answer recorded, if you're updating the project. + If a question is skipped, its answer is not recorded, but its default value is + available in the render context. !!! example diff --git a/tests/test_complex_questions.py b/tests/test_complex_questions.py index 249cec644..829060cf9 100644 --- a/tests/test_complex_questions.py +++ b/tests/test_complex_questions.py @@ -397,6 +397,7 @@ def test_tui_inherited_default( "{{ _copier_answers|to_nice_yaml }}" ), (src / "answers.json.jinja"): "{{ _copier_answers|to_nice_json }}", + (src / "context.json.jinja"): '{"owner2": "{{ owner2 }}"}', } ) with local.cwd(src): @@ -423,6 +424,7 @@ def test_tui_inherited_default( **({"owner2": owner2} if has_2_owners else {}), } assert json.loads((dst / "answers.json").read_text()) == result + assert json.loads((dst / "context.json").read_text()) == {"owner2": owner2} with local.cwd(dst): git("init") git("add", "--all") @@ -432,6 +434,45 @@ def test_tui_inherited_default( assert json.loads((dst / "answers.json").read_text()) == result +def test_tui_typed_default( + tmp_path_factory: pytest.TempPathFactory, spawn: Spawn +) -> None: + """Make sure a template defaults are typed as expected.""" + src, dst = map(tmp_path_factory.mktemp, ("src", "dst")) + build_file_tree( + { + (src / "copier.yaml"): ( + """\ + test1: + type: bool + default: false + when: false + test2: + type: bool + default: "{{ 'a' == 'b' }}" + when: false + """ + ), + (src / "{{ _copier_conf.answers_file }}.jinja"): ( + "{{ _copier_answers|to_nice_yaml }}" + ), + (src / "answers.json.jinja"): "{{ _copier_answers|to_nice_json }}", + (src / "context.json.jinja"): ( + """\ + {"test1": "{{ test1 }}", "test2": "{{ test2 }}"} + """ + ), + } + ) + tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10) + tui.expect_exact(pexpect.EOF) + assert json.loads((dst / "answers.json").read_text()) == {"_src_path": str(src)} + assert json.loads((dst / "context.json").read_text()) == { + "test1": "False", + "test2": "False", + } + + def test_selection_type_cast( tmp_path_factory: pytest.TempPathFactory, spawn: Spawn ) -> None: @@ -559,8 +600,16 @@ def test_omit_answer_for_skipped_question( (src / "{{ _copier_conf.answers_file }}.jinja"): ( "{{ _copier_answers|to_nice_yaml }}" ), + (src / "context.yml.jinja"): yaml.safe_dump( + { + "disabled": "{{ disabled }}", + "disabled_with_default": "{{ disabled_with_default }}", + } + ), } ) run_copy(str(src), dst, defaults=True, data={"disabled": "hello"}) answers = yaml.safe_load((dst / ".copier-answers.yml").read_text()) assert answers == {"_src_path": str(src)} + context = yaml.safe_load((dst / "context.yml").read_text()) + assert context == {"disabled": "hello", "disabled_with_default": "hello"} diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 8b75ef03a..b94d5e00f 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -183,6 +183,11 @@ def test_when( (src / "[[ _copier_conf.answers_file ]].tmpl"): ( "[[ _copier_answers|to_nice_yaml ]]" ), + (src / "context.yml.tmpl"): ( + """\ + question_2: [[ question_2 ]] + """ + ), } ) tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10) @@ -198,6 +203,8 @@ def test_when( "question_1": question_1, **({"question_2": "something"} if asks else {}), } + context = yaml.safe_load((dst / "context.yml").read_text()) + assert context == {"question_2": "something"} def test_placeholder(tmp_path_factory: pytest.TempPathFactory, spawn: Spawn) -> None: