From fa65a752f889771551ab8d3c8ca3a2d0a3d6fbbd Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Tue, 16 Jul 2024 16:40:00 +0800 Subject: [PATCH 01/12] start working on bigym integration --- robobase/cfgs/env/bigym.yaml | 4 +-- robobase/envs/bigym.py | 11 ++----- robobase/envs/utils/bigym_utils.py | 53 ++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 11 deletions(-) create mode 100644 robobase/envs/utils/bigym_utils.py diff --git a/robobase/cfgs/env/bigym.yaml b/robobase/cfgs/env/bigym.yaml index a8a8760..919c2a6 100644 --- a/robobase/cfgs/env/bigym.yaml +++ b/robobase/cfgs/env/bigym.yaml @@ -2,8 +2,8 @@ env: env_name: bigym - episode_length: 1000 + episode_length: 3000 cameras: ["head", "right_wrist", "left_wrist"] action_mode: JOINT_POSITION -action_repeat: 2 + demo_frequency: 25 \ No newline at end of file diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index cd311a6..c2f3148 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -2,8 +2,7 @@ from bigym.bigym_env import BiGymEnv from bigym.action_modes import ActionMode, JointPositionActionMode, TorqueActionMode -from bigym.envs.reach_target import ReachTarget -from bigym.envs.move_plate_between_drainers import MovePlateBetweenDrainers +from robobase.envs.utils.bigym_utils import TASK_MAP import gymnasium as gym from gymnasium.wrappers import TimeLimit from robobase.envs.env import EnvFactory @@ -26,13 +25,7 @@ class ActionModeType(Enum): def _task_name_to_env_class(task_name: str) -> type[BiGymEnv]: - match task_name: - case "reach_target": - return ReachTarget - case "move_plate_between_drainers": - return MovePlateBetweenDrainers - raise NotImplementedError("Env Not Implemented Yet.") - + return TASK_MAP[task_name] def _create_action_mode(action_mode: str) -> ActionMode: if action_mode == ActionModeType.TORQUE.value: diff --git a/robobase/envs/utils/bigym_utils.py b/robobase/envs/utils/bigym_utils.py new file mode 100644 index 0000000..25c0e63 --- /dev/null +++ b/robobase/envs/utils/bigym_utils.py @@ -0,0 +1,53 @@ +from bigym.envs.reach_target import ReachTarget, ReachTargetDual, ReachTargetSingle +from bigym.envs.move_plates import MovePlate, MoveTwoPlates +from bigym.envs.cupboards import CupboardsOpenAll, CupboardsCloseAll, WallCupboardOpen, WallCupboardClose, DrawerTopOpen, DrawerTopClose, DrawersAllOpen, DrawersAllClose +from bigym.envs.dishwasher import DishwasherOpen, DishwasherClose, DishwasherOpenTrays, DishwasherCloseTrays +from bigym.envs.dishwasher_cups import DishwasherLoadCups, DishwasherUnloadCups, DishwasherUnloadCupsLong +from bigym.envs.dishwasher_cutlery import DishwasherLoadCutlery, DishwasherUnloadCutlery, DishwasherUnloadCutleryLong +from bigym.envs.dishwasher_plates import DishwasherLoadPlates, DishwasherUnloadPlates, DishwasherUnloadPlatesLong +from bigym.envs.pick_and_place import PutCups, TakeCups, PickBox, SaucepanToHob, StoreKitchenware, ToastSandwich, FlipSandwich, RemoveSandwich, StoreBox +from bigym.envs.manipulation import FlipCup, FlipCutlery, StackBlocks +from bigym.envs.groceries import GroceriesStoreLower, GroceriesStoreUpper + +TASK_MAP = dict( + reach_target_single=ReachTargetSingle, + reach_target_multi_modal=ReachTarget, + reach_target_dual=ReachTargetDual, + stack_blocks=StackBlocks, + move_plate=MovePlate, + move_two_plates=MoveTwoPlates, + flip_cup=FlipCup, + flip_cutlery=FlipCutlery, + dishwasher_open=DishwasherOpen, + dishwasher_close=DishwasherClose, + dishwasher_open_trays=DishwasherOpenTrays, + dishwasher_close_trays=DishwasherCloseTrays, + dishwasher_load_cups=DishwasherLoadCups, + dishwasher_unload_cups=DishwasherUnloadCups, + dishwasher_unload_cups_long=DishwasherUnloadCupsLong, + dishwasher_load_cutlery=DishwasherLoadCutlery, + dishwasher_unload_cutlery=DishwasherUnloadCutlery, + dishwasher_unload_cutlery_long=DishwasherUnloadCutleryLong, + dishwasher_load_plates=DishwasherLoadPlates, + dishwasher_unload_plates=DishwasherUnloadPlates, + dishwasher_unload_plates_long=DishwasherUnloadPlatesLong, + drawer_top_open=DrawerTopOpen, + drawer_top_close=DrawerTopClose, + drawers_open_all=DrawersAllOpen, + drawers_close_all=DrawersAllClose, + wall_cupboard_open=WallCupboardOpen, + wall_cupboard_close=WallCupboardClose, + cupboards_open_all=CupboardsOpenAll, + cupboards_close_all=CupboardsCloseAll, + take_cups=TakeCups, + put_cups=PutCups, + pick_box=PickBox, + store_box=StoreBox, + saucepan_to_hob=SaucepanToHob, + store_kitchenware=StoreKitchenware, + sandwich_toast=ToastSandwich, + sandwich_flip=FlipSandwich, + sandwich_remove=RemoveSandwich, + store_groceries_lower=GroceriesStoreLower, + store_groceries_upper=GroceriesStoreUpper, +) From ca13c196944ded7b4c00d729c74faa7b340ae052 Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Wed, 17 Jul 2024 15:45:25 +0800 Subject: [PATCH 02/12] finish the env factory; untested --- robobase/cfgs/env/bigym.yaml | 2 +- robobase/envs/bigym.py | 290 ++++++++++++++++++++++++++--- robobase/envs/utils/bigym_utils.py | 128 ++++++++----- 3 files changed, 342 insertions(+), 78 deletions(-) diff --git a/robobase/cfgs/env/bigym.yaml b/robobase/cfgs/env/bigym.yaml index 919c2a6..08c63c9 100644 --- a/robobase/cfgs/env/bigym.yaml +++ b/robobase/cfgs/env/bigym.yaml @@ -6,4 +6,4 @@ env: cameras: ["head", "right_wrist", "left_wrist"] action_mode: JOINT_POSITION - demo_frequency: 25 \ No newline at end of file + demo_down_sample_rate: 20 diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index c2f3148..83e4ca5 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -1,20 +1,36 @@ from enum import Enum -from bigym.bigym_env import BiGymEnv +from bigym.bigym_env import BiGymEnv, CONTROL_FREQUENCY_MAX from bigym.action_modes import ActionMode, JointPositionActionMode, TorqueActionMode +from robobase.utils import rescale_demo_actions, DemoEnv, add_demo_to_replay_buffer from robobase.envs.utils.bigym_utils import TASK_MAP import gymnasium as gym from gymnasium.wrappers import TimeLimit from robobase.envs.env import EnvFactory from robobase.envs.wrappers import ( - RescaleFromTanh, + RescaleFromTanhWithMinMax, OnehotTime, ActionSequence, AppendDemoInfo, FrameStack, ConcatDim, + RecedingHorizonControl, ) from omegaconf import DictConfig +from bigym.utils.observation_config import ObservationConfig, CameraConfig +from bigym.action_modes import PelvisDof +import multiprocessing as mp +import logging +import numpy as np + +from demonstrations.demo import DemoStep +from demonstrations.demo_store import DemoStore, DemoConverter +from demonstrations.utils import Metadata + +from typing import List, Dict, Tuple +from pathlib import Path +import pickle +import copy UNIT_TEST = False @@ -27,6 +43,7 @@ class ActionModeType(Enum): def _task_name_to_env_class(task_name: str) -> type[BiGymEnv]: return TASK_MAP[task_name] + def _create_action_mode(action_mode: str) -> ActionMode: if action_mode == ActionModeType.TORQUE.value: return TorqueActionMode() @@ -35,16 +52,89 @@ def _create_action_mode(action_mode: str) -> ActionMode: class BiGymEnvFactory(EnvFactory): - def _wrap_env(self, env, cfg): - env = RescaleFromTanh(env) - env = ConcatDim(env, 1, -1, "low_dim_state") - env = TimeLimit(env, cfg.env.episode_length) + def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=False): + # last two are grippers + assert cfg.demos > 0 + assert cfg.action_repeat == 1 + + action_space = copy.deepcopy(env.action_space) + observation_space = copy.deepcopy(env.observation_space) + + env = RescaleFromTanhWithMinMax( + action_stats=self._action_stats, + min_max_margin=cfg.min_max_margin, + ) + env = ConcatDim( + env, + shape_length=1, + dim=-1, + new_name="low_dim_state", + keys_to_ignore=["proprioception_floating_base_actions"], + ) if cfg.use_onehot_time_and_no_bootstrap: env = OnehotTime(env, cfg.env.episode_length) env = FrameStack(env, cfg.frame_stack) + env = TimeLimit( + env, cfg.env.episode_length // cfg.demo_down_sample_rate, + ) env = ActionSequence(env, cfg.action_sequence) + + if not demo_env: + if not train: + env = RecedingHorizonControl( + env, + cfg.action_sequence, + cfg.env.episode_length // (cfg.env.demo_down_sample_rate), + cfg.execution_length, + temporal_ensemble=cfg.temporal_ensemble, + gain=cfg.temporal_ensemble_gain, + ) + else: + env = ActionSequence( + env, + cfg.action_sequence, + ) + env = AppendDemoInfo(env) - return env + + if return_raw_spaces: + return env, (action_space, observation_space) + else: + return env + + def _create_env(self, cfg: DictConfig) -> BiGymEnv: + bigym_class = _task_name_to_env_class(cfg.env.task_name) + camera_configs = [ + CameraConfig( + name=camera_name, + rgb=True, + depth=False, + resolution=cfg.visual_observation_shape, + ) for camera_name in cfg.env.cameras + ] + + if cfg.env.enable_all_floating_dof: + action_mode = JointPositionActionMode( + absolute=cfg.env.action_mode == "absolute", + floating_base=True, + floating_dofs=[ + PelvisDof.X, PelvisDof.Y, PelvisDof.Z, PelvisDof.RZ + ] + ) + else: + action_mode = JointPositionActionMode( + absolute=cfg.env.action_mode == "absolute", + floating_base=True, + ) + + return bigym_class( + action_mode=action_mode, + observation_config=ObservationConfig( + cameras=camera_configs if cfg.pixels else [], + proprioception=True, + privileged_information=False if cfg.pixels else True, + ) + ) def make_train_env(self, cfg: DictConfig) -> gym.vector.VectorEnv: vec_env_class = gym.vector.AsyncVectorEnv @@ -52,35 +142,173 @@ def make_train_env(self, cfg: DictConfig) -> gym.vector.VectorEnv: if UNIT_TEST: vec_env_class = gym.vector.SyncVectorEnv kwargs = dict() - bygym_class = _task_name_to_env_class(cfg.env.task_name) - action_mode = _create_action_mode(cfg.env.action_mode) - cameras = cfg.env.cameras if cfg.pixels else None + return vec_env_class( [ lambda: self._wrap_env( - bygym_class( - action_mode=action_mode, - cameras=cameras, - camera_resolution=cfg.visual_observation_shape, - render_mode="rgb_array", - ), - cfg, - ) - for _ in range(cfg.num_train_envs) - ], - **kwargs, + self._create_env(cfg), + demo_env=False, + train=True, + ) for _ in range(cfg.num_train_envs) + ], **kwargs ) - + def make_eval_env(self, cfg: DictConfig) -> gym.Env: - bygym_class = _task_name_to_env_class(cfg.env.task_name) - action_mode = _create_action_mode(cfg.env.action_mode) - cameras = cfg.env.cameras if cfg.pixels else None - return self._wrap_env( - bygym_class( - action_mode=action_mode, - cameras=cameras, - camera_resolution=cfg.visual_observation_shape, - render_mode="rgb_array", + env, self._action_space, self._observation_space = self._wrap_env( + env=self._create_env(cfg), + demo_env=False, + train=False, + return_raw_spaces=True + ) + return env + + def _get_demo_from_scratch(self, cfg: DictConfig, num_demos: int, mp_list: List) -> None: + demos = [] + + logging.info("Start to load demos.") + env = self._create_env(cfg) + + demo_store = DemoStore() + + demos = demo_store.get_demos(Metadata.from_env(env), amount=-1) + + for demo in demos: + for ts in demo.timesteps: + ts.observation = { + k: np.array(v) for k, v in ts.observation.items() + } + + env.close() + logging.info("Finished loading demos.") + mp_list.append(demos) + + def _get_demo_fn(self, cfg: DictConfig, num_demos: int, mp_list: List) -> None: + dataset_root = cfg.env.dataset_root + if dataset_root == "": + dataset_root = Path.home() / ".bigym" / "cache" + + cache_dir = ( + dataset_root / cfg.env.env_name / cfg.env.task_name / cfg.env.action_mode / "pixels" if cfg.pixels else "low_dim_state" + ) + cache_dir.makedirs(exist_ok=True) + cache_path = cache_dir / "cache.pkl" + + if cache_path.exists(): + demos = pickle.load(cache_path.open("rb")) + mp_list.append(demos) + logging.info("Loaded demos from cache.") + else: + self._get_demo_from_scratch(cfg, num_demos, mp_list) + demos = mp_list[0] + pickle.dump(demos, cache_path.open("wb"), pickle.HIGHEST_PROTOCOL) + + def collect_or_fetch_demos(self, cfg: DictConfig, num_demos: int): + manager = mp.Manager() + mp_list = manager.list() + + p = mp.Process( + target=self._get_demo_fn, + args=(cfg, num_demos, mp_list), + ) + p.start() + p.join() + + demos = mp_list[0] + + if num_demos < len(demos): + demos = demos[:num_demos] + + self._raw_demos = [ + DemoConverter.decimate( + demo, + target_freq=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate, + ) for demo in demos + ] + self._action_stats = self._compute_action_stats(cfg, demos) + + def post_collect_or_fetch_demos(self, cfg: DictConfig): + demo_list = [demo.timesteps for demo in self._raw_demos] + demo_list = rescale_demo_actions( + self._rescale_demo_action_helper, demo_list, cfg + ) + self._demos = self._demo_to_steps(cfg, demo_list) + + def load_demos_into_replay(self, cfg: DictConfig, buffer): + """See base class for documentation.""" + assert hasattr(self, "_demos"), ( + "There's no _demo attribute inside the factory, " + "Check `collect_or_fetch_demos` is called before calling this method." + ) + demo_env = self._wrap_env( + DemoEnv( + copy.deepcopy(self._demos), self._action_space, self._observation_space ), cfg, + demo_env=True, + train=False, + ) + for _ in range(len(self._demos)): + add_demo_to_replay_buffer(demo_env, buffer) + + + def _demo_to_steps(self, cfg: DictConfig, demo_list: List[List[DemoStep]]) -> List[DemoStep]: + ret_demos = [] + + for demo in demo_list: + cur_demo = [] + for i, step in enumerate(demo): + step.info.update({"demo": 1}) + if i == 0: + cur_demo.append((step.observation, step.info)) + else: + term, trunc = step.termination, step.truncation + reward = step.reward + if i == len(demo) - 1: + if not (term or trunc): + term = False + trunc = True + + reward = 1 + + cur_demo.append( + ( + step.observation, + reward, + term, + trunc, + step.info + ) + ) + ret_demos.append(cur_demo) + + return ret_demos + + def _compute_action_stats(self, cfg: DictConfig, demos: List[List[DemoStep]]) -> Dict: + actions = [] + for demo in demos: + for step in demo.timesteps: + info = step.info + if "demo_action" in info: + actions.append(info["demo_action"]) + actions = np.stack(actions) + + mean, std, gmax, gmin = self._get_gripper_action_stats(cfg) + action_mean = np.hstack([np.mean(actions, 0)[:-2], mean, mean]) + action_std = np.hstack([np.std(actions, 0)[:-2], std, std]) + action_max = np.hstack([np.max(actions, 0)[:-2], gmax, gmax]) + action_min = np.hstack([np.min(actions, 0)[:-2], gmin, gmin]) + action_stats = { + "mean": action_mean, + "std": action_std, + "max": action_max, + "min": action_min, + } + return action_stats + + def _get_gripper_action_stats(self, cfg: DictConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + return (0.5, 1, 1, 0) + + def _rescale_demo_action_helper(self, info, cfg: DictConfig): + return RescaleFromTanhWithMinMax.transform_to_tanh( + info["demo_action"], action_stats=self._action_stats, min_max_margin=cfg.min_max_margin ) diff --git a/robobase/envs/utils/bigym_utils.py b/robobase/envs/utils/bigym_utils.py index 25c0e63..b65c8fd 100644 --- a/robobase/envs/utils/bigym_utils.py +++ b/robobase/envs/utils/bigym_utils.py @@ -1,53 +1,89 @@ from bigym.envs.reach_target import ReachTarget, ReachTargetDual, ReachTargetSingle from bigym.envs.move_plates import MovePlate, MoveTwoPlates -from bigym.envs.cupboards import CupboardsOpenAll, CupboardsCloseAll, WallCupboardOpen, WallCupboardClose, DrawerTopOpen, DrawerTopClose, DrawersAllOpen, DrawersAllClose -from bigym.envs.dishwasher import DishwasherOpen, DishwasherClose, DishwasherOpenTrays, DishwasherCloseTrays -from bigym.envs.dishwasher_cups import DishwasherLoadCups, DishwasherUnloadCups, DishwasherUnloadCupsLong -from bigym.envs.dishwasher_cutlery import DishwasherLoadCutlery, DishwasherUnloadCutlery, DishwasherUnloadCutleryLong -from bigym.envs.dishwasher_plates import DishwasherLoadPlates, DishwasherUnloadPlates, DishwasherUnloadPlatesLong -from bigym.envs.pick_and_place import PutCups, TakeCups, PickBox, SaucepanToHob, StoreKitchenware, ToastSandwich, FlipSandwich, RemoveSandwich, StoreBox +from bigym.envs.cupboards import ( + CupboardsOpenAll, + CupboardsCloseAll, + WallCupboardOpen, + WallCupboardClose, + DrawerTopOpen, + DrawerTopClose, + DrawersAllOpen, + DrawersAllClose, +) +from bigym.envs.dishwasher import ( + DishwasherOpen, + DishwasherClose, + DishwasherOpenTrays, + DishwasherCloseTrays, +) +from bigym.envs.dishwasher_cups import ( + DishwasherLoadCups, + DishwasherUnloadCups, + DishwasherUnloadCupsLong, +) +from bigym.envs.dishwasher_cutlery import ( + DishwasherLoadCutlery, + DishwasherUnloadCutlery, + DishwasherUnloadCutleryLong, +) +from bigym.envs.dishwasher_plates import ( + DishwasherLoadPlates, + DishwasherUnloadPlates, + DishwasherUnloadPlatesLong, +) +from bigym.envs.pick_and_place import ( + PutCups, + TakeCups, + PickBox, + SaucepanToHob, + StoreKitchenware, + ToastSandwich, + FlipSandwich, + RemoveSandwich, + StoreBox, +) from bigym.envs.manipulation import FlipCup, FlipCutlery, StackBlocks from bigym.envs.groceries import GroceriesStoreLower, GroceriesStoreUpper TASK_MAP = dict( - reach_target_single=ReachTargetSingle, - reach_target_multi_modal=ReachTarget, - reach_target_dual=ReachTargetDual, - stack_blocks=StackBlocks, - move_plate=MovePlate, - move_two_plates=MoveTwoPlates, - flip_cup=FlipCup, - flip_cutlery=FlipCutlery, - dishwasher_open=DishwasherOpen, - dishwasher_close=DishwasherClose, - dishwasher_open_trays=DishwasherOpenTrays, - dishwasher_close_trays=DishwasherCloseTrays, - dishwasher_load_cups=DishwasherLoadCups, - dishwasher_unload_cups=DishwasherUnloadCups, - dishwasher_unload_cups_long=DishwasherUnloadCupsLong, - dishwasher_load_cutlery=DishwasherLoadCutlery, - dishwasher_unload_cutlery=DishwasherUnloadCutlery, - dishwasher_unload_cutlery_long=DishwasherUnloadCutleryLong, - dishwasher_load_plates=DishwasherLoadPlates, - dishwasher_unload_plates=DishwasherUnloadPlates, - dishwasher_unload_plates_long=DishwasherUnloadPlatesLong, - drawer_top_open=DrawerTopOpen, - drawer_top_close=DrawerTopClose, - drawers_open_all=DrawersAllOpen, - drawers_close_all=DrawersAllClose, - wall_cupboard_open=WallCupboardOpen, - wall_cupboard_close=WallCupboardClose, - cupboards_open_all=CupboardsOpenAll, - cupboards_close_all=CupboardsCloseAll, - take_cups=TakeCups, - put_cups=PutCups, - pick_box=PickBox, - store_box=StoreBox, - saucepan_to_hob=SaucepanToHob, - store_kitchenware=StoreKitchenware, - sandwich_toast=ToastSandwich, - sandwich_flip=FlipSandwich, - sandwich_remove=RemoveSandwich, - store_groceries_lower=GroceriesStoreLower, - store_groceries_upper=GroceriesStoreUpper, + reach_target_single=ReachTargetSingle, # 2000, 10, enable_all_floating_dofs=False + reach_target_multi_modal=ReachTarget, # 3000, 10, enable_all_floating_dofs=False + reach_target_dual=ReachTargetDual, # 3000, 10, enable_all_floating_dofs=False + stack_blocks=StackBlocks, # 28500, 25 + move_plate=MovePlate, # 3000, 10 + move_two_plates=MoveTwoPlates, # 5500, 10 + flip_cup=FlipCup, # 5500, 10 + flip_cutlery=FlipCutlery, # 12500, 25 + dishwasher_open=DishwasherOpen, # 7500, 20 + dishwasher_close=DishwasherClose, # 7500, 20 + dishwasher_open_trays=DishwasherOpenTrays, # 9500, 25 + dishwasher_close_trays=DishwasherCloseTrays, # 7500, 25 + dishwasher_load_cups=DishwasherLoadCups, # 7500, 10 + dishwasher_unload_cups=DishwasherUnloadCups, # 10000, 25 + dishwasher_unload_cups_long=DishwasherUnloadCupsLong, # 18000, 25 + dishwasher_load_cutlery=DishwasherLoadCutlery, # 7000, 10 + dishwasher_unload_cutlery=DishwasherUnloadCutlery, # 15500, 25 + dishwasher_unload_cutlery_long=DishwasherUnloadCutleryLong, # 18000, 25 + dishwasher_load_plates=DishwasherLoadPlates, # 14000, 25 + dishwasher_unload_plates=DishwasherUnloadPlates, # 20000, 25 + dishwasher_unload_plates_long=DishwasherUnloadPlatesLong, # 26000, 25 + drawer_top_open=DrawerTopOpen, # 5000, 10 + drawer_top_close=DrawerTopClose, # 3000, 10 + drawers_open_all=DrawersAllOpen, # 12000, 25 + drawers_close_all=DrawersAllClose, # 5000, 25 + wall_cupboard_open=WallCupboardOpen, # 6000, 20 + wall_cupboard_close=WallCupboardClose, # 3000, 10 + cupboards_open_all=CupboardsOpenAll, # 22500, 25 + cupboards_close_all=CupboardsCloseAll, # 15500, 25 + take_cups=TakeCups, # 10500, 25 + put_cups=PutCups, # 8500, 20 + pick_box=PickBox, # 13500, 25 + store_box=StoreBox, # 15000, 25 + saucepan_to_hob=SaucepanToHob, # 11000, 25 + store_kitchenware=StoreKitchenware, # 20000, 25 + sandwich_toast=ToastSandwich, # 16500, 25 + sandwich_flip=FlipSandwich, # 15500, 25 + sandwich_remove=RemoveSandwich, # 13500, 25 + store_groceries_lower=GroceriesStoreLower, # 32000, 25 + store_groceries_upper=GroceriesStoreUpper, # 19000, 25 ) From 41623d489e218af51164d2433e708c599bee2f2a Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Wed, 17 Jul 2024 15:45:59 +0800 Subject: [PATCH 03/12] pre-commit format --- CHANGELOG.md | 2 +- robobase/envs/bigym.py | 96 ++++++++++--------- robobase/envs/rlbench.py | 6 +- robobase/envs/utils/bigym_utils.py | 78 +++++++-------- robobase/envs/wrappers/__init__.py | 6 +- robobase/method/value_based.py | 5 +- robobase/models/lix_utils/analysis_modules.py | 4 +- tests/integration/test_training.py | 4 +- train.py | 4 +- 9 files changed, 115 insertions(+), 90 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b494ad..b239296 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,4 +31,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- None. \ No newline at end of file +- None. diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index 83e4ca5..0deb85e 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -75,7 +75,8 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals env = OnehotTime(env, cfg.env.episode_length) env = FrameStack(env, cfg.frame_stack) env = TimeLimit( - env, cfg.env.episode_length // cfg.demo_down_sample_rate, + env, + cfg.env.episode_length // cfg.demo_down_sample_rate, ) env = ActionSequence(env, cfg.action_sequence) @@ -94,14 +95,14 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals env, cfg.action_sequence, ) - + env = AppendDemoInfo(env) if return_raw_spaces: return env, (action_space, observation_space) else: return env - + def _create_env(self, cfg: DictConfig) -> BiGymEnv: bigym_class = _task_name_to_env_class(cfg.env.task_name) camera_configs = [ @@ -110,30 +111,29 @@ def _create_env(self, cfg: DictConfig) -> BiGymEnv: rgb=True, depth=False, resolution=cfg.visual_observation_shape, - ) for camera_name in cfg.env.cameras + ) + for camera_name in cfg.env.cameras ] if cfg.env.enable_all_floating_dof: action_mode = JointPositionActionMode( absolute=cfg.env.action_mode == "absolute", floating_base=True, - floating_dofs=[ - PelvisDof.X, PelvisDof.Y, PelvisDof.Z, PelvisDof.RZ - ] + floating_dofs=[PelvisDof.X, PelvisDof.Y, PelvisDof.Z, PelvisDof.RZ], ) else: action_mode = JointPositionActionMode( absolute=cfg.env.action_mode == "absolute", floating_base=True, ) - + return bigym_class( action_mode=action_mode, observation_config=ObservationConfig( cameras=camera_configs if cfg.pixels else [], proprioception=True, privileged_information=False if cfg.pixels else True, - ) + ), ) def make_train_env(self, cfg: DictConfig) -> gym.vector.VectorEnv: @@ -149,20 +149,24 @@ def make_train_env(self, cfg: DictConfig) -> gym.vector.VectorEnv: self._create_env(cfg), demo_env=False, train=True, - ) for _ in range(cfg.num_train_envs) - ], **kwargs + ) + for _ in range(cfg.num_train_envs) + ], + **kwargs ) - + def make_eval_env(self, cfg: DictConfig) -> gym.Env: env, self._action_space, self._observation_space = self._wrap_env( env=self._create_env(cfg), demo_env=False, train=False, - return_raw_spaces=True + return_raw_spaces=True, ) return env - def _get_demo_from_scratch(self, cfg: DictConfig, num_demos: int, mp_list: List) -> None: + def _get_demo_from_scratch( + self, cfg: DictConfig, num_demos: int, mp_list: List + ) -> None: demos = [] logging.info("Start to load demos.") @@ -174,21 +178,25 @@ def _get_demo_from_scratch(self, cfg: DictConfig, num_demos: int, mp_list: List) for demo in demos: for ts in demo.timesteps: - ts.observation = { - k: np.array(v) for k, v in ts.observation.items() - } - + ts.observation = {k: np.array(v) for k, v in ts.observation.items()} + env.close() logging.info("Finished loading demos.") mp_list.append(demos) - + def _get_demo_fn(self, cfg: DictConfig, num_demos: int, mp_list: List) -> None: dataset_root = cfg.env.dataset_root if dataset_root == "": dataset_root = Path.home() / ".bigym" / "cache" - + cache_dir = ( - dataset_root / cfg.env.env_name / cfg.env.task_name / cfg.env.action_mode / "pixels" if cfg.pixels else "low_dim_state" + dataset_root + / cfg.env.env_name + / cfg.env.task_name + / cfg.env.action_mode + / "pixels" + if cfg.pixels + else "low_dim_state" ) cache_dir.makedirs(exist_ok=True) cache_path = cache_dir / "cache.pkl" @@ -217,22 +225,23 @@ def collect_or_fetch_demos(self, cfg: DictConfig, num_demos: int): if num_demos < len(demos): demos = demos[:num_demos] - + self._raw_demos = [ DemoConverter.decimate( demo, target_freq=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate, - ) for demo in demos + ) + for demo in demos ] self._action_stats = self._compute_action_stats(cfg, demos) - + def post_collect_or_fetch_demos(self, cfg: DictConfig): demo_list = [demo.timesteps for demo in self._raw_demos] demo_list = rescale_demo_actions( self._rescale_demo_action_helper, demo_list, cfg ) self._demos = self._demo_to_steps(cfg, demo_list) - + def load_demos_into_replay(self, cfg: DictConfig, buffer): """See base class for documentation.""" assert hasattr(self, "_demos"), ( @@ -250,8 +259,9 @@ def load_demos_into_replay(self, cfg: DictConfig, buffer): for _ in range(len(self._demos)): add_demo_to_replay_buffer(demo_env, buffer) - - def _demo_to_steps(self, cfg: DictConfig, demo_list: List[List[DemoStep]]) -> List[DemoStep]: + def _demo_to_steps( + self, cfg: DictConfig, demo_list: List[List[DemoStep]] + ) -> List[DemoStep]: ret_demos = [] for demo in demo_list: @@ -267,23 +277,17 @@ def _demo_to_steps(self, cfg: DictConfig, demo_list: List[List[DemoStep]]) -> Li if not (term or trunc): term = False trunc = True - + reward = 1 - - cur_demo.append( - ( - step.observation, - reward, - term, - trunc, - step.info - ) - ) + + cur_demo.append((step.observation, reward, term, trunc, step.info)) ret_demos.append(cur_demo) return ret_demos - - def _compute_action_stats(self, cfg: DictConfig, demos: List[List[DemoStep]]) -> Dict: + + def _compute_action_stats( + self, cfg: DictConfig, demos: List[List[DemoStep]] + ) -> Dict: actions = [] for demo in demos: for step in demo.timesteps: @@ -304,11 +308,15 @@ def _compute_action_stats(self, cfg: DictConfig, demos: List[List[DemoStep]]) -> "min": action_min, } return action_stats - - def _get_gripper_action_stats(self, cfg: DictConfig) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + + def _get_gripper_action_stats( + self, cfg: DictConfig + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: return (0.5, 1, 1, 0) - + def _rescale_demo_action_helper(self, info, cfg: DictConfig): return RescaleFromTanhWithMinMax.transform_to_tanh( - info["demo_action"], action_stats=self._action_stats, min_max_margin=cfg.min_max_margin + info["demo_action"], + action_stats=self._action_stats, + min_max_margin=cfg.min_max_margin, ) diff --git a/robobase/envs/rlbench.py b/robobase/envs/rlbench.py index 6e57399..1668186 100644 --- a/robobase/envs/rlbench.py +++ b/robobase/envs/rlbench.py @@ -24,7 +24,11 @@ RecedingHorizonControl, AppendDemoInfo, ) -from robobase.utils import DemoStep, observations_to_timesteps, add_demo_to_replay_buffer +from robobase.utils import ( + DemoStep, + observations_to_timesteps, + add_demo_to_replay_buffer, +) from robobase.utils import ( observations_to_action_with_onehot_gripper, observations_to_action_with_onehot_gripper_nbp, diff --git a/robobase/envs/utils/bigym_utils.py b/robobase/envs/utils/bigym_utils.py index b65c8fd..69f2272 100644 --- a/robobase/envs/utils/bigym_utils.py +++ b/robobase/envs/utils/bigym_utils.py @@ -46,44 +46,44 @@ from bigym.envs.groceries import GroceriesStoreLower, GroceriesStoreUpper TASK_MAP = dict( - reach_target_single=ReachTargetSingle, # 2000, 10, enable_all_floating_dofs=False + reach_target_single=ReachTargetSingle, # 2000, 10, enable_all_floating_dofs=False reach_target_multi_modal=ReachTarget, # 3000, 10, enable_all_floating_dofs=False - reach_target_dual=ReachTargetDual, # 3000, 10, enable_all_floating_dofs=False - stack_blocks=StackBlocks, # 28500, 25 - move_plate=MovePlate, # 3000, 10 - move_two_plates=MoveTwoPlates, # 5500, 10 - flip_cup=FlipCup, # 5500, 10 - flip_cutlery=FlipCutlery, # 12500, 25 - dishwasher_open=DishwasherOpen, # 7500, 20 - dishwasher_close=DishwasherClose, # 7500, 20 - dishwasher_open_trays=DishwasherOpenTrays, # 9500, 25 - dishwasher_close_trays=DishwasherCloseTrays, # 7500, 25 - dishwasher_load_cups=DishwasherLoadCups, # 7500, 10 - dishwasher_unload_cups=DishwasherUnloadCups, # 10000, 25 - dishwasher_unload_cups_long=DishwasherUnloadCupsLong, # 18000, 25 - dishwasher_load_cutlery=DishwasherLoadCutlery, # 7000, 10 - dishwasher_unload_cutlery=DishwasherUnloadCutlery, # 15500, 25 - dishwasher_unload_cutlery_long=DishwasherUnloadCutleryLong, # 18000, 25 - dishwasher_load_plates=DishwasherLoadPlates, # 14000, 25 - dishwasher_unload_plates=DishwasherUnloadPlates, # 20000, 25 - dishwasher_unload_plates_long=DishwasherUnloadPlatesLong, # 26000, 25 - drawer_top_open=DrawerTopOpen, # 5000, 10 - drawer_top_close=DrawerTopClose, # 3000, 10 - drawers_open_all=DrawersAllOpen, # 12000, 25 - drawers_close_all=DrawersAllClose, # 5000, 25 - wall_cupboard_open=WallCupboardOpen, # 6000, 20 - wall_cupboard_close=WallCupboardClose, # 3000, 10 - cupboards_open_all=CupboardsOpenAll, # 22500, 25 - cupboards_close_all=CupboardsCloseAll, # 15500, 25 - take_cups=TakeCups, # 10500, 25 - put_cups=PutCups, # 8500, 20 - pick_box=PickBox, # 13500, 25 - store_box=StoreBox, # 15000, 25 - saucepan_to_hob=SaucepanToHob, # 11000, 25 - store_kitchenware=StoreKitchenware, # 20000, 25 - sandwich_toast=ToastSandwich, # 16500, 25 - sandwich_flip=FlipSandwich, # 15500, 25 - sandwich_remove=RemoveSandwich, # 13500, 25 - store_groceries_lower=GroceriesStoreLower, # 32000, 25 - store_groceries_upper=GroceriesStoreUpper, # 19000, 25 + reach_target_dual=ReachTargetDual, # 3000, 10, enable_all_floating_dofs=False + stack_blocks=StackBlocks, # 28500, 25 + move_plate=MovePlate, # 3000, 10 + move_two_plates=MoveTwoPlates, # 5500, 10 + flip_cup=FlipCup, # 5500, 10 + flip_cutlery=FlipCutlery, # 12500, 25 + dishwasher_open=DishwasherOpen, # 7500, 20 + dishwasher_close=DishwasherClose, # 7500, 20 + dishwasher_open_trays=DishwasherOpenTrays, # 9500, 25 + dishwasher_close_trays=DishwasherCloseTrays, # 7500, 25 + dishwasher_load_cups=DishwasherLoadCups, # 7500, 10 + dishwasher_unload_cups=DishwasherUnloadCups, # 10000, 25 + dishwasher_unload_cups_long=DishwasherUnloadCupsLong, # 18000, 25 + dishwasher_load_cutlery=DishwasherLoadCutlery, # 7000, 10 + dishwasher_unload_cutlery=DishwasherUnloadCutlery, # 15500, 25 + dishwasher_unload_cutlery_long=DishwasherUnloadCutleryLong, # 18000, 25 + dishwasher_load_plates=DishwasherLoadPlates, # 14000, 25 + dishwasher_unload_plates=DishwasherUnloadPlates, # 20000, 25 + dishwasher_unload_plates_long=DishwasherUnloadPlatesLong, # 26000, 25 + drawer_top_open=DrawerTopOpen, # 5000, 10 + drawer_top_close=DrawerTopClose, # 3000, 10 + drawers_open_all=DrawersAllOpen, # 12000, 25 + drawers_close_all=DrawersAllClose, # 5000, 25 + wall_cupboard_open=WallCupboardOpen, # 6000, 20 + wall_cupboard_close=WallCupboardClose, # 3000, 10 + cupboards_open_all=CupboardsOpenAll, # 22500, 25 + cupboards_close_all=CupboardsCloseAll, # 15500, 25 + take_cups=TakeCups, # 10500, 25 + put_cups=PutCups, # 8500, 20 + pick_box=PickBox, # 13500, 25 + store_box=StoreBox, # 15000, 25 + saucepan_to_hob=SaucepanToHob, # 11000, 25 + store_kitchenware=StoreKitchenware, # 20000, 25 + sandwich_toast=ToastSandwich, # 16500, 25 + sandwich_flip=FlipSandwich, # 15500, 25 + sandwich_remove=RemoveSandwich, # 13500, 25 + store_groceries_lower=GroceriesStoreLower, # 32000, 25 + store_groceries_upper=GroceriesStoreUpper, # 19000, 25 ) diff --git a/robobase/envs/wrappers/__init__.py b/robobase/envs/wrappers/__init__.py index aeb7f92..5858489 100644 --- a/robobase/envs/wrappers/__init__.py +++ b/robobase/envs/wrappers/__init__.py @@ -13,7 +13,11 @@ RecedingHorizonControl, ) from robobase.envs.wrappers.append_demo_info import AppendDemoInfo -from robobase.envs.wrappers.reward_modifiers import ClipReward, ScaleReward, ShapeRewards +from robobase.envs.wrappers.reward_modifiers import ( + ClipReward, + ScaleReward, + ShapeRewards, +) __all__ = [ "ConcatDim", diff --git a/robobase/method/value_based.py b/robobase/method/value_based.py index 60197e6..174737d 100644 --- a/robobase/method/value_based.py +++ b/robobase/method/value_based.py @@ -10,7 +10,10 @@ from robobase.method.core import OffPolicyMethod from robobase.models.fusion import FusionModule from robobase.models.encoder import EncoderModule -from robobase.models.fully_connected import FullyConnectedModule, RNNFullyConnectedModule +from robobase.models.fully_connected import ( + FullyConnectedModule, + RNNFullyConnectedModule, +) from robobase.replay_buffer.replay_buffer import ReplayBuffer from robobase.replay_buffer.prioritized_replay_buffer import PrioritizedReplayBuffer from robobase.method.utils import ( diff --git a/robobase/models/lix_utils/analysis_modules.py b/robobase/models/lix_utils/analysis_modules.py index 7374059..d3ee7e1 100644 --- a/robobase/models/lix_utils/analysis_modules.py +++ b/robobase/models/lix_utils/analysis_modules.py @@ -4,7 +4,9 @@ import numpy as np import torch.nn as nn -from robobase.models.lix_utils.analysis_layers import NonLearnableParameterizedRegWrapper +from robobase.models.lix_utils.analysis_layers import ( + NonLearnableParameterizedRegWrapper, +) from robobase.models.lix_utils import analysis_layers from robobase.models import EncoderCNNMultiViewDownsampleWithStrides from robobase.utils import weight_init diff --git a/tests/integration/test_training.py b/tests/integration/test_training.py index 51c260c..a331873 100644 --- a/tests/integration/test_training.py +++ b/tests/integration/test_training.py @@ -30,7 +30,9 @@ def run_cmd(hydra_overrides: list[str], target_reward: float, result_queue): try: with tempfile.TemporaryDirectory(dir=Path.cwd()) as tmpdirname: with initialize( - version_base=None, config_path="../../robobase/cfgs", job_name="test_app" + version_base=None, + config_path="../../robobase/cfgs", + job_name="test_app", ): hydra_overrides.append(f"replay.save_dir={tmpdirname}/replay") tmp_dir = Path(tmpdirname) diff --git a/train.py b/train.py index 5e49b2f..79492ca 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,9 @@ import hydra -@hydra.main(config_path="robobase/cfgs", config_name="robobase_config", version_base=None) +@hydra.main( + config_path="robobase/cfgs", config_name="robobase_config", version_base=None +) def main(cfg): from robobase.workspace import Workspace From ec3ade93a439282a95921eff505e826301dc274c Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Wed, 17 Jul 2024 17:44:06 +0800 Subject: [PATCH 04/12] running pipeline; bug with eval; probably the gripper_stats --- robobase/cfgs/env/bigym.yaml | 9 +- .../cfgs/env/bigym/put_plate_in_drainer.yaml | 9 -- robobase/cfgs/env/bigym/reach_target.yaml | 9 -- .../env/bigym/reach_target_multi_modal.yaml | 12 ++ robobase/cfgs/launch/act_pixel_bigym.yaml | 35 ++++++ robobase/envs/bigym.py | 116 +++++++----------- robobase/method/act.py | 2 +- 7 files changed, 102 insertions(+), 90 deletions(-) delete mode 100644 robobase/cfgs/env/bigym/put_plate_in_drainer.yaml delete mode 100644 robobase/cfgs/env/bigym/reach_target.yaml create mode 100644 robobase/cfgs/env/bigym/reach_target_multi_modal.yaml create mode 100644 robobase/cfgs/launch/act_pixel_bigym.yaml diff --git a/robobase/cfgs/env/bigym.yaml b/robobase/cfgs/env/bigym.yaml index 08c63c9..99f746f 100644 --- a/robobase/cfgs/env/bigym.yaml +++ b/robobase/cfgs/env/bigym.yaml @@ -4,6 +4,11 @@ env: env_name: bigym episode_length: 3000 cameras: ["head", "right_wrist", "left_wrist"] - action_mode: JOINT_POSITION - + action_mode: absolute + floating: true + dataset_root: "" demo_down_sample_rate: 20 + render_mode: rgb_array + enable_all_floating_dof: false + +demos: !!float .inf diff --git a/robobase/cfgs/env/bigym/put_plate_in_drainer.yaml b/robobase/cfgs/env/bigym/put_plate_in_drainer.yaml deleted file mode 100644 index 7125122..0000000 --- a/robobase/cfgs/env/bigym/put_plate_in_drainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ -# @package _global_ - -defaults: - - bigym - - _self_ - -env: - task_name: move_plate_between_drainers - stddev_schedule: linear(1.0,0.1,500000) diff --git a/robobase/cfgs/env/bigym/reach_target.yaml b/robobase/cfgs/env/bigym/reach_target.yaml deleted file mode 100644 index d9c8184..0000000 --- a/robobase/cfgs/env/bigym/reach_target.yaml +++ /dev/null @@ -1,9 +0,0 @@ -# @package _global_ - -defaults: - - bigym - - _self_ - -env: - task_name: reach_target - stddev_schedule: linear(1.0,0.1,500000) diff --git a/robobase/cfgs/env/bigym/reach_target_multi_modal.yaml b/robobase/cfgs/env/bigym/reach_target_multi_modal.yaml new file mode 100644 index 0000000..c5315c4 --- /dev/null +++ b/robobase/cfgs/env/bigym/reach_target_multi_modal.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: reach_target_multi_modal + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 25 + enable_all_floating_dof: false diff --git a/robobase/cfgs/launch/act_pixel_bigym.yaml b/robobase/cfgs/launch/act_pixel_bigym.yaml new file mode 100644 index 0000000..b4678f4 --- /dev/null +++ b/robobase/cfgs/launch/act_pixel_bigym.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - ../env: null + - ../method: act + +demos: !!float .inf +num_pretrain_steps: 150000 +num_train_frames: 0 +eval_every_steps: 5000 +num_eval_episodes: 10 +batch_size: 256 +save_snapshot: true +snapshot_every_n: 0 +replay_size_before_train: 500 + +pixels: true +frame_stack: 2 + +action_repeat: 1 +action_sequence: 16 +execution_length: 1 +temporal_ensemble: true +use_standardization: false # Demo-based standardization for action space +use_min_max_normalization: true # Demo-based min-max normalization for action space +min_max_margin: 0 + +update_every_steps: 1 + +replay: + nstep: 1 + +hydra: + run: + dir: ./exp_local/pixel_act/bigym_${env.task_name}_${now:%Y%m%d%H%M%S} diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index 0deb85e..9874814 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -1,8 +1,6 @@ -from enum import Enum - from bigym.bigym_env import BiGymEnv, CONTROL_FREQUENCY_MAX -from bigym.action_modes import ActionMode, JointPositionActionMode, TorqueActionMode -from robobase.utils import rescale_demo_actions, DemoEnv, add_demo_to_replay_buffer +from bigym.action_modes import JointPositionActionMode +from robobase.utils import DemoEnv, add_demo_to_replay_buffer from robobase.envs.utils.bigym_utils import TASK_MAP import gymnasium as gym from gymnasium.wrappers import TimeLimit @@ -24,31 +22,41 @@ import numpy as np from demonstrations.demo import DemoStep -from demonstrations.demo_store import DemoStore, DemoConverter +from demonstrations.demo_store import DemoStore from demonstrations.utils import Metadata -from typing import List, Dict, Tuple -from pathlib import Path -import pickle +from typing import List, Dict, Tuple, Callable import copy UNIT_TEST = False -class ActionModeType(Enum): - TORQUE = "TORQUE" - JOINT_POSITION = "JOINT_POSITION" +def rescale_demo_actions( + rescale_fn: Callable, demos: List[List[DemoStep]], cfg: DictConfig +): + """Rescale actions in demonstrations to [-1, 1] Tanh space. + This is because RoboBase assumes everything to be in [-1, 1] space. + Args: + rescale_fn: callable that takes info containing demo action and cfg and + outputs the rescaled action + demos: list of demo episodes whose actions are raw, i.e., not scaled + cfg: Configs -def _task_name_to_env_class(task_name: str) -> type[BiGymEnv]: - return TASK_MAP[task_name] + Returns: + List[Demo]: list of demo episodes whose actions are rescaled + """ + for demo in demos: + for step in demo: + info = step.info + if "demo_action" in info: + # Rescale demo actions + info["demo_action"] = rescale_fn(info, cfg) + return demos -def _create_action_mode(action_mode: str) -> ActionMode: - if action_mode == ActionModeType.TORQUE.value: - return TorqueActionMode() - elif action_mode == ActionModeType.JOINT_POSITION.value: - return JointPositionActionMode() +def _task_name_to_env_class(task_name: str) -> type[BiGymEnv]: + return TASK_MAP[task_name] class BiGymEnvFactory(EnvFactory): @@ -61,6 +69,7 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals observation_space = copy.deepcopy(env.observation_space) env = RescaleFromTanhWithMinMax( + env=env, action_stats=self._action_stats, min_max_margin=cfg.min_max_margin, ) @@ -73,12 +82,12 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals ) if cfg.use_onehot_time_and_no_bootstrap: env = OnehotTime(env, cfg.env.episode_length) - env = FrameStack(env, cfg.frame_stack) + if not demo_env: + env = FrameStack(env, cfg.frame_stack) env = TimeLimit( env, - cfg.env.episode_length // cfg.demo_down_sample_rate, + cfg.env.episode_length // cfg.env.demo_down_sample_rate, ) - env = ActionSequence(env, cfg.action_sequence) if not demo_env: if not train: @@ -99,7 +108,7 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals env = AppendDemoInfo(env) if return_raw_spaces: - return env, (action_space, observation_space) + return env, action_space, observation_space else: return env @@ -128,88 +137,66 @@ def _create_env(self, cfg: DictConfig) -> BiGymEnv: ) return bigym_class( + render_mode=cfg.env.render_mode, action_mode=action_mode, observation_config=ObservationConfig( cameras=camera_configs if cfg.pixels else [], proprioception=True, privileged_information=False if cfg.pixels else True, ), + control_frequency=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate, ) def make_train_env(self, cfg: DictConfig) -> gym.vector.VectorEnv: vec_env_class = gym.vector.AsyncVectorEnv - kwargs = dict(context="fork") - if UNIT_TEST: - vec_env_class = gym.vector.SyncVectorEnv - kwargs = dict() - return vec_env_class( [ lambda: self._wrap_env( self._create_env(cfg), + cfg, demo_env=False, train=True, ) for _ in range(cfg.num_train_envs) ], - **kwargs ) def make_eval_env(self, cfg: DictConfig) -> gym.Env: env, self._action_space, self._observation_space = self._wrap_env( env=self._create_env(cfg), + cfg=cfg, demo_env=False, train=False, return_raw_spaces=True, ) return env - def _get_demo_from_scratch( - self, cfg: DictConfig, num_demos: int, mp_list: List - ) -> None: + def _get_demo_fn(self, cfg: DictConfig, num_demos: int, mp_list: List) -> None: demos = [] logging.info("Start to load demos.") env = self._create_env(cfg) demo_store = DemoStore() + if np.isinf(num_demos): + num_demos = -1 - demos = demo_store.get_demos(Metadata.from_env(env), amount=-1) + demos = demo_store.get_demos( + Metadata.from_env(env), + amount=num_demos, + frequency=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate, + ) for demo in demos: for ts in demo.timesteps: - ts.observation = {k: np.array(v) for k, v in ts.observation.items()} + ts.observation = { + k: np.array(v, dtype=np.float32) for k, v in ts.observation.items() + } env.close() logging.info("Finished loading demos.") mp_list.append(demos) - def _get_demo_fn(self, cfg: DictConfig, num_demos: int, mp_list: List) -> None: - dataset_root = cfg.env.dataset_root - if dataset_root == "": - dataset_root = Path.home() / ".bigym" / "cache" - - cache_dir = ( - dataset_root - / cfg.env.env_name - / cfg.env.task_name - / cfg.env.action_mode - / "pixels" - if cfg.pixels - else "low_dim_state" - ) - cache_dir.makedirs(exist_ok=True) - cache_path = cache_dir / "cache.pkl" - - if cache_path.exists(): - demos = pickle.load(cache_path.open("rb")) - mp_list.append(demos) - logging.info("Loaded demos from cache.") - else: - self._get_demo_from_scratch(cfg, num_demos, mp_list) - demos = mp_list[0] - pickle.dump(demos, cache_path.open("wb"), pickle.HIGHEST_PROTOCOL) - def collect_or_fetch_demos(self, cfg: DictConfig, num_demos: int): manager = mp.Manager() mp_list = manager.list() @@ -223,16 +210,7 @@ def collect_or_fetch_demos(self, cfg: DictConfig, num_demos: int): demos = mp_list[0] - if num_demos < len(demos): - demos = demos[:num_demos] - - self._raw_demos = [ - DemoConverter.decimate( - demo, - target_freq=CONTROL_FREQUENCY_MAX // cfg.env.demo_down_sample_rate, - ) - for demo in demos - ] + self._raw_demos = demos self._action_stats = self._compute_action_stats(cfg, demos) def post_collect_or_fetch_demos(self, cfg: DictConfig): diff --git a/robobase/method/act.py b/robobase/method/act.py index 56825e2..0f52da2 100644 --- a/robobase/method/act.py +++ b/robobase/method/act.py @@ -381,7 +381,7 @@ def update( metrics = dict() batch = next(replay_iter) - batch = {k: v.to(self.device) for k, v in batch.items()} + batch = {k: v.float().to(self.device) for k, v in batch.items()} actions = batch["action"] reward = batch["reward"] From 3112000021e944e35b7ede71a8c5f38af0a86a84 Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Wed, 17 Jul 2024 18:46:25 +0800 Subject: [PATCH 05/12] working pipeline; note that we clipped actions in the rescale wrapper, need to check if this is correct --- robobase/envs/bigym.py | 5 ++++- robobase/envs/wrappers/rescale_from_tanh.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index 9874814..66cefd0 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -290,7 +290,10 @@ def _compute_action_stats( def _get_gripper_action_stats( self, cfg: DictConfig ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - return (0.5, 1, 1, 0) + if cfg.env.action_mode in ["absolute", "delta"]: + return (0.5, 0.25, 1, 0) + else: + raise NotImplementedError("Unsupported action mode.") def _rescale_demo_action_helper(self, info, cfg: DictConfig): return RescaleFromTanhWithMinMax.transform_to_tanh( diff --git a/robobase/envs/wrappers/rescale_from_tanh.py b/robobase/envs/wrappers/rescale_from_tanh.py index eb7f4a9..0c36b3b 100644 --- a/robobase/envs/wrappers/rescale_from_tanh.py +++ b/robobase/envs/wrappers/rescale_from_tanh.py @@ -218,6 +218,7 @@ def __init__( @staticmethod def transform_from_tanh(action, action_stats, min_max_margin): + action = action.clip(-1.0, 1.0) action_min, action_max = action_stats["min"], action_stats["max"] _action_min = action_min - np.fabs(action_min) * min_max_margin _action_max = action_max + np.fabs(action_max) * min_max_margin From c6458d111935a5ab18ae102fde7446fa3018831e Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Wed, 17 Jul 2024 20:33:17 +0800 Subject: [PATCH 06/12] add all task configs --- robobase/cfgs/env/bigym/cupboards_close_all.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/cupboards_open_all.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_close.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_close_trays.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_load_cups.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_load_cutlery.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_load_plates.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_open.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_open_trays.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/dishwasher_unload_cups.yaml | 12 ++++++++++++ .../cfgs/env/bigym/dishwasher_unload_cups_long.yaml | 12 ++++++++++++ .../cfgs/env/bigym/dishwasher_unload_cutlery.yaml | 12 ++++++++++++ .../env/bigym/dishwasher_unload_cutlery_long.yaml | 12 ++++++++++++ .../cfgs/env/bigym/dishwasher_unload_plates.yaml | 12 ++++++++++++ .../env/bigym/dishwasher_unload_plates_long.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/drawer_top_close.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/drawer_top_open.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/drawers_close_all.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/drawers_open_all.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/flip_cup.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/flip_cutlery.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/move_plate.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/move_two_plates.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/pick_box.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/put_cups.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/reach_target_dual.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/reach_target_single.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/sandwich_flip.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/sandwich_remove.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/sandwich_toast.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/saucepan_to_hob.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/stack_blocks.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/store_box.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/store_groceries_lower.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/store_groceries_upper.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/store_kitchenware.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/take_cups.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/wall_cupboard_close.yaml | 12 ++++++++++++ robobase/cfgs/env/bigym/wall_cupboard_open.yaml | 12 ++++++++++++ 39 files changed, 468 insertions(+) create mode 100644 robobase/cfgs/env/bigym/cupboards_close_all.yaml create mode 100644 robobase/cfgs/env/bigym/cupboards_open_all.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_close.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_close_trays.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_load_cups.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_load_cutlery.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_load_plates.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_open.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_open_trays.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_unload_cups.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_unload_cups_long.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_unload_cutlery.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_unload_cutlery_long.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_unload_plates.yaml create mode 100644 robobase/cfgs/env/bigym/dishwasher_unload_plates_long.yaml create mode 100644 robobase/cfgs/env/bigym/drawer_top_close.yaml create mode 100644 robobase/cfgs/env/bigym/drawer_top_open.yaml create mode 100644 robobase/cfgs/env/bigym/drawers_close_all.yaml create mode 100644 robobase/cfgs/env/bigym/drawers_open_all.yaml create mode 100644 robobase/cfgs/env/bigym/flip_cup.yaml create mode 100644 robobase/cfgs/env/bigym/flip_cutlery.yaml create mode 100644 robobase/cfgs/env/bigym/move_plate.yaml create mode 100644 robobase/cfgs/env/bigym/move_two_plates.yaml create mode 100644 robobase/cfgs/env/bigym/pick_box.yaml create mode 100644 robobase/cfgs/env/bigym/put_cups.yaml create mode 100644 robobase/cfgs/env/bigym/reach_target_dual.yaml create mode 100644 robobase/cfgs/env/bigym/reach_target_single.yaml create mode 100644 robobase/cfgs/env/bigym/sandwich_flip.yaml create mode 100644 robobase/cfgs/env/bigym/sandwich_remove.yaml create mode 100644 robobase/cfgs/env/bigym/sandwich_toast.yaml create mode 100644 robobase/cfgs/env/bigym/saucepan_to_hob.yaml create mode 100644 robobase/cfgs/env/bigym/stack_blocks.yaml create mode 100644 robobase/cfgs/env/bigym/store_box.yaml create mode 100644 robobase/cfgs/env/bigym/store_groceries_lower.yaml create mode 100644 robobase/cfgs/env/bigym/store_groceries_upper.yaml create mode 100644 robobase/cfgs/env/bigym/store_kitchenware.yaml create mode 100644 robobase/cfgs/env/bigym/take_cups.yaml create mode 100644 robobase/cfgs/env/bigym/wall_cupboard_close.yaml create mode 100644 robobase/cfgs/env/bigym/wall_cupboard_open.yaml diff --git a/robobase/cfgs/env/bigym/cupboards_close_all.yaml b/robobase/cfgs/env/bigym/cupboards_close_all.yaml new file mode 100644 index 0000000..d7eb5b8 --- /dev/null +++ b/robobase/cfgs/env/bigym/cupboards_close_all.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: cupboards_close_all + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 15500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/cupboards_open_all.yaml b/robobase/cfgs/env/bigym/cupboards_open_all.yaml new file mode 100644 index 0000000..6a3db30 --- /dev/null +++ b/robobase/cfgs/env/bigym/cupboards_open_all.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: cupboards_open_all + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 22500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_close.yaml b/robobase/cfgs/env/bigym/dishwasher_close.yaml new file mode 100644 index 0000000..c698954 --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_close.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_close + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 7500 + demo_down_sample_rate: 20 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_close_trays.yaml b/robobase/cfgs/env/bigym/dishwasher_close_trays.yaml new file mode 100644 index 0000000..c48473c --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_close_trays.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_close_trays + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_load_cups.yaml b/robobase/cfgs/env/bigym/dishwasher_load_cups.yaml new file mode 100644 index 0000000..de63eb1 --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_load_cups.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_load_cups + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 7500 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_load_cutlery.yaml b/robobase/cfgs/env/bigym/dishwasher_load_cutlery.yaml new file mode 100644 index 0000000..cb93c2b --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_load_cutlery.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_load_cutlery + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 7000 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_load_plates.yaml b/robobase/cfgs/env/bigym/dishwasher_load_plates.yaml new file mode 100644 index 0000000..54dc772 --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_load_plates.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_load_plates + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 14000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_open.yaml b/robobase/cfgs/env/bigym/dishwasher_open.yaml new file mode 100644 index 0000000..af046a8 --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_open.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_open + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 7500 + demo_down_sample_rate: 20 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_open_trays.yaml b/robobase/cfgs/env/bigym/dishwasher_open_trays.yaml new file mode 100644 index 0000000..b89a080 --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_open_trays.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_open_trays + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_unload_cups.yaml b/robobase/cfgs/env/bigym/dishwasher_unload_cups.yaml new file mode 100644 index 0000000..ebd0a0e --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_unload_cups.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_unload_cups + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 10000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_unload_cups_long.yaml b/robobase/cfgs/env/bigym/dishwasher_unload_cups_long.yaml new file mode 100644 index 0000000..a02a8ed --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_unload_cups_long.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_unload_cups_long + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 18000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_unload_cutlery.yaml b/robobase/cfgs/env/bigym/dishwasher_unload_cutlery.yaml new file mode 100644 index 0000000..63a1bbf --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_unload_cutlery.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_unload_cutlery + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 15500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_unload_cutlery_long.yaml b/robobase/cfgs/env/bigym/dishwasher_unload_cutlery_long.yaml new file mode 100644 index 0000000..9785aba --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_unload_cutlery_long.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_unload_cutlery_long + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 18000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_unload_plates.yaml b/robobase/cfgs/env/bigym/dishwasher_unload_plates.yaml new file mode 100644 index 0000000..e078274 --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_unload_plates.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_unload_plates + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 20000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/dishwasher_unload_plates_long.yaml b/robobase/cfgs/env/bigym/dishwasher_unload_plates_long.yaml new file mode 100644 index 0000000..984795d --- /dev/null +++ b/robobase/cfgs/env/bigym/dishwasher_unload_plates_long.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: dishwasher_unload_plates_long + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 26000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/drawer_top_close.yaml b/robobase/cfgs/env/bigym/drawer_top_close.yaml new file mode 100644 index 0000000..fc162c7 --- /dev/null +++ b/robobase/cfgs/env/bigym/drawer_top_close.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: drawer_top_close + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/drawer_top_open.yaml b/robobase/cfgs/env/bigym/drawer_top_open.yaml new file mode 100644 index 0000000..7ffdeae --- /dev/null +++ b/robobase/cfgs/env/bigym/drawer_top_open.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: drawer_top_open + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 5000 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/drawers_close_all.yaml b/robobase/cfgs/env/bigym/drawers_close_all.yaml new file mode 100644 index 0000000..71f6341 --- /dev/null +++ b/robobase/cfgs/env/bigym/drawers_close_all.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: drawers_close_all + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 5000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/drawers_open_all.yaml b/robobase/cfgs/env/bigym/drawers_open_all.yaml new file mode 100644 index 0000000..a614a42 --- /dev/null +++ b/robobase/cfgs/env/bigym/drawers_open_all.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: drawers_open_all + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 12000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/flip_cup.yaml b/robobase/cfgs/env/bigym/flip_cup.yaml new file mode 100644 index 0000000..04fbe5f --- /dev/null +++ b/robobase/cfgs/env/bigym/flip_cup.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: flip_cup + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 5500 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/flip_cutlery.yaml b/robobase/cfgs/env/bigym/flip_cutlery.yaml new file mode 100644 index 0000000..a14fc94 --- /dev/null +++ b/robobase/cfgs/env/bigym/flip_cutlery.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: flip_cutlery + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 12500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/move_plate.yaml b/robobase/cfgs/env/bigym/move_plate.yaml new file mode 100644 index 0000000..14fe5ca --- /dev/null +++ b/robobase/cfgs/env/bigym/move_plate.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: move_plate + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/move_two_plates.yaml b/robobase/cfgs/env/bigym/move_two_plates.yaml new file mode 100644 index 0000000..1749ed0 --- /dev/null +++ b/robobase/cfgs/env/bigym/move_two_plates.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: move_two_plates + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 5500 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/pick_box.yaml b/robobase/cfgs/env/bigym/pick_box.yaml new file mode 100644 index 0000000..de2b083 --- /dev/null +++ b/robobase/cfgs/env/bigym/pick_box.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: pick_box + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 13500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/put_cups.yaml b/robobase/cfgs/env/bigym/put_cups.yaml new file mode 100644 index 0000000..406d629 --- /dev/null +++ b/robobase/cfgs/env/bigym/put_cups.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: put_cups + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 8500 + demo_down_sample_rate: 20 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/reach_target_dual.yaml b/robobase/cfgs/env/bigym/reach_target_dual.yaml new file mode 100644 index 0000000..001284f --- /dev/null +++ b/robobase/cfgs/env/bigym/reach_target_dual.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: reach_target_dual + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 10 + enable_all_floating_dof: false diff --git a/robobase/cfgs/env/bigym/reach_target_single.yaml b/robobase/cfgs/env/bigym/reach_target_single.yaml new file mode 100644 index 0000000..e4c6652 --- /dev/null +++ b/robobase/cfgs/env/bigym/reach_target_single.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: reach_target_single + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 2000 + demo_down_sample_rate: 10 + enable_all_floating_dof: false diff --git a/robobase/cfgs/env/bigym/sandwich_flip.yaml b/robobase/cfgs/env/bigym/sandwich_flip.yaml new file mode 100644 index 0000000..1a7bb4a --- /dev/null +++ b/robobase/cfgs/env/bigym/sandwich_flip.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: sandwich_flip + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 15500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/sandwich_remove.yaml b/robobase/cfgs/env/bigym/sandwich_remove.yaml new file mode 100644 index 0000000..43a6733 --- /dev/null +++ b/robobase/cfgs/env/bigym/sandwich_remove.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: sandwich_remove + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 13500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/sandwich_toast.yaml b/robobase/cfgs/env/bigym/sandwich_toast.yaml new file mode 100644 index 0000000..ce6f626 --- /dev/null +++ b/robobase/cfgs/env/bigym/sandwich_toast.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: sandwich_toast + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 16500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/saucepan_to_hob.yaml b/robobase/cfgs/env/bigym/saucepan_to_hob.yaml new file mode 100644 index 0000000..8a26692 --- /dev/null +++ b/robobase/cfgs/env/bigym/saucepan_to_hob.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: saucepan_to_hob + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 11000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/stack_blocks.yaml b/robobase/cfgs/env/bigym/stack_blocks.yaml new file mode 100644 index 0000000..78d1e57 --- /dev/null +++ b/robobase/cfgs/env/bigym/stack_blocks.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: stack_blocks + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 28500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/store_box.yaml b/robobase/cfgs/env/bigym/store_box.yaml new file mode 100644 index 0000000..9966600 --- /dev/null +++ b/robobase/cfgs/env/bigym/store_box.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: store_box + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 15000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/store_groceries_lower.yaml b/robobase/cfgs/env/bigym/store_groceries_lower.yaml new file mode 100644 index 0000000..f45391b --- /dev/null +++ b/robobase/cfgs/env/bigym/store_groceries_lower.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: store_groceries_lower + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 32000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/store_groceries_upper.yaml b/robobase/cfgs/env/bigym/store_groceries_upper.yaml new file mode 100644 index 0000000..db254a6 --- /dev/null +++ b/robobase/cfgs/env/bigym/store_groceries_upper.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: store_groceries_upper + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 19000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/store_kitchenware.yaml b/robobase/cfgs/env/bigym/store_kitchenware.yaml new file mode 100644 index 0000000..7392e97 --- /dev/null +++ b/robobase/cfgs/env/bigym/store_kitchenware.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: store_kitchenware + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 20000 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/take_cups.yaml b/robobase/cfgs/env/bigym/take_cups.yaml new file mode 100644 index 0000000..b65106e --- /dev/null +++ b/robobase/cfgs/env/bigym/take_cups.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: take_cups + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 10500 + demo_down_sample_rate: 25 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/wall_cupboard_close.yaml b/robobase/cfgs/env/bigym/wall_cupboard_close.yaml new file mode 100644 index 0000000..28f66e6 --- /dev/null +++ b/robobase/cfgs/env/bigym/wall_cupboard_close.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: wall_cupboard_close + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 3000 + demo_down_sample_rate: 10 + enable_all_floating_dof: true diff --git a/robobase/cfgs/env/bigym/wall_cupboard_open.yaml b/robobase/cfgs/env/bigym/wall_cupboard_open.yaml new file mode 100644 index 0000000..4969bb4 --- /dev/null +++ b/robobase/cfgs/env/bigym/wall_cupboard_open.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +defaults: + - bigym + - _self_ + +env: + task_name: wall_cupboard_open + stddev_schedule: linear(1.0,0.1,500000) + episode_length: 6000 + demo_down_sample_rate: 20 + enable_all_floating_dof: true From ee3f3072010b207c53377d20cb79d35f8f056a50 Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Thu, 18 Jul 2024 12:10:09 +0800 Subject: [PATCH 07/12] validated training --- robobase/cfgs/launch/act_pixel_bigym.yaml | 6 +++--- robobase/cfgs/robobase_config.yaml | 1 - robobase/logger.py | 1 - 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/robobase/cfgs/launch/act_pixel_bigym.yaml b/robobase/cfgs/launch/act_pixel_bigym.yaml index b4678f4..d0d395a 100644 --- a/robobase/cfgs/launch/act_pixel_bigym.yaml +++ b/robobase/cfgs/launch/act_pixel_bigym.yaml @@ -4,11 +4,11 @@ defaults: - ../env: null - ../method: act -demos: !!float .inf +demos: 30 num_pretrain_steps: 150000 num_train_frames: 0 -eval_every_steps: 5000 -num_eval_episodes: 10 +eval_every_steps: 2000 +num_eval_episodes: 5 batch_size: 256 save_snapshot: true snapshot_every_n: 0 diff --git a/robobase/cfgs/robobase_config.yaml b/robobase/cfgs/robobase_config.yaml index ebcccdc..b9cc69a 100644 --- a/robobase/cfgs/robobase_config.yaml +++ b/robobase/cfgs/robobase_config.yaml @@ -61,7 +61,6 @@ replay: wandb: # weight and bias use: false project: ${oc.env:USER}RoboBase - entity: rll name: null tb: # TensorBoard diff --git a/robobase/logger.py b/robobase/logger.py index aa49486..a8012a3 100644 --- a/robobase/logger.py +++ b/robobase/logger.py @@ -174,7 +174,6 @@ def __init__(self, log_dir, cfg): wandb.init( project=cfg.wandb.project, - entity=cfg.wandb.entity, name=cfg.wandb.name, config=cfg_dict, ) From 205408559e5a7b34c8d5a76efdc193d70ea5a402 Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Thu, 18 Jul 2024 12:27:38 +0800 Subject: [PATCH 08/12] fix divided by 0 bug of the minmax wrapper --- robobase/envs/wrappers/rescale_from_tanh.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/robobase/envs/wrappers/rescale_from_tanh.py b/robobase/envs/wrappers/rescale_from_tanh.py index 0c36b3b..b3ec8ad 100644 --- a/robobase/envs/wrappers/rescale_from_tanh.py +++ b/robobase/envs/wrappers/rescale_from_tanh.py @@ -233,7 +233,7 @@ def transform_to_tanh(action, action_stats, min_max_margin): _action_min = action_min - np.fabs(action_min) * min_max_margin _action_max = action_max + np.fabs(action_max) * min_max_margin - new_action = (action - _action_min) / (_action_max - _action_min) # to [0, 1] + new_action = (action - _action_min) / (_action_max - _action_min + 1e-8) # to [0, 1] new_action = new_action * 2 - 1 # to [-1, 1] return new_action.astype(np.float32, copy=False) From a97cb5e2e915d9036980187156fa13b2175e3578 Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Thu, 18 Jul 2024 14:10:59 +0800 Subject: [PATCH 09/12] change act frame stack --- robobase/cfgs/launch/act_pixel_bigym.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/robobase/cfgs/launch/act_pixel_bigym.yaml b/robobase/cfgs/launch/act_pixel_bigym.yaml index d0d395a..30008a7 100644 --- a/robobase/cfgs/launch/act_pixel_bigym.yaml +++ b/robobase/cfgs/launch/act_pixel_bigym.yaml @@ -15,7 +15,7 @@ snapshot_every_n: 0 replay_size_before_train: 500 pixels: true -frame_stack: 2 +frame_stack: 4 action_repeat: 1 action_sequence: 16 From 47b924af18e95e80fccb92f5f72fb3b3433bd746 Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Fri, 19 Jul 2024 11:19:06 +0800 Subject: [PATCH 10/12] add obs normalization --- README.md | 2 +- robobase/cfgs/launch/act_pixel_bigym.yaml | 1 + robobase/cfgs/robobase_config.yaml | 1 + robobase/envs/bigym.py | 27 +++++++++++++++++++++ robobase/envs/wrappers/concat_dim.py | 10 ++++++-- robobase/envs/wrappers/rescale_from_tanh.py | 4 ++- 6 files changed, 41 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c5df2c6..31f7a38 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ pip install ".[bigym]" | Method | Paper | 1-line Summary | Differences to paper? | Stable | |-----------------------------------------------|---------------------------------------------------------------------------------------------------------|---------------------------------------------|-----------------------------------|-----------| | [diffusion](robobase/cfgs/method/diffusion.yaml) | [Diffusion Policy: Visuomotor Policy Learning via Action Diffusion](https://arxiv.org/abs/2303.04137) | Brings diffusion to robotics. | None. | :warning: | -| [act](robobase/cfgs/method/act.yaml) | [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) | Transformer and action-sequence prediction. | None. | :warning: | +| [act](robobase/cfgs/method/act.yaml) | [Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware](https://arxiv.org/abs/2304.13705) | Transformer and action-sequence prediction. | None. | :white_check_mark: | ### Algorithmic Features diff --git a/robobase/cfgs/launch/act_pixel_bigym.yaml b/robobase/cfgs/launch/act_pixel_bigym.yaml index 30008a7..20fcfd1 100644 --- a/robobase/cfgs/launch/act_pixel_bigym.yaml +++ b/robobase/cfgs/launch/act_pixel_bigym.yaml @@ -24,6 +24,7 @@ temporal_ensemble: true use_standardization: false # Demo-based standardization for action space use_min_max_normalization: true # Demo-based min-max normalization for action space min_max_margin: 0 +norm_obs: true update_every_steps: 1 diff --git a/robobase/cfgs/robobase_config.yaml b/robobase/cfgs/robobase_config.yaml index b9cc69a..501c534 100644 --- a/robobase/cfgs/robobase_config.yaml +++ b/robobase/cfgs/robobase_config.yaml @@ -41,6 +41,7 @@ temporal_ensemble_gain: 0.01 use_standardization: false # Demo-based standardization for action space use_min_max_normalization: false # Demo-based min-max normalization for action space min_max_margin: 0.0 # If set to > 0, introduce margin for demo-driven min-max normalization +norm_obs: false # Replay buffer settings replay: diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index 66cefd0..5a818a9 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -73,11 +73,17 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals action_stats=self._action_stats, min_max_margin=cfg.min_max_margin, ) + obs_stats = None + if cfg.norm_obs: + obs_stats = self._obs_stats + env = ConcatDim( env, shape_length=1, dim=-1, new_name="low_dim_state", + norm_obs=cfg.norm_obs, + obs_stats=obs_stats, keys_to_ignore=["proprioception_floating_base_actions"], ) if cfg.use_onehot_time_and_no_bootstrap: @@ -212,6 +218,7 @@ def collect_or_fetch_demos(self, cfg: DictConfig, num_demos: int): self._raw_demos = demos self._action_stats = self._compute_action_stats(cfg, demos) + self._obs_stats = self._compute_obs_stats(cfg, demos) def post_collect_or_fetch_demos(self, cfg: DictConfig): demo_list = [demo.timesteps for demo in self._raw_demos] @@ -287,6 +294,26 @@ def _compute_action_stats( } return action_stats + def _compute_obs_stats(self, cfg: DictConfig, demos: List[List[DemoStep]]) -> Dict: + obs = [] + for demo in demos: + for step in demo.timesteps: + obs.append(step.observation) + + keys = obs[0].keys() + obs = {key: np.stack([o[key] for o in obs], axis=0) for key in keys} + obs_mean = {key: np.mean(obs[key], 0) for key in keys} + obs_std = {key: np.std(obs[key], 0) for key in keys} + obs_min = {key: np.min(obs[key], 0) for key in keys} + obs_max = {key: np.max(obs[key], 0) for key in keys} + obs_stats = { + "mean": obs_mean, + "std": obs_std, + "max": obs_max, + "min": obs_min, + } + return obs_stats + def _get_gripper_action_stats( self, cfg: DictConfig ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: diff --git a/robobase/envs/wrappers/concat_dim.py b/robobase/envs/wrappers/concat_dim.py index 113b722..d0eb479 100644 --- a/robobase/envs/wrappers/concat_dim.py +++ b/robobase/envs/wrappers/concat_dim.py @@ -14,6 +14,8 @@ def __init__( shape_length: int, dim: int, new_name: str, + norm_obs: bool = False, + obs_stats: dict = None, keys_to_ignore: list[str] = None, ): """Init. @@ -31,8 +33,10 @@ def __init__( self.is_vector_env = getattr(env, "is_vector_env", False) self._shape_length = shape_length + int(self.is_vector_env) self._dim = dim + int(self.is_vector_env) - self._name_name = new_name + self._new_name = new_name self._keys_to_ignore = [] if keys_to_ignore is None else keys_to_ignore + self._norm_obs = norm_obs + self._obs_stats = obs_stats new_obs_dict = {} combined = [] for k, v in self.observation_space.items(): @@ -52,10 +56,12 @@ def _transform_timestep(self, observation, final: bool = False): combined = [] for k, v in observation.items(): if len(v.shape) == shape_len and k not in self._keys_to_ignore: + if self._norm_obs and k in self._obs_stats: + v = (v - self._obs_stats["mean"][k]) / self._obs_stats["std"][k] combined.append(v) else: new_obs[k] = v - new_obs[self._name_name] = np.concatenate(combined, dim) + new_obs[self._new_name] = np.concatenate(combined, dim) return new_obs def observation(self, observation): diff --git a/robobase/envs/wrappers/rescale_from_tanh.py b/robobase/envs/wrappers/rescale_from_tanh.py index b3ec8ad..3654ead 100644 --- a/robobase/envs/wrappers/rescale_from_tanh.py +++ b/robobase/envs/wrappers/rescale_from_tanh.py @@ -233,7 +233,9 @@ def transform_to_tanh(action, action_stats, min_max_margin): _action_min = action_min - np.fabs(action_min) * min_max_margin _action_max = action_max + np.fabs(action_max) * min_max_margin - new_action = (action - _action_min) / (_action_max - _action_min + 1e-8) # to [0, 1] + new_action = (action - _action_min) / ( + _action_max - _action_min + 1e-8 + ) # to [0, 1] new_action = new_action * 2 - 1 # to [-1, 1] return new_action.astype(np.float32, copy=False) From d4556c1dcd563defeabf4b8a91fb184f41767e1a Mon Sep 17 00:00:00 2001 From: Ma Xiao Date: Fri, 19 Jul 2024 23:44:13 +0800 Subject: [PATCH 11/12] fix move plate task configs --- robobase/cfgs/env/bigym/move_plate.yaml | 2 +- robobase/cfgs/env/bigym/move_two_plates.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/robobase/cfgs/env/bigym/move_plate.yaml b/robobase/cfgs/env/bigym/move_plate.yaml index 14fe5ca..eaee437 100644 --- a/robobase/cfgs/env/bigym/move_plate.yaml +++ b/robobase/cfgs/env/bigym/move_plate.yaml @@ -9,4 +9,4 @@ env: stddev_schedule: linear(1.0,0.1,500000) episode_length: 3000 demo_down_sample_rate: 10 - enable_all_floating_dof: true + enable_all_floating_dof: false diff --git a/robobase/cfgs/env/bigym/move_two_plates.yaml b/robobase/cfgs/env/bigym/move_two_plates.yaml index 1749ed0..ed13c72 100644 --- a/robobase/cfgs/env/bigym/move_two_plates.yaml +++ b/robobase/cfgs/env/bigym/move_two_plates.yaml @@ -9,4 +9,4 @@ env: stddev_schedule: linear(1.0,0.1,500000) episode_length: 5500 demo_down_sample_rate: 10 - enable_all_floating_dof: true + enable_all_floating_dof: false From 1c1cf41e1d7f8f3dc5eacd6dd06e9c40f2bd5e80 Mon Sep 17 00:00:00 2001 From: Xiao Ma Date: Tue, 30 Jul 2024 16:39:51 +0800 Subject: [PATCH 12/12] add comments --- robobase/envs/bigym.py | 2 ++ robobase/envs/wrappers/concat_dim.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/robobase/envs/bigym.py b/robobase/envs/bigym.py index 5a818a9..593cd8a 100644 --- a/robobase/envs/bigym.py +++ b/robobase/envs/bigym.py @@ -77,6 +77,8 @@ def _wrap_env(self, env, cfg, demo_env=False, train=True, return_raw_spaces=Fals if cfg.norm_obs: obs_stats = self._obs_stats + # We normalize the low dimensional observations in the ConcatDim wrapper. + # This is to be consistent with the original ACT implementation. env = ConcatDim( env, shape_length=1, diff --git a/robobase/envs/wrappers/concat_dim.py b/robobase/envs/wrappers/concat_dim.py index d0eb479..872a986 100644 --- a/robobase/envs/wrappers/concat_dim.py +++ b/robobase/envs/wrappers/concat_dim.py @@ -25,6 +25,8 @@ def __init__( shape_length: The ndim we are interested in, e.g. images=3, low_dim=1. dim: The oberservations with this ... new_name: The name of the new observation. + norm_obs: Whether to normalize observations. + obs_stats: The obs statistics for normalizing observations. keys_to_ignore: A list of keys to not include in this combined observation, regardless if they meet shape_len. """ @@ -55,6 +57,10 @@ def _transform_timestep(self, observation, final: bool = False): new_obs = {} combined = [] for k, v in observation.items(): + # We allow normalizing observations in the ConcatDim wrapper + # because all obs stats are stored with original key names and + # ConcatDim will rename them to new keys. Doing it here would + # safer and cleaner. if len(v.shape) == shape_len and k not in self._keys_to_ignore: if self._norm_obs and k in self._obs_stats: v = (v - self._obs_stats["mean"][k]) / self._obs_stats["std"][k]