From b036284bbb84f5f7ed1efcd0bc12a10177058be3 Mon Sep 17 00:00:00 2001 From: Giri Anantharaman Date: Thu, 1 Jul 2021 11:05:15 -0700 Subject: [PATCH 1/4] Isolating multiprocessing change --- mjrl/samplers/core.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mjrl/samplers/core.py b/mjrl/samplers/core.py index be4a988..994fd0b 100644 --- a/mjrl/samplers/core.py +++ b/mjrl/samplers/core.py @@ -134,7 +134,7 @@ def sample_paths( start_time = timer.time() print("####### Gathering Samples #######") - results = _try_multiprocess(do_rollout, input_dict_list, + results = _try_multiprocess_cf(do_rollout, input_dict_list, num_cpu, max_process_time, max_timeouts) paths = [] # result is a paths type and results is list of paths @@ -186,7 +186,7 @@ def sample_data_batch( return paths -def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts): +def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts): # Base case if max_timeouts == 0: @@ -208,3 +208,23 @@ def _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_time pool.terminate() pool.join() return results + +def _try_multiprocess_cf(func, input_dict_list, num_cpu, max_process_time, max_timeouts): + import concurrent.futures + results = None + if max_timeouts != 0: + with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpu) as executor: + submit_futures = [executor.submit(func, **input_dict) for input_dict in input_dict_list] + try: + results = [f.result() for f in submit_futures] + except TimeoutError as e: + print(str(e)) + print("Timeout Error raised...") + except concurrent.futures.CancelledError as e: + print(str(e)) + print("Future Cancelled Error raised...") + except Exception as e: + print(str(e)) + print("Error raised...") + raise e + return results From c48b267751124eb1775675e7d06a22aaa47531f3 Mon Sep 17 00:00:00 2001 From: Giri Anantharaman Date: Thu, 1 Jul 2021 11:44:06 -0700 Subject: [PATCH 2/4] Adding ability to print input_dict_list --- mjrl/samplers/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mjrl/samplers/core.py b/mjrl/samplers/core.py index 994fd0b..58333e9 100644 --- a/mjrl/samplers/core.py +++ b/mjrl/samplers/core.py @@ -134,6 +134,8 @@ def sample_paths( start_time = timer.time() print("####### Gathering Samples #######") + print("Input Dict List %s" % (str(input_dict_list))) + results = _try_multiprocess_cf(do_rollout, input_dict_list, num_cpu, max_process_time, max_timeouts) paths = [] @@ -202,7 +204,7 @@ def _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_t pool.close() pool.terminate() pool.join() - return _try_multiprocess(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1) + return _try_multiprocess_mp(func, input_dict_list, num_cpu, max_process_time, max_timeouts-1) pool.close() pool.terminate() From 80ec7d77674314ad17a38fd237e5e32a497f4be2 Mon Sep 17 00:00:00 2001 From: Giri Anantharaman Date: Fri, 20 Aug 2021 22:36:08 -0400 Subject: [PATCH 3/4] Removing unnecessary prints --- mjrl/samplers/core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mjrl/samplers/core.py b/mjrl/samplers/core.py index 58333e9..b8d5408 100644 --- a/mjrl/samplers/core.py +++ b/mjrl/samplers/core.py @@ -134,8 +134,6 @@ def sample_paths( start_time = timer.time() print("####### Gathering Samples #######") - print("Input Dict List %s" % (str(input_dict_list))) - results = _try_multiprocess_cf(do_rollout, input_dict_list, num_cpu, max_process_time, max_timeouts) paths = [] From 55726c3b135d908a871a8d30f3cf3a2cca16181b Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Mon, 16 Aug 2021 14:39:16 -0400 Subject: [PATCH 4/4] Small edit to remove warning --- mjrl/utils/tensor_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mjrl/utils/tensor_utils.py b/mjrl/utils/tensor_utils.py index 8b0002a..1372fca 100644 --- a/mjrl/utils/tensor_utils.py +++ b/mjrl/utils/tensor_utils.py @@ -61,7 +61,7 @@ def high_res_normalize(probs): def stack_tensor_list(tensor_list): - return np.array(tensor_list) + return np.array(tensor_list, dtype='object') # tensor_shape = np.array(tensor_list[0]).shape # if tensor_shape is tuple(): # return np.array(tensor_list)