Make training job resumable #18

Open · wants to merge 2 commits into base: redesign
7 changes: 7 additions & 0 deletions mjrl/algos/batch_reinforce.py
@@ -205,3 +205,10 @@ def log_rollout_statistics(self, paths):
        self.logger.log_kv('stoc_pol_std', std_return)
        self.logger.log_kv('stoc_pol_max', max_return)
        self.logger.log_kv('stoc_pol_min', min_return)

    @property
    def global_status(self):
        # extra state to pickle alongside the policy snapshots;
        # the base class has none, subclasses can override this pair
        return dict()

    def load_global_status(self, status_dict):
        # restore whatever global_status captured; no-op in the base class
        pass
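For context, a sketch of how a concrete algorithm could use this hook. It is illustrative only: the subclass and its attributes (running_score, n_updates) are hypothetical and not part of this diff, assuming the enclosing class here is BatchREINFORCE.

class MyAlgo(BatchREINFORCE):
    @property
    def global_status(self):
        # everything returned here is pickled next to the policy snapshots
        # (see iterations/global_status.pickle in train_agent.py below)
        return dict(running_score=self.running_score,
                    n_updates=self.n_updates)

    def load_global_status(self, status_dict):
        # restore the same fields when resuming from a checkpoint
        self.running_score = status_dict['running_score']
        self.n_updates = status_dict['n_updates']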
37 changes: 30 additions & 7 deletions mjrl/utils/logger.py
@@ -15,20 +15,27 @@ def __init__(self):

    def log_kv(self, key, value):
        # logs the (key, value) pair

        # TODO: This implementation is wrong and error-prone: the series
        # will not stay aligned if some keys are missing during an iteration.
        if key not in self.log:
            self.log[key] = []
        self.log[key].append(value)
        if len(self.log[key]) > self.max_len:
            self.max_len = self.max_len + 1

    def save_log(self, save_path):
        # TODO: Validate that all series have the same length.
        pickle.dump(self.log, open(save_path + '/log.pickle', 'wb'))
        with open(save_path + '/log.csv', 'w') as csv_file:
            fieldnames = list(self.log.keys())
            if 'iteration' not in fieldnames:
                fieldnames = ['iteration'] + fieldnames

            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
            writer.writeheader()
            for row in range(self.max_len):
                row_dict = {'iteration': row}
                for key in self.log.keys():
                    if row < len(self.log[key]):
                        row_dict[key] = self.log[key][row]
@@ -37,21 +44,37 @@ def save_log(self, save_path):
    def get_current_log(self):
        row_dict = {}
        for key in self.log.keys():
            # TODO: this is very error-prone (alignment is not guaranteed)
            row_dict[key] = self.log[key][-1]
        return row_dict

    def shrink_to(self, num_entries):
        for key in self.log.keys():
            self.log[key] = self.log[key][:num_entries]
        self.max_len = num_entries  # keep max_len consistent after truncation

        assert min([len(series) for series in self.log.values()]) == \
               max([len(series) for series in self.log.values()])

    def read_log(self, log_path):
        assert log_path.endswith('log.csv')

        with open(log_path) as csv_file:
            reader = csv.DictReader(csv_file)
            listr = list(reader)
            keys = reader.fieldnames
            data = {}
            for key in keys:
                data[key] = []
            for row, row_dict in enumerate(listr):
                for key in keys:
                    try:
                        data[key].append(eval(row_dict[key]))
                    except:
                        print("ERROR on reading key {}: {}".format(key, row_dict[key]))

                if 'iteration' in data and data['iteration'][-1] != row:
                    raise RuntimeError("Iteration %d mismatch -- possibly corrupted logfile?" % row)

        self.log = data
        self.max_len = max(len(v) for k, v in self.log.items())
        print("Log read from {}: had {} entries".format(log_path, self.max_len))
55 changes: 54 additions & 1 deletion mjrl/utils/train_agent.py
@@ -11,6 +11,49 @@
import os
import copy


def _load_latest_policy_and_logs(agent, *, policy_dir, logs_dir):
    """Loads the latest saved policy and trims the logs to match.
    Returns the next iteration number to begin with.
    """
    assert os.path.isdir(policy_dir), str(policy_dir)
    assert os.path.isdir(logs_dir), str(logs_dir)

    log_csv_path = os.path.join(logs_dir, 'log.csv')
    if not os.path.exists(log_csv_path):
        return 0   # fresh start

    print("Reading: {}".format(log_csv_path))
    agent.logger.read_log(log_csv_path)
    last_step = agent.logger.max_len - 1
    if last_step <= 0:
        return 0   # fresh start

    # find the latest iteration for which a policy/baseline snapshot exists,
    # scanning backwards from the last logged step
    i = last_step
    while i >= 0:
        policy_path = os.path.join(policy_dir, 'policy_{}.pickle'.format(i))
        baseline_path = os.path.join(policy_dir, 'baseline_{}.pickle'.format(i))
        if not os.path.isfile(policy_path):
            i -= 1   # try the previous iteration
            continue

        with open(policy_path, 'rb') as fp:
            agent.policy = pickle.load(fp)
        with open(baseline_path, 'rb') as fp:
            agent.baseline = pickle.load(fp)

        # restore any extra algorithm state saved alongside the snapshots
        global_status_path = os.path.join(policy_dir, 'global_status.pickle')
        with open(global_status_path, 'rb') as fp:
            agent.load_global_status(pickle.load(fp))

        agent.logger.shrink_to(i + 1)
        assert agent.logger.max_len == i + 1
        return agent.logger.max_len

    # cannot find any saved policy
    raise RuntimeError("Log file exists, but cannot find any saved policy.")

def train_agent(job_name, agent,
                seed = 0,
                niter = 101,
@@ -38,7 +81,16 @@ def train_agent(job_name, agent,
    mean_pol_perf = 0.0
    e = GymEnv(agent.env.env_id)

    # Load from any existing checkpoint: policy, baseline, logs, global status.
    i_start = _load_latest_policy_and_logs(agent,
                                           policy_dir='iterations',
                                           logs_dir='logs')
    if i_start:
        print("Resuming from an existing job folder ...")

    for i in range(i_start, niter):
print("......................................................................................")
print("ITERATION : %i " % i)

@@ -68,6 +120,7 @@ def train_agent(job_name, agent,
            pickle.dump(agent.policy, open('iterations/' + policy_file, 'wb'))
            pickle.dump(agent.baseline, open('iterations/' + baseline_file, 'wb'))
            pickle.dump(best_policy, open('iterations/best_policy.pickle', 'wb'))
            pickle.dump(agent.global_status, open('iterations/global_status.pickle', 'wb'))

        # print results to console
        if i == 0:
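Taken together, resuming only requires re-running the same call: train_agent operates inside the job_name folder, where iterations/ and logs/ live. A hedged end-to-end sketch (the agent construction and job name are assumptions, not part of this diff):

from mjrl.utils.train_agent import train_agent

# first run: suppose the process is killed partway through the 101 iterations
train_agent(job_name='swimmer_job', agent=agent, seed=0, niter=101)

# second run, same call: train_agent now finds logs/log.csv and the newest
# iterations/policy_<i>.pickle / baseline_<i>.pickle pair, restores the
# policy, baseline, and global status, shrinks the log to that checkpoint,
# and continues the loop from i_start instead of 0
train_agent(job_name='swimmer_job', agent=agent, seed=0, niter=101)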