Skip to content

Commit

Permalink
Fixed an issue where '==' was not properly parsed; raise a KeyError i…
Browse files Browse the repository at this point in the history
…f brand new atoms are encountered
  • Loading branch information
Bas van Beek committed Jan 22, 2020
1 parent a5ab4bf commit c5b38c6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
30 changes: 20 additions & 10 deletions FOX/functions/charge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,21 @@ def invert_partial_ufunc(ufunc: functools.partial) -> Callable:

def assign_constraints(constraints: Union[str, Iterable[str]],
param: pd.DataFrame, idx_key: str) -> None:
operator_set = {'>', '<', '>=', '<=', '*', '=='}
operator_set = {'>', '<', '*', '=='}

# Parse integers and floats
if isinstance(constraints, str):
constraints = [constraints]

constrain_list = []
for item in constraints:
intersect = operator_set.intersection(item) # Identify all operators
if not intersect:
continue
for i in operator_set: # Sanitize all operators; ensure they are surrounded by spaces
item = item.replace(i, f'~{i}~')

for i in intersect: # Sanitize all operators; ensure they are surrounded by spaces
item = item.replace(i, f' {i} ')
item_list = [i.strip().rstrip() for i in item.split('~')]
if len(item_list) == 1:
continue

item_list = item.split()
for i, j in enumerate(item_list): # Convert strings to floats where possible
try:
float_j = float(j)
Expand Down Expand Up @@ -225,22 +224,29 @@ def _gt_lt_constraints(constrain: list, param: pd.DataFrame, idx_key: str) -> No
if isinstance(at, float):
at, value = value, at
operator = _INVERT[operator]
if (idx_key, at) not in param.index:
raise KeyError(f"Assigning invalid constraint '({' '.join(str(i) for i in constrain)})'"
f"; no parameter available of type ({repr(idx_key)}, {repr(at)})")
param.at[(idx_key, at), operator] = value


def _find_float(iterable: Tuple[str, str]) -> Tuple[str, float]:
"""Take an iterable of 2 strings and identify which element can be converted into a float."""
i, j = iterable
try:
i, j = iterable
except ValueError:
return iterable[0], 1.0

try:
return j, float(i)
except ValueError:
return i, float(j)


def _eq_constraints(constrain: list, param: pd.DataFrame, idx_key: str) -> None:
def _eq_constraints(constrain_: list, param: pd.DataFrame, idx_key: str) -> None:
"""Parse :math:`a = i * b`-type constraints."""
constrain_dict: Dict[str, functools.partial] = {}
constrain = ''.join(str(i) for i in constrain).split('==')
constrain = ''.join(str(i) for i in constrain_).split('==')
iterator = iter(constrain)

# Set the first item; remove any prefactor and compensate al other items if required
Expand All @@ -262,5 +268,9 @@ def _eq_constraints(constrain: list, param: pd.DataFrame, idx_key: str) -> None:

# Update the dataframe
param['constraints'] = None
for k in constrain_dict:
if (idx_key, k) not in param.index:
raise KeyError(f"Assigning invalid constraint '({' '.join(str(i) for i in constrain_)})"
f"'; no parameter available of type ({repr(idx_key)}, {repr(at)})")
for at, _ in param.loc[idx_key].iterrows():
param.at[(idx_key, at), 'constraints'] = constrain_dict
12 changes: 8 additions & 4 deletions tests/test_charge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
def test_assign_constraints() -> None:
"""Test :func:`assign_constraints`."""
df = pd.DataFrame(index=pd.MultiIndex.from_product(
[['key'], ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']]
[['key'], ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'H H']]
))
df['constraints'] = None
df['min'] = -np.inf
Expand All @@ -28,14 +28,15 @@ def test_assign_constraints() -> None:
'4<O',
'1 < F< 2.0',
'2 > P >1.0',
'S == 2 * Cl == 0.5*Br == 1* I'
'S == 2 * Cl == 0.5*Br == 1* I',
'1 < H H < 2'
]

assign_constraints(constraints, df, 'key')

inf = np.inf
min_ar = np.array([-inf, 2.0, -inf, 4.0, 1.0, 1.0, -inf, -inf, -inf, -inf])
max_ar = np.array([1.0, inf, 3.0, inf, 2.0, 2.0, inf, inf, inf, inf])
min_ar = np.array([-inf, 2.0, -inf, 4.0, 1.0, 1.0, -inf, -inf, -inf, -inf, 1.0])
max_ar = np.array([1.0, inf, 3.0, inf, 2.0, 2.0, inf, inf, inf, inf, 2.0])
np.testing.assert_allclose(df['min'], min_ar)
np.testing.assert_allclose(df['max'], max_ar)

Expand All @@ -50,3 +51,6 @@ def test_assign_constraints() -> None:
assertion.isinstance(v1, functools.partial)
assertion.is_(v1.func, v2.func)
assertion.eq(v1.args, v2.args)

assertion.assert_(assign_constraints, ['bob == bub'], df, 'key', exception=KeyError)
assertion.assert_(assign_constraints, ['bob > 1.0'], df, 'key', exception=KeyError)

0 comments on commit c5b38c6

Please sign in to comment.