-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtest_api_model.py
405 lines (328 loc) · 13.4 KB
/
test_api_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import pathlib
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import rod
import jaxsim.api as js
from jaxsim import VelRepr
from . import utils_idyntree
def test_model_creation_and_reduction(
    jaxsim_model_ergocub: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that reducing a model by locking some of its joints is consistent.

    The full ergoCub model is reduced to a subset of its joints. The (random,
    generally non-zero) joint positions sampled for the full model are passed as
    locked joint positions. The reduced model is then compared with the full one,
    both through iDynTree and through JaxSim, in terms of total mass, center of
    mass, and link transforms.
    """

    model_full = jaxsim_model_ergocub

    # Only the subkey is needed to sample the random model state.
    _, subkey = jax.random.split(prng_key, num=2)
    data_full = js.data.random_model_data(
        model=model_full,
        key=subkey,
        velocity_representation=VelRepr.Inertial,
        base_pos_bounds=((0, 0, 0.8), (0, 0, 0.8)),
    )

    # =====
    # Tests
    # =====

    # Check that the data of the full model is valid.
    assert data_full.valid(model=model_full)

    # Build the ROD model from the original description.
    assert isinstance(model_full.built_from, (str, pathlib.Path))
    rod_sdf = rod.Sdf.load(sdf=model_full.built_from)
    assert len(rod_sdf.models()) == 1

    # Get all non-fixed joint names from the description.
    joint_names_in_description = [
        j.name for j in rod_sdf.models()[0].joints() if j.type != "fixed"
    ]

    # Check that the full model contains exactly the non-fixed joints of the description.
    assert set(joint_names_in_description) == set(model_full.joint_names())

    # ================
    # Reduce the model
    # ================

    # A joint is removed from the reduced model if its name contains any of
    # these substrings.
    removed_joint_patterns = (
        "camera",
        "neck",
        "wrist",
        "thumb",
        "index",
        "middle",
        "ring",
        "pinkie",
        "elbow",
        "shoulder",
        "torso",
        "r_knee",
    )

    # Get the names of the joints to keep in the reduced model.
    reduced_joints = tuple(
        j
        for j in model_full.joint_names()
        if not any(pattern in j for pattern in removed_joint_patterns)
    )

    # Reduce the model.
    # Note: here we also specify a non-zero position of the removed joints.
    # The process should take into account the corresponding joint transforms
    # when the link-joint-link chains are lumped together.
    # The mapping lists every joint of the full model, including the kept ones.
    locked_joint_positions = dict(
        zip(
            model_full.joint_names(),
            data_full.joint_positions(
                model=model_full, joint_names=model_full.joint_names()
            ).tolist(),
        )
    )

    model_reduced = js.model.reduce(
        model=model_full,
        considered_joints=reduced_joints,
        locked_joint_positions=locked_joint_positions,
    )

    # Check that the reduction actually changed the number of DoFs.
    assert model_full.dofs() != model_reduced.dofs()

    # Check that the reduced model contains exactly the considered joints.
    assert set(reduced_joints) == set(model_reduced.joint_names())

    # Build the data of the reduced model, copying the base state and the state
    # of the kept joints from the data of the full model.
    base_velocity_full = data_full.base_velocity()

    data_reduced = js.data.JaxSimModelData.build(
        model=model_reduced,
        base_position=data_full.base_position(),
        base_quaternion=data_full.base_orientation(dcm=False),
        joint_positions=data_full.joint_positions(
            model=model_full, joint_names=model_reduced.joint_names()
        ),
        base_linear_velocity=base_velocity_full[0:3],
        base_angular_velocity=base_velocity_full[3:6],
        joint_velocities=data_full.joint_velocities(
            model=model_full, joint_names=model_reduced.joint_names()
        ),
        velocity_representation=data_full.velocity_representation,
    )

    # =====================
    # Test against iDynTree
    # =====================

    kin_dyn_full = utils_idyntree.build_kindyncomputations_from_jaxsim_model(
        model=model_full, data=data_full
    )

    kin_dyn_reduced = utils_idyntree.build_kindyncomputations_from_jaxsim_model(
        model=model_reduced, data=data_reduced
    )

    # Check that the total mass is preserved.
    assert kin_dyn_full.total_mass() == pytest.approx(kin_dyn_reduced.total_mass())

    # Check that the CoM positions match.
    assert kin_dyn_full.com_position() == pytest.approx(kin_dyn_reduced.com_position())
    assert kin_dyn_full.com_position() == pytest.approx(
        js.com.com_position(model=model_reduced, data=data_reduced)
    )

    # Check that the link transforms match.
    for link_name, link_idx in zip(
        model_reduced.link_names(),
        js.link.names_to_idxs(
            model=model_reduced, link_names=model_reduced.link_names()
        ),
    ):
        W_H_L_reduced_idt = kin_dyn_reduced.frame_transform(frame_name=link_name)

        assert W_H_L_reduced_idt == pytest.approx(
            kin_dyn_full.frame_transform(frame_name=link_name)
        )

        assert W_H_L_reduced_idt == pytest.approx(
            js.link.transform(
                model=model_reduced, data=data_reduced, link_index=link_idx
            )
        )
def test_model_properties(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Compare model-level properties computed by JaxSim against iDynTree.

    The compared quantities are the total mass, the total momentum and its
    Jacobian, the locked spatial inertia, and the average velocity and its
    Jacobian. They are evaluated on a random model state expressed in the
    parametrized velocity representation.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model(
        model=model, data=data
    )

    # =====
    # Tests
    # =====

    def assert_matches(idt_value, js_value):
        # The iDynTree value is the reference wrapped by pytest.approx.
        assert pytest.approx(idt_value) == js_value

    # Arguments are evaluated left to right: iDynTree first, then JaxSim.
    assert_matches(kin_dyn.total_mass(), js.model.total_mass(model=model))

    assert_matches(
        kin_dyn.total_momentum_jacobian(),
        js.model.total_momentum_jacobian(model=model, data=data),
    )

    assert_matches(
        kin_dyn.total_momentum(),
        js.model.total_momentum(model=model, data=data),
    )

    assert_matches(
        kin_dyn.locked_spatial_inertia(),
        js.model.locked_spatial_inertia(model=model, data=data),
    )

    assert_matches(
        kin_dyn.average_velocity_jacobian(),
        js.model.average_velocity_jacobian(model=model, data=data),
    )

    assert_matches(
        kin_dyn.average_velocity(),
        js.model.average_velocity(model=model, data=data),
    )
def test_model_rbda(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
    velocity_representation: VelRepr,
):
    """
    Compare rigid-body dynamics quantities computed by JaxSim against iDynTree.

    The compared quantities are the free-floating mass matrix, the gravity and
    bias forces, the forward kinematics of all links, and the link bias
    accelerations.
    """

    model = jaxsim_models_types

    _, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model(
        model=model, data=data
    )

    # =====
    # Tests
    # =====

    # Floating-base models are compared on all generalized coordinates.
    # Fixed-base models skip the first six (base) rows/columns.
    if model.floating_base():
        dofs_slice = np.s_[0:]
    else:
        dofs_slice = np.s_[6:]

    # Mass matrix.
    mass_matrix_idt = kin_dyn.mass_matrix()
    mass_matrix_js = js.model.free_floating_mass_matrix(model=model, data=data)
    assert (
        pytest.approx(mass_matrix_idt[dofs_slice, dofs_slice])
        == mass_matrix_js[dofs_slice, dofs_slice]
    )

    # Gravity forces.
    gravity_idt = kin_dyn.gravity_forces()
    gravity_js = js.model.free_floating_gravity_forces(model=model, data=data)
    assert pytest.approx(gravity_idt[dofs_slice]) == gravity_js[dofs_slice]

    # Bias forces.
    bias_idt = kin_dyn.bias_forces()
    bias_js = js.model.free_floating_bias_forces(model=model, data=data)
    assert pytest.approx(bias_idt[dofs_slice]) == bias_js[dofs_slice]

    link_names = model.link_names()

    def stack_over_links(idt_frame_method):
        # Evaluate an iDynTree per-frame getter on every link and stack the results.
        return jnp.stack([idt_frame_method(frame_name=name) for name in link_names])

    # Forward kinematics.
    transforms_js = js.model.forward_kinematics(model=model, data=data)
    transforms_idt = stack_over_links(kin_dyn.frame_transform)
    assert pytest.approx(transforms_idt) == transforms_js

    # Bias accelerations.
    bias_acc_js = js.model.link_bias_accelerations(model=model, data=data)
    bias_acc_idt = stack_over_links(kin_dyn.frame_bias_acc)
    assert pytest.approx(bias_acc_idt) == bias_acc_js
def test_model_jacobian(
    jaxsim_models_types: js.model.JaxSimModel,
    prng_key: jax.Array,
):
    """
    Check that the product J.T @ f does not depend on the force representation.

    The product between the transposed generalized free-floating Jacobian and
    the 6D link forces is first computed with inertial-fixed quantities. It is
    then recomputed with an inertial-fixed input representation and with output
    Jacobians and link forces expressed in body-fixed and mixed representations.
    All products must match.
    """

    model = jaxsim_models_types

    key, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=VelRepr.Inertial
    )

    # =====
    # Tests
    # =====

    # Create random references (joint torques and link forces).
    _, subkey1, subkey2 = jax.random.split(key, num=3)
    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),
        link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)),
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # Remove the force applied to the base link if the model is fixed-base.
    if not model.floating_base():
        references = references.apply_link_forces(
            forces=jnp.atleast_2d(jnp.zeros(6)),
            model=model,
            data=data,
            link_names=(model.base_link(),),
            additive=False,
        )

    # Get the J.T @ f product in inertial-fixed input/output representation.
    # We use doubly right-trivialized jacobian with inertial-fixed 6D forces.
    # einsum subscripts: `l` indexes the links, `6` the 6D spatial axis, and
    # `g` the generalized coordinates that are kept in the output.
    with references.switch_velocity_representation(VelRepr.Inertial):
        with data.switch_velocity_representation(VelRepr.Inertial):

            f = references.link_forces(model=model, data=data)
            assert f == pytest.approx(references.input.physics_model.f_ext)

            J = js.model.generalized_free_floating_jacobian(model=model, data=data)
            JTf_inertial = jnp.einsum("l6g,l6->g", J, f)

    for vel_repr in [VelRepr.Body, VelRepr.Mixed]:
        with references.switch_velocity_representation(vel_repr):

            # Get the jacobian having an inertial-fixed input representation (so that
            # it computes the same quantity computed above) and an output representation
            # compatible with the frame in which the external forces are expressed.
            with data.switch_velocity_representation(VelRepr.Inertial):

                J = js.model.generalized_free_floating_jacobian(
                    model=model, data=data, output_vel_repr=vel_repr
                )

            # Get the forces in the tested representation and compute the product
            # O_J_WL_W.T @ O_f, producing a generalized force in W.
            # The resulting generalized force can be tested against the one computed before.
            with data.switch_velocity_representation(vel_repr):

                f = references.link_forces(model=model, data=data)
                JTf_other = jnp.einsum("l6g,l6->g", J, f)
                assert pytest.approx(JTf_inertial) == JTf_other, vel_repr.name
def test_model_fd_id_consistency(
    jaxsim_models_types: js.model.JaxSimModel,
    velocity_representation: VelRepr,
    prng_key: jax.Array,
):
    """
    Check that forward and inverse dynamics are consistent with each other.

    Forward dynamics is computed with both ABA and CRB, and the two results are
    compared. The accelerations from ABA are then fed to inverse dynamics. It
    must return the joint force references and a zero base 6D force.

    For floating-base models, the test also removes the 6D force applied to the
    base link from the inputs. That force must then reappear as the base force
    returned by inverse dynamics.
    """

    model = jaxsim_models_types

    key, subkey = jax.random.split(prng_key, num=2)
    data = js.data.random_model_data(
        model=model, key=subkey, velocity_representation=velocity_representation
    )

    # =====
    # Tests
    # =====

    # Create random references (joint torques and link forces).
    _, subkey1, subkey2 = jax.random.split(key, num=3)
    references = js.references.JaxSimModelReferences.build(
        model=model,
        joint_force_references=10 * jax.random.uniform(subkey1, shape=(model.dofs(),)),
        link_forces=jax.random.uniform(subkey2, shape=(model.number_of_links(), 6)),
        data=data,
        velocity_representation=data.velocity_representation,
    )

    # Remove the force applied to the base link if the model is fixed-base.
    if not model.floating_base():
        references = references.apply_link_forces(
            forces=jnp.atleast_2d(jnp.zeros(6)),
            model=model,
            data=data,
            link_names=(model.base_link(),),
            additive=False,
        )

    # Extract the inputs once: every FD/ID call below uses the same values.
    τ_references = references.joint_force_references(model=model)
    f_L = references.link_forces(model=model, data=data)

    # Compute forward dynamics with ABA.
    v̇_WB_aba, s̈_aba = js.model.forward_dynamics_aba(
        model=model,
        data=data,
        joint_forces=τ_references,
        link_forces=f_L,
    )

    # Compute forward dynamics with CRB.
    v̇_WB_crb, s̈_crb = js.model.forward_dynamics_crb(
        model=model,
        data=data,
        joint_forces=τ_references,
        link_forces=f_L,
    )

    assert pytest.approx(s̈_aba) == s̈_crb
    assert pytest.approx(v̇_WB_aba) == v̇_WB_crb

    # Compute inverse dynamics with the quantities computed by forward dynamics.
    fB_id, τ_id = js.model.inverse_dynamics(
        model=model,
        data=data,
        joint_accelerations=s̈_aba,
        base_acceleration=v̇_WB_aba,
        link_forces=f_L,
    )

    # Check consistency between FD and ID.
    assert pytest.approx(τ_id) == τ_references
    assert pytest.approx(fB_id, abs=1e-9) == jnp.zeros(6)

    if model.floating_base():

        # If we remove the base 6D force from the inputs, we should find it as output.
        # NOTE(review): assumes the base link is the link with index 0.
        fB_id, τ_id = js.model.inverse_dynamics(
            model=model,
            data=data,
            joint_accelerations=s̈_aba,
            base_acceleration=v̇_WB_aba,
            link_forces=f_L.at[0].set(jnp.zeros(6)),
        )

        assert pytest.approx(τ_id) == τ_references
        assert pytest.approx(fB_id, abs=1e-9) == f_L[0]