Skip to content

Commit

Permalink
Lifting select and substitution code in class hierarchy to avoid code…
Browse files Browse the repository at this point in the history
… duplication
  • Loading branch information
ckirsch committed Dec 4, 2024
1 parent a4b8d7c commit 0e958de
Showing 1 changed file with 84 additions and 136 deletions.
220 changes: 84 additions & 136 deletions tools/bitme.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,86 @@ def get(nid):
assert Line.is_defined(nid), f"undefined nid {self.nid} @ {self.line_no}"
return Line.lines[nid]

def get_z3_lambda(self, line):
if self.z3_lambda_line is None:
if line.domain:
self.z3_lambda_line = z3.Lambda([state.get_z3() for state in line.domain], line.get_z3())
else:
self.z3_lambda_line = line.get_z3()
return self.z3_lambda_line

def get_z3_select(self, domain, step):
if step not in self.cache_z3_instance:
if domain:
self.cache_z3_instance[step] = z3.Select(self.get_z3_lambda(), *[state.get_z3_step(step) for state in domain])
else:
self.cache_z3_instance[step] = self.get_z3_lambda()
return self.cache_z3_instance[step]

def get_z3_substitute(self, domain, step):
assert step >= 0
if step not in self.cache_z3_instance:
self.z3 = self.get_z3()
if domain:
if step == 0:
current_states = [state.get_z3() for state in domain]
else:
# assuming that self.z3 is a term over states of step - 1
current_states = [state.get_z3_step(step - 1) for state in domain]
next_states = [state.get_z3_step(step) for state in domain]
renaming = list(zip(current_states, next_states))

self.z3 = z3.substitute(self.z3, renaming)
self.cache_z3_instance[step] = self.z3
return self.cache_z3_instance[step]

def get_z3_step(self, domain, step):
if Line.LAMBDAS:
return self.get_z3_select(domain, step)
else:
return self.get_z3_substitute(domain, step)

def get_bitwuzla_lambda(self, line, tm):
if self.bitwuzla_lambda_line is None:
if line.domain:
self.bitwuzla_lambda_line = tm.mk_term(bitwuzla.Kind.LAMBDA,
[*[state.get_bitwuzla(tm) for state in line.domain], line.get_bitwuzla(tm)])
else:
self.bitwuzla_lambda_line = line.get_bitwuzla(tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, domain, step, tm):
if step not in self.cache_bitwuzla_instance:
if domain:
self.cache_bitwuzla_instance[step] = tm.mk_term(bitwuzla.Kind.APPLY,
[self.get_bitwuzla_lambda(tm), *[state.get_bitwuzla_step(step, tm) for state in domain]])
else:
self.cache_bitwuzla_instance[step] = self.get_bitwuzla_lambda(tm)
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_substitute(self, domain, step, tm):
assert step >= 0
if step not in self.cache_bitwuzla_instance:
self.bitwuzla = self.get_bitwuzla(tm)
if domain:
if step == 0:
current_states = [state.get_bitwuzla(tm) for state in domain]
else:
# assuming that self.bitwuzla is a term over states of step - 1
current_states = [state.get_bitwuzla_step(step - 1, tm) for state in domain]
next_states = [state.get_bitwuzla_step(step, tm) for state in domain]
renaming = dict(zip(current_states, next_states))

self.bitwuzla = tm.substitute_term(self.bitwuzla, renaming)
self.cache_bitwuzla_instance[step] = self.bitwuzla
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_step(self, domain, step, tm):
if Line.LAMBDAS:
return self.get_bitwuzla_select(domain, step, tm)
else:
return self.get_bitwuzla_substitute(domain, step, tm)

class Sort(Line):
keyword = OP_SORT

Expand Down Expand Up @@ -573,33 +653,6 @@ def get_z3_step(self, step):
self.cache_z3[step] = z3.Const(self.get_step_name(step), self.sid_line.get_z3())
return self.cache_z3[step]

def get_z3_lambda(line):
if line.domain:
return z3.Lambda([state.get_z3() for state in line.domain], line.get_z3())
else:
return line.get_z3()

def get_z3_select(line, domain, step):
if domain:
return z3.Select(line.get_z3_lambda(), *[state.get_z3_step(step) for state in domain])
else:
return line.get_z3_lambda()

def get_z3_substitute(line, domain, step):
assert step >= 0
line.z3 = line.get_z3()
if domain:
if step == 0:
current_states = [state.get_z3() for state in domain]
else:
# assuming that line.z3 is a term over states of step - 1
current_states = [state.get_z3_step(step - 1) for state in domain]
next_states = [state.get_z3_step(step) for state in domain]
renaming = list(zip(current_states, next_states))

line.z3 = z3.substitute(line.z3, renaming)
return line.z3

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
self.bitwuzla = tm.mk_var(self.sid_line.get_bitwuzla(tm), self.name)
Expand All @@ -611,35 +664,6 @@ def get_bitwuzla_step(self, step, tm):
self.get_step_name(step))
return self.cache_bitwuzla[step]

