rewrite the editdist function (Levenshtein) in C
Showing 2 changed files with 86 additions and 53 deletions.
@@ -1,67 +1,106 @@
#!/usr/bin/env cython
# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False
# cython: cdivision=True
# cython: embedsignature=True
# coding: utf-8
#
# Copyright (C) 2021 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
# Code adapted from TinyFastSS (public domain), https://github.com/fujimotos/TinyFastSS

"""Create and query FastSS index for fast approximate string similarity search."""
"""Fast approximate string similarity search using the FastSS algorithm."""

import itertools

from cpython.ref cimport PyObject

DEF MAX_WORD_LENGTH = 254  # a trade-off between speed (fast stack allocations) and versatility (long strings)

DEF MAX_WORD_LENGTH = 10000  # Maximum allowed word length, in characters. Must fit in the C `int` range.

def editdist(s1: unicode, s2: unicode, max_dist=None):
cdef extern from *:
    """
    #define WIDTH int
    #define MAX_WORD_LENGTH 10000
    int ceditdist(PyObject * s1, PyObject * s2, WIDTH maximum) {
        WIDTH row1[MAX_WORD_LENGTH + 1];
        WIDTH row2[MAX_WORD_LENGTH + 1];
        WIDTH * CYTHON_RESTRICT pos_new;
        WIDTH * CYTHON_RESTRICT pos_old;
        int row_flip = 1;  /* Does pos_new represent row1 or row2? */
        int kind = PyUnicode_KIND(s1);  /* How many bytes per unicode codepoint? */
        if (kind != PyUnicode_KIND(s2)) return -1;
        WIDTH len_s1 = (WIDTH)PyUnicode_GET_LENGTH(s1);
        WIDTH len_s2 = (WIDTH)PyUnicode_GET_LENGTH(s2);
        if (len_s1 > len_s2) {
            PyObject * tmp = s1; s1 = s2; s2 = tmp;
            const WIDTH tmpi = len_s1; len_s1 = len_s2; len_s2 = tmpi;
        }
        if (len_s2 - len_s1 > maximum) return maximum + 1;
        if (len_s2 > MAX_WORD_LENGTH) return -2;
        void * s1_data = PyUnicode_DATA(s1);
        void * s2_data = PyUnicode_DATA(s2);
        for (WIDTH tmpi = 0; tmpi <= len_s1; tmpi++) row2[tmpi] = tmpi;
        for (WIDTH i2 = 0; i2 < len_s2; i2++) {
            int all_bad = i2 >= maximum;
            const Py_UCS4 ch = PyUnicode_READ(kind, s2_data, i2);
            row_flip = 1 - row_flip;
            if (row_flip) {
                pos_new = row2; pos_old = row1;
            } else {
                pos_new = row1; pos_old = row2;
            }
            *pos_new = i2 + 1;
            for (WIDTH i1 = 0; i1 < len_s1; i1++) {
                WIDTH val = *(pos_old++);
                if (ch != PyUnicode_READ(kind, s1_data, i1)) {
                    const WIDTH _val1 = *pos_old;
                    const WIDTH _val2 = *pos_new;
                    if (_val1 < val) val = _val1;
                    if (_val2 < val) val = _val2;
                    val += 1;
                }
                *(++pos_new) = val;
                if (all_bad && val <= maximum) all_bad = 0;
            }
            if (all_bad) return maximum + 1;
        }
        return row_flip ? row2[len_s1] : row1[len_s1];
    }
    """
    If the Levenshtein distance between two strings is <= max_dist, return that distance.
    Otherwise return max_dist+1.
    int ceditdist(PyObject *s1, PyObject *s2, int maximum)

def editdist(s1: str, s2: str, max_dist=None):
    """
    Return the Levenshtein distance between two strings.
    Use `max_dist` to control the maximum distance you care about. If the actual distance is larger
    than `max_dist`, editdist will return early, with the value `max_dist+1`.
    This is a performance optimization – for example if anything above distance 2 is uninteresting
    to your application, call editdist with `max_dist=2` and ignore any return value greater than 2.
    Leave `max_dist=None` (default) to always return the full Levenshtein distance (slower).
    """
    if s1 == s2:
        return 0

    if len(s1) > len(s2):
        s1, s2 = s2, s1

    if len(s2) > MAX_WORD_LENGTH:
    result = ceditdist(<PyObject *>s1, <PyObject *>s2, MAX_WORD_LENGTH if max_dist is None else int(max_dist))
    if result >= 0:
        return result
    elif result == -1:
        raise ValueError("incompatible types of unicode strings")
    elif result == -2:
        raise ValueError(f"editdist doesn't support strings longer than {MAX_WORD_LENGTH} characters")
    cdef unsigned char len_s1 = len(s1)
    cdef unsigned char len_s2 = len(s2)
    cdef unsigned char maximum = min(len_s2, max_dist or MAX_WORD_LENGTH)

    if len_s2 - len_s1 > maximum:
        return maximum + 1

    cdef unsigned char all_bad, i1, i2, val
    cdef unsigned char[MAX_WORD_LENGTH + 1] row1, row2
    cdef unsigned char * row_new = &row1[0]
    cdef unsigned char * row_old = &row2[0]
    for i1 in range(len_s1 + 1):
        row_old[i1] = i1

    for i2 in range(len_s2):
        row_new[0] = i2 + 1
        all_bad = i2 >= maximum
        for i1 in range(len_s1):
            if s1[i1] == s2[i2]:
                val = row_old[i1]
            else:
                val = 1 + min((row_old[i1], row_old[i1 + 1], row_new[i1]))
            row_new[i1 + 1] = val
            if all_bad and val <= maximum:
                all_bad = 0
        if all_bad:
            return maximum + 1
        row_new, row_old = row_old, row_new

    return row_old[len_s1]
    else:
        raise ValueError(f"editdist returned an error: {result}")


def indexkeys(word, max_dist):
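(Reviewer note, not part of the diff.) The C kernel keeps only two rows of the DP table and returns `maximum + 1` as soon as an entire row exceeds `maximum`, which is what makes the `max_dist` early exit cheap. As a rough sanity check of that contract, usage should look like the following doctest-style sketch, assuming the extension builds and is importable as `gensim.similarities.fastss`:

    >>> from gensim.similarities.fastss import editdist
    >>> editdist("kitten", "sitting")              # full Levenshtein distance
    3
    >>> editdist("kitten", "sitting", max_dist=1)  # true distance 3 > 1, so the early exit reports max_dist + 1
    2
    >>> editdist("abc", "abc", max_dist=0)         # identical strings short-circuit before reaching the C code
    0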
@@ -75,9 +114,7 @@ def indexkeys(word, max_dist):
    limit = min(max_dist, wordlen) + 1

    for dist in range(limit):
        variants = itertools.combinations(word, wordlen - dist)

        for variant in variants:
        for variant in itertools.combinations(word, wordlen - dist):
            res.add(''.join(variant))

    return res
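(Reviewer note, not part of the diff.) The tightened loop still enumerates the same FastSS index keys: every variant of `word` with up to `max_dist` characters deleted. Illustratively (set ordering may differ):

    >>> indexkeys('cat', 1)
    {'cat', 'at', 'ct', 'ca'}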
@@ -98,10 +135,7 @@ def bytes2set(b):
    >>> bytes2set(b'a\x00b\x00c')
    {u'a', u'b', u'c'}
    """
    if not b:
        return set()

    return set(b.decode('utf8').split('\x00'))
    return set(b.decode('utf8').split('\x00')) if b else set()
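(Reviewer note, not part of the diff.) `bytes2set` is the decoding half of the index's value format: a set of words serialized as NUL-separated UTF-8. The encoding direction is the mirror image, sketched here only to illustrate the format, not this file's API:

    >>> '\x00'.join(sorted({'a', 'b', 'c'})).encode('utf8')  # inverse of bytes2set; sorted only for a stable example
    b'a\x00b\x00c'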
class FastSS:
@@ -147,8 +181,7 @@ class FastSS:
            max_dist = self.max_dist
        if max_dist > self.max_dist:
            raise ValueError(
                f"query max_dist={max_dist} cannot be greater than "
                f"max_dist={self.max_dist} specified in the constructor"
                f"query max_dist={max_dist} cannot be greater than max_dist={self.max_dist} from the constructor"
            )

        res = {d: [] for d in range(max_dist + 1)}
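(Reviewer note, not part of the diff.) End to end, the reworded `ValueError` above fires when a query asks for a larger `max_dist` than the index was built with. A rough usage sketch, assuming the `FastSS(words=..., max_dist=...)` constructor and the dict-of-distances return shape set up by the `res = {d: [] ...}` line (list order may differ):

    >>> index = FastSS(words=['help', 'hello', 'yellow'], max_dist=2)
    >>> index.query('hell', max_dist=1)
    {0: [], 1: ['help', 'hello']}
    >>> index.query('hell', max_dist=3)
    Traceback (most recent call last):
        ...
    ValueError: query max_dist=3 cannot be greater than max_dist=2 from the constructor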