Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Dharin-shah committed Oct 12, 2023
1 parent 48b46fa commit deca12a
Show file tree
Hide file tree
Showing 12 changed files with 302 additions and 1 deletion.
39 changes: 38 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,38 @@
# pruning-radix-trie
# Trie Search Library

A Python library for efficient string search and retrieval with spell correction.

## Features

- Efficient string insertion and search using a Trie (Radix Tree) data structure.
- Spell correction using Levenshtein distance.
- Configurable search limit and max edit distance for spell correction.
- Term frequency count for ranking search results.
- Support for additional metadata like entity type, neighbors, and canonical form.

## Installation

Clone the repository and add its directory to your `PYTHONPATH`; a pip package is not yet published.

## Usage

```python
from trie_search import Trie

# Create a Trie instance
trie = Trie()

# Insert words into the Trie
trie.insert("apple", 5, "fruit", [], "Apple")
trie.insert("banana", 7, "fruit", [], "Banana")

# Search for a word
results = trie.search("appl")
print(results)

# Search with spell correction
results = trie.search_with_correction("banan")
print(results)

# Update term frequency count
trie.update("apple", 3)
```
1 change: 1 addition & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from trie import Trie
Binary file added __pycache__/config.cpython-311.pyc
Binary file not shown.
Binary file added __pycache__/node.cpython-311.pyc
Binary file not shown.
Binary file added __pycache__/test.cpython-311.pyc
Binary file not shown.
Binary file added __pycache__/trie.cpython-311.pyc
Binary file not shown.
Binary file added __pycache__/utils.cpython-311.pyc
Binary file not shown.
4 changes: 4 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Configurable defaults for the trie search library.
#
# These module-level constants are used by Trie.__init__ as fallbacks when
# the caller does not pass explicit values.

MAX_EDIT_DISTANCE = 2 # Maximum Levenshtein distance accepted by spell correction.
SEARCH_LIMIT = 10 # Default maximum number of search results returned.
33 changes: 33 additions & 0 deletions node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
class Node:
    """A single node of a radix (compressed) trie.

    Each node carries an edge label (``value``) and a list of children kept
    sorted by label, which lets ``find_child`` use binary search. Word-level
    payload (frequency, entity type, neighbors, canonical form) is only
    meaningful on nodes where ``is_end_of_word`` is True.
    """

    def __init__(self, value=""):
        self.value = value                # edge label leading into this node
        self.children = []                # child nodes, kept sorted by value
        self.is_end_of_word = False       # True when a stored word ends here
        self.full_text = None             # the complete stored word, if any
        self.termFrequencyCount = 0       # ranking weight for search results
        self.entity_type = None           # optional metadata, e.g. "fruit"
        self.nearest_neighbors = []       # optional related-term metadata
        self.canonical_form = None        # optional canonical spelling

    def find_child(self, prefix):
        """Return the child whose label starts with ``prefix``, or None.

        Binary search over the sorted children: labels starting with
        ``prefix`` form a contiguous run that sorts at or after ``prefix``
        itself, so ordinary lexicographic halving locates them.
        """
        lo, hi = 0, len(self.children)
        while lo < hi:
            mid = (lo + hi) // 2
            candidate = self.children[mid]
            if candidate.value.startswith(prefix):
                return candidate
            if candidate.value < prefix:
                lo = mid + 1
            else:
                hi = mid
        return None

    def insert_child(self, child):
        """Insert ``child`` while keeping ``children`` sorted by label."""
        position = len(self.children)
        for idx, existing in enumerate(self.children):
            if not existing.value < child.value:
                position = idx
                break
        self.children.insert(position, child)

    def get_neighbors(self):
        """Return this node's nearest-neighbor metadata list."""
        return self.nearest_neighbors
58 changes: 58 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest
from trie import Trie

