Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sort declaration unparsing #2767

Merged
merged 7 commits into from
Aug 4, 2022
4 changes: 2 additions & 2 deletions pyk/src/pyk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
propagate_up_constraints,
removeSourceMap,
)
from .ktool import KPrint, KProve, build_symbol_table, prettyPrintKast
from .ktool import KPrint, KProve, build_symbol_table, pretty_print_kast
from .prelude import Sorts, mlAnd, mlOr, mlTop

_LOG_FORMAT: Final = '%(levelname)s %(asctime)s %(name)s - %(message)s'
Expand Down Expand Up @@ -84,7 +84,7 @@ def main():
args['output'].write('\n\n')
args['output'].write('Rule: ' + rid.strip())
args['output'].write('\nUnparsed:\n')
args['output'].write(prettyPrintKast(rule, symbol_table))
args['output'].write(pretty_print_kast(rule, symbol_table))

else:
raise ValueError(f'Unknown command: {args["command"]}')
Expand Down
5 changes: 2 additions & 3 deletions pyk/src/pyk/ktool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .kompile import KompileBackend, kompile
from .kprint import (
KPrint,
appliedLabelStr,
applied_label_str,
build_symbol_table,
indent,
paren,
prettyPrintKast,
prettyPrintKastBool,
pretty_print_kast,
unparser_for_production,
)
from .kprove import KProve, kprove
147 changes: 75 additions & 72 deletions pyk/src/pyk/ktool/kprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def pretty_print(self, kast: KAst, debug=False):
- Input: KAST term in JSON.
- Output: Best-effort pretty-printed representation of the KAST term.
"""
return prettyPrintKast(kast, self.symbol_table, debug=debug)
return pretty_print_kast(kast, self.symbol_table, debug=debug)


def unparser_for_production(prod):
Expand Down Expand Up @@ -97,7 +97,7 @@ def build_symbol_table(definition: KDefinition, opinionated=False) -> SymbolTabl
return symbol_table


def prettyPrintKast(kast: KAst, symbol_table: SymbolTable, debug=False):
def pretty_print_kast(kast: KAst, symbol_table: SymbolTable, debug=False):
"""Print out KAST terms/outer syntax.

