
Support gymnasium #226

Merged: 24 commits, Jul 2, 2024
4 changes: 2 additions & 2 deletions .github/workflows/task_tests.yml
@@ -37,6 +37,6 @@ jobs:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT:$COPPELIASIM_ROOT/platforms
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT

pip install ".[dev]"
pip install ".[gym,dev]"
pip install "pytest-xdist[psutil]"
pytest -v -n auto tests/unit
pytest -v -n auto tests/demos
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
@@ -38,6 +38,6 @@ jobs:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT:$COPPELIASIM_ROOT/platforms
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT

pip install ".[dev]"
pip install ".[gym,dev]"
pip install "pytest-xdist[psutil]"
pytest -v -n auto tests/unit
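
Both CI workflows now install the package with the new `gym` extra so that gymnasium is available for the unit and demo tests. The packaging side of this change is not visible in this diff view; the sketch below is only an assumption of how such an extra is typically declared, and the extra contents shown are hypothetical.

```python
# Sketch only; the actual setup.py/pyproject change is not shown in this PR view,
# and the dependency lists below are assumptions.
from setuptools import setup, find_packages

setup(
    name="rlbench",
    packages=find_packages(),
    extras_require={
        "gym": ["gymnasium"],                      # enables `pip install ".[gym]"`
        "dev": ["pytest", "pytest-xdist[psutil]"], # test tooling used by the workflows
    },
)
```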
2 changes: 1 addition & 1 deletion README.md
@@ -309,7 +309,7 @@ the observation mode: 'state' or 'vision'.

```python
import gym
import rlbench.gym
import rlbench

env = gym.make('reach_target-state-v0')
# Alternatively, for vision:
13 changes: 9 additions & 4 deletions examples/rlbench_gym.py
@@ -1,16 +1,21 @@
import gym
import rlbench.gym
import gymnasium as gym
from gymnasium.utils.performance import benchmark_step
import rlbench

env = gym.make('rlbench/reach_target-vision-v0', render_mode="rgb_array")

env = gym.make('reach_target-state-v0', render_mode='human')

training_steps = 120
episode_length = 40
for i in range(training_steps):
    if i % episode_length == 0:
        print('Reset Episode')
        obs = env.reset()
    obs, reward, terminate, _ = env.step(env.action_space.sample())
    obs, reward, terminate, _, _ = env.step(env.action_space.sample())
    env.render()  # Note: rendering increases step time.

print('Done')

fps = benchmark_step(env, target_duration=10)
print(f"FPS: {fps:.2f}")
env.close()
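
The updated example unpacks five values because gymnasium's `env.step()` returns `(obs, reward, terminated, truncated, info)`, and `env.reset()` returns `(obs, info)`. The example above discards the flags; a minimal sketch of an episode loop that actually uses them (environment ID taken from the registration added in `rlbench/__init__.py` below) could look like this:

```python
# Minimal sketch of a gymnasium-style episode loop; not part of this PR.
import gymnasium as gym
import rlbench  # importing registers the rlbench/* environment IDs

env = gym.make('rlbench/reach_target-state-v0')
obs, info = env.reset(seed=0)
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated  # either flag ends the episode
env.close()
```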
48 changes: 48 additions & 0 deletions examples/rlbench_gym_vector.py
@@ -0,0 +1,48 @@
import time
import gymnasium as gym
import rlbench


def benchmark_vector_step(env, target_duration: int = 5, seed=None) -> float:
    steps = 0
    end = 0.0
    env.reset(seed=seed)
    env.action_space.sample()
    start = time.monotonic()

    while True:
        action = env.action_space.sample()
        _, _, terminal, truncated, _ = env.step(action)
        steps += terminal.shape[0]

        # if terminal or truncated:
        #     env.reset()

        if time.monotonic() - start > target_duration:
            end = time.monotonic()
            break

    length = end - start

    steps_per_time = steps / length
    return steps_per_time

if __name__ == "__main__":
# Only works with spawn (multiprocessing) context
env = gym.make_vec('rlbench/reach_target-vision-v0', num_envs=2, vectorization_mode="async", vector_kwargs={"context": "spawn"})

training_steps = 120
episode_length = 40
for i in range(training_steps):
if i % episode_length == 0:
print('Reset Episode')
obs = env.reset()
obs, reward, terminate, _, _ = env.step(env.action_space.sample())
env.render() # Note: rendering increases step time.

print('Done')

fps = benchmark_vector_step(env, target_duration=10)
print(f"FPS: {fps:.2f}")

env.close()
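
`gym.make_vec` with `vectorization_mode="async"` and the `"spawn"` context gives each sub-environment its own CoppeliaSim process, and gymnasium's vector wrappers reset finished sub-environments automatically, which is presumably why the manual reset in `benchmark_vector_step` is commented out. A short sketch of inspecting the batched flags per sub-environment, under the same registration and spawn-context assumptions as the example above:

```python
# Sketch only; assumes the rlbench/* registration below and the spawn-context
# requirement noted in the example above.
import gymnasium as gym
import rlbench

if __name__ == "__main__":
    env = gym.make_vec('rlbench/reach_target-state-v0', num_envs=2,
                       vectorization_mode="async",
                       vector_kwargs={"context": "spawn"})
    obs, info = env.reset(seed=0)
    for _ in range(10):
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        # terminated/truncated have shape (num_envs,); finished sub-envs are
        # reset automatically by the vector wrapper.
        print(terminated, truncated)
    env.close()
```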
57 changes: 44 additions & 13 deletions rlbench/__init__.py
@@ -1,17 +1,48 @@
__version__ = '1.2.0'
__version__ = "1.2.0"

import numpy as np
import pyrep

pr_v = np.array(pyrep.__version__.split('.'), dtype=int)
if pr_v.size < 4 or np.any(pr_v < np.array([4, 1, 0, 2])):
    raise ImportError(
        'PyRep version must be greater than 4.1.0.2. Please update PyRep.')
import os

from gymnasium import register

import rlbench.backend.task as task
from rlbench.action_modes.action_mode import (
    ActionMode,
    ArmActionMode,
    GripperActionMode,
)
from rlbench.environment import Environment
from rlbench.action_modes.action_mode import ActionMode, ArmActionMode, GripperActionMode
from rlbench.observation_config import ObservationConfig
from rlbench.observation_config import CameraConfig
from rlbench.sim2real.domain_randomization import RandomizeEvery
from rlbench.sim2real.domain_randomization import VisualRandomizationConfig
from rlbench.observation_config import CameraConfig, ObservationConfig
from rlbench.sim2real.domain_randomization import (
    RandomizeEvery,
    VisualRandomizationConfig,
)
from rlbench.utils import name_to_task_class

__all__ = [
"ActionMode",
"ArmActionMode",
"GripperActionMode",
"CameraConfig",
"Environment",
"ObservationConfig",
"RandomizeEvery",
"VisualRandomizationConfig",
]

TASKS = [
    t for t in os.listdir(task.TASKS_PATH) if t != "__init__.py" and t.endswith(".py")
]

for task_file in TASKS:
    task_name = task_file.split(".py")[0]
    task_class = name_to_task_class(task_name)
    for obs_mode in ["state", "vision"]:
        register(
            id=f"rlbench/{task_name}-{obs_mode}-v0",
            entry_point="rlbench.gym:RLBenchEnv",
            kwargs={
                "task_class": task_class,
                "observation_mode": obs_mode,
            },
            nondeterministic=True,
        )
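
With this loop, simply importing `rlbench` registers two IDs per task, `rlbench/<task>-state-v0` and `rlbench/<task>-vision-v0`. A small sketch of discovering them, assuming gymnasium exposes its registry as `gymnasium.registry`:

```python
# Sketch: listing the environment IDs registered by the loop above.
# Assumes gymnasium exposes its registry dict as `gymnasium.registry`.
import gymnasium as gym
import rlbench  # importing runs the register() loop

rlbench_ids = sorted(env_id for env_id in gym.registry if env_id.startswith("rlbench/"))
print(len(rlbench_ids), rlbench_ids[:4])
env = gym.make("rlbench/reach_target-vision-v0", render_mode="rgb_array")
```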
Empty file added rlbench/assets/__init__.py
Empty file.
87 changes: 39 additions & 48 deletions tools/dataset_generator.py → rlbench/dataset_generator.py
@@ -1,43 +1,21 @@
from multiprocessing import Process, Manager
import argparse
import os
import pickle
from multiprocessing import Manager, Process

import numpy as np
from PIL import Image
from pyrep.const import RenderMode

import rlbench.backend.task as task
from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
import rlbench.backend.task as task

import os
import pickle
from PIL import Image
from rlbench.backend import utils
from rlbench.backend.const import *
import numpy as np

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('save_path',
                    '/tmp/rlbench_data/',
                    'Where to save the demos.')
flags.DEFINE_list('tasks', [],
                  'The tasks to collect. If empty, all tasks are collected.')
flags.DEFINE_list('image_size', [128, 128],
                  'The size of the images tp save.')
flags.DEFINE_enum('renderer', 'opengl3', ['opengl', 'opengl3'],
                  'The renderer to use. opengl does not include shadows, '
                  'but is faster.')
flags.DEFINE_integer('processes', 1,
                     'The number of parallel processes during collection.')
flags.DEFINE_integer('episodes_per_task', 10,
                     'The number of episodes to collect per task.')
flags.DEFINE_integer('variations', -1,
                     'Number of variations to collect per task. -1 for all.')
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment


def check_and_make(dir):
@@ -166,15 +144,15 @@ def save_demo(demo, example_path):
pickle.dump(demo, f)


def run(i, lock, task_index, variation_count, results, file_lock, tasks):
def run(i, lock, task_index, variation_count, results, file_lock, tasks, args):
"""Each thread will choose one task and variation, and then gather
all the episodes_per_task for that variation."""

# Initialise each thread with random seed
np.random.seed(None)
num_tasks = len(tasks)

img_size = list(map(int, FLAGS.image_size))
img_size = list(map(int, args.image_size))

obs_config = ObservationConfig()
obs_config.set_all(True)
@@ -198,13 +176,13 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
    obs_config.wrist_camera.masks_as_one_channel = False
    obs_config.front_camera.masks_as_one_channel = False

    if FLAGS.renderer == 'opengl':
    if args.renderer == 'opengl':
        obs_config.right_shoulder_camera.render_mode = RenderMode.OPENGL
        obs_config.left_shoulder_camera.render_mode = RenderMode.OPENGL
        obs_config.overhead_camera.render_mode = RenderMode.OPENGL
        obs_config.wrist_camera.render_mode = RenderMode.OPENGL
        obs_config.front_camera.render_mode = RenderMode.OPENGL
    elif FLAGS.renderer == 'opengl3':
    elif args.renderer == 'opengl3':
        obs_config.right_shoulder_camera.render_mode = RenderMode.OPENGL3
        obs_config.left_shoulder_camera.render_mode = RenderMode.OPENGL3
        obs_config.overhead_camera.render_mode = RenderMode.OPENGL3
@@ -233,8 +211,8 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
t = tasks[task_index.value]
task_env = rlbench_env.get_task(t)
var_target = task_env.variation_count()
if FLAGS.variations >= 0:
var_target = np.minimum(FLAGS.variations, var_target)
if args.variations >= 0:
var_target = np.minimum(args.variations, var_target)
if my_variation_count >= var_target:
# If we have reached the required number of variations for this
# task, then move on to the next task.
@@ -252,7 +230,7 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
descriptions, _ = task_env.reset()

variation_path = os.path.join(
FLAGS.save_path, task_env.get_name(),
args.save_path, task_env.get_name(),
VARIATIONS_FOLDER % my_variation_count)

check_and_make(variation_path)
@@ -265,7 +243,7 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
check_and_make(episodes_path)

abort_variation = False
for ex_idx in range(FLAGS.episodes_per_task):
for ex_idx in range(args.episodes_per_task):
print('Process', i, '// Task:', task_env.get_name(),
'// Variation:', my_variation_count, '// Demo:', ex_idx)
attempts = 10
@@ -300,16 +278,29 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
    rlbench_env.shutdown()


def main(argv):
def parse_args():
    parser = argparse.ArgumentParser(description="RLBench Dataset Generator")
    parser.add_argument('--save_path', type=str, default='/tmp/rlbench_data/', help='Where to save the demos.')
    parser.add_argument('--tasks', nargs='*', default=[], help='The tasks to collect. If empty, all tasks are collected.')
    parser.add_argument('--image_size', nargs=2, type=int, default=[128, 128], help='The size of the images to save.')
    parser.add_argument('--renderer', type=str, choices=['opengl', 'opengl3'], default='opengl3', help='The renderer to use. opengl does not include shadows, but is faster.')
    parser.add_argument('--processes', type=int, default=1, help='The number of parallel processes during collection.')
    parser.add_argument('--episodes_per_task', type=int, default=10, help='The number of episodes to collect per task.')
    parser.add_argument('--variations', type=int, default=-1, help='Number of variations to collect per task. -1 for all.')
    return parser.parse_args()


def main():
    args = parse_args()

    task_files = [t.replace('.py', '') for t in os.listdir(task.TASKS_PATH)
                  if t != '__init__.py' and t.endswith('.py')]

    if len(FLAGS.tasks) > 0:
        for t in FLAGS.tasks:
    if len(args.tasks) > 0:
        for t in args.tasks:
            if t not in task_files:
                raise ValueError('Task %s not recognised!.' % t)
        task_files = FLAGS.tasks
        task_files = args.tasks

    tasks = [task_file_to_task_class(t) for t in task_files]

@@ -322,20 +313,20 @@ def main(argv):
    variation_count = manager.Value('i', 0)
    lock = manager.Lock()

    check_and_make(FLAGS.save_path)
    check_and_make(args.save_path)

    processes = [Process(
        target=run, args=(
            i, lock, task_index, variation_count, result_dict, file_lock,
            tasks))
        for i in range(FLAGS.processes)]
            tasks, args))
        for i in range(args.processes)]
    [t.start() for t in processes]
    [t.join() for t in processes]

    print('Data collection done!')
    for i in range(FLAGS.processes):
    for i in range(args.processes):
        print(result_dict[i])


if __name__ == '__main__':
    app.run(main)
    main()
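
Since the script moved into the package and switched from absl flags to argparse, it can be invoked as a module with the same flag names. A hypothetical invocation sketch follows; the module path is assumed from the rename above, and the flag names come from `parse_args()` in this diff.

```python
# Hypothetical invocation; module path assumed from the file move
# tools/dataset_generator.py -> rlbench/dataset_generator.py.
import subprocess

subprocess.run([
    "python", "-m", "rlbench.dataset_generator",
    "--save_path", "/tmp/rlbench_data/",
    "--tasks", "reach_target",
    "--image_size", "128", "128",
    "--processes", "1",
    "--episodes_per_task", "5",
], check=True)
```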