-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
165 lines (111 loc) · 4.55 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import math, random
def get_valid_moves(state):
    """
    Return the indices of all empty cells of the board, in ascending order.

    A cell counts as empty when it holds the empty string "".
    """
    open_cells = []
    for index, cell in enumerate(state):
        if cell == "":
            open_cells.append(index)
    return open_cells
def check_win(state, position, player):
    """
    Return True if *player* owns a complete line passing through *position*.

    The board is a flat list of 9 cells in row-major order
    (index = row * 3 + col). Only the lines that pass through *position* are
    checked, which is enough to detect a win created by a move at *position*.

    Args:
        state: the board as a flat list of 9 cells.
        position: index (0-8) of the cell that was just played.
        player: the symbol to look for.

    Returns:
        True if the row, the column, or a diagonal through *position* is
        entirely *player*, otherwise False.
    """
    row, col = divmod(position, 3)
    # Horizontal line: the three cells of this row.
    if all(state[row * 3 + c] == player for c in range(3)):
        return True
    # Vertical line: the three cells of this column.
    if all(state[col + 3 * r] == player for r in range(3)):
        return True
    # Main diagonal (0, 4, 8) only passes through cells where row == col.
    if row == col and all(state[i] == player for i in (0, 4, 8)):
        return True
    # Anti-diagonal (2, 4, 6) only passes through cells where row + col == 2.
    if row + col == 2 and all(state[i] == player for i in (2, 4, 6)):
        return True
    return False
class TicTacToeNode:
    """
    A node in a Monte Carlo search tree over tic-tac-toe boards.

    Each node holds the board *after* a move was made at ``position``, plus
    statistics (``plays``, ``wins``) gathered from random playouts started
    from that board. The board is a flat list of 9 cells where "" is empty.
    Players are '+' and 'o'; '+' places a piece whenever the number of pieces
    already on the board is even.
    """

    # Exploration constant c in the UCT formula (the classic UCB1 value).
    EXPLORATION = math.sqrt(2)

    def __init__(self, parent, state, position):
        """
        Args:
            parent: the node this one was expanded from, or None for the root.
            state: the board after the move at ``position`` was played.
            position: index of the move that produced ``state``; a negative
                value (the root uses -1) means "no move".
        """
        self.state = state
        self.parent = parent
        self.position = position
        # Number of pieces on the board, i.e. how many moves have been made.
        self.move = sum(cell != "" for cell in state)
        # The player who made the move at ``position``; the root made no move.
        self.me = state[position] if position >= 0 else None
        self.plays = 0
        self.wins = 0
        # Populated by expand_node(); empty until then so select_node() is safe.
        self.children = []

    @staticmethod
    def _player_to_move(move_count):
        """Return the symbol that plays next when *move_count* pieces are on the board."""
        return '+' if move_count % 2 == 0 else 'o'

    def expand_node(self):
        """
        Create one child per empty cell and return the list of children.

        Each child's board has the current player's piece placed in that cell.
        Any previously stored children are replaced.
        """
        player = self._player_to_move(self.move)
        self.children = []
        for pos in get_valid_moves(self.state):
            new_state = self.state.copy()
            new_state[pos] = player
            self.children.append(TicTacToeNode(self, new_state, pos))
        return self.children

    def select_node(self):
        """
        Return the child with the highest UCT score, or None if there are no children.

        Unvisited children score infinity, so each one is tried before any is
        revisited. Ties go to the earliest child in ``children``.
        """
        return max(self.children, key=lambda child: child.calc_UCT(), default=None)

    def backpropagate(self, winner):
        """
        Record the outcome of one playout on this node and all of its ancestors.

        Args:
            winner: the winning symbol, or a falsy value (e.g. None) for a draw.

        Every node on the path gains one play. A node that represents a move
        is credited 1 for a win by its own player, 0.5 for a draw and 0 for a
        loss. The root made no move, so only its play count changes.
        """
        node = self
        while node is not None:
            node.plays += 1
            if node.parent is not None:
                if not winner:
                    node.wins += 0.5
                elif node.me == winner:
                    node.wins += 1
            node = node.parent

    def calc_UCT(self):
        """
        Return the UCT (UCB1) score used to decide which child to explore next.

        An unvisited node scores infinity. Otherwise the score is
        wins / plays + c * sqrt(ln(parent.plays) / plays).
        """
        if self.plays == 0:
            return float('inf')
        exploitation = self.wins / self.plays
        # UCB1 uses the natural logarithm of the parent's visit count.
        exploration = self.EXPLORATION * math.sqrt(math.log(self.parent.plays) / self.plays)
        return exploitation + exploration

    def dump_node(self):
        """Print this node's move and statistics on one line."""
        print(f"Position: {self.position}, Plays: {self.plays}, Wins: {self.wins}, Me: {self.me}")

    def simulate_node(self):
        """
        Play uniformly random moves from this node's board until the game ends.

        The playout works on a copy of the board, so ``self.state`` is not
        modified.

        Returns:
            The winning symbol, or None if the board fills up with no line
            (a draw).
        """
        # The move that created this node may already have ended the game.
        if self.me and check_win(self.state, self.position, self.me):
            return self.me
        state = self.state.copy()
        move = self.move
        while True:
            positions = get_valid_moves(state)
            if not positions:
                return None
            position = random.choice(positions)
            # Parity of the piece count decides who plays, matching expand_node.
            player = self._player_to_move(move)
            state[position] = player
            move += 1
            if check_win(state, position, player):
                return player
def main(initial_state, iterations=10000):
    """
    Choose a move on *initial_state* with a one-level Monte Carlo tree search.

    Every empty cell becomes a child of a root node. If one of those moves
    wins immediately it is returned at once. Otherwise ``iterations`` random
    playouts are distributed over the children by UCT. The per-child
    statistics are printed, and the most-played child is returned.

    Args:
        initial_state: board as a flat list of 9 cells ("" = empty); it is
            copied, not modified.
        iterations: number of playouts to run (default 10,000).

    Returns:
        The chosen child TicTacToeNode, or None if the board has no empty cell.
    """
    root_node = TicTacToeNode(None, initial_state.copy(), -1)
    nodes = root_node.expand_node()
    if not nodes:
        return None
    # Take an immediate win without searching.
    for node in nodes:
        if check_win(node.state, node.position, node.me):
            return node
    for _ in range(iterations):
        node = root_node.select_node()
        winner = node.simulate_node()
        node.backpropagate(winner)
    # Display the statistics gathered for every candidate move.
    for node in nodes:
        node.dump_node()
    # The most-visited move is the most robust choice; max keeps the first on ties.
    return max(nodes, key=lambda n: n.plays)
if __name__ == "__main__":
    # Example position with 4 pieces on the board, so '+' is to move.
    chosen = main([ "", "", "o",
                    "", "+", "",
                    "+", "o", "" ])
    print(f"\nResult: Position: {chosen.position}, Plays: {chosen.plays}, Wins: {chosen.wins}")