class TestTrie(unittest.TestCase):
    """Unit tests for Trie insertion, prefix search, spell correction, and
    term-frequency updates."""

    def setUp(self):
        self.trie = Trie()
        self.trie.insert("apple", termFrequencyCount=5, entity_type="fruit", neighbors=["apples"], canonical="Apple")

    def test_search(self):
        # "aple" is within the default max edit distance of "apple".
        results = self.trie.search_with_correction("aple")
        self.assertEqual(results[0]['text'], "apple")

    def test_search_without_correction(self):
        # With correction disabled a misspelling must yield no results.
        results = self.trie.search_with_correction("aple", correct_spelling=False)
        self.assertEqual(len(results), 0)

    def test_insert_and_search(self):
        trie = Trie()
        trie.insert("apple", 5, "fruit", [], "Apple")
        trie.insert("appetite", 3, "noun", [], "Appetite")
        trie.insert("apex", 2, "noun", [], "Apex")
        trie.insert("banana", 7, "fruit", [], "Banana")

        # Bug fixed: these were bare `assert` statements, which are stripped
        # under `python -O` and give no diagnostic message on failure.
        self.assertEqual(trie.search("apple")[0]['text'], "apple")
        self.assertEqual(trie.search("app")[0]['text'], "apple")
        self.assertEqual(trie.search("ban")[0]['text'], "banana")
        self.assertFalse(trie.search("berry"))

    def test_search_with_correction(self):
        trie = Trie()
        trie.insert("apple", 5, "fruit", [], "Apple")
        trie.insert("appetite", 3, "noun", [], "Appetite")
        trie.insert("apex", 2, "noun", [], "Apex")
        trie.insert("banana", 7, "fruit", [], "Banana")

        results = trie.search_with_correction("appl")
        self.assertTrue(results and results[0]['text'] == "apple")

        results = trie.search_with_correction("banan")
        self.assertTrue(results and results[0]['text'] == "banana")

        results = trie.search_with_correction("bery")
        self.assertFalse(results)  # Expecting no results for "bery"

    def test_update(self):
        trie = Trie()
        trie.insert("apple", 5, "fruit", [], "Apple")
        trie.update("apple", 3)
        self.assertEqual(trie.search("apple")[0]['termFrequencyCount'], 8)

        # Bug fixed: the old try/except silently PASSED when update() did not
        # raise at all; assertRaises fails the test if no ValueError occurs.
        with self.assertRaises(ValueError) as ctx:
            trie.update("berry", 2)
        self.assertEqual(str(ctx.exception), "Word 'berry' not found in trie.")

if __name__ == '__main__':
    unittest.main()
150 changes: 150 additions & 0 deletions trie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from node import Node
from utils import calculate_levenshtein_distance
from config import MAX_EDIT_DISTANCE, SEARCH_LIMIT
import heapq

