Skip to content

Commit

Permalink
Take into account precedence when parsing arithmetic subexpressions
Browse files Browse the repository at this point in the history
  • Loading branch information
Dashadower committed Oct 31, 2022
1 parent 93b5132 commit e621504
Showing 1 changed file with 78 additions and 6 deletions.
84 changes: 78 additions & 6 deletions stanify/builders/ast_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,19 @@ def walk(self, ast_node, node_name: str) -> None:
class BlockCodegenWalker(BaseNodeWaler):
lookup_function_names: Dict[str, str]
datastructure_function_names: Set[str]

def walk(self, ast_node) -> str:

operator_precedence = { # lower comes first
"+": 3,
"-": 3,
"*": 2,
"/": 2,
"^": 1
}

def walk(self, ast_node, current_precedence: int = 100) -> str:
"""
current_precedence is only relevant for ArithmeticStructure. It represents the precedence level of the
current(parent) expression being parsed.
"""
if isinstance(ast_node, int):
return f"{ast_node}"
elif isinstance(ast_node, float):
Expand All @@ -208,14 +218,53 @@ def walk(self, ast_node) -> str:
elif isinstance(ast_node, ArithmeticStructure):
# ArithmeticStructure consists of chained arithmetic expressions.
# We parse them one by one into a single expression
# Assume ArithmeticStructure.operators are in order of precedence(lower comes first)
"""
Suppose we have:
A = 1 + 1; B = A * 2
Naively substituting in A will result in:
1 + 1 * 2
which is not correct, and instead should be:
(1 + 1) * 2
Assume we have the following AST:
* # precedence level 2
/\
/ 2
+ # precedence level 3
/\
1 1
We check if the operators of the subtrees of '*' have a precedence level higher than its parent, and
enclose them in parentheses if so.
"""

output_string = ""
last_argument_index = len(ast_node.arguments) - 1

# Find the maximum precedence value of the operators
max_precedence = max([self.operator_precedence[op] for op in ast_node.operators])

if max_precedence > current_precedence:
output_string += "("

for index, argument in enumerate(ast_node.arguments):
output_string += self.walk(argument)
# Find the operators which set the precedence level of its children
if index == 0:
operators_to_check = [ast_node.operators[0]]
elif index == last_argument_index:
operators_to_check = [ast_node.operators[-1]]
else:
operators_to_check = [ast_node.operators[index - 1], ast_node.operators[index]]

# Pass the minimum precedence level
output_string += self.walk(argument, min([self.operator_precedence[op] for op in operators_to_check]))

if index < last_argument_index:
output_string += " "
output_string += ast_node.operators[index]
output_string += " "

if max_precedence > current_precedence:
output_string += ")"
return output_string

elif isinstance(ast_node, ReferenceStructure):
Expand Down Expand Up @@ -291,7 +340,7 @@ def walk(self, ast_node) -> str:
class InitialValueCodegenWalker(BlockCodegenWalker):
variable_ast_dict: Dict[str, AbstractSyntax]

def walk(self, ast_node):
def walk(self, ast_node, current_precedence: int = 100):

if isinstance(ast_node, IntegStructure):
return self.walk(ast_node.initial)
Expand All @@ -312,15 +361,38 @@ def walk(self, ast_node):
elif isinstance(ast_node, ArithmeticStructure):
# ArithmeticStructure consists of chained arithmetic expressions.
# We parse them one by one into a single expression
# Assume ArithmeticStructure.operators are in order of precedence(lower comes first)
# TODO: This is duplicated code from BlockCodegenWalker.
output_string = ""
last_argument_index = len(ast_node.arguments) - 1

# Find the maximum precedence value of the operators
max_precedence = max([self.operator_precedence[op] for op in ast_node.operators])

if max_precedence > current_precedence:
output_string += "("

for index, argument in enumerate(ast_node.arguments):
output_string += self.walk(argument)
# Find the operators which set the precedence level of its children
if index == 0:
operators_to_check = [ast_node.operators[0]]
elif index == last_argument_index:
operators_to_check = [ast_node.operators[-1]]
else:
operators_to_check = [ast_node.operators[index - 1], ast_node.operators[index]]

# Pass the minimum precedence level
output_string += self.walk(argument, min([self.operator_precedence[op] for op in operators_to_check]))

if index < last_argument_index:
output_string += " "
output_string += ast_node.operators[index]
output_string += " "

if max_precedence > current_precedence:
output_string += ")"
return output_string

else:
return super().walk(ast_node)

Expand Down

0 comments on commit e621504

Please sign in to comment.