diff --git a/examples/manopth_demo.py b/examples/manopth_demo.py index f5849cb..58f6a5d 100644 --- a/examples/manopth_demo.py +++ b/examples/manopth_demo.py @@ -18,7 +18,8 @@ action='store_true', help="Disable display output of ManoLayer given random inputs") parser.add_argument('--side', default='left', choices=['left', 'right']) - parser.add_argument('--random_shape', action='store_true') + parser.add_argument('--random_shape', action='store_true', help="Random hand shape") + parser.add_argument('--rand_mag', type=float, default=1, help="Controls pose variability") parser.add_argument( '--flat_hand_mean', action='store_true', @@ -31,6 +32,7 @@ "Use for quick profiling of forward and backward pass accross ManoLayer" ) parser.add_argument('--mano_root', default='mano/models') + parser.add_argument('--root_rot_mode', default='axisang', choices=['rot6d', 'axisang']) parser.add_argument( '--mano_ncomps', default=6, type=int, help="Number of PCA components") args = parser.parse_args() @@ -41,12 +43,16 @@ flat_hand_mean=args.flat_hand_mean, side=args.side, mano_root=args.mano_root, - ncomps=args.mano_ncomps) + ncomps=args.mano_ncomps, + root_rot_mode=args.root_rot_mode) n_components = 6 - rot = 3 + if args.root_rot_mode == 'axisang': + rot = 3 + else: + rot = 6 # Generate random pose coefficients - pose_params = torch.rand(args.batch_size, n_components + rot) + pose_params = args.rand_mag * torch.rand(args.batch_size, n_components + rot) pose_params.requires_grad = True if args.random_shape: shape = torch.rand(args.batch_size, 10) diff --git a/manopth/manolayer.py b/manopth/manolayer.py index 6c9a9ac..3cb231a 100644 --- a/manopth/manolayer.py +++ b/manopth/manolayer.py @@ -5,7 +5,7 @@ from torch.nn import Module from mano.webuser.smpl_handpca_wrapper_HAND_only import ready_arguments -from manopth import rodrigues_layer, rotproj +from manopth import rodrigues_layer, rotproj, rot6d from manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, subtract_flat_id, make_list) @@ -22,7 +22,8 @@ def __init__(self, ncomps=6, side='right', mano_root='mano/models', - use_pca=True): + use_pca=True, + root_rot_mode='axisang'): """ Args: center_idx: index of center joint in our computations, @@ -38,10 +39,14 @@ def __init__(self, super().__init__() self.center_idx = center_idx - self.rot = 3 + if root_rot_mode == 'axisang': + self.rot = 3 + else: + self.rot = 6 self.flat_hand_mean = flat_hand_mean self.side = side self.use_pca = use_pca + self.root_rot_mode = root_rot_mode if use_pca: self.ncomps = ncomps else: @@ -114,17 +119,28 @@ def forward(self, batch_size = th_pose_coeffs.shape[0] # Get axis angle from PCA components and coefficients if self.use_pca: + # Remove global rot coeffs th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot + self.ncomps] + # PCA components --> axis angles th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps) + + # Concatenate back global rot th_full_pose = torch.cat([ th_pose_coeffs[:, :self.rot], self.th_hands_mean + th_full_hand_pose ], 1) - th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose) - th_full_pose = th_full_pose.view(batch_size, -1, 3) - root_rot = rodrigues_layer.batch_rodrigues( - th_full_pose[:, 0]).view(batch_size, 3, 3) + if self.root_rot_mode == 'axisang': + # compute rotation matrixes from axis-angle while skipping global rotation + th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose) + th_full_pose = th_full_pose.view(batch_size, -1, 3) + root_rot = rodrigues_layer.batch_rodrigues( + th_full_pose[:, 0]).view(batch_size, 3, 3) + else: + # th_posemap offsets by 3, so add offset or 3 to get to self.rot=6 + th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 3:]) + root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6]) + th_full_pose = th_full_pose.view(batch_size, -1, 3) else: assert th_pose_coeffs.dim() == 4, ( 'When not self.use_pca, ' diff --git a/manopth/rot6d.py b/manopth/rot6d.py new file mode 100644 index 0000000..12685c7 --- /dev/null +++ b/manopth/rot6d.py @@ -0,0 +1,44 @@ +import torch + + +def compute_rotation_matrix_from_ortho6d(poses): + """ + Code from + https://github.com/papagina/RotationContinuity + On the Continuity of Rotation Representations in Neural Networks + Zhou et al. CVPR19 + https://zhouyisjtu.github.io/project_rotation/rotation.html + """ + x_raw = poses[:, 0:3] # batch*3 + y_raw = poses[:, 3:6] # batch*3 + + x = normalize_vector(x_raw) # batch*3 + z = cross_product(x, y_raw) # batch*3 + z = normalize_vector(z) # batch*3 + y = cross_product(z, x) # batch*3 + + x = x.view(-1, 3, 1) + y = y.view(-1, 3, 1) + z = z.view(-1, 3, 1) + matrix = torch.cat((x, y, z), 2) # batch*3*3 + return matrix + + +def normalize_vector(v): + batch = v.shape[0] + v_mag = torch.sqrt(v.pow(2).sum(1)) # batch + v_mag = torch.max(v_mag, v.new([1e-8])) + v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1]) + v = v/v_mag + return v + + +def cross_product(u, v): + batch = u.shape[0] + i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] + j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] + k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] + + out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) + + return out