Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 29, 2024
1 parent 7234614 commit a597f85
Showing 1 changed file with 2 additions and 41 deletions.
43 changes: 2 additions & 41 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,9 @@ def test_box_with_external_forces(
additive=False,
)

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model, data=data0, system_dynamics=js.ode.system_dynamics
)
)

# Initialize the integrator.
tf = 0.5
T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int)
state_aux_dict = None

# Copy the initial data...
data = data0.copy()
Expand All @@ -83,8 +75,6 @@ def test_box_with_external_forces(
data, state_aux_dict = js.model.step(
model=model,
data=data,
integrator=integrator,
integrator_state=state_aux_dict,
link_forces=references.link_forces(model=model, data=data),
)

Expand Down Expand Up @@ -154,8 +144,6 @@ def test_box_with_zero_gravity(
# Copy the initial data...
data = data0.copy()

state_aux_dict = None

# ... and step the simulation.
for _ in T:

Expand All @@ -164,11 +152,10 @@ def test_box_with_zero_gravity(
references.switch_velocity_representation(velocity_representation),
):

data, state_aux_dict = js.model.step(
data, _ = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
integrator_state=state_aux_dict,
)

# Check that the box moved as expected.
Expand All @@ -186,33 +173,13 @@ def run_simulation(
tf: jtp.FloatLike,
) -> js.data.JaxSimModelData:

@functools.cache
def get_integrator() -> tuple[jaxsim.integrators.Integrator, dict[str, jtp.PyTree]]:

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.Heun2.build(
fsal_enabled_if_supported=False,
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model,
data=data_t0,
system_dynamics=js.ode.system_dynamics,
),
)

# Initialize the integrator state.
integrator_state_t0 = {}

return integrator, integrator_state_t0

# Initialize the integration horizon.
T_ns = jnp.arange(start=0.0, stop=int(tf * 1e9), step=int(dt * 1e9)).astype(int)

# Initialize the simulation data.
integrator = None
integrator_state = None
data = data_t0.copy()

for t_ns in T_ns:
for _ in T_ns:

match model.contact_model:

Expand All @@ -226,16 +193,10 @@ def get_integrator() -> tuple[jaxsim.integrators.Integrator, dict[str, jtp.PyTre

case _:

integrator, integrator_state = (
get_integrator() if t_ns == 0 else (integrator, integrator_state)
)

data, integrator_state = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
)

return data
Expand Down

0 comments on commit a597f85

Please sign in to comment.