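"""Tabular Q-learning agent for the tic-tac-toe Board defined in game.py.

The agent stores one 9-entry action-value vector per serialized board state
and per marker (cross / circle), plays epsilon-greedily, and updates its
Q-table from finished games in update_policy, exploiting board symmetries.
"""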
import collections
import copy
import json
import time

import numpy as np

from game import Board


class QLearningAgent:
    """Tabular Q-learning agent for tic-tac-toe with epsilon-greedy move selection."""

    # Map a flat action index 0..8 to a (row, col) board position, row-major, and back.
    action_idx_to_move = {
        row * 3 + col: (row, col) for row in range(3) for col in range(3)
    }
    move_to_action_idx = {value: key for key, value in action_idx_to_move.items()}

    def __init__(self, *, seed, epsilon, alpha, gamma, gui=None):
        self.seed = seed
        self.epsilon = epsilon  # exploration probability for epsilon-greedy play
        self.alpha = alpha  # learning rate
        self.gamma = gamma  # discount factor
        self.reset_policy()
        self.rng = np.random.default_rng(self.seed)

    def clone(self):
        """Return a copy of this agent (shifted RNG seed, same hyperparameters, deep-copied Q-table)."""
        agent = QLearningAgent(
            seed=self.seed + 5678,
            epsilon=self.epsilon,
            alpha=self.alpha,
            gamma=self.gamma,
        )
        agent.policy = copy.deepcopy(self.policy)
        return agent

    def get_move(self, ui, board, marker):
        """Choose a move epsilon-greedily: random with probability epsilon, greedy otherwise."""
        ui.show_board(board)
        p = self.rng.uniform()
        if p < self.epsilon:
            move = self.random_move(board)
        else:
            move = self.policy_move(ui, board, marker)
        assert board.is_empty(move[0], move[1])
        return move

    def n_boards_seen(self):
        """Return the number of distinct board states stored for both markers combined."""
        return len(self.policy[Board.FieldState.CROSS]) + len(
            self.policy[Board.FieldState.CIRCLE]
        )

    def load_policy(self, fn):
        """Load a Q-table from a JSON file written by save_policy."""
        with open(fn, "r") as f:
            policy = json.load(f)
        self.reset_policy()
        for str_marker in policy:
            for key in policy[str_marker]:
                self.policy[Board.str_value_to_state[str_marker]][key] = np.array(
                    policy[str_marker][key]
                )
        return True

    def policy_move(self, ui, board, marker):
        """Pick the greedy move: the highest-valued legal action for the current board."""
        key = board.to_str()
        values = self.policy[marker][key].copy()
        # mask occupied positions so they can never be selected
        for row in range(3):
            for col in range(3):
                if not board.is_empty(row, col):
                    action_idx = QLearningAgent.move_to_action_idx[(row, col)]
                    values[action_idx] = -np.inf
        ui.show_policy(values)
        # break ties uniformly at random among actions that share the maximum value
        max_value = np.max(values)
        if sum(values == max_value) == 1:
            action_idx = np.argmax(values)
        else:
            probs = np.ones_like(values)
            probs[values < max_value] = 0.0
            probs /= np.sum(probs)
            action_idx = self.rng.choice(range(9), p=probs)
        move = QLearningAgent.action_idx_to_move[action_idx]
        return move

    def random_move(self, board):
        """Pick a uniformly random legal move (used for exploration)."""
        possible_moves = []
        for row in range(3):
            for col in range(3):
                if board.is_empty(row, col):
                    possible_moves.append((row, col))
        move = self.rng.choice(possible_moves)
        return tuple(move)

    def reset_policy(self):
        """Reset the Q-table: a 9-entry action-value vector per board state and marker."""
        self.policy = {
            Board.FieldState.CROSS: collections.defaultdict(lambda: np.zeros(9)),
            Board.FieldState.CIRCLE: collections.defaultdict(lambda: np.zeros(9)),
        }

    def save_policy(self, fn):
        """Write the Q-table to a JSON file (numpy arrays stored as plain lists)."""
        policy = {}
        for marker in self.policy:
            policy[marker] = {}
            for key in self.policy[marker]:
                policy[marker][key] = self.policy[marker][key].tolist()
        with open(fn, "w") as f:
            json.dump(policy, f)

    def update_policy(self, final_reward, move_history, marker):
        """Replay a finished game backwards and apply the Q-learning update.

        move_history is a list of (board key, move) pairs for this marker's moves;
        each update is also applied to every symmetric variant of the board.
        """
        T = len(move_history)
        board = Board()
        next_board = Board()
        for t, (key, move) in reversed(list(zip(range(T), move_history))):
            board.from_str(key)
            considered_keys = set()
            for board_symmetry, move_symmetry in zip(
                Board.board_symmetries(), Board.move_symmetries()
            ):
                rotated_board = board_symmetry(board)
                rotated_move = move_symmetry(move)
                rotated_key = rotated_board.to_str()
                # symmetric boards can coincide; update each distinct key only once
                if rotated_key in considered_keys:
                    continue
                considered_keys.add(rotated_key)
                action_idx = QLearningAgent.move_to_action_idx[rotated_move]
                if t == (T - 1):
                    # last move of the game: reward is the game outcome, no successor state
                    r = final_reward
                    max_Q = 0.0
                else:
                    # intermediate move: no reward, bootstrap from the next state's best value
                    r = 0.0
                    next_key, _next_move = move_history[t + 1]
                    next_board.from_str(next_key)
                    next_rotated_board = board_symmetry(next_board)
                    next_rotated_key = next_rotated_board.to_str()
                    max_Q = np.max(self.policy[marker][next_rotated_key])
                # Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
                self.policy[marker][rotated_key][action_idx] += self.alpha * (
                    r
                    + self.gamma * max_Q
                    - self.policy[marker][rotated_key][action_idx]
                )
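

# ---------------------------------------------------------------------------
# Usage sketch: a minimal outline of how the public API above can be wired
# into a training driver.  `play_one_episode` is a hypothetical placeholder
# for whatever game loop collects the agent's (board key, move) history and
# the final reward for its marker; only QLearningAgent's own methods are real.
#
#   agent = QLearningAgent(seed=0, epsilon=0.1, alpha=0.5, gamma=0.9)
#   for _ in range(10_000):
#       move_history, final_reward = play_one_episode(agent)  # hypothetical
#       agent.update_policy(final_reward, move_history, Board.FieldState.CROSS)
#   agent.save_policy("policy.json")
# ---------------------------------------------------------------------------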