Skip to content

Commit

Permalink
Fix sort declaration unparsing (#2767)
Browse files Browse the repository at this point in the history
* pyk/test_pretty_print_kast: add failing test of sort declaration

* pyk/ktool/kprint: fix bug in printing KProduction

* pyk/: rename prettyPrintKast => pretty_print_kast

* pyk/ktool/kprint: rename variables using new convention

* pyk/ktool/kprint: add test of failing sort declaration with an attribute

* pyk/kprint: another test of unparsing productions
  • Loading branch information
ehildenb authored Aug 4, 2022
1 parent 1eff007 commit 84491b6
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 80 deletions.
4 changes: 2 additions & 2 deletions pyk/src/pyk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
propagate_up_constraints,
removeSourceMap,
)
from .ktool import KPrint, KProve, build_symbol_table, prettyPrintKast
from .ktool import KPrint, KProve, build_symbol_table, pretty_print_kast
from .prelude import Sorts, mlAnd, mlOr, mlTop

_LOG_FORMAT: Final = '%(levelname)s %(asctime)s %(name)s - %(message)s'
Expand Down Expand Up @@ -84,7 +84,7 @@ def main():
args['output'].write('\n\n')
args['output'].write('Rule: ' + rid.strip())
args['output'].write('\nUnparsed:\n')
args['output'].write(prettyPrintKast(rule, symbol_table))
args['output'].write(pretty_print_kast(rule, symbol_table))