def get_bitwuzla_lambda(line, tm):
if line.domain:
return tm.mk_term(bitwuzla.Kind.LAMBDA,
[*[state.get_bitwuzla(tm) for state in line.domain], line.get_bitwuzla(tm)])
else:
return line.get_bitwuzla(tm)

def get_bitwuzla_select(line, domain, step, tm):
if domain:
return tm.mk_term(bitwuzla.Kind.APPLY,
[line.get_bitwuzla_lambda(tm), *[state.get_bitwuzla_step(step, tm) for state in domain]])
else:
return line.get_bitwuzla_lambda(tm)

def get_bitwuzla_substitute(line, domain, step, tm):
assert step >= 0
line.bitwuzla = line.get_bitwuzla(tm)
if domain:
if step == 0:
current_states = [state.get_bitwuzla(tm) for state in domain]
else:
# assuming that line.bitwuzla is a term over states of step - 1
current_states = [state.get_bitwuzla_step(step - 1, tm) for state in domain]
next_states = [state.get_bitwuzla_step(step, tm) for state in domain]
renaming = dict(zip(current_states, next_states))

line.bitwuzla = tm.substitute_term(line.bitwuzla, renaming)
return line.bitwuzla

class Indexed(Expression):
def __init__(self, nid, sid_line, arg1_line, comment, line_no):
super().__init__(nid, sid_line, arg1_line.domain, comment, line_no)
Expand Down Expand Up @@ -1212,28 +1236,11 @@ def get_z3(self):

def get_z3_lambda(self):
# only needed for branching
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(self)
return self.z3_lambda_line

def get_z3_select(self, domain, step):
# only needed for branching
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_select(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_substitute(self, domain, step):
# only needed for branching
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_substitute(self, domain, step)
return self.cache_z3_instance[step]
return super().get_z3_lambda(self)

def get_z3_step(self, step):
# only needed for branching
if Line.LAMBDAS:
return self.get_z3_select(self.domain, step)
else:
return self.get_z3_substitute(self.domain, step)
return super().get_z3_step(self.domain, step)

def get_bitwuzla(self, tm):
if self.bitwuzla is None:
Expand All @@ -1245,28 +1252,11 @@ def get_bitwuzla(self, tm):

def get_bitwuzla_lambda(self, tm):
# only needed for branching
if self.bitwuzla_lambda_line is None:
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(self, tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, domain, step, tm):
# only needed for branching
if step not in self.cache_bitwuzla_instance:
self.cache_bitwuzla_instance[step] = State.get_bitwuzla_select(self, domain, step, tm)
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_substitute(self, domain, step, tm):
# only needed for branching
if step not in self.cache_bitwuzla_instance:
self.cache_bitwuzla_instance[step] = State.get_bitwuzla_substitute(self, domain, step, tm)
return self.cache_bitwuzla_instance[step]
return super().get_bitwuzla_lambda(self, tm)

def get_bitwuzla_step(self, step, tm):
# only needed for branching
if Line.LAMBDAS:
return self.get_bitwuzla_select(self.domain, step, tm)
else:
return self.get_bitwuzla_substitute(self.domain, step, tm)
return super().get_bitwuzla_step(self.domain, step, tm)

class Write(Ternary):
keyword = OP_WRITE
Expand Down Expand Up @@ -1347,53 +1337,11 @@ def get_z3(self, line):
self.z3 = line.get_z3()
return self.z3

def get_z3_lambda(self, line):
if self.z3_lambda_line is None:
self.z3_lambda_line = State.get_z3_lambda(line)
return self.z3_lambda_line

def get_z3_select(self, domain, step):
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_select(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_substitute(self, domain, step):
if step not in self.cache_z3_instance:
self.cache_z3_instance[step] = State.get_z3_substitute(self, domain, step)
return self.cache_z3_instance[step]

def get_z3_step(self, domain, step):
if Line.LAMBDAS:
return self.get_z3_select(domain, step)
else:
return self.get_z3_substitute(domain, step)

def get_bitwuzla(self, line, tm):
if self.bitwuzla is None:
self.bitwuzla = line.get_bitwuzla(tm)
return self.bitwuzla

def get_bitwuzla_lambda(self, line, tm):
if self.bitwuzla_lambda_line is None:
self.bitwuzla_lambda_line = State.get_bitwuzla_lambda(line, tm)
return self.bitwuzla_lambda_line

def get_bitwuzla_select(self, domain, step, tm):
if step not in self.cache_bitwuzla_instance:
self.cache_bitwuzla_instance[step] = State.get_bitwuzla_select(self, domain, step, tm)
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_substitute(self, domain, step, tm):
if step not in self.cache_bitwuzla_instance:
self.cache_bitwuzla_instance[step] = State.get_bitwuzla_substitute(self, domain, step, tm)
return self.cache_bitwuzla_instance[step]

def get_bitwuzla_step(self, domain, step, tm):
if Line.LAMBDAS:
return self.get_bitwuzla_select(domain, step, tm)
else:
return self.get_bitwuzla_substitute(domain, step, tm)

class Transitional(Sequential):
def __init__(self, nid, sid_line, state_line, exp_line, comment, line_no, array_line, index):
super().__init__(nid, comment, line_no)
Expand Down

0 comments on commit 0e958de

Please sign in to comment.