diff --git a/graphein/protein/tensor/angles.py b/graphein/protein/tensor/angles.py index 3d03f753..3fd47282 100644 --- a/graphein/protein/tensor/angles.py +++ b/graphein/protein/tensor/angles.py @@ -70,10 +70,18 @@ def _extract_torsion_coords( res_atoms = [] idxs = [] + # Whether or not the protein contains selenocysteine + selenium = coords.shape[1] == 38 + # Iterate over residues and grab indices of the atoms for each Chi angle for i, res in enumerate(res_types): res_coords = [] - for angle_coord_set in CHI_ANGLES_ATOMS[res]: + + angle_groups = CHI_ANGLES_ATOMS[res] + if not selenium and res == "SEC": + angle_groups = [] + + for angle_coord_set in angle_groups: res_coords.append([ATOM_NUMBERING[i] for i in angle_coord_set]) idxs.append(i) res_atoms.append(torch.tensor(res_coords, device=coords.device)) @@ -115,6 +123,9 @@ def sidechain_torsion( :return: _description_ :rtype: Union[TorsionTensor, Tuple[TorsionTensor, torch.Tensor]] """ + # Whether or not the protein contains selenocysteine + selenium = coords.shape[1] == 38 + idxs, coords = _extract_torsion_coords(coords, res_types) angles = _dihedral_angle( coords[:, 0, :].unsqueeze(1), @@ -139,7 +150,11 @@ def sidechain_torsion( res_types = copy.deepcopy(res_types) res_types.reverse() for res in res_types: - if res in ["ALA", "GLY", "UNK"]: + PAD_RESIDUES = ["ALA", "GLY", "UNK"] + if not selenium: + PAD_RESIDUES.append("SEC") + + if res in PAD_RESIDUES: post_pad_len += 1 else: break