else:
raise ValueError(f'Unknown command: {args["command"]}')
Expand Down
5 changes: 2 additions & 3 deletions pyk/src/pyk/ktool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .kompile import KompileBackend, kompile
from .kprint import (
KPrint,
appliedLabelStr,
applied_label_str,
build_symbol_table,
indent,
paren,
prettyPrintKast,
prettyPrintKastBool,
pretty_print_kast,
unparser_for_production,
)
from .kprove import KProve, kprove
147 changes: 75 additions & 72 deletions pyk/src/pyk/ktool/kprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def pretty_print(self, kast: KAst, debug=False):
- Input: KAST term in JSON.
- Output: Best-effort pretty-printed representation of the KAST term.
"""
return prettyPrintKast(kast, self.symbol_table, debug=debug)
return pretty_print_kast(kast, self.symbol_table, debug=debug)


def unparser_for_production(prod):
Expand Down Expand Up @@ -97,7 +97,7 @@ def build_symbol_table(definition: KDefinition, opinionated=False) -> SymbolTabl
return symbol_table


def prettyPrintKast(kast: KAst, symbol_table: SymbolTable, debug=False):
def pretty_print_kast(kast: KAst, symbol_table: SymbolTable, debug=False):
"""Print out KAST terms/outer syntax.
- Input: KAST term.
Expand All @@ -116,120 +116,123 @@ def prettyPrintKast(kast: KAst, symbol_table: SymbolTable, debug=False):
if type(kast) is KApply:
label = kast.label.name
args = kast.args
unparsedArgs = [prettyPrintKast(arg, symbol_table, debug=debug) for arg in args]
unparsed_args = [pretty_print_kast(arg, symbol_table, debug=debug) for arg in args]
if kast.is_cell:
cellContents = '\n'.join(unparsedArgs).rstrip()
cellStr = label + '\n' + indent(cellContents) + '\n</' + label[1:]
return cellStr.rstrip()
unparser = appliedLabelStr(label) if label not in symbol_table else symbol_table[label]
return unparser(*unparsedArgs)
cell_contents = '\n'.join(unparsed_args).rstrip()
cell_str = label + '\n' + indent(cell_contents) + '\n</' + label[1:]
return cell_str.rstrip()
unparser = applied_label_str(label) if label not in symbol_table else symbol_table[label]
return unparser(*unparsed_args)
if type(kast) is KAs:
patternStr = prettyPrintKast(kast.pattern, symbol_table, debug=debug)
aliasStr = prettyPrintKast(kast.alias, symbol_table, debug=debug)
return patternStr + ' #as ' + aliasStr
pattern_str = pretty_print_kast(kast.pattern, symbol_table, debug=debug)
alias_str = pretty_print_kast(kast.alias, symbol_table, debug=debug)
return pattern_str + ' #as ' + alias_str
if type(kast) is KRewrite:
lhsStr = prettyPrintKast(kast.lhs, symbol_table, debug=debug)
rhsStr = prettyPrintKast(kast.rhs, symbol_table, debug=debug)
return '( ' + lhsStr + ' => ' + rhsStr + ' )'
lhs_str = pretty_print_kast(kast.lhs, symbol_table, debug=debug)
rhs_str = pretty_print_kast(kast.rhs, symbol_table, debug=debug)
return '( ' + lhs_str + ' => ' + rhs_str + ' )'
if type(kast) is KSequence:
if kast.arity == 0:
return prettyPrintKast(KApply(Labels.EMPTY_K), symbol_table, debug=debug)
return pretty_print_kast(KApply(Labels.EMPTY_K), symbol_table, debug=debug)
if kast.arity == 1:
return prettyPrintKast(kast.items[0], symbol_table, debug=debug)
unparsedKSequence = '\n~> '.join([prettyPrintKast(item, symbol_table, debug=debug) for item in kast.items[0:-1]])
return pretty_print_kast(kast.items[0], symbol_table, debug=debug)
unparsed_k_seq = '\n~> '.join([pretty_print_kast(item, symbol_table, debug=debug) for item in kast.items[0:-1]])
if kast.items[-1] == ktokenDots:
unparsedKSequence = unparsedKSequence + '\n' + prettyPrintKast(ktokenDots, symbol_table, debug=debug)
unparsed_k_seq = unparsed_k_seq + '\n' + pretty_print_kast(ktokenDots, symbol_table, debug=debug)
else:
unparsedKSequence = unparsedKSequence + '\n~> ' + prettyPrintKast(kast.items[-1], symbol_table, debug=debug)
return unparsedKSequence
unparsed_k_seq = unparsed_k_seq + '\n~> ' + pretty_print_kast(kast.items[-1], symbol_table, debug=debug)
return unparsed_k_seq
if type(kast) is KTerminal:
return '"' + kast.value + '"'
if type(kast) is KRegexTerminal:
return 'r"' + kast.regex + '"'
if type(kast) is KNonTerminal:
return prettyPrintKast(kast.sort, symbol_table, debug=debug)
return pretty_print_kast(kast.sort, symbol_table, debug=debug)
if type(kast) is KProduction:
if 'klabel' not in kast.att and kast.klabel:
kast = kast.update_atts({'klabel': kast.klabel.name})
sortStr = prettyPrintKast(kast.sort, symbol_table, debug=debug)
productionStr = ' '.join([prettyPrintKast(pi, symbol_table, debug=debug) for pi in kast.items])
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax ' + sortStr + ' ::= ' + productionStr + ' ' + attStr
syntax_str = 'syntax ' + pretty_print_kast(kast.sort, symbol_table, debug=debug)
if kast.items:
syntax_str += ' ::= ' + ' '.join([pretty_print_kast(pi, symbol_table, debug=debug) for pi in kast.items])
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
if att_str:
syntax_str += ' ' + att_str
return syntax_str
if type(kast) is KSyntaxSort:
sortStr = prettyPrintKast(kast.sort, symbol_table, debug=debug)
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax ' + sortStr + ' ' + attStr
sort_str = pretty_print_kast(kast.sort, symbol_table, debug=debug)
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax ' + sort_str + ' ' + att_str
if type(kast) is KSortSynonym:
newSortStr = prettyPrintKast(kast.new_sort, symbol_table, debug=debug)
oldSortStr = prettyPrintKast(kast.old_sort, symbol_table, debug=debug)
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax ' + newSortStr + ' = ' + oldSortStr + ' ' + attStr
new_sort_str = pretty_print_kast(kast.new_sort, symbol_table, debug=debug)
old_sort_str = pretty_print_kast(kast.old_sort, symbol_table, debug=debug)
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str
if type(kast) is KSyntaxLexical:
nameStr = kast.name
regexStr = kast.regex
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
name_str = kast.name
regex_str = kast.regex
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
# todo: proper escaping
return 'syntax lexical ' + nameStr + ' = r"' + regexStr + '" ' + attStr
return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str
if type(kast) is KSyntaxAssociativity:
assocStr = kast.assoc.value
tagsStr = ' '.join(kast.tags)
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax associativity ' + assocStr + ' ' + tagsStr + ' ' + attStr
assoc_str = kast.assoc.value
tags_str = ' '.join(kast.tags)
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str
if type(kast) is KSyntaxPriority:
prioritiesStr = ' > '.join([' '.join(group) for group in kast.priorities])
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return 'syntax priority ' + prioritiesStr + ' ' + attStr
priorities_str = ' > '.join([' '.join(group) for group in kast.priorities])
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return 'syntax priority ' + priorities_str + ' ' + att_str
if type(kast) is KBubble:
body = '// KBubble(' + kast.sentence_type + ', ' + kast.content + ')'
attStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
return body + ' ' + attStr
att_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
return body + ' ' + att_str
if type(kast) is KRule or type(kast) is KClaim:
body = '\n '.join(prettyPrintKast(kast.body, symbol_table, debug=debug).split('\n'))
ruleStr = 'rule ' if type(kast) is KRule else 'claim '
body = '\n '.join(pretty_print_kast(kast.body, symbol_table, debug=debug).split('\n'))
rule_str = 'rule ' if type(kast) is KRule else 'claim '
if 'label' in kast.att:
ruleStr = ruleStr + '[' + kast.att['label'] + ']:'
ruleStr = ruleStr + ' ' + body
attsStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
rule_str = rule_str + '[' + kast.att['label'] + ']:'
rule_str = rule_str + ' ' + body
atts_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
if kast.requires != Bool.true:
requiresStr = 'requires ' + '\n '.join(prettyPrintKastBool(kast.requires, symbol_table, debug=debug).split('\n'))
ruleStr = ruleStr + '\n ' + requiresStr
requires_str = 'requires ' + '\n '.join(pretty_print_kast_bool(kast.requires, symbol_table, debug=debug).split('\n'))
rule_str = rule_str + '\n ' + requires_str
if kast.ensures != Bool.true:
ensuresStr = 'ensures ' + '\n '.join(prettyPrintKastBool(kast.ensures, symbol_table, debug=debug).split('\n'))
ruleStr = ruleStr + '\n ' + ensuresStr
return ruleStr + '\n ' + attsStr
ensures_str = 'ensures ' + '\n '.join(pretty_print_kast_bool(kast.ensures, symbol_table, debug=debug).split('\n'))
rule_str = rule_str + '\n ' + ensures_str
return rule_str + '\n ' + atts_str
if type(kast) is KContext:
body = indent(prettyPrintKast(kast.body, symbol_table, debug=debug))
contextStr = 'context alias ' + body
requiresStr = ''
attsStr = prettyPrintKast(kast.att, symbol_table, debug=debug)
body = indent(pretty_print_kast(kast.body, symbol_table, debug=debug))
context_str = 'context alias ' + body
requires_str = ''
atts_str = pretty_print_kast(kast.att, symbol_table, debug=debug)
if kast.requires != Bool.true:
requiresStr = prettyPrintKast(kast.requires, symbol_table, debug=debug)
requiresStr = 'requires ' + indent(requiresStr)
return contextStr + '\n ' + requiresStr + '\n ' + attsStr
requires_str = pretty_print_kast(kast.requires, symbol_table, debug=debug)
requires_str = 'requires ' + indent(requires_str)
return context_str + '\n ' + requires_str + '\n ' + atts_str
if type(kast) is KAtt:
if not kast.atts:
return ''
attStrs = [k + '(' + v + ')' for k, v in kast.atts.items()]
return '[' + ', '.join(attStrs) + ']'
att_strs = [k + '(' + v + ')' for k, v in kast.atts.items()]
return '[' + ', '.join(att_strs) + ']'
if type(kast) is KImport:
return ' '.join(['imports', ('public' if kast.public else 'private'), kast.name])
if type(kast) is KFlatModule:
name = kast.name
imports = '\n'.join([prettyPrintKast(kimport, symbol_table, debug=debug) for kimport in kast.imports])
sentences = '\n\n'.join([prettyPrintKast(sentence, symbol_table, debug=debug) for sentence in kast.sentences])
imports = '\n'.join([pretty_print_kast(kimport, symbol_table, debug=debug) for kimport in kast.imports])
sentences = '\n\n'.join([pretty_print_kast(sentence, symbol_table, debug=debug) for sentence in kast.sentences])
contents = imports + '\n\n' + sentences
return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule'
if type(kast) is KRequire:
return 'requires "' + kast.require + '"'
if type(kast) is KDefinition:
requires = '\n'.join([prettyPrintKast(require, symbol_table, debug=debug) for require in kast.requires])
modules = '\n\n'.join([prettyPrintKast(module, symbol_table, debug=debug) for module in kast.modules])
requires = '\n'.join([pretty_print_kast(require, symbol_table, debug=debug) for require in kast.requires])
modules = '\n\n'.join([pretty_print_kast(module, symbol_table, debug=debug) for module in kast.modules])
return requires + '\n\n' + modules

