diff --git a/pyk/src/pyk/__main__.py b/pyk/src/pyk/__main__.py index 5c1fa8863cd..828df3326fa 100644 --- a/pyk/src/pyk/__main__.py +++ b/pyk/src/pyk/__main__.py @@ -15,7 +15,7 @@ propagate_up_constraints, removeSourceMap, ) -from .ktool import KPrint, KProve, build_symbol_table, prettyPrintKast +from .ktool import KPrint, KProve, build_symbol_table, pretty_print_kast from .prelude import Sorts, mlAnd, mlOr, mlTop _LOG_FORMAT: Final = '%(levelname)s %(asctime)s %(name)s - %(message)s' @@ -84,7 +84,7 @@ def main(): args['output'].write('\n\n') args['output'].write('Rule: ' + rid.strip()) args['output'].write('\nUnparsed:\n') - args['output'].write(prettyPrintKast(rule, symbol_table)) + args['output'].write(pretty_print_kast(rule, symbol_table)) else: raise ValueError(f'Unknown command: {args["command"]}') diff --git a/pyk/src/pyk/ktool/__init__.py b/pyk/src/pyk/ktool/__init__.py index b02872349a1..1c1b15f8692 100644 --- a/pyk/src/pyk/ktool/__init__.py +++ b/pyk/src/pyk/ktool/__init__.py @@ -1,12 +1,11 @@ from .kompile import KompileBackend, kompile from .kprint import ( KPrint, - appliedLabelStr, + applied_label_str, build_symbol_table, indent, paren, - prettyPrintKast, - prettyPrintKastBool, + pretty_print_kast, unparser_for_production, ) from .kprove import KProve, kprove diff --git a/pyk/src/pyk/ktool/kprint.py b/pyk/src/pyk/ktool/kprint.py index 19dedd94f17..f8c5b609e7d 100644 --- a/pyk/src/pyk/ktool/kprint.py +++ b/pyk/src/pyk/ktool/kprint.py @@ -55,7 +55,7 @@ def pretty_print(self, kast: KAst, debug=False): - Input: KAST term in JSON. - Output: Best-effort pretty-printed representation of the KAST term. """ - return prettyPrintKast(kast, self.symbol_table, debug=debug) + return pretty_print_kast(kast, self.symbol_table, debug=debug) def unparser_for_production(prod): @@ -97,7 +97,7 @@ def build_symbol_table(definition: KDefinition, opinionated=False) -> SymbolTabl return symbol_table -def prettyPrintKast(kast: KAst, symbol_table: SymbolTable, debug=False): +def pretty_print_kast(kast: KAst, symbol_table: SymbolTable, debug=False): """Print out KAST terms/outer syntax. - Input: KAST term. @@ -116,120 +116,123 @@ def prettyPrintKast(kast: KAst, symbol_table: SymbolTable, debug=False): if type(kast) is KApply: label = kast.label.name args = kast.args - unparsedArgs = [prettyPrintKast(arg, symbol_table, debug=debug) for arg in args] + unparsed_args = [pretty_print_kast(arg, symbol_table, debug=debug) for arg in args] if kast.is_cell: - cellContents = '\n'.join(unparsedArgs).rstrip() - cellStr = label + '\n' + indent(cellContents) + '\n ' + rhsStr + ' )' + lhs_str = pretty_print_kast(kast.lhs, symbol_table, debug=debug) + rhs_str = pretty_print_kast(kast.rhs, symbol_table, debug=debug) + return '( ' + lhs_str + ' => ' + rhs_str + ' )' if type(kast) is KSequence: if kast.arity == 0: - return prettyPrintKast(KApply(Labels.EMPTY_K), symbol_table, debug=debug) + return pretty_print_kast(KApply(Labels.EMPTY_K), symbol_table, debug=debug) if kast.arity == 1: - return prettyPrintKast(kast.items[0], symbol_table, debug=debug) - unparsedKSequence = '\n~> '.join([prettyPrintKast(item, symbol_table, debug=debug) for item in kast.items[0:-1]]) + return pretty_print_kast(kast.items[0], symbol_table, debug=debug) + unparsed_k_seq = '\n~> '.join([pretty_print_kast(item, symbol_table, debug=debug) for item in kast.items[0:-1]]) if kast.items[-1] == ktokenDots: - unparsedKSequence = unparsedKSequence + '\n' + prettyPrintKast(ktokenDots, symbol_table, debug=debug) + unparsed_k_seq = unparsed_k_seq + '\n' + pretty_print_kast(ktokenDots, symbol_table, debug=debug) else: - unparsedKSequence = unparsedKSequence + '\n~> ' + prettyPrintKast(kast.items[-1], symbol_table, debug=debug) - return unparsedKSequence + unparsed_k_seq = unparsed_k_seq + '\n~> ' + pretty_print_kast(kast.items[-1], symbol_table, debug=debug) + return unparsed_k_seq if type(kast) is KTerminal: return '"' + kast.value + '"' if type(kast) is KRegexTerminal: return 'r"' + kast.regex + '"' if type(kast) is KNonTerminal: - return prettyPrintKast(kast.sort, symbol_table, debug=debug) + return pretty_print_kast(kast.sort, symbol_table, debug=debug) if type(kast) is KProduction: if 'klabel' not in kast.att and kast.klabel: kast = kast.update_atts({'klabel': kast.klabel.name}) - sortStr = prettyPrintKast(kast.sort, symbol_table, debug=debug) - productionStr = ' '.join([prettyPrintKast(pi, symbol_table, debug=debug) for pi in kast.items]) - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) - return 'syntax ' + sortStr + ' ::= ' + productionStr + ' ' + attStr + syntax_str = 'syntax ' + pretty_print_kast(kast.sort, symbol_table, debug=debug) + if kast.items: + syntax_str += ' ::= ' + ' '.join([pretty_print_kast(pi, symbol_table, debug=debug) for pi in kast.items]) + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) + if att_str: + syntax_str += ' ' + att_str + return syntax_str if type(kast) is KSyntaxSort: - sortStr = prettyPrintKast(kast.sort, symbol_table, debug=debug) - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) - return 'syntax ' + sortStr + ' ' + attStr + sort_str = pretty_print_kast(kast.sort, symbol_table, debug=debug) + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) + return 'syntax ' + sort_str + ' ' + att_str if type(kast) is KSortSynonym: - newSortStr = prettyPrintKast(kast.new_sort, symbol_table, debug=debug) - oldSortStr = prettyPrintKast(kast.old_sort, symbol_table, debug=debug) - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) - return 'syntax ' + newSortStr + ' = ' + oldSortStr + ' ' + attStr + new_sort_str = pretty_print_kast(kast.new_sort, symbol_table, debug=debug) + old_sort_str = pretty_print_kast(kast.old_sort, symbol_table, debug=debug) + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) + return 'syntax ' + new_sort_str + ' = ' + old_sort_str + ' ' + att_str if type(kast) is KSyntaxLexical: - nameStr = kast.name - regexStr = kast.regex - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) + name_str = kast.name + regex_str = kast.regex + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) # todo: proper escaping - return 'syntax lexical ' + nameStr + ' = r"' + regexStr + '" ' + attStr + return 'syntax lexical ' + name_str + ' = r"' + regex_str + '" ' + att_str if type(kast) is KSyntaxAssociativity: - assocStr = kast.assoc.value - tagsStr = ' '.join(kast.tags) - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) - return 'syntax associativity ' + assocStr + ' ' + tagsStr + ' ' + attStr + assoc_str = kast.assoc.value + tags_str = ' '.join(kast.tags) + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) + return 'syntax associativity ' + assoc_str + ' ' + tags_str + ' ' + att_str if type(kast) is KSyntaxPriority: - prioritiesStr = ' > '.join([' '.join(group) for group in kast.priorities]) - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) - return 'syntax priority ' + prioritiesStr + ' ' + attStr + priorities_str = ' > '.join([' '.join(group) for group in kast.priorities]) + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) + return 'syntax priority ' + priorities_str + ' ' + att_str if type(kast) is KBubble: body = '// KBubble(' + kast.sentence_type + ', ' + kast.content + ')' - attStr = prettyPrintKast(kast.att, symbol_table, debug=debug) - return body + ' ' + attStr + att_str = pretty_print_kast(kast.att, symbol_table, debug=debug) + return body + ' ' + att_str if type(kast) is KRule or type(kast) is KClaim: - body = '\n '.join(prettyPrintKast(kast.body, symbol_table, debug=debug).split('\n')) - ruleStr = 'rule ' if type(kast) is KRule else 'claim ' + body = '\n '.join(pretty_print_kast(kast.body, symbol_table, debug=debug).split('\n')) + rule_str = 'rule ' if type(kast) is KRule else 'claim ' if 'label' in kast.att: - ruleStr = ruleStr + '[' + kast.att['label'] + ']:' - ruleStr = ruleStr + ' ' + body - attsStr = prettyPrintKast(kast.att, symbol_table, debug=debug) + rule_str = rule_str + '[' + kast.att['label'] + ']:' + rule_str = rule_str + ' ' + body + atts_str = pretty_print_kast(kast.att, symbol_table, debug=debug) if kast.requires != Bool.true: - requiresStr = 'requires ' + '\n '.join(prettyPrintKastBool(kast.requires, symbol_table, debug=debug).split('\n')) - ruleStr = ruleStr + '\n ' + requiresStr + requires_str = 'requires ' + '\n '.join(pretty_print_kast_bool(kast.requires, symbol_table, debug=debug).split('\n')) + rule_str = rule_str + '\n ' + requires_str if kast.ensures != Bool.true: - ensuresStr = 'ensures ' + '\n '.join(prettyPrintKastBool(kast.ensures, symbol_table, debug=debug).split('\n')) - ruleStr = ruleStr + '\n ' + ensuresStr - return ruleStr + '\n ' + attsStr + ensures_str = 'ensures ' + '\n '.join(pretty_print_kast_bool(kast.ensures, symbol_table, debug=debug).split('\n')) + rule_str = rule_str + '\n ' + ensures_str + return rule_str + '\n ' + atts_str if type(kast) is KContext: - body = indent(prettyPrintKast(kast.body, symbol_table, debug=debug)) - contextStr = 'context alias ' + body - requiresStr = '' - attsStr = prettyPrintKast(kast.att, symbol_table, debug=debug) + body = indent(pretty_print_kast(kast.body, symbol_table, debug=debug)) + context_str = 'context alias ' + body + requires_str = '' + atts_str = pretty_print_kast(kast.att, symbol_table, debug=debug) if kast.requires != Bool.true: - requiresStr = prettyPrintKast(kast.requires, symbol_table, debug=debug) - requiresStr = 'requires ' + indent(requiresStr) - return contextStr + '\n ' + requiresStr + '\n ' + attsStr + requires_str = pretty_print_kast(kast.requires, symbol_table, debug=debug) + requires_str = 'requires ' + indent(requires_str) + return context_str + '\n ' + requires_str + '\n ' + atts_str if type(kast) is KAtt: if not kast.atts: return '' - attStrs = [k + '(' + v + ')' for k, v in kast.atts.items()] - return '[' + ', '.join(attStrs) + ']' + att_strs = [k + '(' + v + ')' for k, v in kast.atts.items()] + return '[' + ', '.join(att_strs) + ']' if type(kast) is KImport: return ' '.join(['imports', ('public' if kast.public else 'private'), kast.name]) if type(kast) is KFlatModule: name = kast.name - imports = '\n'.join([prettyPrintKast(kimport, symbol_table, debug=debug) for kimport in kast.imports]) - sentences = '\n\n'.join([prettyPrintKast(sentence, symbol_table, debug=debug) for sentence in kast.sentences]) + imports = '\n'.join([pretty_print_kast(kimport, symbol_table, debug=debug) for kimport in kast.imports]) + sentences = '\n\n'.join([pretty_print_kast(sentence, symbol_table, debug=debug) for sentence in kast.sentences]) contents = imports + '\n\n' + sentences return 'module ' + name + '\n ' + '\n '.join(contents.split('\n')) + '\n\nendmodule' if type(kast) is KRequire: return 'requires "' + kast.require + '"' if type(kast) is KDefinition: - requires = '\n'.join([prettyPrintKast(require, symbol_table, debug=debug) for require in kast.requires]) - modules = '\n\n'.join([prettyPrintKast(module, symbol_table, debug=debug) for module in kast.modules]) + requires = '\n'.join([pretty_print_kast(require, symbol_table, debug=debug) for require in kast.requires]) + modules = '\n\n'.join([pretty_print_kast(module, symbol_table, debug=debug) for module in kast.modules]) return requires + '\n\n' + modules raise ValueError(f'Error unparsing: {kast}') -def prettyPrintKastBool(kast, symbol_table, debug=False): +def pretty_print_kast_bool(kast, symbol_table, debug=False): """Print out KAST requires/ensures clause. - Input: KAST Bool for requires/ensures clause. @@ -240,7 +243,7 @@ def prettyPrintKastBool(kast, symbol_table, debug=False): sys.stderr.write('\n') sys.stderr.flush() if type(kast) is KApply and kast.label.name in ['_andBool_', '_orBool_']: - clauses = [prettyPrintKastBool(c, symbol_table, debug=debug) for c in flatten_label(kast.label.name, kast)] + clauses = [pretty_print_kast_bool(c, symbol_table, debug=debug) for c in flatten_label(kast.label.name, kast)] head = kast.label.name.replace('_', ' ') if head == ' orBool ': head = ' orBool ' @@ -253,14 +256,14 @@ def joinSep(s): clauses = ['( ' + joinSep(clauses[0])] + [head + '( ' + joinSep(c) for c in clauses[1:]] + [spacer + (')' * len(clauses))] return '\n'.join(clauses) else: - return prettyPrintKast(kast, symbol_table, debug=debug) + return pretty_print_kast(kast, symbol_table, debug=debug) def paren(printer): return (lambda *args: '( ' + printer(*args) + ' )') -def appliedLabelStr(symbol): +def applied_label_str(symbol): return (lambda *args: symbol + ' ( ' + ' , '.join(args) + ' )') diff --git a/pyk/src/pyk/tests/test_pretty_print_kast.py b/pyk/src/pyk/tests/test_pretty_print_kast.py index 0da09d79890..13cc7ce1447 100644 --- a/pyk/src/pyk/tests/test_pretty_print_kast.py +++ b/pyk/src/pyk/tests/test_pretty_print_kast.py @@ -1,10 +1,20 @@ from typing import Final, Tuple from unittest import TestCase -from pyk.kast import KApply, KAst, KLabel, KProduction, KRule, KSort, KTerminal +from pyk.kast import ( + KApply, + KAst, + KAtt, + KLabel, + KNonTerminal, + KProduction, + KRule, + KSort, + KTerminal, +) from pyk.ktool.kprint import ( SymbolTable, - prettyPrintKast, + pretty_print_kast, unparser_for_production, ) from pyk.prelude import Bool @@ -17,6 +27,9 @@ class PrettyPrintKastTest(TestCase): (KRule(Bool.true), 'rule true\n '), (KRule(Bool.true, ensures=Bool.true), 'rule true\n '), (KRule(Bool.true, ensures=KApply('_andBool_', [Bool.true, Bool.true])), 'rule true\n ensures ( true\n andBool ( true\n ))\n '), + (KProduction(KSort('Test')), 'syntax Test'), + (KProduction(KSort('Test'), att=KAtt({'token': ''})), 'syntax Test [token()]'), + (KProduction(KSort('Test'), [KTerminal('foo'), KNonTerminal(KSort('Int'))], att=KAtt({'function': ''})), 'syntax Test ::= "foo" Int [function()]'), ) SYMBOL_TABLE: Final[SymbolTable] = {} @@ -24,7 +37,7 @@ class PrettyPrintKastTest(TestCase): def test_pretty_print(self): for i, (kast, expected) in enumerate(self.TEST_DATA): with self.subTest(i=i): - actual = prettyPrintKast(kast, self.SYMBOL_TABLE) + actual = pretty_print_kast(kast, self.SYMBOL_TABLE) actual_tokens = actual.split('\n') expected_tokens = expected.split('\n') self.assertListEqual(actual_tokens, expected_tokens)