Skip to content

Commit

Permalink
Improves search to handle smaller search terms. (#4735)
Browse files Browse the repository at this point in the history
Co-authored-by: Agriya Khetarpal <[email protected]>
Co-authored-by: Saransh Chopra <[email protected]>
  • Loading branch information
3 people authored Feb 14, 2025
1 parent b0b0fb6 commit e39d72e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 14 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@

- Fixed interpolation bug in `pybamm.QuickPlot` with spatial variables. ([#4841](https://github.com/pybamm-team/PyBaMM/pull/4841))

## Optimizations

- Improved search to handle cases with shorter input strings and provide more relevant results. ([#4735](https://github.com/pybamm-team/PyBaMM/pull/4735))

# [v25.1.1](https://github.com/pybamm-team/PyBaMM/tree/v25.1.1) - 2025-01-20

## Features
Expand Down
66 changes: 52 additions & 14 deletions src/pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def __getitem__(self, key):
f"'{key}' not found. Best matches are {best_matches}"
) from error

def _find_matches(self, search_key: str, known_keys: list[str]):
def _find_matches(
self, search_key: str, known_keys: list[str], min_similarity: float = 0.4
):
"""
Helper method to find exact and partial matches for a given search key.
Expand All @@ -101,13 +103,37 @@ def _find_matches(self, search_key: str, known_keys: list[str]):
The term to search for in the keys.
known_keys : list of str
The list of known dictionary keys to search within.
min_similarity : float, optional
The minimum similarity threshold for a match.
Default is 0.4
"""
exact = [key for key in known_keys if search_key in key.lower()]
partial = difflib.get_close_matches(search_key, known_keys, n=5, cutoff=0.5)
return exact, partial
search_key = search_key.lower()
exact_matches = []
partial_matches = []

for key in known_keys:
key_lower = key.lower()
if search_key in key_lower:
key_words = key_lower.split()

for word in key_words:
similarity = difflib.SequenceMatcher(None, search_key, word).ratio()

def search(self, keys: str | list[str], print_values: bool = False):
if similarity >= min_similarity:
exact_matches.append(key)

else:
partial_matches = difflib.get_close_matches(
search_key, known_keys, n=5, cutoff=0.5
)
return exact_matches, partial_matches

def search(
self,
keys: str | list[str],
print_values: bool = False,
min_similarity: float = 0.4,
):
"""
Search dictionary for keys containing all terms in 'keys'.
If print_values is True, both the keys and values will be printed.
Expand All @@ -121,6 +147,9 @@ def search(self, keys: str | list[str], print_values: bool = False):
print_values : bool, optional
If True, print both keys and values. Otherwise, print only keys.
Default is False.
min_similarity : float, optional
The minimum similarity threshold for a match.
Default is 0.4
"""

if not isinstance(keys, (str, list)) or not all(
Expand All @@ -145,14 +174,23 @@ def search(self, keys: str | list[str], print_values: bool = False):
search_keys = [k.strip().lower() for k in keys if k.strip()]

known_keys = list(self.keys())
known_keys.sort()

# Check for exact matches where all search keys appear together in a key
exact_matches = [
key
for key in known_keys
if all(term in key.lower() for term in search_keys)
]
exact_matches = []
for key in known_keys:
key_lower = key.lower()
if all(term in key_lower for term in search_keys):
key_words = key_lower.split()

# Ensure all search terms match at least one word in the key
if all(
any(
difflib.SequenceMatcher(None, term, word).ratio()
>= min_similarity
for word in key_words
)
for term in search_keys
):
exact_matches.append(key)

if exact_matches:
print(
Expand All @@ -166,7 +204,7 @@ def search(self, keys: str | list[str], print_values: bool = False):
# If no exact matches, iterate over search keys individually
for original_key, search_key in zip(original_keys, search_keys):
exact_key_matches, partial_matches = self._find_matches(
search_key, known_keys
search_key, known_keys, min_similarity
)

if exact_key_matches:
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,22 @@ def test_url_gets_to_stdout(self, mocker):
match="'keys' must be a string or a list of strings, got <class 'int'>",
):
model.variables.search(123)

# Test smaller strings
with mocker.patch("sys.stdout", new=StringIO()) as fake_out:
model.variables.search(["El", "co"], print_values=True)
out = "No matches found for 'El'\nNo matches found for 'co'\n"
assert fake_out.getvalue() == out

# Case where min_similarity is high (0.9)
with mocker.patch("sys.stdout", new=StringIO()) as fake_out:
model.variables.search("electro", min_similarity=0.9)
assert fake_out.getvalue() == "No matches found for 'electro'\n"

# Case where min_similarity is low (0.3)
with mocker.patch("sys.stdout", new=StringIO()) as fake_out:
model.variables.search("electro", min_similarity=0.3)
assert (
fake_out.getvalue()
== "Results for 'electro': ['Electrolyte concentration', 'Electrode potential']\n"
)

0 comments on commit e39d72e

Please sign in to comment.