Skip to content

Commit

Permalink
feat: add support for computed values via skipped questions (#1220)
Browse files Browse the repository at this point in the history
* feat: add support for computed values via skipped questions

* docs: apply review suggestions

Co-authored-by: Jairo Llopis <[email protected]>

* docs: update inline comments

* docs: fix formatting

---------

Co-authored-by: Jairo Llopis <[email protected]>
  • Loading branch information
sisp and yajo authored Jul 8, 2023
1 parent 0c5a6ca commit 1e81fd5
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 35 deletions.
17 changes: 10 additions & 7 deletions copier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def _answers_to_remember(self) -> Mapping:
(str(k), v)
for (k, v) in self.answers.combined.items()
if not k.startswith("_")
and k not in self.answers.hidden
and k not in self.template.secret_questions
and k in self.template.questions_data
and isinstance(k, JSONSerializable)
Expand Down Expand Up @@ -413,7 +414,6 @@ def _render_allowed(
def _ask(self) -> None:
"""Ask the questions of the questionary and record their answers."""
result = AnswersMap(
default=self.template.default_answers,
user_defaults=self.user_defaults,
init=self.data,
last=self.subproject.last_answers,
Expand All @@ -427,12 +427,14 @@ def _ask(self) -> None:
var_name=var_name,
**details,
)
# Skip a question when the skip condition is met and remove any data
# from the answers map, so no answer for this question is recorded
# in the answers file.
# Skip a question when the skip condition is met.
if not question.get_when():
result.remove(var_name)
continue
# Omit its answer from the answers file.
result.hide(var_name)
# Skip immediately to the next question when it has no default
# value.
if question.default is MISSING:
continue
if var_name in result.init:
# Try to parse the answer value.
answer = question.parse_answer(result.init[var_name])
Expand All @@ -452,7 +454,8 @@ def _ask(self) -> None:
raise ValueError(f'Question "{var_name}" is required')
else:
new_answer = unsafe_prompt(
[question.get_questionary_structure()], answers=result.combined
[question.get_questionary_structure()],
answers={question.var_name: question.get_default()},
)[question.var_name]
except KeyboardInterrupt as err:
raise CopierAnswersInterrupt(result, question, self.template) from err
Expand Down
5 changes: 0 additions & 5 deletions copier/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,6 @@ def config_data(self) -> AnyByStrDict:
verify_copier_version(result["min_copier_version"])
return result

@cached_property
def default_answers(self) -> AnyByStrDict:
"""Get default answers for template's questions."""
return {key: value.get("default") for key, value in self.questions_data.items()}

@cached_property
def envops(self) -> Mapping:
"""Get the Jinja configuration specified in the template, or default values.
Expand Down
19 changes: 4 additions & 15 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,54 +103,43 @@ class AnswersMap:
last:
Data from [the answers file][the-copier-answersyml-file].
default:
Default data from the template.
See [copier.template.Template.default_answers][].
user_defaults:
Default data from the user e.g. previously completed and restored data.
See [copier.main.Worker][].
"""

# Private
removed: Set[str] = field(default_factory=set, init=False)
hidden: Set[str] = field(default_factory=set, init=False)

# Public
user: AnyByStrDict = field(default_factory=dict)
init: AnyByStrDict = field(default_factory=dict)
metadata: AnyByStrDict = field(default_factory=dict)
last: AnyByStrDict = field(default_factory=dict)
user_defaults: AnyByStrDict = field(default_factory=dict)
default: AnyByStrDict = field(default_factory=dict)

@property
def combined(self) -> Mapping[str, Any]:
"""Answers combined from different sources, sorted by priority."""
combined = dict(
return dict(
ChainMap(
self.user,
self.init,
self.metadata,
self.last,
self.user_defaults,
self.default,
DEFAULT_DATA,
)
)
for key in self.removed:
if key in combined:
del combined[key]
return combined

def old_commit(self) -> OptStr:
"""Commit when the project was updated from this template the last time."""
return self.last.get("_commit")

def remove(self, key: str) -> None:
def hide(self, key: str) -> None:
"""Remove an answer by key."""
self.removed.add(key)
self.hidden.add(key)


@dataclass(config=AllowArbitraryTypes)
Expand Down
13 changes: 5 additions & 8 deletions docs/configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,14 @@ Supported keys:

- **when**: Condition that, if `false`, skips the question.

If it is a boolean, it is used directly, but it's a bit absurd in that case.
If it is a boolean, it is used directly. Setting it to `false` is useful for
creating a computed value.

If it is a string, it is converted to boolean using a parser similar to YAML, but
only for boolean values.
only for boolean values. The string can be [templated][prompt-templating].

This is most useful when [templated](#prompt-templating).

If a question is skipped, its answer will be:

- The default value, if you're generating the project for the 1st time.
- The last answer recorded, if you're updating the project.
If a question is skipped, its answer is not recorded, but its default value is
available in the render context.

!!! example

Expand Down
49 changes: 49 additions & 0 deletions tests/test_complex_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def test_tui_inherited_default(
"{{ _copier_answers|to_nice_yaml }}"
),
(src / "answers.json.jinja"): "{{ _copier_answers|to_nice_json }}",
(src / "context.json.jinja"): '{"owner2": "{{ owner2 }}"}',
}
)
with local.cwd(src):
Expand All @@ -423,6 +424,7 @@ def test_tui_inherited_default(
**({"owner2": owner2} if has_2_owners else {}),
}
assert json.loads((dst / "answers.json").read_text()) == result
assert json.loads((dst / "context.json").read_text()) == {"owner2": owner2}
with local.cwd(dst):
git("init")
git("add", "--all")
Expand All @@ -432,6 +434,45 @@ def test_tui_inherited_default(
assert json.loads((dst / "answers.json").read_text()) == result


def test_tui_typed_default(
tmp_path_factory: pytest.TempPathFactory, spawn: Spawn
) -> None:
"""Make sure a template defaults are typed as expected."""
src, dst = map(tmp_path_factory.mktemp, ("src", "dst"))
build_file_tree(
{
(src / "copier.yaml"): (
"""\
test1:
type: bool
default: false
when: false
test2:
type: bool
default: "{{ 'a' == 'b' }}"
when: false
"""
),
(src / "{{ _copier_conf.answers_file }}.jinja"): (
"{{ _copier_answers|to_nice_yaml }}"
),
(src / "answers.json.jinja"): "{{ _copier_answers|to_nice_json }}",
(src / "context.json.jinja"): (
"""\
{"test1": "{{ test1 }}", "test2": "{{ test2 }}"}
"""
),
}
)
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
tui.expect_exact(pexpect.EOF)
assert json.loads((dst / "answers.json").read_text()) == {"_src_path": str(src)}
assert json.loads((dst / "context.json").read_text()) == {
"test1": "False",
"test2": "False",
}


def test_selection_type_cast(
tmp_path_factory: pytest.TempPathFactory, spawn: Spawn
) -> None:
Expand Down Expand Up @@ -559,8 +600,16 @@ def test_omit_answer_for_skipped_question(
(src / "{{ _copier_conf.answers_file }}.jinja"): (
"{{ _copier_answers|to_nice_yaml }}"
),
(src / "context.yml.jinja"): yaml.safe_dump(
{
"disabled": "{{ disabled }}",
"disabled_with_default": "{{ disabled_with_default }}",
}
),
}
)
run_copy(str(src), dst, defaults=True, data={"disabled": "hello"})
answers = yaml.safe_load((dst / ".copier-answers.yml").read_text())
assert answers == {"_src_path": str(src)}
context = yaml.safe_load((dst / "context.yml").read_text())
assert context == {"disabled": "hello", "disabled_with_default": "hello"}
7 changes: 7 additions & 0 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def test_when(
(src / "[[ _copier_conf.answers_file ]].tmpl"): (
"[[ _copier_answers|to_nice_yaml ]]"
),
(src / "context.yml.tmpl"): (
"""\
question_2: [[ question_2 ]]
"""
),
}
)
tui = spawn(COPIER_PATH + ("copy", str(src), str(dst)), timeout=10)
Expand All @@ -198,6 +203,8 @@ def test_when(
"question_1": question_1,
**({"question_2": "something"} if asks else {}),
}
context = yaml.safe_load((dst / "context.yml").read_text())
assert context == {"question_2": "something"}


def test_placeholder(tmp_path_factory: pytest.TempPathFactory, spawn: Spawn) -> None:
Expand Down

0 comments on commit 1e81fd5

Please sign in to comment.