Skip to content

Commit

Permalink
Add own game over detection
Browse files Browse the repository at this point in the history
  • Loading branch information
ArcticXWolf committed Mar 26, 2021
1 parent c631896 commit a0bde4b
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 104 deletions.
19 changes: 7 additions & 12 deletions engines/axwchessbot/evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations
from typing import Tuple
import chess

from evaluation.game_over_detection import GameOverDetection
from . import score_tables
import functools
import itertools
import operator


class EvaluationResult:
Expand Down Expand Up @@ -61,7 +59,7 @@ def evaluate(self) -> Evaluation:
self.evaluate_tempo(color)
self.evaluate_blocked_pieces(color)
self.evaluate_king_shield(color)
self.evaluate_mobility(color)
# self.evaluate_mobility(color)
self.evaluate_passed_pawns()

self.combine_results()
Expand Down Expand Up @@ -114,22 +112,19 @@ def combine_results(self) -> None:
self.total_score_perspective = -self.total_score

def evaluate_gameover(self) -> float:
if not self.board.is_game_over(claim_draw=True):
if not GameOverDetection.is_game_over(self.board):
return None

result = self.board.result(claim_draw=True)
if result == "0-1":
return float("-inf")
elif result == "1-0":
return float("inf")
if self.board.is_checkmate():
return float("-inf") if self.board.turn == chess.WHITE else float("inf")