class Trie:
    """Radix-trie-backed string index with frequency-ranked prefix search
    and optional Levenshtein-based spell correction.

    Results are dicts with keys ``text``, ``termFrequencyCount``, ``type``,
    ``neighbors`` (list of neighbor full texts), and ``canonical``.
    """

    def __init__(self, max_edit_distance=None, search_limit=None):
        # Fall back to module-level defaults when not supplied.
        self.root = Node()
        self.max_edit_distance = max_edit_distance or MAX_EDIT_DISTANCE
        self.search_limit = search_limit or SEARCH_LIMIT

    def insert(self, word, termFrequencyCount, entity_type, neighbors, canonical):
        """Insert ``word`` with its metadata.

        Bug fixed: the previous implementation only descended when the
        remaining word started with a child's *entire* edge label, so words
        sharing a partial prefix (e.g. "apple"/"appetite") were stored as
        duplicate siblings sharing a first character — breaking the radix
        invariant that ``find_child`` relies on. The old split branch also
        computed the split point from ``len(word[i:])`` (the remaining word)
        instead of the common-prefix length, producing empty-labelled nodes.
        Edges are now split at the longest common prefix.
        """
        node = self.root
        i = 0
        while i < len(word):
            remaining = word[i:]
            child = node.find_child(remaining[0])
            if child is None:
                # No child shares the first character: attach the whole rest.
                leaf = Node(remaining)
                node.insert_child(leaf)
                node = leaf
                i = len(word)
                continue
            common = self._common_prefix_length(child.value, remaining)
            if common < len(child.value):
                # Partial match: split the child's edge at the common prefix
                # and continue from the new upper node.
                node = self._split_child(node, child, common)
            else:
                node = child
            i += common

        node.is_end_of_word = True
        node.full_text = word
        node.termFrequencyCount = termFrequencyCount
        node.entity_type = entity_type
        # Bug fixed: search() returns dicts, but _dfs used to read
        # `.full_text` off each stored neighbor (AttributeError whenever a
        # neighbor resolved). Store plain neighbor texts instead.
        resolved = []
        for neighbor in neighbors:
            hits = self.search(neighbor, limit=1)
            if hits:
                resolved.append(hits[0]["text"])
        node.nearest_neighbors = resolved
        node.canonical_form = canonical

    @staticmethod
    def _common_prefix_length(a, b):
        # Length of the longest common prefix of two strings.
        n = 0
        limit = min(len(a), len(b))
        while n < limit and a[n] == b[n]:
            n += 1
        return n

    @staticmethod
    def _split_child(parent, child, split_at):
        """Split ``child``'s edge label at ``split_at``; return the new upper node.

        The original payload (end-of-word flag, frequency, metadata) moves
        intact onto the suffix node so previously stored words survive.
        """
        suffix = Node(child.value[split_at:])
        suffix.children = child.children
        suffix.is_end_of_word = child.is_end_of_word
        suffix.full_text = child.full_text
        suffix.termFrequencyCount = child.termFrequencyCount
        suffix.entity_type = child.entity_type
        suffix.nearest_neighbors = child.nearest_neighbors
        suffix.canonical_form = child.canonical_form

        upper = Node(child.value[:split_at])
        parent.children.remove(child)
        parent.insert_child(upper)
        upper.insert_child(suffix)
        return upper

    def print_trie(self, node=None, indent=""):
        """Debug helper: print the tree, marking end-of-word nodes with '*'."""
        if node is None:
            node = self.root
        print(indent + node.value + ("*" if node.is_end_of_word else ""))
        for child in node.children:
            self.print_trie(child, indent + "  ")

    def update(self, word, termFrequencyCount_increment):
        """Add ``termFrequencyCount_increment`` to an existing word's count.

        Raises:
            ValueError: if ``word`` is not stored in the trie.
        """
        node = self._find_word_node(word)
        if node:
            node.termFrequencyCount += termFrequencyCount_increment
        else:
            raise ValueError(f"Word '{word}' not found in trie.")

    def _find_word_node(self, word):
        # Walk full edge labels; only an exact, terminal match qualifies.
        node = self.root
        i = 0
        while i < len(word):
            prefix = self._longest_prefix(word[i:], node)
            if not prefix:
                return None
            node = node.find_child(prefix)
            i += len(prefix)
        return node if node.is_end_of_word else None

    def search(self, query, limit=SEARCH_LIMIT):
        """Return up to ``limit`` stored words starting with ``query``,
        ranked by descending term frequency.

        Bug fixed: results used to be truncated at ``limit`` during the DFS
        and only sorted afterwards, so the highest-frequency matches could
        be dropped entirely. All matches are now collected first, sorted,
        then sliced.
        """
        node = self._find_prefix_node(query)
        if node is None:
            return []
        results = []
        self._dfs(node, query, results, limit)
        results.sort(key=lambda entry: entry["termFrequencyCount"], reverse=True)
        return results[:limit]

    def _find_prefix_node(self, query):
        """Descend to the deepest node compatible with ``query``.

        Returns None when no stored word can start with ``query``; when the
        query ends inside an edge label, the child below that edge is
        returned (its whole subtree matches).
        """
        node = self.root
        i = 0
        while i < len(query):
            child = node.find_child(query[i])
            if child is None:
                return None
            remaining = query[i:]
            if remaining.startswith(child.value):
                node = child
                i += len(child.value)
            elif child.value.startswith(remaining):
                return child
            else:
                return None
        return node

    def _dfs(self, node, prefix, results, limit):
        # Collect every completed word in this subtree. ``limit`` is kept
        # for interface compatibility; truncation happens after sorting in
        # search().
        if node.is_end_of_word and node.full_text.startswith(prefix):
            results.append({
                "text": node.full_text,
                "termFrequencyCount": node.termFrequencyCount,
                "type": node.entity_type,
                "neighbors": list(node.get_neighbors()),
                "canonical": node.canonical_form
            })
        for child in node.children:
            self._dfs(child, prefix, results, limit)

    def search_with_correction(self, query, correct_spelling=True):
        """Search for ``query``; on a miss, optionally retry with the
        closest stored word within ``max_edit_distance``."""
        results = self.search(query, self.search_limit)
        if not results and correct_spelling:
            closest_word = self.find_closest_word(query)
            if closest_word:
                results = self.search(closest_word, self.search_limit)
        return results

    def find_closest_word(self, query):
        """Return the stored word nearest to ``query`` by Levenshtein
        distance, or None if none is within ``max_edit_distance``.

        NOTE(review): this scans every stored word — O(total words); see the
        SymSpell TODO in utils.py for a faster alternative.
        """
        min_distance = float('inf')
        closest_word = None
        for word in self.all_words:
            distance = calculate_levenshtein_distance(query, word)
            if distance <= self.max_edit_distance and distance < min_distance:
                min_distance = distance
                closest_word = word
        return closest_word

    @property
    def all_words(self):
        """All complete words stored in the trie (used by spell correction)."""
        words = []

        def _collect_words(node, current_word):
            if node.is_end_of_word:
                words.append(current_word + node.value)
            for child in node.children:
                _collect_words(child, current_word + node.value)

        _collect_words(self.root, "")
        return words

    def _longest_prefix(self, word, node):
        # Return the child's full edge label when the word starts with it,
        # else "" (partial overlaps are handled by insert's splitting).
        child = node.find_child(word[0])
        if child and word.startswith(child.value):
            return child.value
        return ""
18 changes: 18 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

def calculate_levenshtein_distance(s1, s2):
    """Return the Levenshtein (edit) distance between two strings.

    Classic dynamic-programming formulation keeping only two rows at a
    time, so memory is O(min(len(s1), len(s2))).
    """
    # Keep the shorter string as the row dimension to minimise memory.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    previous_row = list(range(len(s1) + 1))
    for row, char_b in enumerate(s2, start=1):
        current_row = [row]
        for col, char_a in enumerate(s1, start=1):
            if char_a == char_b:
                # Characters match: no additional edit needed.
                current_row.append(previous_row[col - 1])
            else:
                # 1 + cheapest of substitution, deletion, insertion.
                substitution = previous_row[col - 1]
                deletion = previous_row[col]
                insertion = current_row[-1]
                current_row.append(1 + min(substitution, deletion, insertion))
        previous_row = current_row

    return previous_row[-1]

# TODO: Implement SymSpell or other efficient spell correction algorithms for optimization.

0 comments on commit deca12a

Please sign in to comment.