Skip to content

Commit

Permalink
disabled start on implict captures to explicit capture lists as an as…
Browse files Browse the repository at this point in the history
…t->ast transformation prior to codegen (needed for by const ref capture for non-escaping/immediately invoked lambdas rather than simple [&] mutable ref capture). needs fixes.
  • Loading branch information
ehren committed Aug 25, 2024
1 parent eaab6b8 commit a20925c
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 8 deletions.
9 changes: 6 additions & 3 deletions ceto/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,10 @@ def is_capture(n):
if isinstance(node.parent, Call) and node is node.parent.func:
# immediately invoked (TODO: nonescaping)
# capture by const ref: https://stackoverflow.com/questions/3772867/lambda-capture-as-const-reference/32440415#32440415
capture_list = ["&" + i + " = " + "std::as_const(" + i + ")" for i in possible_captures]
# TODO we still have some work with lowering implicit captures to explicit capture lists as an ast->ast pass in semanticanalysis
# before this can work (scope handling and _decltype_string with explicit capture lists fixes)
#capture_list = ["&" + i + " = " + "std::as_const(" + i + ")" for i in possible_captures]
capture_list = ["&"] # just allow mutable ref capture until we resolve the above
else:
# capture only a few things by const value (shared/weak instances, arithithmetic_v, enums):
capture_list = [i + " = " + "ceto::default_capture(" + i + ")" for i in possible_captures]
Expand Down Expand Up @@ -1711,7 +1714,7 @@ def codegen_type(expr_node, type_node, cx):
pass
elif not isinstance(expr_node, (ListLiteral, TupleLiteral, Call, Identifier, TypeOp, AttributeAccess)):
raise CodeGenError("unexpected typed expression", expr_node)
if isinstance(expr_node, Call) and expr_node.func.name not in ["lambda", "def"]:
if isinstance(expr_node, Call) and not is_call_lambda(expr_node) and expr_node.func.name != "def":
raise CodeGenError("unexpected typed call", expr_node)

types = type_node_to_list_of_types(type_node)
Expand Down Expand Up @@ -2572,7 +2575,7 @@ def codegen_node(node: Node, cx: Scope):
raise CodeGenError("unexpected context for typed construct", node)

return codegen_type(node, node, cx) # this is a type inside a more complicated expression e.g. std.is_same_v<Foo, int:ptr>
elif isinstance(node, Call) and node.func.name not in ["lambda", "def"] and node.declared_type.name not in ["const", "mut"]:
elif isinstance(node, Call) and node.func.name != "def" and not is_call_lambda(node) and node.declared_type.name not in ["const", "mut"]:
raise CodeGenError("Unexpected typed call", node)

if isinstance(node, Call):
Expand Down
5 changes: 4 additions & 1 deletion ceto/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,21 @@ def report_error(e):
print(" " * (col) + "^", file=sys.stderr)
return
elif isinstance(e, (SemanticAnalysisError, CodeGenError)):
print(e.args)
try:
msg, node = e.args
except ValueError:
pass
else:
if isinstance(node, Node):
loc = node.source[1]
loc = node.source.loc
#source = node.source.source
# # lineindex = source.count("\n", 0, loc)
# beg = source.rfind("\n", loc)
# end = source.find("\n", loc)
# print(source[beg:end], file=sys.stderr)
# print(" " * (beg) + "^", file=sys.stderr)
print(e.__class__.__name__)
print(source[loc:loc+10], file=sys.stderr)
# print(" " * (beg) + "^", file=sys.stderr)
print(msg, file=sys.stderr)
Expand Down
7 changes: 6 additions & 1 deletion ceto/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ class ParameterDefinition(VariableDefinition):


def creates_new_variable_scope(e: Node) -> bool:
return isinstance(e, Call) and e.func.name in ["def", "lambda", "class", "struct"]
if isinstance(e, Call):
if e.func.name in ["def", "lambda", "class", "struct"]:
return True
elif isinstance(e.func, ArrayAccess) and e.func.func.name == "lambda":
return True
return False


def _node_depth(node):
Expand Down
91 changes: 91 additions & 0 deletions ceto/semanticanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,96 @@ def visit_Module(self, module):
self._module_scope = module.scope


class ImplicitLambdaCaptureVisitor:

def visit_Call(self, call):
# TODO needs fixes + fixes for explicit capture lists in _decltype_str
return call

