diff --git a/tests/plugins/seeder/_utils.py b/tests/plugins/seeder/_utils.py index 995af570..d7b5fa1d 100644 --- a/tests/plugins/seeder/_utils.py +++ b/tests/plugins/seeder/_utils.py @@ -44,7 +44,7 @@ async def fire_arg_parsed_event(dispatcher: Dispatcher, *, seed: Optional[str] = arg_parse_event = ArgParseEvent(ArgumentParser()) await dispatcher.fire(arg_parse_event) - arg_parsed_event = ArgParsedEvent(Namespace(seed=seed)) + arg_parsed_event = ArgParsedEvent(Namespace(seed=seed, fixed_seed=False)) await dispatcher.fire(arg_parsed_event) diff --git a/tests/plugins/slicer/test_slicer_plugin.py b/tests/plugins/slicer/test_slicer_plugin.py index f5909e5e..e5c89e12 100644 --- a/tests/plugins/slicer/test_slicer_plugin.py +++ b/tests/plugins/slicer/test_slicer_plugin.py @@ -120,16 +120,17 @@ async def test_arg_validation(total: Union[int, None], index: Union[int, None], assert res is None -@pytest.mark.parametrize(("total", "index"), [ - (1, None), # index is None - (None, 1), # total is None - (0, 1), # total <= 0 - (1, -1), # index < 0 - (1, 1), # index > total +@pytest.mark.parametrize(("total", "index", "error"), [ + (1, None, "`--slicer-index` must be specified if `--slicer-total` is specified"), + (None, 1, "`--slicer-total` must be specified if `--slicer-index` is specified"), + (0, 1, "`--slicer-total` must be greater than 0, 0 given"), + (1, -1, + "`--slicer-index` must be greater than 0 and less than `--slicer-total` (1), -1 given"), + (1, 1, "`--slicer-index` must be greater than 0 and less than `--slicer-total` (1), 1 given"), ]) @pytest.mark.asyncio -async def test_arg_validation_error(total: Union[int, None], index: Union[int, None], *, - slicer: SlicerPlugin, dispatcher: Dispatcher): +async def test_arg_validation_error(total: Union[int, None], index: Union[int, None], error: str, + *, slicer: SlicerPlugin, dispatcher: Dispatcher): with given: event = ArgParsedEvent(Namespace(slicer_total=total, slicer_index=index)) @@ -137,4 +138,5 @@ async def test_arg_validation_error(total: Union[int, None], index: Union[int, N await dispatcher.fire(event) with then: - assert exc_info.type is AssertionError + assert exc_info.type is ValueError + assert str(exc_info.value) == error diff --git a/vedro/plugins/seeder/_seeder.py b/vedro/plugins/seeder/_seeder.py index 3750b0af..23f748b0 100644 --- a/vedro/plugins/seeder/_seeder.py +++ b/vedro/plugins/seeder/_seeder.py @@ -25,7 +25,8 @@ def __init__(self, config: Type["Seeder"], *, random: RandomGenerator = _random) super().__init__(config) self._random = random - self._inital_seed: Union[str, None] = None + self._use_fixed_seed = config.use_fixed_seed + self._initial_seed: Union[str, None] = None self._discovered_seed: Union[int, None] = None self._scheduled_seed: Union[int, None] = None self._scheduled_state: Union[StateType, None] = None @@ -45,13 +46,17 @@ def subscribe(self, dispatcher: Dispatcher) -> None: def on_arg_parse(self, event: ArgParseEvent) -> None: event.arg_parser.add_argument("--seed", nargs="?", help="Set seed") + help_msg = "Use the same seed when a scenario is run multiple times in the same execution" + event.arg_parser.add_argument("--fixed-seed", action="store_true", + default=self._use_fixed_seed, help=help_msg) def on_arg_parsed(self, event: ArgParsedEvent) -> None: - self._inital_seed = event.args.seed if (event.args.seed is not None) else str(uuid.uuid4()) + self._initial_seed = event.args.seed if event.args.seed is not None else str(uuid.uuid4()) + self._use_fixed_seed = event.args.fixed_seed def on_startup(self, event: StartupEvent) -> None: - assert self._inital_seed is not None - self._random.set_seed(self._inital_seed) + assert self._initial_seed is not None + self._random.set_seed(self._initial_seed) self._scheduled_seed = self._generate_seed() self._random.set_seed(self._scheduled_seed) @@ -78,15 +83,21 @@ def on_scenario_run(self, event: ScenarioRunEvent) -> None: seed = self._scenarios[unique_id] self._random.set_seed(seed) + if self._use_fixed_seed: + return + for _ in range(self._history[unique_id]): seed = self._generate_seed() self._random.set_seed(seed) def on_cleanup(self, event: CleanupEvent) -> None: if (event.report.passed + event.report.failed) > 0: - event.report.add_summary(f"--seed {self._inital_seed}") + event.report.add_summary(f"--seed {self._initial_seed}") class Seeder(PluginConfig): plugin = SeederPlugin description = "Sets seeds for deterministic random behavior in scenarios" + + # Use the same seed when a scenario is run multiple times in the same execution + use_fixed_seed: bool = False diff --git a/vedro/plugins/slicer/_slicer.py b/vedro/plugins/slicer/_slicer.py index 5c863dfa..bcfca0b7 100644 --- a/vedro/plugins/slicer/_slicer.py +++ b/vedro/plugins/slicer/_slicer.py @@ -25,12 +25,24 @@ def on_arg_parse(self, event: ArgParseEvent) -> None: def on_arg_parsed(self, event: ArgParsedEvent) -> None: self._total = event.args.slicer_total self._index = event.args.slicer_index + if self._total is not None: - assert self._index is not None - assert self._total > 0 + if self._index is None: + raise ValueError( + "`--slicer-index` must be specified if `--slicer-total` is specified") + if self._total <= 0: + raise ValueError( + f"`--slicer-total` must be greater than 0, {self._total} given") + if self._index is not None: - assert self._total is not None - assert 0 <= self._index < self._total + if self._total is None: + raise ValueError( + "`--slicer-total` must be specified if `--slicer-index` is specified") + if not (0 <= self._index < self._total): + raise ValueError( + "`--slicer-index` must be greater than 0 and " + f"less than `--slicer-total` ({self._total}), {self._index} given" + ) async def on_startup(self, event: StartupEvent) -> None: if (self._total is None) or (self._index is None):