# draw, calculate contempt factor via gamephase
# on midgame, +60 for enemy
# on endgame, 0
draw_score = (
float(
self.eval_result[chess.WHITE].gamephase
+ self.eval_result[chess.WHITE].gamephase
+ self.eval_result[chess.BLACK].gamephase
)
* 60.0
/ 24.0
Expand Down
61 changes: 61 additions & 0 deletions engines/axwchessbot/evaluation/game_over_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import chess
import collections


# Unfortunately the gameover detection of python-chess is very slow
# because it treats the game as drawn if one player can claim draw
# AFTER their next move. This means the detection loops over every
# possible move and that calls the move generator which is SLOW.
# We dont need this claiming mechanic here, so we implement our own
# detection. (Speedup of 2x)
class GameOverDetection:
def is_game_over(board) -> bool:
if board.is_seventyfive_moves():
return True

# Insufficient material.
if board.is_insufficient_material():
return True

# Stalemate or checkmate.
if not any(board.generate_legal_moves()):
return True

# Fivefold repetition.
if board.is_fivefold_repetition():
return True

# Fifty move rule
if board.halfmove_clock > 100:
return True

# Threefold repetition
if GameOverDetection.is_threefold_repetition(board):
return True

return False

def is_threefold_repetition(board) -> bool:
transposition_key = board._transposition_key()
transpositions = collections.Counter()
transpositions.update((transposition_key,))

# Count positions.
switchyard = []
while board.move_stack:
move = board.pop()
switchyard.append(move)

if board.is_irreversible(move):
break

transpositions.update((board._transposition_key(),))

while switchyard:
board.push(switchyard.pop())

# Threefold repetition occured.
if transpositions[transposition_key] >= 3:
return True

return False
6 changes: 3 additions & 3 deletions engines/axwchessbot/profile.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash

python -m cProfile -o tests/results/profiling main.py --bench
gprof2dot -f pstats tests/results/profiling | dot -Tpng -o tests/results/profile.png
dt=$(date '+%Y-%m-%d-%H-%M-%S')
python -m cProfile -o "tests/results/profiling-$dt" main.py --bench
gprof2dot -f pstats "tests/results/profiling-$dt" | dot -Tpng -o "tests/results/profile-$dt.png"
10 changes: 6 additions & 4 deletions engines/axwchessbot/search/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@


class Entry:
def __init__(self, val, flag, entry_depth, move, debug_info):
def __init__(self, val, flag, entry_depth, move):
self.val = val
self.flag = flag
self.entry_depth = entry_depth
self.move = move
self.debug_info = debug_info


class TranspositionTable:
Expand All @@ -19,14 +18,17 @@ def __init__(self, size):
def __getitem__(self, position):
return self.basic_cache.get(chess.polyglot.zobrist_hash(position), None)

def store(self, position, value, flag, entry_depth, move, debug_info):
def store(self, position, value, flag, entry_depth, move):
if len(self.basic_cache) > self.size:
self.empty_cache()
self.basic_cache[chess.polyglot.zobrist_hash(position)] = Entry(
value, flag, entry_depth, move, debug_info
value, flag, entry_depth, move
)

return True

def get_length(self) -> int:
return len(self.basic_cache)

def empty_cache(self):
self.basic_cache = {}
141 changes: 75 additions & 66 deletions engines/axwchessbot/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import chess.syzygy
from evaluation import evaluation
import os

from evaluation.game_over_detection import GameOverDetection
from .cache import TranspositionTable
from .timeout import TimeOut
from timeit import default_timer as timer

LOWER = -1
EXACT = 0
Expand Down Expand Up @@ -40,58 +43,58 @@ def __init__(
self.timeout = timeout
self.killer_moves = {}

self.nodes_traversed = 0
self.q_nodes_traversed = 0
self.cache_hits = 0
self.cache_cutoffs = 0
self.max_depth_used = 0
self.search_finished = False
self.total_search_time = 0.0
self.time_spent_per_depth = {}

if cache:
self.cache = cache

def next_move(self):
debug_info = {"moves_analysis": [], "positions_analyzed": 0, "cache_hits": 0}
if self.search_finished:
raise Exception("reused old search object!")

move = self.next_move_by_opening_db()
if move is not None:
debug_info["moves"] = [move.uci()]
debug_info["eval"] = "opening_db"
debug_info["depth_reached"] = 0
return (move, debug_info)
return move

move = self.next_move_by_ending_db()
if move is not None:
debug_info["moves"] = [move.uci()]
debug_info["eval"] = "ending_db"
debug_info["depth_reached"] = 0
return (move, debug_info)
return move

return self.next_move_by_engine()

def next_move_by_engine(self):
moves, score, debug_info = self.iterative_deepening()
debug_info["moves"] = [move.uci() for move in moves]
eval = evaluation.Evaluation(self.board).evaluate()
debug_info["current_eval"] = eval.total_score
debug_info["gamephase"] = (
eval.eval_result[chess.WHITE].gamephase
+ eval.eval_result[chess.BLACK].gamephase
)
return moves[-1], debug_info
start = timer()
moves, _ = self.iterative_deepening()
self.total_search_time = timer() - start
self.search_finished = True
return moves[-1]

def iterative_deepening(self):
depth_reached = 1
board_copy = self.board.copy()
moves, score, debug_info = self.alpha_beta_search(1)
start_depth_time = timer()
moves, score = self.alpha_beta_search(1)
self.time_spent_per_depth[1] = timer() - start_depth_time
timeout = TimeOut(self.timeout)
try:
timeout.start()
for i in range(2, self.alpha_beta_depth + 1):
moves, score, debug_info = self.alpha_beta_search(
i, previous_moves=moves
)
depth_reached += 1
start_depth_time = timer()
moves, score = self.alpha_beta_search(i, previous_moves=moves)
self.time_spent_per_depth[i] = timer() - start_depth_time
self.max_depth_used = i
except TimeOut.TimeOutException as e:
self.board = board_copy
finally:
timeout.disable_timeout()

debug_info["depth_reached"] = depth_reached

return moves, score, debug_info
return moves, score

def alpha_beta_search(
self,
Expand All @@ -106,42 +109,34 @@ def alpha_beta_search(
best_move = None
alpha_orig = alpha
moves = []
debug_info = {
"moves_analysis": [],
"positions_analyzed": 0,
"cache_hits": 0,
"cache_cutoffs": 0,
}

if depth_left <= 0 or GameOverDetection.is_game_over(self.board):
moves.append(move)
return (
moves,
self.quiesce_search(alpha, beta, self.quiesce_depth - 1),
)

self.nodes_traversed += 1

cached = self.cache[self.board]
if cached:
debug_info["cache_hits"] += 1
self.cache_hits += 1
if cached.entry_depth >= depth_left:
if cached.flag == EXACT:
move = cached.move if not move else move
debug_info["positions_analyzed"] += 1
debug_info["cache_cutoffs"] += 1
moves.append(move)
return moves, cached.val, debug_info
self.cache_cutoffs += 1
return moves, cached.val
elif cached.flag == LOWER:
alpha = max(alpha, cached.val)
elif cached.flag == UPPER:
beta = min(beta, cached.val)
if alpha >= beta:
move = cached.move if not move else move
debug_info["positions_analyzed"] += 1
debug_info["cache_cutoffs"] += 1
moves.append(move)
return moves, cached.val, debug_info

if depth_left <= 0 or self.board.is_game_over():
moves.append(move)
debug_info["positions_analyzed"] += 1
return (
moves,
self.quiesce_search(alpha, beta, self.quiesce_depth - 1),
debug_info,
)
self.cache_cutoffs += 1
return moves, cached.val

move_list_to_choose_from = evaluation.Evaluation(self.board).move_order()

Expand Down Expand Up @@ -178,22 +173,12 @@ def alpha_beta_search(
pass

for m in move_list_to_choose_from:
san = self.board.san(m)
self.board.push(m)

new_moves, score, new_debug_info = self.alpha_beta_search(
new_moves, score = self.alpha_beta_search(
depth_left - 1, -beta, -alpha, m, previous_moves, ply + 1
)
score = -score
debug_info["moves_analysis"].append(
(
str(san),
score,
)
)
debug_info["positions_analyzed"] += new_debug_info["positions_analyzed"]
debug_info["cache_hits"] += new_debug_info["cache_hits"]
debug_info["cache_cutoffs"] += new_debug_info["cache_cutoffs"]

self.board.pop()

Expand All @@ -216,21 +201,21 @@ def alpha_beta_search(

if not best_move:
best_move = m
self.cache.store(
self.board, best_score, flag, depth_left, best_move, debug_info
)
self.cache.store(self.board, best_score, flag, depth_left, best_move)
moves.append(best_move)
return (moves, best_score, debug_info)
return (moves, best_score)

def quiesce_search(self, alpha: float, beta: float, depth_left: int = 0):
self.nodes_traversed += 1
self.q_nodes_traversed += 1

stand_pat = evaluation.Evaluation(self.board).evaluate().total_score_perspective
if stand_pat >= beta:
return beta
if alpha < stand_pat:
alpha = stand_pat

if depth_left > 0 or not self.board.is_game_over(claim_draw=True):
if depth_left > 0 or not GameOverDetection.is_game_over(self.board):
for move in self.get_captures_by_value():
if self.board.is_capture(move):
self.board.push(move)
Expand Down Expand Up @@ -311,4 +296,28 @@ def sort_function(move):
move for move in self.board.legal_moves if self.board.is_capture(move)
]
captures_ordered = sorted(captures, key=sort_function, reverse=True)
return list(captures_ordered)
return list(captures_ordered)

def get_measurements(
self, show_exact_timings: bool = False, show_cache: bool = False
):
result = {
"finished": self.search_finished,
"max_depth_used": self.max_depth_used,
"nodes_traversed": self.nodes_traversed,
"q_nodes_traversed": self.q_nodes_traversed,
"total_search_time": self.total_search_time,
}
if self.nodes_traversed >= 0 and self.total_search_time >= 1.0:
result["nps"] = float(self.nodes_traversed) / self.total_search_time
if show_exact_timings:
result["time_spent_per_depth"] = self.time_spent_per_depth
if show_cache:
result["cache_hits"] = self.cache_hits
result["cache_cutoffs"] = self.cache_cutoffs
result["cache_length"] = self.cache.get_length()
return result

def __str__(self):
measurements = [f"{k}={str(v)}" for k, v in self.get_measurements().items()]
return f"<Search {' '.join(measurements)}>"
Loading

0 comments on commit a0bde4b

Please sign in to comment.