-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
73 lines (57 loc) · 2.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import matplotlib.pyplot as plt
import BitFlipEnv as bflip
from dqn_without_her import DQNAgent as dqn
# Train a plain DQN agent (simple experience replay, no hindsight experience
# replay) on the bit-flipping environment, report the success rate over
# fixed-size windows of episodes, checkpoint the best model, and plot the
# learning curve.
if __name__ == '__main__':
    n_bits = 8
    env = bflip.BitFlipEnv(n_bits)
    n_episodes = 30000
    # Episodes per success-rate measurement; averaging over a window smooths
    # out per-episode spikes.
    window = 500
    epsilon_history = []
    episodes = []
    win_percent = []  # success rate per window, as a fraction in [0, 1]
    success = 0
    best_win_rate = 0.0
    # When True, load saved weights and run without storing experience,
    # learning, or writing checkpoints.
    load_checkpoint = False

    # The sub-directory must be relative (no leading '/'): os.path.join drops
    # every component before an absolute one. The trailing '' keeps a trailing
    # separator in case the agent builds file paths by string concatenation
    # (its path handling is not visible here — confirm).
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoint', '')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Initializes the DQN agent with simple experience replay.
    # NOTE(review): `dqn` is bound by `from dqn_without_her import DQNAgent as
    # dqn`, so DQNAgent is presumably a submodule holding a class of the same
    # name — confirm.
    agent = dqn.DQNAgent(learning_rate=0.0001, n_actions=n_bits,
                         input_dims=n_bits, gamma=0.99,
                         epsilon=0.9, batch_size=64, memory_size=10000,
                         replace_network_count=50,
                         checkpoint_dir=checkpoint_dir)
    if load_checkpoint:
        agent.load_model()

    for episode in range(n_episodes):
        env.reset_env()
        state = env.state
        done = False
        # Each episode is capped at n_bits steps; presumably every action
        # flips one bit, so that is enough to reach any goal.
        for _ in range(n_bits):
            action = agent.choose_action(state)
            next_state, reward, done = env.take_step(action)
            if not load_checkpoint:
                agent.store_experience(state, action, reward, next_state, done)
                agent.learn()
            state = next_state
            if done:
                break
        if done:
            success += 1

        # Report at the end of every `window` episodes. Testing (episode + 1)
        # makes every window, including the first, span exactly `window`
        # episodes and includes the last episodes of the run. Episodes that
        # do not fill a whole final window are not reported.
        if (episode + 1) % window == 0:
            win_rate = success / window
            print('success rate for last', window, 'episodes after',
                  episode + 1, ':', 100.0 * win_rate)
            # Checkpoint only on a new best, so a later, worse window can never
            # overwrite a better model; nothing is saved when evaluating a
            # loaded model.
            if not load_checkpoint and win_rate > best_win_rate:
                best_win_rate = win_rate
                agent.save_model()
            epsilon_history.append(agent.epsilon)
            episodes.append(episode + 1)
            win_percent.append(win_rate)
            success = 0

    print('Epsilon History:', epsilon_history)
    print('Episodes:', episodes)
    print('Win percentage:', win_percent)

    plt.figure()
    plt.plot(episodes, win_percent)
    plt.title('DQN without HER')
    plt.ylabel('Win Rate')  # values are fractions in [0, 1], see ylim
    plt.xlabel('Number of Episodes')
    plt.ylim([0, 1])
    # savefig needs a file path, not a directory, and the directory must
    # exist. The file name below is a chosen default.
    plots_dir = os.path.join(os.getcwd(), 'plots')
    os.makedirs(plots_dir, exist_ok=True)
    plt.savefig(os.path.join(plots_dir, 'dqn_without_her.png'))