-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTTTMCTS.cpp
161 lines (133 loc) · 4.6 KB
/
TTTMCTS.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include "TTTMCTS.hpp"
// Constructor: performs no work beyond default member initialisation.
MCTS::MCTS() = default;

// Destructor: performs no explicit cleanup.
// NOTE(review): expansion() allocates nodes with `new`; nothing here releases
// them — confirm that MCTSNode's own destructor frees its children.
MCTS::~MCTS() = default;
// Returns the mark of the other player: 'x' becomes 'o' and 'o' becomes 'x'.
// Any other character yields ' '.
char MCTS::nextNodeChar(char player)
{
    switch (player)
    {
        case 'x':
            return 'o';
        case 'o':
            return 'x';
        default:
            return ' ';
    }
}
// Selection step. Starting from the root, keeps descending to the node
// returned by getBestChildChoice() until it reaches a node with no children,
// or a node that reports itself as its own best choice.
// Returns the node at which the descent stopped (never nullptr).
MCTSNode* MCTS::selection()
{
    MCTSNode* current_node = &rootNode;
    // empty() states the intent directly. The previous `!size() == 0` only
    // behaved correctly because `!` binds tighter than `==`.
    while (!current_node->children.empty())
    {
        MCTSNode* next_node = current_node->getBestChildChoice();
        // Defensive: never step onto a null node; stay on the last valid one.
        if (next_node == nullptr)
            break;
        current_node = next_node;
        // A node that selects itself cannot be descended any further;
        // without this break the loop would never advance.
        if (current_node == current_node->getBestChildChoice())
            break;
    }
    return current_node;
}
// Expansion step. Adds one new child to `node` for a move that none of its
// existing children already represents, and returns that child.
// Returns `node` itself when it already has at least as many children as
// available moves (which includes the case of no available moves at all).
//
// The move is drawn at random from the *untried* moves only. Drawing from all
// available moves could create several children for the same move, so the
// child count would reach the move count while some moves were never explored.
MCTSNode* MCTS::expansion(MCTSNode* node)
{
    // Nothing left to expand from this node
    if (node->children.size() >= node->availablemoves.size())
        return node;

    // True when an existing child already carries `candidate` as its move.
    // Assumes MCTSNode::move holds the move the child was constructed with
    // (findNextMove reads it that way) — TODO confirm.
    const auto alreadyExpanded = [node](const Player::AiMove& candidate) -> bool
    {
        for (const auto& child : node->children)
        {
            if (child->move.row == candidate.row && child->move.column == candidate.column)
                return true;
        }
        return false;
    };

    // First pass: count the moves that have no child yet
    unsigned int untried_count = 0;
    for (auto& candidate : node->availablemoves)
    {
        if (!alreadyExpanded(candidate))
            ++untried_count;
    }
    if (untried_count == 0)
        return node;

    // Second pass: walk to the randomly chosen untried move and expand it
    unsigned int remaining = rand() % untried_count;
    for (auto& candidate : node->availablemoves)
    {
        if (alreadyExpanded(candidate))
            continue;
        if (remaining == 0)
        {
            // The child gets the opposite mark of its parent
            char next_node_char = nextNodeChar(node->player_char);
            MCTSNode* newchild = new MCTSNode(node->boardstate, next_node_char, candidate, node);
            node->children.push_back(newchild);
            return node->children.back();
        }
        --remaining;
    }

    // Not reachable: untried_count > 0 guarantees the loop above returns.
    return node;
}
// Rollout: plays random moves from `node`'s position on a scratch copy and
// reports the outcome.
// Returns the value of ifWinFound() when it is not ' ' (a win was found),
// 'd' when ifDrawFound() reports a draw, and otherwise whatever ifWinFound()
// returns once the playout stops (' ' when no win was detected).
// `node` itself is not modified; all moves are placed on a temporary node.
char MCTS::Rollout(MCTSNode* node)
{
    // Already decided positions need no playout
    const char initial_winner = node->boardstate.ifWinFound();
    if (initial_winner != ' ')
        return initial_winner;
    if (node->boardstate.ifDrawFound())
        return 'd';

    // Scratch node holding the rolled out game
    MCTSNode temp_node{node->boardstate, node->player_char};
    char current_char = node->player_char;

    // At most one move per move available at the start of the playout
    for (unsigned int i = 0; i < node->availablemoves.size(); i++)
    {
        // Guard: `rand() % 0` is undefined behaviour, so stop as soon as the
        // scratch node has no move left to play.
        if (temp_node.availablemoves.empty())
            break;

        // Picks a random move and places it on the scratch board
        int randommoveindex = rand() % temp_node.availablemoves.size();
        Player::AiMove randommove = temp_node.availablemoves[randommoveindex];
        temp_node.boardstate.board[randommove.row][randommove.column].status = current_char;

        // Rebuilds the move list so it reflects the board after this move
        temp_node.availablemoves.clear();
        temp_node.setAvailableMoves();

        current_char = nextNodeChar(current_char);

        const char playout_winner = temp_node.boardstate.ifWinFound();
        if (playout_winner != ' ')
            return playout_winner;
        if (temp_node.boardstate.ifDrawFound())
            return 'd';
    }
    return temp_node.boardstate.ifWinFound();
}
// Simulation step: produces the outcome character for `node`.
// A position where ifWinFound() is not ' ' returns that value, a drawn
// position returns 'd', and anything else is decided by Rollout().
char MCTS::simulation(MCTSNode* node)
{
    const bool win_on_board = node->boardstate.ifWinFound() != ' ';
    if (win_on_board)
        return node->boardstate.ifWinFound();

    return node->boardstate.ifDrawFound() ? 'd' : Rollout(node);
}
// Records one simulation result on a single node.
// Every call adds 10 to node_visits. wins grows by 10 when `winner` equals
// the node's player_char, by 5 when `winner` is 'd' (draw), and is left
// unchanged otherwise.
void MCTS::updateNodes(MCTSNode* node, char winner)
{
    const int visit_step = 10;
    const int win_reward = 10;
    const int draw_reward = 5;

    const bool node_player_won = (winner == node->player_char);
    const bool game_drawn = (winner == 'd');

    node->node_visits += visit_step;

    // The win check takes precedence over the draw check
    if (node_player_won)
    {
        node->wins += win_reward;
    }
    else if (game_drawn)
    {
        node->wins += draw_reward;
    }
}
// Backpropagation step: applies updateNodes() to `node` and then to each
// ancestor in turn, following the parent pointers until a null parent is hit.
void MCTS::backpropogation(MCTSNode* node, char winner)
{
    for (MCTSNode* cursor = node; cursor != nullptr; cursor = cursor->parent)
    {
        updateNodes(cursor, winner);
    }
}
// Runs the search from the given position and returns the move to play.
//
// _board         position to search from; copied into the root node.
// _player_char   mark stored as the root node's player_char.
// boardssearched caller-owned counter, incremented once per search iteration.
//
// The search stops as soon as EITHER the time budget (max_time) or the
// iteration budget (max_iter) is used up. boardspersec is updated as a side
// effect. Returns the move of the child picked by getChildWithBestScore(),
// or a value-initialised move when there is no such child.
Player::AiMove MCTS::findNextMove(Board _board, char _player_char, unsigned int &boardssearched)
{
    // Copying given board and player_char to root node.
    // The copy treats the board as dim_of_board * dim_of_board contiguous cells.
    std::copy(&_board.board[0][0], &_board.board[0][0] + _board.dim_of_board * _board.dim_of_board, &rootNode.boardstate.board[0][0]);
    rootNode.player_char = _player_char;
    rootNode.availablemoves.clear();
    // NOTE(review): the children are raw pointers allocated with `new` in
    // expansion(); clear() only drops the pointers. Unless they are released
    // elsewhere, the previous search tree leaks here — confirm.
    rootNode.children.clear();
    rootNode.setAvailableMoves();

    // Timer and iteration counter for the two search budgets
    const double start = GetTime();
    double time_span = 0.0;
    unsigned int iter = 0;

    // Both budgets must still have room. The former `||` combined with
    // `iter == max_iter` never enforced the iteration cap.
    while (time_span < max_time && iter < max_iter)
    {
        // Selection
        MCTSNode* bestNode = selection();

        // Expansion
        MCTSNode* expanded_node = expansion(bestNode);

        // Simulation
        char winner = simulation(expanded_node);

        // Update
        backpropogation(expanded_node, winner);

        // Used for timing
        time_span = GetTime() - start;
        iter++;
        boardssearched++;
    }

    // Rate over the time actually spent searching (not the absolute timestamp),
    // guarded so a search that took no measurable time cannot divide by zero.
    boardspersec = (time_span > 0.0) ? boardssearched / time_span : 0.0;

    // Query the best child once so row and column always come from the same
    // node, and so a missing child is not dereferenced.
    Player::AiMove bestmove{};
    const auto best_child = rootNode.getChildWithBestScore();
    if (best_child != nullptr)
    {
        bestmove.row = best_child->move.row;
        bestmove.column = best_child->move.column;
    }
    return bestmove;
}