From dda9e8fedc75eefb0255245abff0782ecb26f882 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 15:09:53 +0200 Subject: [PATCH 01/10] Refactor chess.engine.BaseCommand to pass stricter type checking --- chess/engine.py | 614 +++++++++++++++++++++++++++--------------------- 1 file changed, 344 insertions(+), 270 deletions(-) diff --git a/chess/engine.py b/chess/engine.py index c2b470dd..d1bbaaee 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -24,6 +24,9 @@ from types import TracebackType from typing import Any, Callable, Coroutine, Deque, Dict, Generator, Generic, Iterable, Iterator, List, Literal, Mapping, MutableMapping, Optional, Tuple, Type, TypedDict, TypeVar, Union +if typing.TYPE_CHECKING: + from typing_extensions import Self + WdlModel = Literal["sf", "sf16.1", "sf16", "sf15.1", "sf15", "sf14", "sf12", "lichess"] @@ -895,7 +898,7 @@ class Protocol(asyncio.SubprocessProtocol, metaclass=abc.ABCMeta): returncode: asyncio.Future[int] """Future: Exit code of the process.""" - def __init__(self: ProtocolT) -> None: + def __init__(self) -> None: self.loop = asyncio.get_running_loop() self.transport: Optional[asyncio.SubprocessTransport] = None @@ -904,8 +907,8 @@ def __init__(self: ProtocolT) -> None: 2: bytearray(), # stderr } - self.command: Optional[BaseCommand[ProtocolT, Any]] = None - self.next_command: Optional[BaseCommand[ProtocolT, Any]] = None + self.command: Optional[BaseCommand[Any]] = None + self.next_command: Optional[BaseCommand[Any]] = None self.initialized = False self.returncode: asyncio.Future[int] = asyncio.Future() @@ -915,7 +918,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = transport # type: ignore LOGGER.debug("%s: Connection made", self) - def connection_lost(self: ProtocolT, exc: Optional[Exception]) -> None: + def connection_lost(self, exc: Optional[Exception]) -> None: assert self.transport is not None code = self.transport.get_returncode() assert code is not None, "connect lost, but got no returncode" @@ -923,10 +926,10 @@ def connection_lost(self: ProtocolT, exc: Optional[Exception]) -> None: # Terminate commands. if self.command is not None: - self.command._engine_terminated(self, code) + self.command._engine_terminated(code) self.command = None if self.next_command is not None: - self.next_command._engine_terminated(self, code) + self.next_command._engine_terminated(code) self.next_command = None self.returncode.set_result(code) @@ -960,18 +963,18 @@ def pipe_data_received(self, fd: int, data: Union[bytes, str]) -> None: def error_line_received(self, line: str) -> None: LOGGER.warning("%s: stderr >> %s", self, line) - def _line_received(self: ProtocolT, line: str) -> None: + def _line_received(self: Protocol, line: str) -> None: LOGGER.debug("%s: >> %s", self, line) self.line_received(line) if self.command: - self.command._line_received(self, line) + self.command._line_received(line) def line_received(self, line: str) -> None: pass - async def communicate(self: ProtocolT, command_factory: Callable[[ProtocolT], BaseCommand[ProtocolT, T]]) -> T: + async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) -> T: command = command_factory(self) if self.returncode.done(): @@ -993,18 +996,18 @@ def previous_command_finished(_: Optional[asyncio.Future[None]]) -> None: def cancel_if_cancelled(result: asyncio.Future[T]) -> None: if result.cancelled(): - cmd._cancel(self) + cmd._cancel() cmd.result.add_done_callback(cancel_if_cancelled) cmd.finished.add_done_callback(previous_command_finished) - cmd._start(self) + cmd._start() if self.command is None: previous_command_finished(None) elif not self.command.result.done(): self.command.result.cancel() elif not self.command.result.cancelled(): - self.command._cancel(self) + self.command._cancel() return await command.result @@ -1207,32 +1210,34 @@ class CommandState(enum.Enum): DONE = enum.auto() -class BaseCommand(Generic[ProtocolT, T]): - def __init__(self, engine: ProtocolT) -> None: +class BaseCommand(Generic[T]): + def __init__(self, engine: Protocol) -> None: + self._engine = engine + self.state = CommandState.NEW self.result: asyncio.Future[T] = asyncio.Future() self.finished: asyncio.Future[None] = asyncio.Future() - def _engine_terminated(self, engine: ProtocolT, code: int) -> None: + def _engine_terminated(self, code: int) -> None: hint = ", binary not compatible with cpu?" if code in [-4, 0xc000001d] else "" exc = EngineTerminatedError(f"engine process died unexpectedly (exit code: {code}{hint})") if self.state == CommandState.ACTIVE: - self.engine_terminated(engine, exc) + self.engine_terminated(exc) elif self.state == CommandState.CANCELLING: self.finished.set_result(None) elif self.state == CommandState.NEW: - self._handle_exception(engine, exc) + self._handle_exception(exc) - def _handle_exception(self, engine: ProtocolT, exc: Exception) -> None: + def _handle_exception(self, exc: Exception) -> None: if not self.result.done(): self.result.set_exception(exc) else: - engine.loop.call_exception_handler({ + self._engine.loop.call_exception_handler({ # XXX "message": f"{type(self).__name__} failed after returning preliminary result ({self.result!r})", "exception": exc, - "protocol": engine, - "transport": engine.transport, + "protocol": self._engine, + "transport": self._engine.transport, }) if not self.finished.done(): @@ -1245,43 +1250,43 @@ def set_finished(self) -> None: self.finished.set_result(None) self.state = CommandState.DONE - def _cancel(self, engine: ProtocolT) -> None: + def _cancel(self) -> None: if self.state != CommandState.CANCELLING and self.state != CommandState.DONE: assert self.state == CommandState.ACTIVE self.state = CommandState.CANCELLING - self.cancel(engine) + self.cancel() - def _start(self, engine: ProtocolT) -> None: + def _start(self) -> None: assert self.state == CommandState.NEW self.state = CommandState.ACTIVE try: - self.check_initialized(engine) - self.start(engine) + self.check_initialized() + self.start() except EngineError as err: - self._handle_exception(engine, err) + self._handle_exception(err) - def _line_received(self, engine: ProtocolT, line: str) -> None: + def _line_received(self, line: str) -> None: assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING] try: - self.line_received(engine, line) + self.line_received(line) except EngineError as err: - self._handle_exception(engine, err) + self._handle_exception(err) - def cancel(self, engine: ProtocolT) -> None: + def cancel(self) -> None: pass - def check_initialized(self, engine: ProtocolT) -> None: - if not engine.initialized: + def check_initialized(self) -> None: + if not self._engine.initialized: raise EngineError("tried to run command, but engine is not initialized") - def start(self, engine: ProtocolT) -> None: + def start(self) -> None: raise NotImplementedError - def line_received(self, engine: ProtocolT, line: str) -> None: + def line_received(self, line: str) -> None: pass - def engine_terminated(self, engine: ProtocolT, exc: Exception) -> None: - self._handle_exception(engine, exc) + def engine_terminated(self, exc: Exception) -> None: + self._handle_exception(exc) def __repr__(self) -> str: return "<{} at {:#x} (state={}, result={}, finished={}>".format(type(self).__name__, id(self), self.state, self.result, self.finished) @@ -1307,26 +1312,33 @@ def __init__(self) -> None: self.ponderhit = False async def initialize(self) -> None: - class UciInitializeCommand(BaseCommand[UciProtocol, None]): - def check_initialized(self, engine: UciProtocol) -> None: - if engine.initialized: + class UciInitializeCommand(BaseCommand[None]): + def __init__(self, engine: UciProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def check_initialized(self) -> None: + if self.engine.initialized: raise EngineError("engine already initialized") - def start(self, engine: UciProtocol) -> None: - engine.send_line("uci") + @typing.override + def start(self) -> None: + self.engine.send_line("uci") - def line_received(self, engine: UciProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: token, remaining = _next_token(line) if line.strip() == "uciok" and not self.result.done(): - engine.initialized = True + self.engine.initialized = True self.result.set_result(None) self.set_finished() elif token == "option": - self._option(engine, remaining) + self._option(remaining) elif token == "id": - self._id(engine, remaining) + self._id(remaining) - def _option(self, engine: UciProtocol, arg: str) -> None: + def _option(self, arg: str) -> None: current_parameter = None option_parts: dict[str, str] = {k: "" for k in ["name", "type", "default", "min", "max"]} var = [] @@ -1357,16 +1369,16 @@ def parse_min_max_value(option_parts: dict[str, str], which: Literal["min", "max without_default = Option(name, type, None, min, max, var) option = Option(without_default.name, without_default.type, without_default.parse(default), min, max, var) - engine.options[option.name] = option + self.engine.options[option.name] = option if option.default is not None: - engine.config[option.name] = option.default + self.engine.config[option.name] = option.default if option.default is not None and not option.is_managed() and option.name.lower() != "uci_analysemode": - engine.target_config[option.name] = option.default + self.engine.target_config[option.name] = option.default - def _id(self, engine: UciProtocol, arg: str) -> None: + def _id(self, arg: str) -> None: key, value = _next_token(arg) - engine.id[key] = value.strip() + self.engine.id[key] = value.strip() return await self.communicate(UciInitializeCommand) @@ -1395,16 +1407,21 @@ def debug(self, on: bool = True) -> None: self.send_line("debug off") async def ping(self) -> None: - class UciPingCommand(BaseCommand[UciProtocol, None]): - def start(self, engine: UciProtocol) -> None: - engine._isready() + class UciPingCommand(BaseCommand[None]): + def __init__(self, engine: UciProtocol) -> None: + super().__init__(engine) + self.engine = engine - def line_received(self, engine: UciProtocol, line: str) -> None: + def start(self) -> None: + self.engine._isready() + + @typing.override + def line_received(self, line: str) -> None: if line.strip() == "readyok": self.result.set_result(None) self.set_finished() else: - LOGGER.warning("%s: Unexpected engine output: %r", engine, line) + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) return await self.communicate(UciPingCommand) @@ -1438,10 +1455,14 @@ def _configure(self, options: ConfigMapping) -> None: self._setoption(name, value) async def configure(self, options: ConfigMapping) -> None: - class UciConfigureCommand(BaseCommand[UciProtocol, None]): - def start(self, engine: UciProtocol) -> None: - engine._configure(options) - engine.target_config.update({name: value for name, value in options.items() if value is not None}) + class UciConfigureCommand(BaseCommand[None]): + def __init__(self, engine: UciProtocol): + super().__init__(engine) + self.engine = engine + + def start(self) -> None: + self.engine._configure(options) + self.engine.target_config.update({name: value for name, value in options.items() if value is not None}) self.result.set_result(None) self.set_finished() @@ -1541,77 +1562,82 @@ async def play(self, board: chess.Board, limit: Limit, *, game: object = None, i new_options[name] = value new_options.update(self._opponent_configuration(opponent=opponent)) - class UciPlayCommand(BaseCommand[UciProtocol, PlayResult]): + engine = self + + class UciPlayCommand(BaseCommand[PlayResult]): def __init__(self, engine: UciProtocol): super().__init__(engine) + self.engine = engine # May ponderhit only in the same game and with unchanged target # options. The managed options UCI_AnalyseMode, Ponder, and # MultiPV never change between pondering play commands. engine.may_ponderhit = board if ponder and not engine.first_game and game == engine.game and not engine._changed_options(new_options) else None - def start(self, engine: UciProtocol) -> None: + @typing.override + def start(self) -> None: self.info: InfoDict = {} self.pondering: Optional[chess.Board] = None self.sent_isready = False self.start_time = time.perf_counter() - if engine.ponderhit: - engine.ponderhit = False - engine.send_line("ponderhit") + if self.engine.ponderhit: + self.engine.ponderhit = False + self.engine.send_line("ponderhit") return - if "UCI_AnalyseMode" in engine.options and "UCI_AnalyseMode" not in engine.target_config and all(name.lower() != "uci_analysemode" for name in new_options): - engine._setoption("UCI_AnalyseMode", False) - if "Ponder" in engine.options: - engine._setoption("Ponder", ponder) - if "MultiPV" in engine.options: - engine._setoption("MultiPV", engine.options["MultiPV"].default) + if "UCI_AnalyseMode" in self.engine.options and "UCI_AnalyseMode" not in self.engine.target_config and all(name.lower() != "uci_analysemode" for name in new_options): + self.engine._setoption("UCI_AnalyseMode", False) + if "Ponder" in self.engine.options: + self.engine._setoption("Ponder", ponder) + if "MultiPV" in self.engine.options: + self.engine._setoption("MultiPV", self.engine.options["MultiPV"].default) - new_opponent = new_options.get("UCI_Opponent") or engine.target_config.get("UCI_Opponent") - opponent_changed = new_opponent != engine.config.get("UCI_Opponent") - engine._configure(new_options) + new_opponent = new_options.get("UCI_Opponent") or self.engine.target_config.get("UCI_Opponent") + opponent_changed = new_opponent != self.engine.config.get("UCI_Opponent") + self.engine._configure(new_options) - if engine.first_game or engine.game != game or opponent_changed: - engine.game = game - engine._ucinewgame() + if self.engine.first_game or self.engine.game != game or opponent_changed: + self.engine.game = game + self.engine._ucinewgame() self.sent_isready = True - engine._isready() + self.engine._isready() else: - self._readyok(engine) + self._readyok() - def line_received(self, engine: UciProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token == "info": - self._info(engine, remaining) + self._info(remaining) elif token == "bestmove": - self._bestmove(engine, remaining) + self._bestmove(remaining) elif line.strip() == "readyok" and self.sent_isready: - self._readyok(engine) + self._readyok() else: - LOGGER.warning("%s: Unexpected engine output: %r", engine, line) + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) - def _readyok(self, engine: UciProtocol) -> None: + def _readyok(self) -> None: self.sent_isready = False engine._position(board) engine._go(limit, root_moves=root_moves) - def _info(self, engine: UciProtocol, arg: str) -> None: + def _info(self, arg: str) -> None: if not self.pondering: - self.info.update(_parse_uci_info(arg, engine.board, info)) + self.info.update(_parse_uci_info(arg, self.engine.board, info)) - def _bestmove(self, engine: UciProtocol, arg: str) -> None: + def _bestmove(self, arg: str) -> None: if self.pondering: self.pondering = None elif not self.result.cancelled(): - best = _parse_uci_bestmove(engine.board, arg) + best = _parse_uci_bestmove(self.engine.board, arg) self.result.set_result(PlayResult(best.move, best.ponder, self.info)) if ponder and best.move and best.ponder: self.pondering = board.copy() self.pondering.push(best.move) self.pondering.push(best.ponder) - engine._position(self.pondering) + self.engine._position(self.pondering) # Adjust clocks for pondering. time_used = time.perf_counter() - self.start_time @@ -1627,89 +1653,98 @@ def _bestmove(self, engine: UciProtocol, arg: str) -> None: if ponder_limit.remaining_moves: ponder_limit.remaining_moves -= 1 - engine._go(ponder_limit, ponder=True) + self.engine._go(ponder_limit, ponder=True) if not self.pondering: - self.end(engine) + self.end() - def end(self, engine: UciProtocol) -> None: + def end(self) -> None: engine.may_ponderhit = None self.set_finished() - def cancel(self, engine: UciProtocol) -> None: - if engine.may_ponderhit and self.pondering and engine.may_ponderhit.move_stack == self.pondering.move_stack and engine.may_ponderhit == self.pondering: - engine.ponderhit = True - self.end(engine) + @typing.override + def cancel(self) -> None: + if self.engine.may_ponderhit and self.pondering and self.engine.may_ponderhit.move_stack == self.pondering.move_stack and self.engine.may_ponderhit == self.pondering: + self.engine.ponderhit = True + self.end() else: - engine.send_line("stop") + self.engine.send_line("stop") - def engine_terminated(self, engine: UciProtocol, exc: Exception) -> None: + @typing.override + def engine_terminated(self, exc: Exception) -> None: # Allow terminating engine while pondering. if not self.result.done(): - super().engine_terminated(engine, exc) + super().engine_terminated(exc) return await self.communicate(UciPlayCommand) async def analysis(self, board: chess.Board, limit: Optional[Limit] = None, *, multipv: Optional[int] = None, game: object = None, info: Info = INFO_ALL, root_moves: Optional[Iterable[chess.Move]] = None, options: ConfigMapping = {}) -> AnalysisResult: - class UciAnalysisCommand(BaseCommand[UciProtocol, AnalysisResult]): - def start(self, engine: UciProtocol) -> None: - self.analysis = AnalysisResult(stop=lambda: self.cancel(engine)) + class UciAnalysisCommand(BaseCommand[AnalysisResult]): + def __init__(self, engine: UciProtocol): + super().__init__(engine) + self.engine = engine + + def start(self) -> None: + self.analysis = AnalysisResult(stop=lambda: self.cancel()) self.sent_isready = False - if "Ponder" in engine.options: - engine._setoption("Ponder", False) - if "UCI_AnalyseMode" in engine.options and "UCI_AnalyseMode" not in engine.target_config and all(name.lower() != "uci_analysemode" for name in options): - engine._setoption("UCI_AnalyseMode", True) - if "MultiPV" in engine.options or (multipv and multipv > 1): - engine._setoption("MultiPV", 1 if multipv is None else multipv) + if "Ponder" in self.engine.options: + self.engine._setoption("Ponder", False) + if "UCI_AnalyseMode" in self.engine.options and "UCI_AnalyseMode" not in self.engine.target_config and all(name.lower() != "uci_analysemode" for name in options): + self.engine._setoption("UCI_AnalyseMode", True) + if "MultiPV" in self.engine.options or (multipv and multipv > 1): + self.engine._setoption("MultiPV", 1 if multipv is None else multipv) - engine._configure(options) + self.engine._configure(options) - if engine.first_game or engine.game != game: - engine.game = game - engine._ucinewgame() + if self.engine.first_game or self.engine.game != game: + self.engine.game = game + self.engine._ucinewgame() self.sent_isready = True - engine._isready() + self.engine._isready() else: - self._readyok(engine) + self._readyok() - def line_received(self, engine: UciProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token == "info": - self._info(engine, remaining) + self._info(remaining) elif token == "bestmove": - self._bestmove(engine, remaining) + self._bestmove(remaining) elif line.strip() == "readyok" and self.sent_isready: - self._readyok(engine) + self._readyok() else: - LOGGER.warning("%s: Unexpected engine output: %r", engine, line) + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) - def _readyok(self, engine: UciProtocol) -> None: + def _readyok(self) -> None: self.sent_isready = False - engine._position(board) + self.engine._position(board) if limit: - engine._go(limit, root_moves=root_moves) + self.engine._go(limit, root_moves=root_moves) else: - engine._go(Limit(), root_moves=root_moves, infinite=True) + self.engine._go(Limit(), root_moves=root_moves, infinite=True) self.result.set_result(self.analysis) - def _info(self, engine: UciProtocol, arg: str) -> None: - self.analysis.post(_parse_uci_info(arg, engine.board, info)) + def _info(self, arg: str) -> None: + self.analysis.post(_parse_uci_info(arg, self.engine.board, info)) - def _bestmove(self, engine: UciProtocol, arg: str) -> None: + def _bestmove(self, arg: str) -> None: if not self.result.done(): raise EngineError("was not searching, but engine sent bestmove") - best = _parse_uci_bestmove(engine.board, arg) + best = _parse_uci_bestmove(self.engine.board, arg) self.set_finished() self.analysis.set_finished(best) - def cancel(self, engine: UciProtocol) -> None: - engine.send_line("stop") + @typing.override + def cancel(self) -> None: + self.engine.send_line("stop") - def engine_terminated(self, engine: UciProtocol, exc: Exception) -> None: - LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", engine, exc) + @typing.override + def engine_terminated(self, exc: Exception) -> None: + LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", self.engine, exc) self.analysis.set_exception(exc) return await self.communicate(UciAnalysisCommand) @@ -1938,89 +1973,96 @@ def __init__(self) -> None: self.first_game = True async def initialize(self) -> None: - class XBoardInitializeCommand(BaseCommand[XBoardProtocol, None]): - def check_initialized(self, engine: XBoardProtocol) -> None: - if engine.initialized: + class XBoardInitializeCommand(BaseCommand[None]): + def __init__(self, engine: XBoardProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def check_initialized(self) -> None: + if self.engine.initialized: raise EngineError("engine already initialized") - def start(self, engine: XBoardProtocol) -> None: - engine.send_line("xboard") - engine.send_line("protover 2") - self.timeout_handle = engine.loop.call_later(2.0, lambda: self.timeout(engine)) + @typing.override + def start(self) -> None: + self.engine.send_line("xboard") + self.engine.send_line("protover 2") + self.timeout_handle = self.engine.loop.call_later(2.0, lambda: self.timeout()) - def timeout(self, engine: XBoardProtocol) -> None: - LOGGER.error("%s: Timeout during initialization", engine) - self.end(engine) + def timeout(self) -> None: + LOGGER.error("%s: Timeout during initialization", self.engine) + self.end() - def line_received(self, engine: XBoardProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token.startswith("#"): pass elif token == "feature": - self._feature(engine, remaining) + self._feature(remaining) elif XBOARD_ERROR_REGEX.match(line): raise EngineError(line) - def _feature(self, engine: XBoardProtocol, arg: str) -> None: + def _feature(self, arg: str) -> None: for feature in shlex.split(arg): key, value = feature.split("=", 1) if key == "option": option = _parse_xboard_option(value) if option.name not in ["random", "computer", "cores", "memory"]: - engine.options[option.name] = option + self.engine.options[option.name] = option else: try: - engine.features[key] = int(value) + self.engine.features[key] = int(value) except ValueError: - engine.features[key] = value + self.engine.features[key] = value - if "done" in engine.features: + if "done" in self.engine.features: self.timeout_handle.cancel() - if engine.features.get("done"): - self.end(engine) + if self.engine.features.get("done"): + self.end() - def end(self, engine: XBoardProtocol) -> None: - if not engine.features.get("ping", 0): + def end(self) -> None: + if not self.engine.features.get("ping", 0): self.result.set_exception(EngineError("xboard engine did not declare required feature: ping")) self.set_finished() return - if not engine.features.get("setboard", 0): + if not self.engine.features.get("setboard", 0): self.result.set_exception(EngineError("xboard engine did not declare required feature: setboard")) self.set_finished() return - if not engine.features.get("reuse", 1): - LOGGER.warning("%s: Rejecting feature reuse=0", engine) - engine.send_line("rejected reuse") - if not engine.features.get("sigterm", 1): - LOGGER.warning("%s: Rejecting feature sigterm=0", engine) - engine.send_line("rejected sigterm") - if engine.features.get("san", 0): - LOGGER.warning("%s: Rejecting feature san=1", engine) - engine.send_line("rejected san") - - if "myname" in engine.features: - engine.id["name"] = str(engine.features["myname"]) - - if engine.features.get("memory", 0): - engine.options["memory"] = Option("memory", "spin", 16, 1, None, None) - engine.send_line("accepted memory") - if engine.features.get("smp", 0): - engine.options["cores"] = Option("cores", "spin", 1, 1, None, None) - engine.send_line("accepted smp") - if engine.features.get("egt"): - for egt in str(engine.features["egt"]).split(","): + if not self.engine.features.get("reuse", 1): + LOGGER.warning("%s: Rejecting feature reuse=0", self.engine) + self.engine.send_line("rejected reuse") + if not self.engine.features.get("sigterm", 1): + LOGGER.warning("%s: Rejecting feature sigterm=0", self.engine) + self.engine.send_line("rejected sigterm") + if self.engine.features.get("san", 0): + LOGGER.warning("%s: Rejecting feature san=1", self.engine) + self.engine.send_line("rejected san") + + if "myname" in self.engine.features: + self.engine.id["name"] = str(self.engine.features["myname"]) + + if self.engine.features.get("memory", 0): + self.engine.options["memory"] = Option("memory", "spin", 16, 1, None, None) + self.engine.send_line("accepted memory") + if self.engine.features.get("smp", 0): + self.engine.options["cores"] = Option("cores", "spin", 1, 1, None, None) + self.engine.send_line("accepted smp") + if self.engine.features.get("egt"): + for egt in str(self.engine.features["egt"]).split(","): name = f"egtpath {egt}" - engine.options[name] = Option(name, "path", None, None, None, None) - engine.send_line("accepted egt") + self.engine.options[name] = Option(name, "path", None, None, None, None) + self.engine.send_line("accepted egt") - for option in engine.options.values(): + for option in self.engine.options.values(): if option.default is not None: - engine.config[option.name] = option.default + self.engine.config[option.name] = option.default if option.default is not None and not option.is_managed(): - engine.target_config[option.name] = option.default + self.engine.target_config[option.name] = option.default - engine.initialized = True + self.engine.initialized = True self.result.set_result(None) self.set_finished() @@ -2105,18 +2147,24 @@ def _new(self, board: chess.Board, game: object, options: ConfigMapping, opponen self.board.push(move) async def ping(self) -> None: - class XBoardPingCommand(BaseCommand[XBoardProtocol, None]): - def start(self, engine: XBoardProtocol) -> None: + class XBoardPingCommand(BaseCommand[None]): + def __init__(self, engine: XBoardProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def start(self) -> None: n = id(self) & 0xffff self.pong = f"pong {n}" - engine._ping(n) + self.engine._ping(n) - def line_received(self, engine: XBoardProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: if line == self.pong: self.result.set_result(None) self.set_finished() elif not line.startswith("#"): - LOGGER.warning("%s: Unexpected engine output: %r", engine, line) + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) elif XBOARD_ERROR_REGEX.match(line): raise EngineError(line) @@ -2126,54 +2174,60 @@ async def play(self, board: chess.Board, limit: Limit, *, game: object = None, i if root_moves is not None: raise EngineError("play with root_moves, but xboard supports 'include' only in analysis mode") - class XBoardPlayCommand(BaseCommand[XBoardProtocol, PlayResult]): - def start(self, engine: XBoardProtocol) -> None: + class XBoardPlayCommand(BaseCommand[PlayResult]): + def __init__(self, engine: XBoardProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def start(self) -> None: self.play_result = PlayResult(None, None) self.stopped = False self.pong_after_move: Optional[str] = None self.pong_after_ponder: Optional[str] = None # Set game, position and configure. - engine._new(board, game, options, opponent) + self.engine._new(board, game, options, opponent) # Limit or time control. clock = limit.white_clock if board.turn else limit.black_clock increment = limit.white_inc if board.turn else limit.black_inc - if limit.clock_id is None or limit.clock_id != engine.clock_id: - self._send_time_control(engine, clock, increment) - engine.clock_id = limit.clock_id + if limit.clock_id is None or limit.clock_id != self.engine.clock_id: + self._send_time_control(clock, increment) + self.engine.clock_id = limit.clock_id if limit.nodes is not None: if limit.time is not None or limit.white_clock is not None or limit.black_clock is not None or increment is not None: raise EngineError("xboard does not support mixing node limits with time limits") - if "nps" not in engine.features: + if "nps" not in self.engine.features: LOGGER.warning("%s: Engine did not explicitly declare support for node limits (feature nps=?)") - elif not engine.features["nps"]: + elif not self.engine.features["nps"]: raise EngineError("xboard engine does not support node limits (feature nps=0)") - engine.send_line("nps 1") - engine.send_line(f"st {max(1, int(limit.nodes))}") + self.engine.send_line("nps 1") + self.engine.send_line(f"st {max(1, int(limit.nodes))}") if limit.depth is not None: - engine.send_line(f"sd {max(1, int(limit.depth))}") + self.engine.send_line(f"sd {max(1, int(limit.depth))}") if limit.white_clock is not None: - engine.send_line("{} {}".format("time" if board.turn else "otim", max(1, round(limit.white_clock * 100)))) + self.engine.send_line("{} {}".format("time" if board.turn else "otim", max(1, round(limit.white_clock * 100)))) if limit.black_clock is not None: - engine.send_line("{} {}".format("otim" if board.turn else "time", max(1, round(limit.black_clock * 100)))) + self.engine.send_line("{} {}".format("otim" if board.turn else "time", max(1, round(limit.black_clock * 100)))) - if draw_offered and engine.features.get("draw", 1): - engine.send_line("draw") + if draw_offered and self.engine.features.get("draw", 1): + self.engine.send_line("draw") # Start thinking. - engine.send_line("post" if info else "nopost") - engine.send_line("hard" if ponder else "easy") - engine.send_line("go") + self.engine.send_line("post" if info else "nopost") + self.engine.send_line("hard" if ponder else "easy") + self.engine.send_line("go") - def line_received(self, engine: XBoardProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token == "move": - self._move(engine, remaining.strip()) + self._move(remaining.strip()) elif token == "Hint:": - self._hint(engine, remaining.strip()) + self._hint(remaining.strip()) elif token == "pong": pong_line = f"{token} {remaining.strip()}" if pong_line == self.pong_after_move: @@ -2188,84 +2242,86 @@ def line_received(self, engine: XBoardProtocol, line: str) -> None: elif f"{token} {remaining.strip()}" == "offer draw": if not self.result.done(): self.play_result.draw_offered = True - self._ping_after_move(engine) + self._ping_after_move() elif line.strip() == "resign": if not self.result.done(): self.play_result.resigned = True - self._ping_after_move(engine) + self._ping_after_move() elif token in ["1-0", "0-1", "1/2-1/2"]: if "resign" in line and not self.result.done(): self.play_result.resigned = True - self._ping_after_move(engine) + self._ping_after_move() elif token.startswith("#"): pass elif XBOARD_ERROR_REGEX.match(line): - engine.first_game = True # Board state might no longer be in sync + self.engine.first_game = True # Board state might no longer be in sync raise EngineError(line) elif len(line.split()) >= 4 and line.lstrip()[0].isdigit(): - self._post(engine, line) + self._post(line) else: - LOGGER.warning("%s: Unexpected engine output: %r", engine, line) + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) - def _send_time_control(self, engine: XBoardProtocol, clock: Optional[float], increment: Optional[float]) -> None: + def _send_time_control(self, clock: Optional[float], increment: Optional[float]) -> None: if limit.remaining_moves or clock is not None or increment is not None: base_mins, base_secs = divmod(int(clock or 0), 60) - engine.send_line(f"level {limit.remaining_moves or 0} {base_mins}:{base_secs:02d} {increment or 0}") + self.engine.send_line(f"level {limit.remaining_moves or 0} {base_mins}:{base_secs:02d} {increment or 0}") if limit.time is not None: - engine.send_line(f"st {max(0.01, limit.time)}") + self.engine.send_line(f"st {max(0.01, limit.time)}") - def _post(self, engine: XBoardProtocol, line: str) -> None: + def _post(self, line: str) -> None: if not self.result.done(): - self.play_result.info = _parse_xboard_post(line, engine.board, info) + self.play_result.info = _parse_xboard_post(line, self.engine.board, info) - def _move(self, engine: XBoardProtocol, arg: str) -> None: + def _move(self, arg: str) -> None: if not self.result.done() and self.play_result.move is None: try: - self.play_result.move = engine.board.push_xboard(arg) + self.play_result.move = self.engine.board.push_xboard(arg) except ValueError as err: self.result.set_exception(EngineError(err)) else: - self._ping_after_move(engine) + self._ping_after_move() else: try: - engine.board.push_xboard(arg) + self.engine.board.push_xboard(arg) except ValueError: LOGGER.exception("Exception playing unexpected move") - def _hint(self, engine: XBoardProtocol, arg: str) -> None: + def _hint(self, arg: str) -> None: if not self.result.done() and self.play_result.move is not None and self.play_result.ponder is None: try: - self.play_result.ponder = engine.board.parse_xboard(arg) + self.play_result.ponder = self.engine.board.parse_xboard(arg) except ValueError: LOGGER.exception("Exception parsing hint") else: LOGGER.warning("Unexpected hint: %r", arg) - def _ping_after_move(self, engine: XBoardProtocol) -> None: + def _ping_after_move(self) -> None: if self.pong_after_move is None: n = id(self) & 0xffff self.pong_after_move = f"pong {n}" - engine._ping(n) + self.engine._ping(n) - def cancel(self, engine: XBoardProtocol) -> None: + @typing.override + def cancel(self) -> None: if self.stopped: return self.stopped = True if self.result.cancelled(): - engine.send_line("?") + self.engine.send_line("?") if ponder: - engine.send_line("easy") + self.engine.send_line("easy") n = (id(self) + 1) & 0xffff self.pong_after_ponder = f"pong {n}" - engine._ping(n) + self.engine._ping(n) - def engine_terminated(self, engine: XBoardProtocol, exc: Exception) -> None: + @typing.override + def engine_terminated(self, exc: Exception) -> None: # Allow terminating engine while pondering. if not self.result.done(): - super().engine_terminated(engine, exc) + super().engine_terminated(exc) return await self.communicate(XBoardPlayCommand) @@ -2276,49 +2332,55 @@ async def analysis(self, board: chess.Board, limit: Optional[Limit] = None, *, m if limit is not None and (limit.white_clock is not None or limit.black_clock is not None): raise EngineError("xboard analysis does not support clock limits") - class XBoardAnalysisCommand(BaseCommand[XBoardProtocol, AnalysisResult]): - def start(self, engine: XBoardProtocol) -> None: + class XBoardAnalysisCommand(BaseCommand[AnalysisResult]): + def __init__(self, engine: XBoardProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def start(self) -> None: self.stopped = False self.best_move: Optional[chess.Move] = None - self.analysis = AnalysisResult(stop=lambda: self.cancel(engine)) + self.analysis = AnalysisResult(stop=lambda: self.cancel()) self.final_pong: Optional[str] = None - engine._new(board, game, options) + self.engine._new(board, game, options) if root_moves is not None: - if not engine.features.get("exclude", 0): + if not self.engine.features.get("exclude", 0): raise EngineError("xboard engine does not support root_moves (feature exclude=0)") - engine.send_line("exclude all") + self.engine.send_line("exclude all") for move in root_moves: - engine.send_line(f"include {engine.board.xboard(move)}") + self.engine.send_line(f"include {self.engine.board.xboard(move)}") - engine.send_line("post") - engine.send_line("analyze") + self.engine.send_line("post") + self.engine.send_line("analyze") self.result.set_result(self.analysis) if limit is not None and limit.time is not None: - self.time_limit_handle: Optional[asyncio.Handle] = engine.loop.call_later(limit.time, lambda: self.cancel(engine)) + self.time_limit_handle: Optional[asyncio.Handle] = self.engine.loop.call_later(limit.time, lambda: self.cancel()) else: self.time_limit_handle = None - def line_received(self, engine: XBoardProtocol, line: str) -> None: + @typing.override + def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token.startswith("#"): pass elif len(line.split()) >= 4 and line.lstrip()[0].isdigit(): - self._post(engine, line) + self._post(line) elif f"{token} {remaining.strip()}" == self.final_pong: - self.end(engine) + self.end() elif XBOARD_ERROR_REGEX.match(line): - engine.first_game = True # Board state might no longer be in sync + self.engine.first_game = True # Board state might no longer be in sync raise EngineError(line) else: - LOGGER.warning("%s: Unexpected engine output: %r", engine, line) + LOGGER.warning("%s: Unexpected engine output: %r", self.engine, line) - def _post(self, engine: XBoardProtocol, line: str) -> None: - post_info = _parse_xboard_post(line, engine.board, info) + def _post(self, line: str) -> None: + post_info = _parse_xboard_post(line, self.engine.board, info) self.analysis.post(post_info) pv = post_info.get("pv") @@ -2327,36 +2389,38 @@ def _post(self, engine: XBoardProtocol, line: str) -> None: if limit is not None: if limit.time is not None and post_info.get("time", 0) >= limit.time: - self.cancel(engine) + self.cancel() elif limit.nodes is not None and post_info.get("nodes", 0) >= limit.nodes: - self.cancel(engine) + self.cancel() elif limit.depth is not None and post_info.get("depth", 0) >= limit.depth: - self.cancel(engine) + self.cancel() elif limit.mate is not None and "score" in post_info: if post_info["score"].relative >= Mate(limit.mate): - self.cancel(engine) + self.cancel() - def end(self, engine: XBoardProtocol) -> None: + def end(self) -> None: if self.time_limit_handle: self.time_limit_handle.cancel() self.set_finished() self.analysis.set_finished(BestMove(self.best_move, None)) - def cancel(self, engine: XBoardProtocol) -> None: + @typing.override + def cancel(self) -> None: if self.stopped: return self.stopped = True - engine.send_line(".") - engine.send_line("exit") + self.engine.send_line(".") + self.engine.send_line("exit") n = id(self) & 0xffff self.final_pong = f"pong {n}" - engine._ping(n) + self.engine._ping(n) - def engine_terminated(self, engine: XBoardProtocol, exc: Exception) -> None: - LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", engine, exc) + @typing.override + def engine_terminated(self, exc: Exception) -> None: + LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", self.engine, exc) if self.time_limit_handle: self.time_limit_handle.cancel() @@ -2397,10 +2461,15 @@ def _configure(self, options: ConfigMapping) -> None: self._setoption(name, value) async def configure(self, options: ConfigMapping) -> None: - class XBoardConfigureCommand(BaseCommand[XBoardProtocol, None]): - def start(self, engine: XBoardProtocol) -> None: - engine._configure(options) - engine.target_config.update({name: value for name, value in options.items() if value is not None}) + class XBoardConfigureCommand(BaseCommand[None]): + def __init__(self, engine: XBoardProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def start(self) -> None: + self.engine._configure(options) + self.engine.target_config.update({name: value for name, value in options.items() if value is not None}) self.result.set_result(None) self.set_finished() @@ -2423,12 +2492,17 @@ async def send_opponent_information(self, *, opponent: Optional[Opponent] = None return await self.configure(self._opponent_configuration(opponent=opponent, engine_rating=engine_rating)) async def send_game_result(self, board: chess.Board, winner: Optional[Color] = None, game_ending: Optional[str] = None, game_complete: bool = True) -> None: - class XBoardGameResultCommand(BaseCommand[XBoardProtocol, None]): - def start(self, engine: XBoardProtocol) -> None: + class XBoardGameResultCommand(BaseCommand[None]): + def __init__(self, engine: XBoardProtocol): + super().__init__(engine) + self.engine = engine + + @typing.override + def start(self) -> None: if game_ending and any(c in game_ending for c in "{}\n\r"): raise EngineError(f"invalid line break or curly braces in game ending message: {game_ending!r}") - engine._new(board, engine.game, {}) # Send final moves to engine. + self.engine._new(board, self.engine.game, {}) # Send final moves to engine. outcome = board.outcome(claim_draw=True) @@ -2451,7 +2525,7 @@ def start(self, engine: XBoardProtocol) -> None: ending = "" ending_text = f"{{{ending}}}" if ending else "" - engine.send_line(f"result {result} {ending_text}".strip()) + self.engine.send_line(f"result {result} {ending_text}".strip()) self.result.set_result(None) self.set_finished() @@ -2845,7 +2919,7 @@ def id(self) -> Mapping[str, str]: future = asyncio.run_coroutine_threadsafe(coro, self.protocol.loop) return future.result() - def communicate(self, command_factory: Callable[[Protocol], BaseCommand[Protocol, T]]) -> T: + def communicate(self, command_factory: Callable[[Protocol], BaseCommand[T]]) -> T: with self._not_shut_down(): coro = self.protocol.communicate(command_factory) future = asyncio.run_coroutine_threadsafe(coro, self.protocol.loop) From 5d8e82da82156f5eae2d15f07e2c24d800c87a41 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 16:05:29 +0200 Subject: [PATCH 02/10] Let variant boards manage their own stack --- chess/__init__.py | 9 +++---- chess/variant.py | 66 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/chess/__init__.py b/chess/__init__.py index a90df9ef..b82f8559 100644 --- a/chess/__init__.py +++ b/chess/__init__.py @@ -1542,7 +1542,7 @@ def from_chess960_pos(cls: Type[BaseBoardT], scharnagl: int) -> BaseBoardT: BoardT = TypeVar("BoardT", bound="Board") -class _BoardState(Generic[BoardT]): +class _BoardState: def __init__(self, board: BoardT) -> None: self.pawns = board.pawns @@ -1701,7 +1701,7 @@ def __init__(self: BoardT, fen: Optional[str] = STARTING_FEN, *, chess960: bool self.ep_square = None self.move_stack = [] - self._stack: List[_BoardState[BoardT]] = [] + self._stack: List[_BoardState] = [] if fen is None: self.clear() @@ -2304,9 +2304,6 @@ def is_repetition(self, count: int = 3) -> bool: return False - def _board_state(self: BoardT) -> _BoardState[BoardT]: - return _BoardState(self) - def _push_capture(self, move: Move, capture_square: Square, piece_type: PieceType, was_promoted: bool) -> None: pass @@ -2335,7 +2332,7 @@ def push(self: BoardT, move: Move) -> None: """ # Push move and remember board state. move = self._to_chess960(move) - board_state = self._board_state() + board_state = _BoardState(self) self.castling_rights = self.clean_castling_rights() # Before pushing stack self.move_stack.append(self._from_chess960(self.chess960, move.from_square, move.to_square, move.promotion, move.drop)) self._stack.append(board_state) diff --git a/chess/variant.py b/chess/variant.py index 6160696a..119a0402 100644 --- a/chess/variant.py +++ b/chess/variant.py @@ -673,14 +673,12 @@ def status(self) -> chess.Status: ThreeCheckBoardT = TypeVar("ThreeCheckBoardT", bound="ThreeCheckBoard") -class _ThreeCheckBoardState(Generic[ThreeCheckBoardT], chess._BoardState[ThreeCheckBoardT]): - def __init__(self, board: ThreeCheckBoardT) -> None: - super().__init__(board) +class _ThreeCheckBoardState: + def __init__(self, board: ThreeCheckBoard) -> None: self.remaining_checks_w = board.remaining_checks[chess.WHITE] self.remaining_checks_b = board.remaining_checks[chess.BLACK] - def restore(self, board: ThreeCheckBoardT) -> None: - super().restore(board) + def restore(self, board: ThreeCheckBoard) -> None: board.remaining_checks[chess.WHITE] = self.remaining_checks_w board.remaining_checks[chess.BLACK] = self.remaining_checks_b @@ -698,8 +696,13 @@ class ThreeCheckBoard(chess.Board): def __init__(self, fen: Optional[str] = starting_fen, chess960: bool = False) -> None: self.remaining_checks = [3, 3] + self._three_check_stack: List[_ThreeCheckBoardState] = [] super().__init__(fen, chess960=chess960) + def clear_stack(self) -> None: + super().clear_stack() + self._three_check_stack.clear() + def reset_board(self) -> None: super().reset_board() self.remaining_checks[chess.WHITE] = 3 @@ -710,14 +713,17 @@ def clear_board(self) -> None: self.remaining_checks[chess.WHITE] = 3 self.remaining_checks[chess.BLACK] = 3 - def _board_state(self: ThreeCheckBoardT) -> _ThreeCheckBoardState[ThreeCheckBoardT]: - return _ThreeCheckBoardState(self) - def push(self, move: chess.Move) -> None: + self._three_check_stack.append(_ThreeCheckBoardState(self)) super().push(move) if self.is_check(): self.remaining_checks[not self.turn] -= 1 + def pop(self) -> chess.Move: + move = super().pop() + self._three_check_stack.pop().restore(self) + return move + def has_insufficient_material(self, color: chess.Color) -> bool: # Any remaining piece can give check. return not (self.occupied_co[color] & ~self.kings) @@ -792,8 +798,19 @@ def _transposition_key(self) -> Hashable: def copy(self: ThreeCheckBoardT, stack: Union[bool, int] = True) -> ThreeCheckBoardT: board = super().copy(stack=stack) board.remaining_checks = self.remaining_checks.copy() + if stack: + stack = len(self.move_stack) if stack is True else stack + board._three_check_stack = self._three_check_stack[-stack:] return board + def root(self: ThreeCheckBoardT) -> ThreeCheckBoardT: + if self._three_check_stack: + board = super().root() + self._three_check_stack[0].restore(board) + return board + else: + return self.copy(stack=False) + def mirror(self: ThreeCheckBoardT) -> ThreeCheckBoardT: board = super().mirror() board.remaining_checks[chess.WHITE] = self.remaining_checks[chess.BLACK] @@ -803,14 +820,12 @@ def mirror(self: ThreeCheckBoardT) -> ThreeCheckBoardT: CrazyhouseBoardT = TypeVar("CrazyhouseBoardT", bound="CrazyhouseBoard") -class _CrazyhouseBoardState(Generic[CrazyhouseBoardT], chess._BoardState[CrazyhouseBoardT]): - def __init__(self, board: CrazyhouseBoardT) -> None: - super().__init__(board) +class _CrazyhouseBoardState: + def __init__(self, board: CrazyhouseBoard) -> None: self.pockets_w = board.pockets[chess.WHITE].copy() self.pockets_b = board.pockets[chess.BLACK].copy() - def restore(self, board: CrazyhouseBoardT) -> None: - super().restore(board) + def restore(self, board: CrazyhouseBoard) -> None: board.pockets[chess.WHITE] = self.pockets_w board.pockets[chess.BLACK] = self.pockets_b @@ -870,8 +885,13 @@ class CrazyhouseBoard(chess.Board): def __init__(self, fen: Optional[str] = starting_fen, chess960: bool = False) -> None: self.pockets = [CrazyhousePocket(), CrazyhousePocket()] + self._crazyhouse_stack: List[_CrazyhouseBoardState] = [] super().__init__(fen, chess960=chess960) + def clear_stack(self) -> None: + super().clear_stack() + self._crazyhouse_stack.clear() + def reset_board(self) -> None: super().reset_board() self.pockets[chess.WHITE].reset() @@ -882,10 +902,8 @@ def clear_board(self) -> None: self.pockets[chess.WHITE].reset() self.pockets[chess.BLACK].reset() - def _board_state(self: CrazyhouseBoardT) -> _CrazyhouseBoardState[CrazyhouseBoardT]: - return _CrazyhouseBoardState(self) - def push(self, move: chess.Move) -> None: + self._crazyhouse_stack.append(_CrazyhouseBoardState(self)) super().push(move) if move.drop: self.pockets[not self.turn].remove(move.drop) @@ -896,6 +914,11 @@ def _push_capture(self, move: chess.Move, capture_square: chess.Square, piece_ty else: self.pockets[self.turn].add(piece_type) + def pop(self) -> chess.Move: + move = super().pop() + self._crazyhouse_stack.pop().restore(self) + return move + def _is_halfmoves(self, n: int) -> bool: # No draw by 50-move rule or 75-move rule. return False @@ -1028,8 +1051,19 @@ def copy(self: CrazyhouseBoardT, stack: Union[bool, int] = True) -> CrazyhouseBo board = super().copy(stack=stack) board.pockets[chess.WHITE] = self.pockets[chess.WHITE].copy() board.pockets[chess.BLACK] = self.pockets[chess.BLACK].copy() + if stack: + stack = len(self.move_stack) if stack is True else stack + board._crazyhouse_stack = self._crazyhouse_stack[-stack:] return board + def root(self: CrazyhouseBoardT) -> CrazyhouseBoardT: + if self._crazyhouse_stack: + board = super().root() + self._crazyhouse_stack[0].restore(board) + return board + else: + return self.copy(stack=False) + def mirror(self: CrazyhouseBoardT) -> CrazyhouseBoardT: board = super().mirror() board.pockets[chess.WHITE] = self.pockets[chess.BLACK].copy() From a652aaad5d22edd4cc078dba9214fb1d81789d7a Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 16:13:51 +0200 Subject: [PATCH 03/10] Add fallback for override decorator --- chess/engine.py | 60 ++++++++++++++++++++++++++++--------------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/chess/engine.py b/chess/engine.py index d1bbaaee..f5fdff96 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -24,6 +24,14 @@ from types import TracebackType from typing import Any, Callable, Coroutine, Deque, Dict, Generator, Generic, Iterable, Iterator, List, Literal, Mapping, MutableMapping, Optional, Tuple, Type, TypedDict, TypeVar, Union +try: + from typing import override +except: + # Before Python 3.12 + F = typing.TypeVar("F", bound=Callable[..., Any]) + def override(fn: F, /) -> F: + return fn + if typing.TYPE_CHECKING: from typing_extensions import Self @@ -1317,16 +1325,16 @@ def __init__(self, engine: UciProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def check_initialized(self) -> None: if self.engine.initialized: raise EngineError("engine already initialized") - @typing.override + @override def start(self) -> None: self.engine.send_line("uci") - @typing.override + @override def line_received(self, line: str) -> None: token, remaining = _next_token(line) if line.strip() == "uciok" and not self.result.done(): @@ -1415,7 +1423,7 @@ def __init__(self, engine: UciProtocol) -> None: def start(self) -> None: self.engine._isready() - @typing.override + @override def line_received(self, line: str) -> None: if line.strip() == "readyok": self.result.set_result(None) @@ -1574,7 +1582,7 @@ def __init__(self, engine: UciProtocol): # MultiPV never change between pondering play commands. engine.may_ponderhit = board if ponder and not engine.first_game and game == engine.game and not engine._changed_options(new_options) else None - @typing.override + @override def start(self) -> None: self.info: InfoDict = {} self.pondering: Optional[chess.Board] = None @@ -1605,7 +1613,7 @@ def start(self) -> None: else: self._readyok() - @typing.override + @override def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token == "info": @@ -1662,7 +1670,7 @@ def end(self) -> None: engine.may_ponderhit = None self.set_finished() - @typing.override + @override def cancel(self) -> None: if self.engine.may_ponderhit and self.pondering and self.engine.may_ponderhit.move_stack == self.pondering.move_stack and self.engine.may_ponderhit == self.pondering: self.engine.ponderhit = True @@ -1670,7 +1678,7 @@ def cancel(self) -> None: else: self.engine.send_line("stop") - @typing.override + @override def engine_terminated(self, exc: Exception) -> None: # Allow terminating engine while pondering. if not self.result.done(): @@ -1705,7 +1713,7 @@ def start(self) -> None: else: self._readyok() - @typing.override + @override def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token == "info": @@ -1738,11 +1746,11 @@ def _bestmove(self, arg: str) -> None: self.set_finished() self.analysis.set_finished(best) - @typing.override + @override def cancel(self) -> None: self.engine.send_line("stop") - @typing.override + @override def engine_terminated(self, exc: Exception) -> None: LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", self.engine, exc) self.analysis.set_exception(exc) @@ -1978,12 +1986,12 @@ def __init__(self, engine: XBoardProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def check_initialized(self) -> None: if self.engine.initialized: raise EngineError("engine already initialized") - @typing.override + @override def start(self) -> None: self.engine.send_line("xboard") self.engine.send_line("protover 2") @@ -1993,7 +2001,7 @@ def timeout(self) -> None: LOGGER.error("%s: Timeout during initialization", self.engine) self.end() - @typing.override + @override def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token.startswith("#"): @@ -2152,13 +2160,13 @@ def __init__(self, engine: XBoardProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def start(self) -> None: n = id(self) & 0xffff self.pong = f"pong {n}" self.engine._ping(n) - @typing.override + @override def line_received(self, line: str) -> None: if line == self.pong: self.result.set_result(None) @@ -2179,7 +2187,7 @@ def __init__(self, engine: XBoardProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def start(self) -> None: self.play_result = PlayResult(None, None) self.stopped = False @@ -2221,7 +2229,7 @@ def start(self) -> None: self.engine.send_line("hard" if ponder else "easy") self.engine.send_line("go") - @typing.override + @override def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token == "move": @@ -2301,7 +2309,7 @@ def _ping_after_move(self) -> None: self.pong_after_move = f"pong {n}" self.engine._ping(n) - @typing.override + @override def cancel(self) -> None: if self.stopped: return @@ -2317,7 +2325,7 @@ def cancel(self) -> None: self.pong_after_ponder = f"pong {n}" self.engine._ping(n) - @typing.override + @override def engine_terminated(self, exc: Exception) -> None: # Allow terminating engine while pondering. if not self.result.done(): @@ -2337,7 +2345,7 @@ def __init__(self, engine: XBoardProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def start(self) -> None: self.stopped = False self.best_move: Optional[chess.Move] = None @@ -2364,7 +2372,7 @@ def start(self) -> None: else: self.time_limit_handle = None - @typing.override + @override def line_received(self, line: str) -> None: token, remaining = _next_token(line) if token.startswith("#"): @@ -2405,7 +2413,7 @@ def end(self) -> None: self.set_finished() self.analysis.set_finished(BestMove(self.best_move, None)) - @typing.override + @override def cancel(self) -> None: if self.stopped: return @@ -2418,7 +2426,7 @@ def cancel(self) -> None: self.final_pong = f"pong {n}" self.engine._ping(n) - @typing.override + @override def engine_terminated(self, exc: Exception) -> None: LOGGER.debug("%s: Closing analysis because engine has been terminated (error: %s)", self.engine, exc) @@ -2466,7 +2474,7 @@ def __init__(self, engine: XBoardProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def start(self) -> None: self.engine._configure(options) self.engine.target_config.update({name: value for name, value in options.items() if value is not None}) @@ -2497,7 +2505,7 @@ def __init__(self, engine: XBoardProtocol): super().__init__(engine) self.engine = engine - @typing.override + @override def start(self) -> None: if game_ending and any(c in game_ending for c in "{}\n\r"): raise EngineError(f"invalid line break or curly braces in game ending message: {game_ending!r}") From 7a59a3683f33682c0c2c9b359469c8b9876ece9e Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 17:02:12 +0200 Subject: [PATCH 04/10] Import override from typing_extensions --- chess/engine.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/chess/engine.py b/chess/engine.py index f5fdff96..347acf29 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -24,10 +24,9 @@ from types import TracebackType from typing import Any, Callable, Coroutine, Deque, Dict, Generator, Generic, Iterable, Iterator, List, Literal, Mapping, MutableMapping, Optional, Tuple, Type, TypedDict, TypeVar, Union -try: - from typing import override -except: - # Before Python 3.12 +if typing.TYPE_CHECKING: + from typing_extensions import override +else: F = typing.TypeVar("F", bound=Callable[..., Any]) def override(fn: F, /) -> F: return fn From 571658baa358b97f3cb5104c3dfab33c6d962deb Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 17:25:03 +0200 Subject: [PATCH 05/10] Use typing_extensions.Self --- chess/__init__.py | 32 ++++++++++++++++---------------- chess/engine.py | 2 +- chess/pgn.py | 7 +++++-- chess/syzygy.py | 7 ++++--- chess/variant.py | 18 +++++++++++------- 5 files changed, 37 insertions(+), 29 deletions(-) diff --git a/chess/__init__.py b/chess/__init__.py index b82f8559..3553eee6 100644 --- a/chess/__init__.py +++ b/chess/__init__.py @@ -25,7 +25,7 @@ from typing import ClassVar, Callable, Counter, Dict, Generic, Hashable, Iterable, Iterator, List, Literal, Mapping, Optional, SupportsInt, Tuple, Type, TypeVar, Union if typing.TYPE_CHECKING: - from typing_extensions import TypeAlias + from typing_extensions import Self, TypeAlias EnPassantSpec = Literal["legal", "fen", "xfen"] @@ -1455,7 +1455,7 @@ def apply_transform(self, f: Callable[[Bitboard], Bitboard]) -> None: self.occupied = f(self.occupied) self.promoted = f(self.promoted) - def transform(self: BaseBoardT, f: Callable[[Bitboard], Bitboard]) -> BaseBoardT: + def transform(self, f: Callable[[Bitboard], Bitboard]) -> Self: """ Returns a transformed copy of the board (without move stack) by applying a bitboard transformation function. @@ -1473,11 +1473,11 @@ def transform(self: BaseBoardT, f: Callable[[Bitboard], Bitboard]) -> BaseBoardT board.apply_transform(f) return board - def apply_mirror(self: BaseBoardT) -> None: + def apply_mirror(self) -> None: self.apply_transform(flip_vertical) self.occupied_co[WHITE], self.occupied_co[BLACK] = self.occupied_co[BLACK], self.occupied_co[WHITE] - def mirror(self: BaseBoardT) -> BaseBoardT: + def mirror(self) -> Self: """ Returns a mirrored copy of the board (without move stack). @@ -1491,7 +1491,7 @@ def mirror(self: BaseBoardT) -> BaseBoardT: board.apply_mirror() return board - def copy(self: BaseBoardT) -> BaseBoardT: + def copy(self) -> Self: """Creates a copy of the board.""" board = type(self)(None) @@ -1509,10 +1509,10 @@ def copy(self: BaseBoardT) -> BaseBoardT: return board - def __copy__(self: BaseBoardT) -> BaseBoardT: + def __copy__(self) -> Self: return self.copy() - def __deepcopy__(self: BaseBoardT, memo: Dict[int, object]) -> BaseBoardT: + def __deepcopy__(self, memo: Dict[int, object]) -> Self: board = self.copy() memo[id(self)] = board return board @@ -1694,7 +1694,7 @@ class Board(BaseBoard): manipulation. """ - def __init__(self: BoardT, fen: Optional[str] = STARTING_FEN, *, chess960: bool = False) -> None: + def __init__(self, fen: Optional[str] = STARTING_FEN, *, chess960: bool = False) -> None: BaseBoard.__init__(self, None) self.chess960 = chess960 @@ -1786,7 +1786,7 @@ def clear_stack(self) -> None: self.move_stack.clear() self._stack.clear() - def root(self: BoardT) -> BoardT: + def root(self) -> Self: """Returns a copy of the root position.""" if self._stack: board = type(self)(None, chess960=self.chess960) @@ -2307,7 +2307,7 @@ def is_repetition(self, count: int = 3) -> bool: def _push_capture(self, move: Move, capture_square: Square, piece_type: PieceType, was_promoted: bool) -> None: pass - def push(self: BoardT, move: Move) -> None: + def push(self, move: Move) -> None: """ Updates the position with the given *move* and puts it onto the move stack. @@ -2428,7 +2428,7 @@ def push(self: BoardT, move: Move) -> None: # Swap turn. self.turn = not self.turn - def pop(self: BoardT) -> Move: + def pop(self) -> Move: """ Restores the previous position and returns the last move from the stack. @@ -2838,7 +2838,7 @@ def _validate_epd_opcode(self, opcode: str) -> None: if blacklisted in opcode: raise ValueError(f"invalid character {blacklisted!r} in epd opcode: {opcode!r}") - def _parse_epd_ops(self: BoardT, operation_part: str, make_board: Callable[[], BoardT]) -> Dict[str, Union[None, str, int, float, Move, List[Move]]]: + def _parse_epd_ops(self, operation_part: str, make_board: Callable[[], Self]) -> Dict[str, Union[None, str, int, float, Move, List[Move]]]: operations: Dict[str, Union[None, str, int, float, Move, List[Move]]] = {} state = "opcode" opcode = "" @@ -3831,16 +3831,16 @@ def apply_transform(self, f: Callable[[Bitboard], Bitboard]) -> None: self.ep_square = None if self.ep_square is None else msb(f(BB_SQUARES[self.ep_square])) self.castling_rights = f(self.castling_rights) - def transform(self: BoardT, f: Callable[[Bitboard], Bitboard]) -> BoardT: + def transform(self, f: Callable[[Bitboard], Bitboard]) -> Self: board = self.copy(stack=False) board.apply_transform(f) return board - def apply_mirror(self: BoardT) -> None: + def apply_mirror(self) -> None: super().apply_mirror() self.turn = not self.turn - def mirror(self: BoardT) -> BoardT: + def mirror(self) -> Self: """ Returns a mirrored copy of the board. @@ -3855,7 +3855,7 @@ def mirror(self: BoardT) -> BoardT: board.apply_mirror() return board - def copy(self: BoardT, *, stack: Union[bool, int] = True) -> BoardT: + def copy(self, *, stack: Union[bool, int] = True) -> Self: """ Creates a copy of the board. diff --git a/chess/engine.py b/chess/engine.py index 347acf29..59b843b8 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -970,7 +970,7 @@ def pipe_data_received(self, fd: int, data: Union[bytes, str]) -> None: def error_line_received(self, line: str) -> None: LOGGER.warning("%s: stderr >> %s", self, line) - def _line_received(self: Protocol, line: str) -> None: + def _line_received(self, line: str) -> None: LOGGER.debug("%s: >> %s", self, line) self.line_received(line) diff --git a/chess/pgn.py b/chess/pgn.py index 0d0879ae..83ddaf8f 100644 --- a/chess/pgn.py +++ b/chess/pgn.py @@ -15,6 +15,9 @@ from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Literal, Mapping, MutableMapping, Set, TextIO, Tuple, Type, TypeVar, Optional, Union from chess import Color, Square +if typing.TYPE_CHECKING: + from typing_extensions import Self + LOGGER = logging.getLogger(__name__) @@ -1014,10 +1017,10 @@ def __iter__(self) -> Iterator[str]: def __len__(self) -> int: return len(self._tag_roster) + len(self._others) - def copy(self: HeadersT) -> HeadersT: + def copy(self) -> Self: return type(self)(self) - def __copy__(self: HeadersT) -> HeadersT: + def __copy__(self) -> Self: return self.copy() def __repr__(self) -> str: diff --git a/chess/syzygy.py b/chess/syzygy.py index 8ba1cbca..77a6eede 100644 --- a/chess/syzygy.py +++ b/chess/syzygy.py @@ -14,6 +14,9 @@ from types import TracebackType from typing import Deque, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union +if typing.TYPE_CHECKING: + from typing_extensions import Self + UINT64_BE = struct.Struct(">Q") UINT32 = struct.Struct(" None: self.data.close() self.data = None - def __enter__(self: TableT) -> TableT: + def __enter__(self) -> Self: return self def __exit__(self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType]) -> None: diff --git a/chess/variant.py b/chess/variant.py index 119a0402..d28a6e96 100644 --- a/chess/variant.py +++ b/chess/variant.py @@ -2,9 +2,13 @@ import chess import itertools +import typing from typing import Dict, Generic, Hashable, Iterable, Iterator, List, Optional, Type, TypeVar, Union +if typing.TYPE_CHECKING: + from typing_extensions import Self + class SuicideBoard(chess.Board): @@ -795,7 +799,7 @@ def _transposition_key(self) -> Hashable: return (super()._transposition_key(), self.remaining_checks[chess.WHITE], self.remaining_checks[chess.BLACK]) - def copy(self: ThreeCheckBoardT, stack: Union[bool, int] = True) -> ThreeCheckBoardT: + def copy(self, stack: Union[bool, int] = True) -> Self: board = super().copy(stack=stack) board.remaining_checks = self.remaining_checks.copy() if stack: @@ -803,7 +807,7 @@ def copy(self: ThreeCheckBoardT, stack: Union[bool, int] = True) -> ThreeCheckBo board._three_check_stack = self._three_check_stack[-stack:] return board - def root(self: ThreeCheckBoardT) -> ThreeCheckBoardT: + def root(self) -> Self: if self._three_check_stack: board = super().root() self._three_check_stack[0].restore(board) @@ -811,7 +815,7 @@ def root(self: ThreeCheckBoardT) -> ThreeCheckBoardT: else: return self.copy(stack=False) - def mirror(self: ThreeCheckBoardT) -> ThreeCheckBoardT: + def mirror(self) -> Self: board = super().mirror() board.remaining_checks[chess.WHITE] = self.remaining_checks[chess.BLACK] board.remaining_checks[chess.BLACK] = self.remaining_checks[chess.WHITE] @@ -865,7 +869,7 @@ def __len__(self) -> int: def __repr__(self) -> str: return f"CrazyhousePocket('{self}')" - def copy(self: CrazyhousePocketT) -> CrazyhousePocketT: + def copy(self) -> Self: """Returns a copy of this pocket.""" pocket = type(self)() pocket._pieces = self._pieces[:] @@ -1047,7 +1051,7 @@ def epd(self, shredder: bool = False, en_passant: chess.EnPassantSpec = "legal", board_part, info_part = epd.split(" ", 1) return f"{board_part}[{str(self.pockets[chess.WHITE]).upper()}{self.pockets[chess.BLACK]}] {info_part}" - def copy(self: CrazyhouseBoardT, stack: Union[bool, int] = True) -> CrazyhouseBoardT: + def copy(self, stack: Union[bool, int] = True) -> Self: board = super().copy(stack=stack) board.pockets[chess.WHITE] = self.pockets[chess.WHITE].copy() board.pockets[chess.BLACK] = self.pockets[chess.BLACK].copy() @@ -1056,7 +1060,7 @@ def copy(self: CrazyhouseBoardT, stack: Union[bool, int] = True) -> CrazyhouseBo board._crazyhouse_stack = self._crazyhouse_stack[-stack:] return board - def root(self: CrazyhouseBoardT) -> CrazyhouseBoardT: + def root(self) -> Self: if self._crazyhouse_stack: board = super().root() self._crazyhouse_stack[0].restore(board) @@ -1064,7 +1068,7 @@ def root(self: CrazyhouseBoardT) -> CrazyhouseBoardT: else: return self.copy(stack=False) - def mirror(self: CrazyhouseBoardT) -> CrazyhouseBoardT: + def mirror(self) -> Self: board = super().mirror() board.pockets[chess.WHITE] = self.pockets[chess.BLACK].copy() board.pockets[chess.BLACK] = self.pockets[chess.WHITE].copy() From f286d687e496123fb46d26ce3fd06bc0db505dc4 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 17:40:10 +0200 Subject: [PATCH 06/10] Towards pyright compatibility --- chess/__init__.py | 4 ++-- chess/engine.py | 7 +++++-- chess/variant.py | 8 ++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/chess/__init__.py b/chess/__init__.py index 3553eee6..55bd4cad 100644 --- a/chess/__init__.py +++ b/chess/__init__.py @@ -1544,7 +1544,7 @@ def from_chess960_pos(cls: Type[BaseBoardT], scharnagl: int) -> BaseBoardT: class _BoardState: - def __init__(self, board: BoardT) -> None: + def __init__(self, board: Board) -> None: self.pawns = board.pawns self.knights = board.knights self.bishops = board.bishops @@ -1564,7 +1564,7 @@ def __init__(self, board: BoardT) -> None: self.halfmove_clock = board.halfmove_clock self.fullmove_number = board.fullmove_number - def restore(self, board: BoardT) -> None: + def restore(self, board: Board) -> None: board.pawns = self.pawns board.knights = self.knights board.bishops = self.bishops diff --git a/chess/engine.py b/chess/engine.py index 59b843b8..9a6a21f4 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -406,8 +406,10 @@ class Score(abc.ABC): """ @typing.overload + @abc.abstractmethod def score(self, *, mate_score: int) -> int: ... @typing.overload + @abc.abstractmethod def score(self, *, mate_score: Optional[int] = None) -> Optional[int]: ... @abc.abstractmethod def score(self, *, mate_score: Optional[int] = None) -> Optional[int]: @@ -2120,12 +2122,13 @@ def _new(self, board: chess.Board, game: object, options: ConfigMapping, opponen if self.config.get("computer"): self.send_line("computer") - self.send_line("force") + self.send_line("force") - if new_game: fen = root.fen(shredder=board.chess960, en_passant="fen") if variant != "normal" or fen != chess.STARTING_FEN or board.chess960: self.send_line(f"setboard {fen}") + else: + self.send_line("force") # Undo moves until common position. common_stack_len = 0 diff --git a/chess/variant.py b/chess/variant.py index d28a6e96..6e9161dc 100644 --- a/chess/variant.py +++ b/chess/variant.py @@ -133,7 +133,7 @@ def _transposition_key(self) -> Hashable: else: return super()._transposition_key() - def board_fen(self, promoted: Optional[bool] = None) -> str: + def board_fen(self, *, promoted: Optional[bool] = None) -> str: if promoted is None: promoted = self.has_chess960_castling_rights() return super().board_fen(promoted=promoted) @@ -799,7 +799,7 @@ def _transposition_key(self) -> Hashable: return (super()._transposition_key(), self.remaining_checks[chess.WHITE], self.remaining_checks[chess.BLACK]) - def copy(self, stack: Union[bool, int] = True) -> Self: + def copy(self, *, stack: Union[bool, int] = True) -> Self: board = super().copy(stack=stack) board.remaining_checks = self.remaining_checks.copy() if stack: @@ -1041,7 +1041,7 @@ def set_fen(self, fen: str) -> None: self.pockets[chess.WHITE] = white_pocket self.pockets[chess.BLACK] = black_pocket - def board_fen(self, promoted: Optional[bool] = None) -> str: + def board_fen(self, *, promoted: Optional[bool] = None) -> str: if promoted is None: promoted = True return super().board_fen(promoted=promoted) @@ -1051,7 +1051,7 @@ def epd(self, shredder: bool = False, en_passant: chess.EnPassantSpec = "legal", board_part, info_part = epd.split(" ", 1) return f"{board_part}[{str(self.pockets[chess.WHITE]).upper()}{self.pockets[chess.BLACK]}] {info_part}" - def copy(self, stack: Union[bool, int] = True) -> Self: + def copy(self, *, stack: Union[bool, int] = True) -> Self: board = super().copy(stack=stack) board.pockets[chess.WHITE] = self.pockets[chess.WHITE].copy() board.pockets[chess.BLACK] = self.pockets[chess.BLACK].copy() From 33eea7ca062ed7396590ac3408ec05121bba9483 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 18:01:47 +0200 Subject: [PATCH 07/10] Add changelog entries --- CHANGELOG.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a4534bad..754b5d41 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,13 @@ Changes: some 8 piece positions with decisive captures can be probed successfully. * The string wrapper returned by ``chess.svg`` functions now also implements ``_repr_html_``. +* Significant changes to ``chess.engine`` internals: + ``chess.engine.BaseCommand`` methods other than the constructor no longer + receive ``engine: Protocol``. +* Significant changes to board state internals: Subclasses of ``chess.Board`` + can no longer hook into board state recording/restoration and need to + override relevant methods instead (``clear_stack``, ``copy``, ``root``, + ``push``, ``pop``). New features: From caefd4dc6c25369750f6cc461885adfbbd52f09c Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 21:11:38 +0200 Subject: [PATCH 08/10] chess.engine._next_token() cosmetics --- chess/engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/chess/engine.py b/chess/engine.py index 9a6a21f4..c6698bc7 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -2650,7 +2650,8 @@ def _parse_xboard_post(line: str, root_board: chess.Board, selector: Info = INFO def _next_token(line: str) -> tuple[str, str]: - """Get the next token in a whitespace-delimited line of text. + """ + Get the next token in a whitespace-delimited line of text. The result is returned as a 2-part tuple of strings. @@ -2660,10 +2661,10 @@ def _next_token(line: str) -> tuple[str, str]: If the input line is not empty and not completely whitespace, then the first element of the returned tuple is a single word with leading and trailing whitespace removed. The second element is the - unchanged rest of the line.""" - + unchanged rest of the line. + """ parts = line.split(maxsplit=1) - return (parts[0] if parts else "", parts[1] if len(parts) == 2 else "") + return parts[0] if parts else "", parts[1] if len(parts) == 2 else "" class BestMove: From 71e7c31fba31554a2b174ff7fb88a77b61674543 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 22:21:28 +0200 Subject: [PATCH 09/10] Show actual state in engine command state assertions (#1049, #1071) --- chess/engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/chess/engine.py b/chess/engine.py index c6698bc7..302aa2bb 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -989,7 +989,7 @@ async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) - if self.returncode.done(): raise EngineTerminatedError(f"engine process dead (exit code: {self.returncode.result()})") - assert command.state == CommandState.NEW + assert command.state == CommandState.NEW, command.state if self.next_command is not None: self.next_command.result.cancel() @@ -1253,7 +1253,7 @@ def _handle_exception(self, exc: Exception) -> None: self.finished.set_result(None) def set_finished(self) -> None: - assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING] + assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state if not self.result.done(): self.result.set_exception(EngineError(f"engine command finished before returning result: {self!r}")) self.finished.set_result(None) @@ -1261,12 +1261,12 @@ def set_finished(self) -> None: def _cancel(self) -> None: if self.state != CommandState.CANCELLING and self.state != CommandState.DONE: - assert self.state == CommandState.ACTIVE + assert self.state == CommandState.ACTIVE, self.state self.state = CommandState.CANCELLING self.cancel() def _start(self) -> None: - assert self.state == CommandState.NEW + assert self.state == CommandState.NEW, self.state self.state = CommandState.ACTIVE try: self.check_initialized() @@ -1275,7 +1275,7 @@ def _start(self) -> None: self._handle_exception(err) def _line_received(self, line: str) -> None: - assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING] + assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state try: self.line_received(line) except EngineError as err: From 7299216641f5bd0434c06111608892617aa39147 Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Wed, 31 Jul 2024 22:37:53 +0200 Subject: [PATCH 10/10] Immediately dispatch line/termination/finish (fixes #1049, fixes #1071) Avoids races between queued up lines and command finish callbacks. --- CHANGELOG.rst | 2 ++ chess/engine.py | 40 +++++++++++++++++++++++++++------------- test.py | 18 ++++++++++++++++++ 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 754b5d41..9a0a16e6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,6 +30,8 @@ New features: Bugfixes: +* Fix unsolicited engine output may cause assertion errors with regard to + command states. * Fix handling of whitespace in UCI engine communication. * For ``chess.Board.epd()`` and ``chess.Board.set_epd()``, require that EPD opcodes start with a letter. diff --git a/chess/engine.py b/chess/engine.py index 302aa2bb..ccd6894d 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -883,7 +883,7 @@ def write(self, data: bytes) -> None: expectation, responses = self.expectations.popleft() assert expectation == line, f"expected {expectation}, got: {line}" if responses: - self.protocol.pipe_data_received(1, "\n".join(responses + [""]).encode("utf-8")) + self.protocol.loop.call_soon(self.protocol.pipe_data_received, 1, "\n".join(responses + [""]).encode("utf-8")) def get_pid(self) -> int: return id(self) @@ -934,12 +934,12 @@ def connection_lost(self, exc: Optional[Exception]) -> None: LOGGER.debug("%s: Connection lost (exit code: %d, error: %s)", self, code, exc) # Terminate commands. - if self.command is not None: - self.command._engine_terminated(code) - self.command = None - if self.next_command is not None: - self.next_command._engine_terminated(code) - self.next_command = None + command, self.command = self.command, None + next_command, self.next_command = self.next_command, None + if command: + command._engine_terminated(code) + if next_command: + next_command._engine_terminated(code) self.returncode.set_result(code) @@ -965,9 +965,9 @@ def pipe_data_received(self, fd: int, data: Union[bytes, str]) -> None: LOGGER.warning("%s: >> %r (%s)", self, bytes(line_bytes), err) else: if fd == 1: - self.loop.call_soon(self._line_received, line) + self._line_received(line) else: - self.loop.call_soon(self.error_line_received, line) + self.error_line_received(line) def error_line_received(self, line: str) -> None: LOGGER.warning("%s: stderr >> %s", self, line) @@ -998,7 +998,7 @@ async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) - self.next_command = command - def previous_command_finished(_: Optional[asyncio.Future[None]]) -> None: + def previous_command_finished() -> None: self.command, self.next_command = self.next_command, None if self.command is not None: cmd = self.command @@ -1008,11 +1008,11 @@ def cancel_if_cancelled(result: asyncio.Future[T]) -> None: cmd._cancel() cmd.result.add_done_callback(cancel_if_cancelled) - cmd.finished.add_done_callback(previous_command_finished) cmd._start() + cmd.add_finished_callback(previous_command_finished) if self.command is None: - previous_command_finished(None) + previous_command_finished() elif not self.command.result.done(): self.command.result.cancel() elif not self.command.result.cancelled(): @@ -1228,6 +1228,17 @@ def __init__(self, engine: Protocol) -> None: self.result: asyncio.Future[T] = asyncio.Future() self.finished: asyncio.Future[None] = asyncio.Future() + self._finished_callbacks: List[Callable[[], None]] = [] + + def add_finished_callback(self, callback: Callable[[], None]) -> None: + self._finished_callbacks.append(callback) + self._dispatch_finished() + + def _dispatch_finished(self) -> None: + if self.finished.done(): + while self._finished_callbacks: + self._finished_callbacks.pop()() + def _engine_terminated(self, code: int) -> None: hint = ", binary not compatible with cpu?" if code in [-4, 0xc000001d] else "" exc = EngineTerminatedError(f"engine process died unexpectedly (exit code: {code}{hint})") @@ -1235,6 +1246,7 @@ def _engine_terminated(self, code: int) -> None: self.engine_terminated(exc) elif self.state == CommandState.CANCELLING: self.finished.set_result(None) + self._dispatch_finished() elif self.state == CommandState.NEW: self._handle_exception(exc) @@ -1251,13 +1263,15 @@ def _handle_exception(self, exc: Exception) -> None: if not self.finished.done(): self.finished.set_result(None) + self._dispatch_finished() def set_finished(self) -> None: assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state if not self.result.done(): self.result.set_exception(EngineError(f"engine command finished before returning result: {self!r}")) - self.finished.set_result(None) self.state = CommandState.DONE + self.finished.set_result(None) + self._dispatch_finished() def _cancel(self) -> None: if self.state != CommandState.CANCELLING and self.state != CommandState.DONE: diff --git a/test.py b/test.py index 00d766a0..438628b9 100755 --- a/test.py +++ b/test.py @@ -3527,6 +3527,24 @@ async def main(): asyncio.run(main()) + def test_uci_output_after_command(self): + async def main(): + protocol = chess.engine.UciProtocol() + mock = chess.engine.MockTransport(protocol) + + mock.expect("uci", [ + "Arasan v24.0.0-10-g367aa9f Copyright 1994-2023 by Jon Dart.", + "All rights reserved.", + "id name Arasan v24.0.0-10-g367aa9f", + "uciok", + "info string out of do_all_pending, list size=0" + ]) + await protocol.initialize() + + mock.assert_done() + + asyncio.run(main()) + def test_hiarcs_bestmove(self): async def main(): protocol = chess.engine.UciProtocol()