Skip to content

Commit

Permalink
v0 of edit distance repair
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Dec 5, 2024
1 parent 4be4067 commit bcb61ee
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 1 deletion.
141 changes: 140 additions & 1 deletion src/ast/sls/sls_seq_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ namespace sls {
return ev.lhs;
}

ptr_vector<expr> const& seq_plugin::concats(expr* x) {
auto& ev = get_eval(x);
if (ev.lhs.empty())
seq.str.get_concat(x, ev.lhs);
return ev.lhs;
}

ptr_vector<expr> const& seq_plugin::rhs(expr* eq) {
lhs(eq);
auto& e = get_eval(eq);
Expand Down Expand Up @@ -593,7 +600,8 @@ namespace sls {
VERIFY(m.is_eq(e, x, y));
IF_VERBOSE(3, verbose_stream() << is_true << ": " << mk_bounded_pp(e, m, 3) << "\n");
if (ctx.is_true(e)) {
if (ctx.rand(10) != 0)
//return repair_down_str_eq_edit_distance(e);
if (ctx.rand(2) != 0)
return repair_down_str_eq_unify(e);
if (!is_value(x))
m_str_updates.push_back({ x, strval1(y), 1 });
Expand All @@ -619,6 +627,100 @@ namespace sls {
return apply_update();
}

/**
* \brief compute the edit distance between two strings.
*/
unsigned seq_plugin::edit_distance(zstring const& a, zstring const& b) {
unsigned n = a.length();
unsigned m = b.length();
vector<unsigned_vector> d(n + 1);
for (unsigned i = 0; i <= n; ++i)
d[i].resize(m + 1, 0);
for (unsigned i = 0; i <= n; ++i)
d[i][0] = i;
for (unsigned j = 0; j <= m; ++j)
d[0][j] = j;
for (unsigned j = 1; j <= m; ++j) {
for (unsigned i = 1; i <= n; ++i) {
if (a[i - 1] == b[j - 1])
d[i][j] = d[i - 1][j - 1];
else
d[i][j] = std::min(std::min(d[i - 1][j] + 1, d[i][j - 1] + 1), d[i - 1][j - 1] + 1);
}
}
return d[n][m];
}

void seq_plugin::add_edit_updates(ptr_vector<expr> const& w, uint_set const& chars) {
for (auto x : w) {
if (is_value(x))
continue;
zstring const & a = strval0(x);
for (auto ch : chars)
m_str_updates.push_back({ x, a + zstring(ch), 1 });
for (auto ch : chars)
m_str_updates.push_back({ x, zstring(ch) + a, 1 });
if (a.length() > 0) {
zstring b = a.extract(0, a.length() - 1);
m_str_updates.push_back({ x, b, 1 }); // truncate a
for (auto ch : chars)
m_str_updates.push_back({ x, b + zstring(ch), 1 }); // replace last character in a by ch
b = a.extract(1, a.length() - 1);
m_str_updates.push_back({ x, b, 1 }); // truncate a
for (auto ch : chars)
m_str_updates.push_back({ x, zstring(ch) + b, 1 }); // replace first character in a by ch
}
}
}

bool seq_plugin::repair_down_str_eq_edit_distance(app* eq) {
auto const& L = lhs(eq);
auto const& R = rhs(eq);
zstring a, b;
uint_set a_chars, b_chars;

for (auto x : L) {
for (auto ch : strval0(x))
a_chars.insert(ch);
a += strval0(x);
}
for (auto y : R) {
for (auto ch : strval0(y))
b_chars.insert(ch);
b += strval0(y);
}
if (a == b)
return update(eq->get_arg(0), a) && update(eq->get_arg(1), b);

unsigned diff = a.length() + b.length() + L.size() + R.size();

add_edit_updates(L, b_chars);
add_edit_updates(R, a_chars);

for (auto& [x, s, score] : m_str_updates) {
a.reset();
b.reset();
for (auto z : L) {
if (z == x)
a += s;
else
a += strval0(z);
}
for (auto z : R) {
if (z == x)
b += s;
else
b += strval0(z);
}
unsigned local_diff = edit_distance(a, b);
if (local_diff >= diff)
score = 0.1;
else
score = (diff - local_diff) * (diff - local_diff);
}
return apply_update();
}

bool seq_plugin::repair_down_str_eq_unify(app* eq) {
auto const& L = lhs(eq);
auto const& R = rhs(eq);
Expand Down Expand Up @@ -1081,6 +1183,42 @@ namespace sls {
return apply_update();
}

#if 1
bool seq_plugin::repair_down_str_concat(app* e) {
auto const& es = concats(e);
zstring value;
zstring value0 = strval0(e);
for (auto const& e : es)
value += strval0(e);
if (value == value0)
return true;
uint_set chars;

for (auto ch : value0)
chars.insert(ch);

add_edit_updates(es, chars);

unsigned diff = edit_distance(value, value0);
for (auto& [x, s, score] : m_str_updates) {
value.reset();
for (auto z : es) {
if (z == x)
value += s;
else
value += strval0(z);
}
unsigned local_diff = edit_distance(value, value0);
if (local_diff >= diff)
score = 0.1;
else
score = (diff - local_diff) * (diff - local_diff);
}
return apply_update();

}
#else
bool seq_plugin::repair_down_str_concat(app* e) {
zstring val_e = strval0(e);
unsigned len_e = val_e.length();
Expand Down Expand Up @@ -1125,6 +1263,7 @@ namespace sls {
}
return true;
}
#endif



Expand Down
5 changes: 5 additions & 0 deletions src/ast/sls/sls_seq_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ namespace sls {
bool repair_down_seq(app* e);
bool repair_down_eq(app* e);
bool repair_down_str_eq_unify(app* e);
bool repair_down_str_eq_edit_distance(app* e);
bool repair_down_str_eq(app* e);
bool repair_down_str_extract(app* e);
bool repair_down_str_contains(expr* e);
Expand All @@ -90,6 +91,9 @@ namespace sls {
void repair_up_str_itos(app* e);
void repair_up_str_stoi(app* e);

unsigned edit_distance(zstring const& a, zstring const& b);
void add_edit_updates(ptr_vector<expr> const& w, uint_set const& chars);

// regex functionality

// enumerate set of strings that can match a prefix of regex r.
Expand All @@ -111,6 +115,7 @@ namespace sls {
eval* get_eval(expr* e) const;
ptr_vector<expr> const& lhs(expr* eq);
ptr_vector<expr> const& rhs(expr* eq);
ptr_vector<expr> const& concats(expr* eq);

bool is_value(expr* e);
public:
Expand Down
5 changes: 5 additions & 0 deletions src/util/zstring.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ class zstring {
bool operator!=(const zstring& other) const;
unsigned hash() const;

void reset() { m_buffer.reset(); }
zstring& operator+=(zstring const& other) { m_buffer.append(other.m_buffer); return *this; }
uint32_t const* begin() const { return m_buffer.begin(); }
uint32_t const* end() const { return m_buffer.end(); }

friend std::ostream& operator<<(std::ostream &os, const zstring &str);
friend bool operator<(const zstring& lhs, const zstring& rhs);
};

0 comments on commit bcb61ee

Please sign in to comment.