- Input: KAST term.
Expand All @@ -116,120 +116,123 @@ def prettyPrintKast(kast: KAst, symbol_table: SymbolTable, debug=False):
if type(kast) is KApply:
label = kast.label.name
args = kast.args
unparsedArgs = [prettyPrintKast(arg, symbol_table, debug=debug) for arg in args]
unparsed_args = [pretty_print_kast(arg, symbol_table, debug=debug) for arg in args]
if kast.is_cell:
cellContents = '\n'.join(unparsedArgs).rstrip()
cellStr = label + '\n' + indent(cellContents) + '\n</' + label[1:]
return cellStr.rstrip()
unparser = appliedLabelStr(label) if label not in symbol_table else symbol_table[label]
return unparser(*unparsedArgs)
cell_contents = '\n'.join(unparsed_args).rstrip()
cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:]
return cell_str.rstrip()
unparser = applied_label_str(label) if label not in symbol_table else symbol_table[label]
return unparser(*unparsed_args)
if type(kast) is KAs:
patternStr = prettyPrintKast(kast.pattern, symbol_table, debug=debug)
aliasStr = prettyPrintKast(kast.alias, symbol_table, debug=debug)
return patternStr + ' #as ' + aliasStr
pattern_str = pretty_print_kast(kast.pattern, symbol_table, debug=debug)
alias_str = pretty_print_kast(kast.alias, symbol_table, debug=debug)
return pattern_str + ' #as ' + alias_str
if type(kast) is KRewrite:
lhsStr = prettyPrintKast(kast.lhs, symbol_table, debug=debug)
rhsStr = prettyPrintKast(kast.rhs, symbol_table, debug=debug)
return '( ' + lhsStr + ' => ' + rhsStr + ' )'
lhs_str = pretty_print_kast(kast.lhs, symbol_table, debug=debug)
rhs_str = pretty_print_kast(kast.rhs, symbol_table, debug=debug)
return '( ' + lhs_str + ' => ' + rhs_str + ' )'
if type(kast) is KSequence:
if kast.arity == 0:
return prettyPrintKast(KApply(Labels.EMPTY_K), symbol_table, debug=debug)
return pretty_print_kast(KApply(Labels.EMPTY_K), symbol_table, debug=debug)
if kast.arity == 1:
return prettyPrintKast(kast.items[0], symbol_table, debug=debug)
unparsedKSequence = '\n~> '.join([prettyPrintKast(item, symbol_table, debug=debug) for item in kast.items[0:-1]])
return pretty_print_kast(kast.items[0], symbol_table, debug=debug)
unparsed_k_seq = '\n~> '.join([pretty_print_kast(item, symbol_table, debug=debug) for item in kast.items[0:-1]])
if kast.items[-1] == ktokenDots:
unparsedKSequence = unparsedKSequence + '\n' + prettyPrintKast(ktokenDots, symbol_table, debug=debug)
unparsed_k_seq = unparsed_k_seq + '\n' + pretty_print_kast(ktokenDots, symbol_table, debug=debug)
else:
unparsedKSequence = unparsedKSequence + '\n~> ' + prettyPrintKast(kast.items[-1], symbol_table, debug=debug)
return unparsedKSequence
unparsed_k_seq = unparsed_k_seq + '\n~> ' + pretty_print_kast(kast.items[-1], symbol_table, debug=debug)
return unparsed_k_seq
if type(kast) is KTerminal:
return '"' + kast.value + '"'
if type(kast) is KRegexTerminal:
return 'r"' + kast.regex + '"'
if type(kast) is KNonTerminal:
return prettyPrintKast(kast.sort, symbol_table, debug=debug)
return pretty_print_kast(kast.sort, symbol_table, debug=debug)
if type(kast) is KProduction:
if 'klabel' not in kast.att and kast.klabel:
kast = kast.update_atts({'klabel': kast.klabel.name})
sortStr = prettyPrintKast(kast.sort, symbol_table, debug=debug)
productionStr = ' '.join([prettyPrintKast(pi, symbol_table, debug=debug) for pi in kast.items])
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax ' + sortStr + ' ::= ' + productionStr + ' ' + attStr
syntax_str = 'syntax ' + pretty_print_kast(kast.sort, symbol_table, debug=debug)
if kast.items:
syntax_str += ' ::= ' + ' '.join([pretty_print_kast(pi, symbol_table, debug=debug) for pi in kast.items])
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
if att_str:
syntax_str += ' ' + att_str
return syntax_str
if type(kast) is KSyntaxSort:
sortStr = prettyPrintKast(kast.sort, symbol_table, debug=debug)
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax ' + sortStr + ' ' + attStr
sort_str = pretty_print_kast(kast.sort, symbol_table, debug=debug)
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax ' + sort_str + ' ' + att_str
if type(kast) is KSortSynonym:
newSortStr = prettyPrintKast(kast.new_sort, symbol_table, debug=debug)
oldSortStr = prettyPrintKast(kast.old_sort, symbol_table, debug=debug)
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax ' + newSortStr + ' = ' + oldSortStr + ' ' + attStr
new_sort_str = pretty_print_kast(kast.new_sort, symbol_table, debug=debug)
old_sort_str = pretty_print_kast(kast.old_sort, symbol_table, debug=debug)
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str
if type(kast) is KSyntaxLexical:
nameStr = kast.name
regexStr = kast.regex
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
name_str = kast.name
regex_str = kast.regex
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
# todo: proper escaping
return 'syntax lexical ' + nameStr + ' = r"' + regexStr + '" ' + attStr
return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str
if type(kast) is KSyntaxAssociativity:
assocStr = kast.assoc.value
tagsStr = ' '.join(kast.tags)
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax associativity ' + assocStr + ' ' + tagsStr + ' ' + attStr
assoc_str = kast.assoc.value
tags_str = ' '.join(kast.tags)
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str
if type(kast) is KSyntaxPriority:
prioritiesStr = ' > '.join([' '.join(group) for group in kast.priorities])
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax priority ' + prioritiesStr + ' ' + attStr
priorities_str = ' > '.join([' '.join(group) for group in kast.priorities])
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax priority ' + priorities_str + ' ' + att_str
if type(kast) is KBubble:
body = '// KBubble(' + kast.sentence_type + ', ' + kast.content + ')'
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return body + ' ' + attStr
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return body + ' ' + att_str
if type(kast) is KRule or type(kast) is KClaim:
body = '\n '.join(prettyPrintKast(kast.body, symbol_table, debug=debug).split('\n'))
ruleStr = 'rule ' if type(kast) is KRule else 'claim '
body = '\n '.join(pretty_print_kast(kast.body, symbol_table, debug=debug).split('\n'))
rule_str = 'rule ' if type(kast) is KRule else 'claim '
if 'label' in kast.att:
ruleStr = ruleStr + '[' + kast.att['label'] + ']:'
ruleStr = ruleStr + ' ' + body
attsStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
rule_str = rule_str + '[' + kast.att['label'] + ']:'
rule_str = rule_str + ' ' + body
atts_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
if kast.requires != Bool.true:
requiresStr = 'requires ' + '\n '.join(prettyPrintKastBool(kast.requires, symbol_table, debug=debug).split('\n'))
ruleStr = ruleStr + '\n ' + requiresStr
requires_str = 'requires ' + '\n '.join(pretty_print_kast_bool(kast.requires, symbol_table, debug=debug).split('\n'))
rule_str = rule_str + '\n ' + requires_str
if kast.ensures != Bool.true:
ensuresStr = 'ensures ' + '\n '.join(prettyPrintKastBool(kast.ensures, symbol_table, debug=debug).split('\n'))
ruleStr = ruleStr + '\n ' + ensuresStr
return ruleStr + '\n ' + attsStr
ensures_str = 'ensures ' + '\n '.join(pretty_print_kast_bool(kast.ensures, symbol_table, debug=debug).split('\n'))
rule_str = rule_str + '\n ' + ensures_str
return rule_str + '\n ' + atts_str
if type(kast) is KContext:
body = indent(prettyPrintKast(kast.body, symbol_table, debug=debug))
contextStr = 'context alias ' + body
requiresStr = ''
attsStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
body = indent(pretty_print_kast(kast.body, symbol_table, debug=debug))
context_str = 'context alias ' + body
requires_str = ''
atts_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
if kast.requires != Bool.true:
requiresStr = prettyPrintKast(kast.requires, symbol_table, debug=debug)
requiresStr = 'requires ' + indent(requiresStr)
return contextStr + '\n ' + requiresStr + '\n ' + attsStr
requires_str = pretty_print_kast(kast.requires, symbol_table, debug=debug)
requires_str = 'requires ' + indent(requires_str)
return context_str + '\n ' + requires_str + '\n ' + atts_str
if type(kast) is KAtt:
if not kast.atts:
return ''
attStrs = [k + '(' + v + ')' for k, v in kast.atts.items()]
return '[' + ', '.join(attStrs) + ']'
att_strs = [k + '(' + v + ')' for k, v in kast.atts.items()]
return '[' + ', '.join(att_strs) + ']'
if type(kast) is KImport:
return ' '.join(['imports', ('public' if kast.public else 'private'), kast.name])
if type(kast) is KFlatModule:
name = kast.name
imports = '\n'.join([prettyPrintKast(kimport, symbol_table, debug=debug) for kimport in kast.imports])
sentences = '\n\n'.join([prettyPrintKast(sentence, symbol_table, debug=debug) for sentence in kast.sentences])
imports = '\n'.join([pretty_print_kast(kimport, symbol_table, debug=debug) for kimport in kast.imports])
sentences = '\n\n'.join([pretty_print_kast(sentence, symbol_table, debug=debug) for sentence in kast.sentences])
contents = imports + '\n\n' + sentences
return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule'
if type(kast) is KRequire:
return 'requires "' + kast.require + '"'
if type(kast) is KDefinition:
requires = '\n'.join([prettyPrintKast(require, symbol_table, debug=debug) for require in kast.requires])
modules = '\n\n'.join([prettyPrintKast(module, symbol_table, debug=debug) for module in kast.modules])
requires = '\n'.join([pretty_print_kast(require, symbol_table, debug=debug) for require in kast.requires])
modules = '\n\n'.join([pretty_print_kast(module, symbol_table, debug=debug) for module in kast.modules])
return requires + '\n\n' + modules

