RGB image observations, medium-1m rebalance
What's Changed
We have added support for rendering observations as RGB images. This expands the space of possible experiments and architectures. For example, it is now possible to properly test generalization to new objects. This was impossible with the discrete encoding, because it is difficult to add embeddings for new objects after pre-training. The main disadvantage is that rendering significantly reduces throughput.
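To give a rough sense of why rendering costs throughput, here is a back-of-the-envelope comparison of the size of a single observation in each format. The numbers rest on assumptions that are not stated in these notes: a 5x5 agent view, symbolic observations stored as `(view, view, 2)` uint8 arrays of (tile ID, color ID), and 32x32-pixel RGB tiles.

```python
# Rough per-observation size comparison (assumed sizes, not taken from the library).
VIEW_SIZE = 5   # assumption: agent sees a 5x5 window of the grid
TILE_PX = 32    # assumption: each grid cell is rendered as a 32x32 RGB tile


def symbolic_obs_bytes(view_size: int) -> int:
    # one uint8 for the tile id and one uint8 for the color id per grid cell
    return view_size * view_size * 2


def rgb_obs_bytes(view_size: int, tile_px: int) -> int:
    # each grid cell becomes a tile_px x tile_px patch with 3 uint8 channels
    side = view_size * tile_px
    return side * side * 3


sym = symbolic_obs_bytes(VIEW_SIZE)
rgb = rgb_obs_bytes(VIEW_SIZE, TILE_PX)
print(sym, rgb, rgb // sym)  # bytes per symbolic obs, bytes per RGB obs, ratio
```

Under these assumptions an RGB observation is over a thousand times larger than a symbolic one, before counting the extra compute spent on rendering itself.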
This is a major update and therefore experimental for now. Example usage:
```python
import jax
import xminigrid
from xminigrid.wrappers import GymAutoResetWrapper
from xminigrid.experimental.img_obs import RGBImgObservationWrapper

key = jax.random.PRNGKey(0)
reset_key, ruleset_key = jax.random.split(key)

benchmark = xminigrid.load_benchmark(name="trivial-1m")
ruleset = benchmark.sample_ruleset(ruleset_key)

env, env_params = xminigrid.make("XLand-MiniGrid-R9-25x25")
env_params = env_params.replace(ruleset=ruleset)

# auto-reset wrapper
env = GymAutoResetWrapper(env)

# for faster rendering, pre-rendered tiles will be saved at XLAND_MINIGRID_CACHE path
# use XLAND_MINIGRID_RELOAD_CACHE=True to force cache reload
env = RGBImgObservationWrapper(env)

timestep = jax.jit(env.reset)(env_params, reset_key)
timestep = jax.jit(env.step)(env_params, timestep, action=0)
```
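The two environment variables mentioned in the comments above can also be set from Python. This is a minimal sketch: the cache directory is a hypothetical choice, and because we have not checked exactly when xminigrid reads these variables, we assume it is safest to set them before importing the library.

```python
import os
import tempfile
from pathlib import Path

# Hypothetical cache location; any writable directory should do.
cache_dir = Path(tempfile.gettempdir()) / "xminigrid_tile_cache"
cache_dir.mkdir(parents=True, exist_ok=True)

# Variable names come from the release notes. Setting them before the import
# is a precaution (assumption), not documented library behavior.
os.environ["XLAND_MINIGRID_CACHE"] = str(cache_dir)
os.environ["XLAND_MINIGRID_RELOAD_CACHE"] = "True"  # force the tile cache to be rebuilt

# import xminigrid  # import only after the variables above are set
```

Once the cache has been rebuilt, dropping `XLAND_MINIGRID_RELOAD_CACHE` avoids re-rendering the tiles on every run.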
To make rendering possible under jit, we had to make a few changes that altered the IDs of objects and colors. This broke compatibility with the old benchmarks, so we completely re-generated them. Some time ago we also noticed that the `medium-1m` benchmark is not harder than `small-1m`, so we took the opportunity to make it somewhat more complex as well. The updated configs can still be found in `scripts/generate_benchmarks.sh`. Be careful: results obtained with previous releases may change significantly!
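One common way to render under jit, and roughly what the cached pre-rendered tiles hint at, is to keep an atlas of tiles indexed by (tile ID, color ID) and gather from it using the symbolic grid. The sketch below is our own illustration, not the library's implementation; the sizes and the `make_atlas`/`render` names are hypothetical. It shows why IDs matter in such a scheme: they are used directly as array indices, so they must be dense integers in range (JAX clamps out-of-bounds gather indices instead of raising an error).

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, for illustration only.
NUM_TILES = 4
NUM_COLORS = 3
TILE_PX = 8


def make_atlas(num_tiles: int, num_colors: int, tile_px: int) -> jnp.ndarray:
    # atlas[t, c] is the pre-rendered RGB patch for tile id t in color id c.
    # Each patch is a flat fill with value 10 * t + c so the output is easy to check.
    t = jnp.arange(num_tiles, dtype=jnp.uint8)[:, None, None, None, None]
    c = jnp.arange(num_colors, dtype=jnp.uint8)[None, :, None, None, None]
    value = t * 10 + c  # shape (T, C, 1, 1, 1)
    return jnp.broadcast_to(value, (num_tiles, num_colors, tile_px, tile_px, 3))


@jax.jit
def render(atlas: jnp.ndarray, grid: jnp.ndarray) -> jnp.ndarray:
    # grid: (H, W, 2) integer array holding (tile id, color id) for every cell
    patches = atlas[grid[..., 0], grid[..., 1]]  # gather -> (H, W, P, P, 3)
    h, w, p, _, ch = patches.shape  # shapes are static, so this works under jit
    # interleave patch rows with grid rows: (H, P, W, P, 3) -> (H * P, W * P, 3)
    return patches.transpose(0, 2, 1, 3, 4).reshape(h * p, w * p, ch)


atlas = make_atlas(NUM_TILES, NUM_COLORS, TILE_PX)
grid = jnp.array([[[0, 0], [1, 2]],
                  [[3, 1], [2, 0]]])  # a 2x2 grid of (tile id, color id)
img = render(atlas, grid)
print(img.shape)  # (16, 16, 3)
```

Because the whole image is produced by one gather and one reshape with fixed shapes, the function traces once and stays jit-compatible; renumbering the IDs would silently change which tile each cell picks up, which is consistent with the old benchmarks no longer being compatible.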
Full Changelog: v0.6.0...v0.7.0