def replace_lambda(old_call, new_capture_list):
# this is why scope/parent/declared_type are unavailable to the macro system!
# TODO allow scope lookup in the macro system and selfhost passes without this brittleness
lmb = Identifier("lambda")
lmb.scope = old_call.func.scope
new_call = Call(func=ArrayAccess(func=lmb, args=[]), args=old_call.args)
new_call.parent = old_call.parent
new_call.scope = old_call.scope
if old_call.declared_type:
new_call.declared_type = old_call.declared_type
print("new_call.declared_type", new_call.declared_type)
new_call.declared_type.parent = new_call
if old_call.declared_type.scope:
new_call.declared_type.scope = old_call.declared_type.scope
new_call.func.parent = new_call
new_call.func.scope = old_call.func.scope
block = new_call.args[-1]
assert isinstance(block, Block)
last_statement = block.args[-1]
if hasattr(last_statement, "synthetic_lambda_return_lambda") and last_statement.synthetic_lambda_return_lambda:
last_statement.synthetic_lambda_return_lambda = new_call
for a in new_call.args:
a.parent = new_call
return new_call

from .parser import parse

if not call.func.name == "lambda":
# we don't want is_call_lambda here (lambdas with explicit capture lists handled in codegen)
return call

if not call.parent.scope.in_function_body:
return replace_lambda(call, [])

def is_capture(n):
if not isinstance(n, Identifier):
return False
elif isinstance(n.parent, (Call, ArrayAccess, BracedCall, Template)) and n is n.parent.func:
return False
elif isinstance(n.parent, AttributeAccess) and n is n.parent.rhs:
return False
return True

# find all identifiers but not call funcs etc or anything in a nested class
idents = find_all(call, test=is_capture, stop=lambda c: isinstance(c.func, Identifier) and c.func.name in ["class", "struct"])

idents = {i.name: i for i in idents}.values() # remove duplicates

possible_captures = []
for i in idents:
if i.name == "self":
possible_captures.append(i.name)
elif isinstance(i.parent, Call) and i.parent.func.name in ["def", "lambda"]:
pass # don't capture a lambda parameter
elif (d := i.scope.find_def(i)) and isinstance(d, (LocalVariableDefinition, ParameterDefinition)):
defnode = d.defined_node
is_capture = True
while defnode is not None:
if defnode is call:
# defined in lambda or by lambda params (not a capture)
is_capture = False
break
defnode = defnode.parent
if is_capture:
possible_captures.append(i.name)

if isinstance(call.parent, Call) and call is call.parent.func:
# immediately invoked (TODO: nonescaping)
# capture by const ref: https://stackoverflow.com/questions/3772867/lambda-capture-as-const-reference/32440415#32440415
capture_list = ["&" + i + " = " + "std::as_const(" + i + ")" for i in possible_captures]
else:
# capture only a few things by const value (shared/weak instances, arithithmetic_v, enums):
capture_list = [i + " = " + "ceto::default_capture(" + i + ")" for i in possible_captures]

# this is lazy but it's fine
capture_list_ast = [parse(s).args[0] for s in capture_list]

new_lambda = replace_lambda(call, capture_list_ast)
#for capture_list_arg in new_lambda.func.args:
# TODO need to add_variable_definition for newly added capture list
# capture_list_arg.scope
return new_lambda


def apply_replacers(module: Module, visitors):

def replace(node):
Expand Down Expand Up @@ -912,6 +1002,7 @@ def semantic_analysis(expr: Module):
expr = build_types(expr)
expr = build_parents(expr)
expr = apply_replacers(expr, [ScopeVisitor()])
expr = apply_replacers(expr, [ImplicitLambdaCaptureVisitor()])

def defs(node):
if not isinstance(node, Node):
Expand Down
8 changes: 6 additions & 2 deletions selfhost/scope.cth
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ class (ParameterDefinition(VariableDefinition):
def (creates_new_variable_scope, e: Node:
if (isinstance(e, Call):
name = e.func.name()
return name and contains(["def"s, "lambda"s, "class"s, "struct"s], name.value())
if (name:
return contains(["def"s, "lambda"s, "class"s, "struct"s], name.value())
elif isinstance(e.func, ArrayAccess) and e.func.func.name() == "lambda":
return True
)
)
return false
return False
)


Expand Down
6 changes: 5 additions & 1 deletion selfhost/scope.donotedit.autogenerated.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ using VariableDefinition::VariableDefinition;
inline auto creates_new_variable_scope(const std::shared_ptr<const Node>& e) -> auto {
if ((std::dynamic_pointer_cast<const Call>(e) != nullptr)) {
const auto name = (*ceto::mad((*ceto::mad(e)).func)).name();
return (name && contains(std::vector {{std::string {"def"}, std::string {"lambda"}, std::string {"class"}, std::string {"struct"}}}, (*ceto::mad_smartptr(name)).value()));
if (name) {
return contains(std::vector {{std::string {"def"}, std::string {"lambda"}, std::string {"class"}, std::string {"struct"}}}, (*ceto::mad_smartptr(name)).value());
} else if (((std::dynamic_pointer_cast<const ArrayAccess>((*ceto::mad(e)).func) != nullptr) && ((*ceto::mad((*ceto::mad((*ceto::mad(e)).func)).func)).name() == "lambda"))) {
return true;
}
}
return false;
}
Expand Down

0 comments on commit a20925c

Please sign in to comment.