-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRidgeAgents.py
37 lines (32 loc) · 1.45 KB
/
RidgeAgents.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from util import manhattanDistance, Queue
from game import Directions, Actions
import random, util
from collections import defaultdict
import math
from game import Agent
from MonteCarlo import MCTS, Node
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import neighbors
from sklearn.svm import SVC
from sklearn.linear_model import Ridge
from multiAgents import MultiAgentSearchAgent, extractFeature, getActionByNumber, dataColumns
class RidgeAgent(MultiAgentSearchAgent):
    """Agent that chooses its move with a ridge-regression model.

    The model is fitted once, at construction time, on recorded game data
    read from a CSV file, and is then queried on every call to
    ``getAction``.  Whenever the predicted move is not among the legal
    actions of the current state, a random legal move is played instead.
    """

    def __init__(self):
        """Load the recorded game data and fit the ridge regressor.

        The ``labelNextAction`` column of the CSV is the regression target;
        every other column is used as a feature.
        """
        # NOTE(review): the base-class ``__init__`` is not invoked here;
        # confirm nothing relies on attributes it would normally set.
        samples = pd.read_csv("dataGameWonMoreThan1500WithColumnNames.csv")
        self.dataTarget = samples["labelNextAction"]
        self.dataTrain = samples.drop(columns=["labelNextAction"])

        # Only the 80% training portion is used for fitting; the held-out
        # 20% is discarded.  The split is kept so the model is still fitted
        # on a random subset of the data rather than on all of it.
        xtrain, _, ytrain, _ = train_test_split(
            self.dataTrain, self.dataTarget, train_size=0.8
        )
        self.rr = Ridge(alpha=100)
        self.rr.fit(xtrain, ytrain)

    def getAction(self, currGameState):
        """Return the action to play in ``currGameState``.

        A single-row feature frame is built from the state, the fitted
        model predicts an action number for it, and that number is mapped
        to an action by ``getActionByNumber``.  If the resulting action is
        not in ``currGameState.getLegalActions(0)``, "Illegal Action" is
        printed and a random legal action is returned instead.
        """
        row = pd.DataFrame(columns=dataColumns)
        # "South" is presumably just a placeholder for the label slot of
        # the row (TODO confirm against extractFeature); the label column
        # is dropped again before predicting.
        row.loc[0, :] = extractFeature(currGameState, "South")
        features = row.drop(columns=["labelNextAction"])

        # NOTE(review): Ridge.predict returns an array of continuous
        # values; it is handed to getActionByNumber unchanged.  Verify that
        # helper copes with a one-element float array.
        nextPredictedAction = getActionByNumber(self.rr.predict(features))

        # Query the legal moves once and reuse the list for both the
        # membership test and the random fallback.
        legalActions = currGameState.getLegalActions(0)
        if nextPredictedAction not in legalActions:
            print("Illegal Action")
            nextPredictedAction = random.choice(legalActions)
        return nextPredictedAction