train.py
import collections

import numpy as np
import torch
import tqdm

import config
import environment  # noqa
import environment.sql_env
import environment.wrappers
import experiment_buddy
import ppo.envs
import ppo.model
from lstmDQN.custom_agent import FixedLengthAgent

# Set the random seed manually for reproducibility.
np.random.seed(config.seed)
torch.manual_seed(config.seed)


def train(tb):
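    """Run a DQN-style, epsilon-greedy training loop on the SQL text environment and log training scalars to `tb`."""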
    env = ppo.envs.make_vec_envs(config.env_name, config.seed, config.num_processes, device=config.device)

    # This is a hack, but the experiment defines its own action space.
    env.action_space = environment.sql_env.TextSpace(ppo.model.get_output_vocab(), env.action_space.sequence_length, (1, env.action_space.sequence_length))
    obs_len = env.observation_space.sequence_length + env.action_space.sequence_length * config.action_history_len
    env.observation_space = environment.sql_env.TextSpace(env.observation_space.vocab + env.action_space.vocab, obs_len, (1, obs_len))

    agent = FixedLengthAgent(env.observation_space, env.action_space, config.device)
    agent.model.train()
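
    # Assumption: WordLevelPreprocessing handles the word-level tokenization and the
    # action-history stacking implied by the widened observation space above.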
    env = environment.wrappers.WordLevelPreprocessing(env, config.action_history_len)
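
    # Bookkeeping: rolling 100-episode windows of episode length, per (hidden_parameter, selected_columns) task and overall.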
    pbar = tqdm.tqdm(total=config.num_env_steps)
    env_steps = 0
    num_episodes = 0
    task_performance = collections.defaultdict(lambda: collections.deque(maxlen=100))
    avg_task_performance = collections.deque(maxlen=100)
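
    # Main training loop: run episodes until the environment-step budget is exhausted.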
    while env_steps < config.num_env_steps:
        obs = env.reset()
        # Reach through the wrappers to the underlying SQL environment to read the current task identity.
        sql_env = env.env.envs[0].env

        done = False
        episode_length = 0
        episode_reward = 0
        episode_loss = 0
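
        # One episode: act epsilon-greedily, take one learning update per environment step, and store each transition.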
        while not done:
            actions = agent.eps_greedy(obs.to(config.device).unsqueeze(-1))
            episode_loss += agent.update(config.gamma)

            # queries = processed_env.action_decode(actions)
            next_obs, rewards, dones, infos = env.step(actions)
            done, = dones

            env_steps += 1
            episode_length += 1
            episode_reward += rewards
            pbar.update(1)  # the progress bar tracks environment steps, so tick it once per step

            agent.replay_memory.add(obs, next_obs, actions.cpu(), rewards, dones, infos)
            obs = next_obs
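
        # Linearly anneal epsilon toward epsilon_anneal_to at the end of each episode.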
        if num_episodes < agent.epsilon_anneal_episodes and agent.epsilon > agent.epsilon_anneal_to:
            agent.epsilon -= (agent.epsilon_anneal_from - agent.epsilon_anneal_to) / float(agent.epsilon_anneal_episodes)
        num_episodes += 1  # count completed episodes for the annealing schedule

        task_performance[sql_env.hidden_parameter, sql_env.selected_columns].append(episode_length)
        avg_task_performance.append(episode_length)
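
        # Log per-episode diagnostics against the cumulative environment-step count.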
        tb.add_scalar('train/epsilon', agent.epsilon, env_steps)
        tb.add_scalar('train/episode_reward', episode_reward, env_steps)
        tb.add_scalar('train/episode_length', episode_length, env_steps)
        tb.add_scalar('train/episode_loss', episode_loss / episode_length, env_steps)
        tb.add_scalar('train/avg_episode_length', np.mean(avg_task_performance), env_steps)
        tb.add_scalar(f'train/task_{sql_env.hidden_parameter}_{sql_env.selected_columns}',
                      np.mean(task_performance[sql_env.hidden_parameter, sql_env.selected_columns]), env_steps)


if __name__ == '__main__':
    experiment_buddy.register_defaults(vars(config))
    PROC_NUM = 10
    # HOST = "mila" if config.user == "esac" else ""
    HOST = "mila"
    RUN_SWEEP = True
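    # Deploy to the "mila" host with PROC_NUM processes; run the sweep defined in sweep.yml when
    # RUN_SWEEP is set, and keep W&B disabled while config.DEBUG is on.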
    tb = experiment_buddy.deploy(host=HOST, sweep_yaml="sweep.yml" if RUN_SWEEP else "", proc_num=PROC_NUM,
                                 wandb_kwargs={"mode": "disabled" if config.DEBUG else "online", "entity": "rl-sql"})
    train(tb)