-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
48b46fa
commit deca12a
Showing
12 changed files
with
302 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,38 @@ | ||
# pruning-radix-trie | ||
# Trie Search Library | ||
|
||
A Python library for efficient string search and retrieval with spell correction. | ||
|
||
## Features | ||
|
||
- Efficient string insertion and search using a Trie (Radix Tree) data structure. | ||
- Spell correction using Levenshtein distance. | ||
- Configurable search limit and max edit distance for spell correction. | ||
- Term frequency count for ranking search results. | ||
- Support for additional metadata like entity type, neighbors, and canonical form. | ||
|
||
## Installation | ||
|
||
Install from a local checkout with `pip install .` (or `pip install -e .` for development).
|
||
## Usage | ||
|
||
```python | ||
from trie_search import Trie | ||
|
||
# Create a Trie instance | ||
trie = Trie() | ||
|
||
# Insert words into the Trie | ||
trie.insert("apple", 5, "fruit", [], "Apple") | ||
trie.insert("banana", 7, "fruit", [], "Banana") | ||
|
||
# Search for a word | ||
results = trie.search("appl") | ||
print(results) | ||
|
||
# Search with spell correction | ||
results = trie.search_with_correction("banan") | ||
print(results) | ||
|
||
# Update term frequency count | ||
trie.update("apple", 3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from trie import Trie |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Configurable parameters for the PruningRadixTrieLib.
# These are the module-level defaults; Trie.__init__ accepts per-instance
# overrides for both values.

MAX_EDIT_DISTANCE = 2  # Maximum Levenshtein distance allowed for spell correction.
SEARCH_LIMIT = 10  # Default maximum number of results returned by a search.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
class Node:
    """A single node (edge label) in the radix trie.

    ``value`` holds the edge label leading to this node; terminal nodes
    (``is_end_of_word``) additionally carry the stored word and its metadata.
    Children are kept sorted by ``value`` so lookups can binary-search.
    """

    def __init__(self, value=""):
        self.value = value              # edge label for this node
        self.children = []              # child nodes, sorted by value
        self.is_end_of_word = False     # True when a stored word ends here
        self.full_text = None           # the complete stored word, if terminal
        self.termFrequencyCount = 0     # frequency used for result ranking
        self.entity_type = None         # optional metadata: entity category
        self.nearest_neighbors = []     # optional metadata: related entries
        self.canonical_form = None      # optional metadata: canonical spelling

    def find_child(self, prefix):
        """Binary-search the sorted children for one whose value starts
        with *prefix*; return it, or None when no child matches."""
        lo, hi = 0, len(self.children) - 1
        while lo <= hi:
            mid = (lo + hi) // 2
            candidate = self.children[mid]
            if candidate.value.startswith(prefix):
                return candidate
            if candidate.value < prefix:
                lo = mid + 1
            else:
                hi = mid - 1
        return None

    def insert_child(self, child):
        """Insert *child* so that ``self.children`` stays sorted by value."""
        pos = len(self.children)
        for idx, existing in enumerate(self.children):
            if not existing.value < child.value:
                pos = idx
                break
        self.children.insert(pos, child)

    def get_neighbors(self):
        """Return the list of nearest-neighbor entries stored on this node."""
        return self.nearest_neighbors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import unittest | ||
from trie import Trie | ||
|
||
class TestTrie(unittest.TestCase):
    """Unit tests for the Trie: insertion, prefix search, spell correction
    and frequency updates."""

    def setUp(self):
        # A one-word trie shared by the correction tests.
        self.trie = Trie()
        self.trie.insert("apple", termFrequencyCount=5, entity_type="fruit",
                         neighbors=["apples"], canonical="Apple")

    def test_search(self):
        # "aple" is within the default edit distance of "apple".
        hits = self.trie.search_with_correction("aple")
        self.assertEqual(hits[0]['text'], "apple")

    def test_search_without_correction(self):
        # With correction disabled a misspelling must return nothing.
        hits = self.trie.search_with_correction("aple", correct_spelling=False)
        self.assertEqual(len(hits), 0)

    def test_insert_and_search(self):
        trie = Trie()
        for word, freq, kind, canon in (
            ("apple", 5, "fruit", "Apple"),
            ("appetite", 3, "noun", "Appetite"),
            ("apex", 2, "noun", "Apex"),
            ("banana", 7, "fruit", "Banana"),
        ):
            trie.insert(word, freq, kind, [], canon)

        self.assertEqual(trie.search("apple")[0]['text'], "apple")
        # "apple" outranks "appetite" for the shared prefix (5 > 3).
        self.assertEqual(trie.search("app")[0]['text'], "apple")
        self.assertEqual(trie.search("ban")[0]['text'], "banana")
        self.assertFalse(trie.search("berry"))

    def test_search_with_correction(self):
        trie = Trie()
        for word, freq, kind, canon in (
            ("apple", 5, "fruit", "Apple"),
            ("appetite", 3, "noun", "Appetite"),
            ("apex", 2, "noun", "Apex"),
            ("banana", 7, "fruit", "Banana"),
        ):
            trie.insert(word, freq, kind, [], canon)

        hits = trie.search_with_correction("appl")
        self.assertTrue(hits and hits[0]['text'] == "apple")

        hits = trie.search_with_correction("banan")
        self.assertTrue(hits and hits[0]['text'] == "banana")

        # "bery" is too far from every stored word: expect no results.
        self.assertFalse(trie.search_with_correction("bery"))

    def test_update(self):
        trie = Trie()
        trie.insert("apple", 5, "fruit", [], "Apple")
        trie.update("apple", 3)
        self.assertEqual(trie.search("apple")[0]['termFrequencyCount'], 8)

        # The original test used try/except with no fail(), so it passed
        # silently when update() did NOT raise; assertRaises closes that gap.
        with self.assertRaises(ValueError) as ctx:
            trie.update("berry", 2)
        self.assertEqual(str(ctx.exception), "Word 'berry' not found in trie.")
|
||
# Allow running the suite directly: `python test_trie.py`.
if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from node import Node | ||
from utils import calculate_levenshtein_distance | ||
from config import MAX_EDIT_DISTANCE, SEARCH_LIMIT | ||
import heapq | ||
|
||
class Trie:
    """Radix trie storing words with a term-frequency count and metadata
    (entity type, nearest neighbors, canonical form).

    Supports prefix search ranked by descending frequency, exact-word
    frequency updates, and Levenshtein-distance spell correction.
    """

    def __init__(self, max_edit_distance=None, search_limit=None):
        # Per-instance overrides fall back to the module-level defaults.
        self.root = Node()
        self.max_edit_distance = max_edit_distance or MAX_EDIT_DISTANCE
        self.search_limit = search_limit or SEARCH_LIMIT

    def insert(self, word, termFrequencyCount, entity_type, neighbors, canonical):
        """Insert *word* with its metadata, splitting edges on partial overlap.

        Fixes two defects in the previous implementation:
        * the split point was taken from ``len(word[i:])`` instead of the
          common-prefix length, which created empty-valued nodes and lost the
          existing word's data whenever the new word extended a stored word;
        * words sharing only part of an edge were attached as siblings with
          the same first character, breaking the sorted-children invariant
          that ``Node.find_child``'s binary search (and thus ``update``)
          relies on.
        """
        node = self.root
        i = 0
        while i < len(word):
            remainder = word[i:]
            child = node.find_child(remainder[0])
            if child is None:
                # No edge shares the first character: attach the whole
                # remaining suffix as a single new leaf.
                leaf = Node(remainder)
                node.insert_child(leaf)
                node = leaf
                i = len(word)
                continue
            # Length of the common prefix of the edge label and the suffix.
            # find_child matched on the first character, so common >= 1.
            common = 0
            bound = min(len(child.value), len(remainder))
            while common < bound and child.value[common] == remainder[common]:
                common += 1
            if common == len(child.value):
                # Edge fully consumed: descend and keep matching.
                node = child
            else:
                # Partial overlap: split the edge at the divergence point.
                branch = Node(child.value[:common])
                child.value = child.value[common:]
                node.children.remove(child)
                node.insert_child(branch)
                branch.insert_child(child)
                node = branch
            i += common
        node.is_end_of_word = True
        node.full_text = word
        node.termFrequencyCount = termFrequencyCount
        node.entity_type = entity_type
        # Store neighbor *nodes*, not search-result dicts: _dfs reads
        # .full_text off each neighbor, so dicts would raise AttributeError.
        resolved = []
        for neighbor in neighbors:
            neighbor_node = self._find_word_node(neighbor)
            if neighbor_node is not None:
                resolved.append(neighbor_node)
        node.nearest_neighbors = resolved
        node.canonical_form = canonical

    def print_trie(self, node=None, indent=""):
        """Recursively print the trie; '*' marks nodes where a word ends."""
        if node is None:
            node = self.root
        print(indent + node.value + ("*" if node.is_end_of_word else ""))
        for child in node.children:
            self.print_trie(child, indent + "  ")

    def update(self, word, termFrequencyCount_increment):
        """Add *termFrequencyCount_increment* to an existing word's count.

        Raises:
            ValueError: if *word* is not stored in the trie.
        """
        node = self._find_word_node(word)
        if node is None:
            raise ValueError(f"Word '{word}' not found in trie.")
        node.termFrequencyCount += termFrequencyCount_increment

    def _find_word_node(self, word):
        """Return the terminal node for exactly *word*, or None if absent."""
        node = self.root
        i = 0
        while i < len(word):
            prefix = self._longest_prefix(word[i:], node)
            if not prefix:
                return None
            node = node.find_child(prefix)
            i += len(prefix)
        return node if node.is_end_of_word else None

    def search(self, query, limit=SEARCH_LIMIT):
        """Return up to *limit* result dicts for stored words starting with
        *query*, ordered by descending termFrequencyCount.

        All matches under the deepest matching node are collected before
        ranking, so the returned top-*limit* really are the most frequent
        (the previous version truncated the walk at *limit* before sorting,
        which could return arbitrary matches instead of the best ones).
        """
        node = self.root
        i = 0
        while i < len(query):
            rest = query[i:]
            child = node.find_child(rest[0])
            if child is None:
                break
            if rest.startswith(child.value):
                # Edge fully matched; continue below it.
                node = child
                i += len(child.value)
            elif child.value.startswith(rest):
                # Query ends inside this edge; its whole subtree matches.
                node = child
                i = len(query)
            else:
                # Diverged mid-edge: nothing below can start with the query.
                break
        results = []
        self._dfs(node, query, results)
        results.sort(key=lambda r: r['termFrequencyCount'], reverse=True)
        return results[:limit]

    def _dfs(self, node, prefix, results, limit=None):
        """Append a result dict for every stored word under *node* whose
        full text starts with *prefix*.  *limit*, when given, caps how many
        matches are gathered; None collects everything."""
        if limit is not None and len(results) >= limit:
            return
        if node.is_end_of_word and node.full_text.startswith(prefix):
            results.append({
                "text": node.full_text,
                "termFrequencyCount": node.termFrequencyCount,
                "type": node.entity_type,
                "neighbors": [neighbor.full_text for neighbor in node.get_neighbors()],
                "canonical": node.canonical_form,
            })
        for child in node.children:
            self._dfs(child, prefix, results, limit)

    def search_with_correction(self, query, correct_spelling=True):
        """Search for *query*; when it yields nothing and *correct_spelling*
        is True, retry with the closest stored word within the configured
        edit distance."""
        results = self.search(query, self.search_limit)
        if not results and correct_spelling:
            closest_word = self.find_closest_word(query)
            if closest_word:
                results = self.search(closest_word, self.search_limit)
        return results

    def find_closest_word(self, query):
        """Return the stored word with the smallest Levenshtein distance to
        *query*, provided it is within max_edit_distance; else None."""
        min_distance = float('inf')
        closest_word = None
        for word in self.all_words:
            distance = calculate_levenshtein_distance(query, word)
            if distance <= self.max_edit_distance and distance < min_distance:
                min_distance = distance
                closest_word = word
        return closest_word

    @property
    def all_words(self):
        """Every stored word, gathered by walking the whole trie.

        NOTE(review): O(total nodes) per access, and spell correction
        recomputes it on every miss — fine for small tries, a candidate for
        caching if the trie grows large.
        """
        words = []

        def _collect(node, acc):
            if node.is_end_of_word:
                words.append(acc + node.value)
            for child in node.children:
                _collect(child, acc + node.value)

        _collect(self.root, "")
        return words

    def _longest_prefix(self, word, node):
        """Return the label of *node*'s child edge fully consumed by *word*,
        or '' when no edge matches completely. *word* must be non-empty."""
        child = node.find_child(word[0])
        if child and word.startswith(child.value):
            return child.value
        return ""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
|
||
def calculate_levenshtein_distance(s1, s2):
    """Return the Levenshtein (edit) distance between *s1* and *s2*.

    Classic Wagner–Fischer dynamic programming, keeping only two rows so
    memory stays O(min(len(s1), len(s2))).
    """
    # Iterate over the longer string so the row is as short as possible.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    previous = list(range(len(s1) + 1))
    for row, ch2 in enumerate(s2, start=1):
        current = [row]
        for col, ch1 in enumerate(s1, start=1):
            if ch1 == ch2:
                # Characters match: no extra edit needed.
                current.append(previous[col - 1])
            else:
                # 1 + min(substitution, deletion, insertion).
                current.append(1 + min(previous[col - 1],
                                       previous[col],
                                       current[-1]))
        previous = current

    return previous[-1]
|
||
# TODO: Implement SymSpell or other efficient spell correction algorithms for optimization. |