raise ValueError(f'Error unparsing: {kast}')


def prettyPrintKastBool(kast, symbol_table, debug=False):
def pretty_print_kast_bool(kast, symbol_table, debug=False):
"""Print out KAST requires/ensures clause.
- Input: KAST Bool for requires/ensures clause.
Expand All @@ -240,7 +243,7 @@ def prettyPrintKastBool(kast, symbol_table, debug=False):
sys.stderr.write('\n')
sys.stderr.flush()
if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']:
clauses = [prettyPrintKastBool(c, symbol_table, debug=debug) for c in flatten_label(kast.label.name, kast)]
clauses = [pretty_print_kast_bool(c, symbol_table, debug=debug) for c in flatten_label(kast.label.name, kast)]
head = kast.label.name.replace('_', ' ')
if head == ' orBool ':
head = ' orBool '
Expand All @@ -253,14 +256,14 @@ def joinSep(s):
clauses = ['( ' + joinSep(clauses[0])] + [head + '( ' + joinSep(c) for c in clauses[1:]] + [spacer + (')' * len(clauses))]
return '\n'.join(clauses)
else:
return prettyPrintKast(kast, symbol_table, debug=debug)
return pretty_print_kast(kast, symbol_table, debug=debug)


def paren(printer):
return (lambda *args: '( ' + printer(*args) + ' )')


def appliedLabelStr(symbol):
def applied_label_str(symbol):
return (lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )')


Expand Down
19 changes: 16 additions & 3 deletions pyk/src/pyk/tests/test_pretty_print_kast.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from typing import Final, Tuple
from unittest import TestCase

from pyk.kast import KApply, KAst, KLabel, KProduction, KRule, KSort, KTerminal
from pyk.kast import (
KApply,
KAst,
KAtt,
KLabel,
KNonTerminal,
KProduction,
KRule,
KSort,
KTerminal,
)
from pyk.ktool.kprint import (
SymbolTable,
prettyPrintKast,
pretty_print_kast,
unparser_for_production,
)
from pyk.prelude import Bool
Expand All @@ -17,14 +27,17 @@ class PrettyPrintKastTest(TestCase):
(KRule(Bool.true), 'rule true\n '),
(KRule(Bool.true, ensures=Bool.true), 'rule true\n '),
(KRule(Bool.true, ensures=KApply('_andBool_', [Bool.true, Bool.true])), 'rule true\n ensures ( true\n andBool ( true\n ))\n '),
(KProduction(KSort('Test')), 'syntax Test'),
(KProduction(KSort('Test'), att=KAtt({'token': ''})), 'syntax Test [token()]'),
(KProduction(KSort('Test'), [KTerminal('foo'), KNonTerminal(KSort('Int'))], att=KAtt({'function': ''})), 'syntax Test ::= "foo" Int [function()]'),
)

SYMBOL_TABLE: Final[SymbolTable] = {}

def test_pretty_print(self):
for i, (kast, expected) in enumerate(self.TEST_DATA):
with self.subTest(i=i):
actual = prettyPrintKast(kast, self.SYMBOL_TABLE)
actual = pretty_print_kast(kast, self.SYMBOL_TABLE)
actual_tokens = actual.split('\n')
expected_tokens = expected.split('\n')
self.assertListEqual(actual_tokens, expected_tokens)
Expand Down

0 comments on commit 84491b6

Please sign in to comment.