Skip to content

Commit

Permalink
Increased the flexibility of MultiMolecule.get_pair_dict()
Browse files Browse the repository at this point in the history
  • Loading branch information
Bas van Beek committed Jan 27, 2020
1 parent c0ec45c commit b583af7
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions FOX/classes/multi_mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from collections import abc
from itertools import chain, combinations_with_replacement, zip_longest, islice, repeat
from typing import (
Sequence, Optional, Union, List, Hashable, Callable, Iterable, Dict, Tuple, Any
Sequence, Optional, Union, List, Hashable, Callable, Iterable, Dict, Tuple, Any, Mapping
)

import numpy as np
Expand Down Expand Up @@ -1355,8 +1355,9 @@ def get_dist_mat(self, mol_subset: MolSubset = None,
ret[k] = cdist(a, b)
return ret

@staticmethod
def get_pair_dict(atom_subset: AtomSubset, r: int = 2) -> Dict[str, str]:
def get_pair_dict(self, atom_subset: Union[Sequence[AtomSubset],
Mapping[Hashable, Sequence[AtomSubset]]],
r: int = 2) -> Dict[str, Tuple[np.ndarray, ...]]:
"""Take a subset of atoms and return a dictionary.
Parameters
Expand All @@ -1370,15 +1371,16 @@ def get_pair_dict(atom_subset: AtomSubset, r: int = 2) -> Dict[str, str]:
The length of the to-be returned subsets.
"""
values = list(combinations_with_replacement(atom_subset, r))

if not isinstance(next(iter(atom_subset)), str):
str_ = 'series' + ''.join(' {:d}' for _ in values[0])
return {str_.format(*[i.index(j) for j in i]): i for i in values}

if isinstance(atom_subset, abc.Mapping):
key_iter = (str(i) for i in atom_subset.keys())
value_iter = (self._get_atom_subset(i) for i in atom_subset.values())
else:
str_ = ''.join(' {}' for _ in values[0])[1:]
return {str_.format(*i): i for i in values}
key_iter = ((j if isinstance(j, abc.Hashable) else i) for i, j in enumerate(atom_subset)) # noqa
value_iter = (self._get_atom_subset(i) for i in atom_subset)

key_gen = combinations_with_replacement(key_iter, r)
value_gen = combinations_with_replacement(value_iter, r)
return {' '.join(str(i) for i in k): v for k, v in zip(key_gen, value_gen)}

"""#################################### Power spectrum ###################################"""

Expand Down Expand Up @@ -1600,12 +1602,10 @@ def init_adf(self, mol_subset: MolSubset = None,
m_subset = self._get_mol_subset(mol_subset)
at_subset = self._get_atom_subset(atom_subset, as_array=True)

# Construct a dictionary unique atom-pair identifiers as keys
atom_pairs = self.get_pair_dict(atom_subset or sorted(self.atoms, key=str), r=3)

# Slice this MultiMolecule instance based on **atom_subset** and **mol_subset**
del_atom = np.arange(0, self.shape[1])[~at_subset]
mol = self.delete_atoms(del_atom)[m_subset]
atom_pairs = mol.get_pair_dict(atom_subset or sorted(mol.atoms, key=str), r=3)
for k, v in atom_pairs.items():
v_new = []
for at in v:
Expand Down

0 comments on commit b583af7

Please sign in to comment.