-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathdqn_train.py
executable file
·112 lines (90 loc) · 3.73 KB
/
dqn_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python
# Copyright (c) 2021 Computer Vision Center (CVC) at the Universitat Autonoma de
# Barcelona (UAB).
#
# This work is licensed under the terms of the MIT license.
# For a copy, see <https://opensource.org/licenses/MIT>.
"""DQN Algorithm. Tested with CARLA.
You can visualize experiment results in ~/ray_results using TensorBoard.
"""
from __future__ import print_function
import argparse
import os
import yaml
import ray
from ray import tune
from rllib_integration.carla_env import CarlaEnv
from rllib_integration.carla_core import kill_all_servers
from rllib_integration.helper import get_checkpoint, launch_tensorboard
from dqn_example.dqn_experiment import DQNExperiment
from dqn_example.dqn_callbacks import DQNCallbacks
from dqn_example.dqn_trainer import CustomDQNTrainer
# Experiment class that parse_config() stores in
# config["env_config"]["experiment"]["type"], so the environment receives it
# through its configuration dictionary.
EXPERIMENT_CLASS = DQNExperiment
def run(args):
    """Launch DQN training through Ray Tune and always clean up afterwards.

    Args:
        args: the parsed command-line namespace built in ``main``; this
            function reads ``auto``, ``name``, ``directory``, ``restore``,
            ``overwrite`` and ``config`` from it.

    The trial is stopped once the reported ``perf/ram_util_percent`` metric
    reaches 85.0. A checkpoint is written every training iteration and once
    more at the end. Regardless of how training ends, the CARLA servers are
    killed and Ray is shut down.
    """
    try:
        # With --auto, connect to an already running Ray cluster; otherwise
        # let ray.init() start a local instance.
        ray.init(address="auto" if args.auto else None)
        tune.run(CustomDQNTrainer,
                 name=args.name,
                 local_dir=args.directory,
                 # Stop the run once RAM utilisation reaches 85%.
                 stop={"perf/ram_util_percent": 85.0},
                 checkpoint_freq=1,
                 checkpoint_at_end=True,
                 restore=get_checkpoint(args.name, args.directory,
                                        args.restore, args.overwrite),
                 config=args.config,
                 queue_trials=True)
    finally:
        # Nested so that ray.shutdown() still runs even if killing the
        # CARLA servers raises.
        try:
            kill_all_servers()
        finally:
            ray.shutdown()
def parse_config(args):
    """Parse the .yaml configuration file into a dictionary usable by RLlib.

    After loading the file, the objects that cannot be written in YAML are
    injected: the environment class (``config["env"]``), the experiment class
    (``config["env_config"]["experiment"]["type"]``) and the callbacks class
    (``config["callbacks"]``).

    Args:
        args: namespace with a ``configuration_file`` attribute holding the
            path of the YAML file.

    Returns:
        dict: the loaded configuration with the injected entries.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
        KeyError: if ``env_config`` or ``env_config["experiment"]`` is missing.
    """
    # YAML streams are Unicode; decode explicitly instead of relying on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(args.configuration_file, encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct some Python objects. If the
        # configuration never relies on Python-specific tags, consider
        # yaml.safe_load, especially for files from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)

    # An empty file loads as None; fail with a clear message instead of an
    # opaque TypeError on the item assignments below.
    if not isinstance(config, dict):
        raise ValueError(
            "Configuration file {!r} must contain a YAML mapping, got {}".format(
                args.configuration_file, type(config).__name__))

    config["env"] = CarlaEnv
    config["env_config"]["experiment"]["type"] = EXPERIMENT_CLASS
    config["callbacks"] = DQNCallbacks
    return config
def main():
    """Parse the command line, load the configuration and start training.

    Unless ``--tboff`` is passed, TensorBoard is launched on
    ``<directory>/<name>`` before training starts. It binds to ``0.0.0.0``
    when ``--auto`` is given and to ``localhost`` otherwise.
    """
    # __doc__ here resolves to the module docstring (a global), not to this
    # function's docstring.
    argparser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument("configuration_file",
                           help="Configuration file (*.yaml)")
    argparser.add_argument("-d", "--directory",
                           metavar='D',
                           # os.path.join keeps path separators consistent on
                           # every platform.
                           default=os.path.join(os.path.expanduser("~"),
                                                "ray_results", "carla_rllib"),
                           help="Specified directory to save results (default: ~/ray_results/carla_rllib)")
    argparser.add_argument("-n", "--name",
                           metavar="N",
                           default="dqn_example",
                           help="Name of the experiment (default: dqn_example)")
    # store_true flags already default to False.
    argparser.add_argument("--restore",
                           action="store_true",
                           help="Flag to restore from the specified directory")
    argparser.add_argument("--overwrite",
                           action="store_true",
                           help="Flag to overwrite a specific directory (warning: all content of the folder will be lost.)")
    argparser.add_argument("--tboff",
                           action="store_true",
                           help="Flag to deactivate Tensorboard")
    argparser.add_argument("--auto",
                           action="store_true",
                           help="Flag to use auto address")

    args = argparser.parse_args()
    args.config = parse_config(args)

    if not args.tboff:
        launch_tensorboard(logdir=os.path.join(args.directory, args.name),
                           host="0.0.0.0" if args.auto else "localhost")

    run(args)
if __name__ == '__main__':
    # The outer finally guarantees the closing message is printed however
    # main() ends. The inner handler lets Ctrl-C stop the script quietly,
    # without a traceback.
    try:
        try:
            main()
        except KeyboardInterrupt:
            pass
    finally:
        print('\ndone.')