raise ValueError(f'Error unparsing: {kast}')


def prettyPrintKastBool(kast, symbol_table, debug=False):
def pretty_print_kast_bool(kast, symbol_table, debug=False):
"""Print out KAST requires/ensures clause.

- Input: KAST Bool for requires/ensures clause.
Expand All @@ -240,7 +243,7 @@ def prettyPrintKastBool(kast, symbol_table, debug=False):
sys.stderr.write('\n')
sys.stderr.flush()
if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']:
clauses = [prettyPrintKastBool(c, symbol_table, debug=debug) for c in flatten_label(kast.label.name, kast)]
clauses = [pretty_print_kast_bool(c, symbol_table, debug=debug) for c in flatten_label(kast.label.name, kast)]
head = kast.label.name.replace('_', ' ')
if head == ' orBool ':
head = ' orBool '
Expand All @@ -253,14 +256,14 @@ def joinSep(s):
clauses = ['( ' + joinSep(clauses[0])] + [head + '( ' + joinSep(c) for c in clauses[1:]] + [spacer + (')' * len(clauses))]
return '\n'.join(clauses)
else:
return prettyPrintKast(kast, symbol_table, debug=debug)
return pretty_print_kast(kast, symbol_table, debug=debug)


def paren(printer):
return (lambda *args: '( ' + printer(*args) + ' )')


def appliedLabelStr(symbol):
def applied_label_str(symbol):
return (lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )')


Expand Down
19 changes: 16 additions & 3 deletions pyk/src/pyk/tests/test_pretty_print_kast.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from typing import Final, Tuple
from unittest import TestCase

from pyk.kast import KApply, KAst, KLabel, KProduction, KRule, KSort, KTerminal
from pyk.kast import (
KApply,
KAst,
KAtt,
KLabel,
KNonTerminal,
KProduction,
KRule,
KSort,
KTerminal,
)
from pyk.ktool.kprint import (
SymbolTable,
prettyPrintKast,
pretty_print_kast,
unparser_for_production,
)
from pyk.prelude import Bool
Expand All @@ -17,14 +27,17 @@ class PrettyPrintKastTest(TestCase):
(KRule(Bool.true), 'rule true\n '),
(KRule(Bool.true, ensures=Bool.true), 'rule true\n '),
(KRule(Bool.true, ensures=KApply('_andBool_', [Bool.true, Bool.true])), 'rule true\n ensures ( true\n andBool ( true\n ))\n '),
(KProduction(KSort('Test')), 'syntax Test'),
(KProduction(KSort('Test'), att=KAtt({'token': ''})), 'syntax Test [token()]'),
ehildenb marked this conversation as resolved.
Show resolved Hide resolved
(KProduction(KSort('Test'), [KTerminal('foo'), KNonTerminal(KSort('Int'))], att=KAtt({'function': ''})), 'syntax Test ::= "foo" Int [function()]'),
)

SYMBOL_TABLE: Final[SymbolTable] = {}

def test_pretty_print(self):
for i, (kast, expected) in enumerate(self.TEST_DATA):
with self.subTest(i=i):
actual = prettyPrintKast(kast, self.SYMBOL_TABLE)
actual = pretty_print_kast(kast, self.SYMBOL_TABLE)
actual_tokens = actual.split('\n')
expected_tokens = expected.split('\n')
self.assertListEqual(actual_tokens, expected_tokens)
Expand Down