Skip to content

Commit

Permalink
Enhance and generalize how heightmaps are inserted in the mujoco viewer
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jul 11, 2024
1 parent 6005a82 commit 02e45e4
Showing 1 changed file with 81 additions and 15 deletions.
96 changes: 81 additions & 15 deletions src/jaxsim/mujoco/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import mujoco as mj
import numpy as np
import numpy.typing as npt
import xmltodict
from scipy.spatial.transform import Rotation

import jaxsim.typing as jtp
Expand Down Expand Up @@ -42,16 +43,27 @@ def build_from_xml(
mjcf_description: str | pathlib.Path,
assets: dict[str, Any] | None = None,
heightmap: HeightmapCallable | None = None,
heightmap_name: str = "terrain",
heightmap_radius_xy: tuple[float, float] = (1.0, 1.0),
) -> MujocoModelHelper:
"""
Build a Mujoco model from an XML description and an optional assets dictionary.
Build a Mujoco model from an MJCF description.
Args:
mjcf_description: A string containing the XML description of the Mujoco model
mjcf_description:
A string containing the XML description of the Mujoco model
or a path to a file containing the XML description.
assets: An optional dictionary containing the assets of the model.
heightmap: A function in two variables that returns the height of a terrain
heightmap:
A function in two variables that returns the height of a terrain
in the specified coordinate point.
heightmap_name:
The default name of the heightmap in the MJCF description
to load the corresponding configuration.
heightmap_radius_xy:
The extension of the heightmap in the x-y surface corresponding to the
plane over which the grid of the sampled heightmap is generated.
Returns:
A MujocoModelHelper object.
"""
Expand All @@ -63,15 +75,61 @@ def build_from_xml(
else mjcf_description
)

# Create the Mujoco model from the XML and, optionally, the assets dictionary.
if heightmap is None:
hfield = None

else:

mjcf_description_dict = xmltodict.parse(xml_input=mjcf_description)

# Create a dictionary of all hfield configurations from the MJCF.
hfields = mjcf_description_dict["mujoco"]["asset"].get("hfield", [])
hfields = hfields if isinstance(hfields, list) else [hfields]
hfields_dict = {hfield["@name"]: hfield for hfield in hfields}

if heightmap_name not in hfields_dict:
raise ValueError(f"Heightmap '{heightmap_name}' not found in MJCF")

hfield_element = hfields_dict[heightmap_name]

# Generate the hfield by sampling the heightmap function.
hfield = generate_hfield(
heightmap=heightmap,
samples_xy=(int(hfield_element["@nrow"]), int(hfield_element["@ncol"])),
radius_xy=heightmap_radius_xy,
)

# Update dynamically the '/asset/hfield[@name=heightmap_name]@size' attribute
# with the information of the sampled points.
# This is necessary for correctly rendering the heightmap over the
# specified xy area with the correct z elevation.
size = [float(el) for el in hfield_element["@size"].split(" ")]
size[0], size[1] = heightmap_radius_xy
size[2] = 1.0
size[3] = max(0, -min(hfield))

# Replace the 'size' attribute.
hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size)

# Update the hfield elements of the original MJCF.
# Only the hfield corresponding to 'heightmap_name' was actually edited.
mjcf_description_dict["mujoco"]["asset"]["hfield"] = list(
hfields_dict.values()
)

# Serialize the updated MJCF to XML.
mjcf_description = xmltodict.unparse(
input_dict=mjcf_description_dict, pretty=True
)

# Create the Mujoco model from the XML and, optionally, the dictionary of assets.
model = mj.MjModel.from_xml_string(xml=mjcf_description, assets=assets)
data = mj.MjData(model)

if heightmap:
nrow = model.hfield_nrow.item()
ncol = model.hfield_ncol.item()
new_hfield = generate_hfield(heightmap, (nrow, ncol))
model.hfield_data = new_hfield
# Store the sampled heightmap into the Mujoco model.
if heightmap is not None:
assert hfield is not None
model.hfield_data = hfield

return MujocoModelHelper(model=model, data=data)

Expand Down Expand Up @@ -385,10 +443,13 @@ def _mask_qpos(self, joint_names: tuple[str, ...]) -> npt.NDArray:


def generate_hfield(
heightmap: HeightmapCallable, size: tuple[int, int] = (10, 10)
heightmap: HeightmapCallable,
samples_xy: tuple[int, int] = (11, 11),
radius_xy: tuple[float, float] = (1.0, 1.0),
) -> npt.NDArray:
"""
Generates a numpy array representing the heightmap of
Generate an array with elevation points sampled from a heightmap function.
The map will have the following format:
```
heightmap[0, 0] heightmap[0, 1] ... heightmap[0, size[1]-1]
Expand All @@ -398,17 +459,22 @@ def generate_hfield(
```
Args:
heightmap: A function that takes two arguments (x, y) and returns the height
heightmap:
A function that takes two arguments (x, y) and returns the height
at that point.
size: A tuple of two integers representing the size of the grid.
radius_xy:
A tuple of two floats representing extension of the heightmap in the
x-y surface corresponding to the area over which the grid of the sampled
heightmap is generated.
Returns:
np.ndarray: The terrain heightmap
A flat array of the sampled terrain heightmap.
"""

# Generate the grid.
x = np.linspace(0, 1, size[0])
y = np.linspace(0, 1, size[1])
x = np.linspace(-radius_xy[0], radius_xy[0], samples_xy[0])
y = np.linspace(-radius_xy[1], radius_xy[1], samples_xy[1])

# Generate the heightmap.
return np.array([[heightmap(xi, yi) for xi in x] for yi in y]).flatten()

0 comments on commit 02e45e4

Please sign in to comment.