Skip to content

Commit

Permalink
Ensure new and wrapper nodes inherit UUID (#6067)
Browse files Browse the repository at this point in the history
Instrumentation of calls involving warning values never really worked because:
1) newly created nodes didn't set the UUID of their children
2) the instrumentable wrappers always had an empty (i.e. null) UUID and
they never referred `get`/`setId` calls to their delegates

On the surface, everything worked fine. Except when one actually relied on the instrumentation of values with warnings for proper setup. Then no instrumentation (replacement of nodes) was performed due to empty UUID (as required by `hasTag` of `FunctionCallInstrumentationNode`).

Closes #6045. Discovered in #5893.
  • Loading branch information
hubertp authored Mar 27, 2023
1 parent b977b5a commit 76409b2
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 14 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@
- [Use SHA-1 for calculating hashes of modules' IR and bindings][5791]
- [Don't install Python component on Windows][5900]
- [Detect potential name conflicts between exported types and FQNs][5966]
- [Ensure calls involving warnings remain instrumented][6067]

[3227]: https://github.com/enso-org/enso/pull/3227
[3248]: https://github.com/enso-org/enso/pull/3248
Expand Down Expand Up @@ -753,6 +754,7 @@
[5791]: https://github.com/enso-org/enso/pull/5791
[5900]: https://github.com/enso-org/enso/pull/5900
[5966]: https://github.com/enso-org/enso/pull/5966
[6067]: https://github.com/enso-org/enso/pull/6067

# Enso 2.0.0-alpha.18 (2021-10-12)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ public void onReturnValue(VirtualFrame frame, Object result) {
Node node = context.getInstrumentedNode();

if (node instanceof FunctionCallInstrumentationNode
&& result instanceof FunctionCallInstrumentationNode.FunctionCall) {
&& result instanceof FunctionCallInstrumentationNode.FunctionCall functionCall) {
UUID nodeId = ((FunctionCallInstrumentationNode) node).getId();
onFunctionReturn(nodeId, result, context);
onFunctionReturn(nodeId, functionCall, context);
} else if (node instanceof ExpressionNode) {
onExpressionReturn(result, node, context);
}
Expand Down Expand Up @@ -307,11 +307,10 @@ private FunctionCallInfo functionCallInfoById(UUID nodeId) {
}

@CompilerDirectives.TruffleBoundary
private void onFunctionReturn(UUID nodeId, Object result, EventContext context) throws ThreadDeath {
private void onFunctionReturn(UUID nodeId, FunctionCallInstrumentationNode.FunctionCall result, EventContext context) throws ThreadDeath {
calls.put(
nodeId, new FunctionCallInfo((FunctionCallInstrumentationNode.FunctionCall) result));
functionCallCallback.accept(
new ExpressionCall(nodeId, (FunctionCallInstrumentationNode.FunctionCall) result));
nodeId, new FunctionCallInfo(result));
functionCallCallback.accept(new ExpressionCall(nodeId, result));
// Return cached value after capturing the enterable function call in `functionCallCallback`
Object cachedResult = cache.get(nodeId);
if (cachedResult != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package org.enso.interpreter.test;

import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.instrumentation.EventContext;
import com.oracle.truffle.api.instrumentation.ExecutionEventNode;
import com.oracle.truffle.api.instrumentation.ExecutionEventNodeFactory;
import com.oracle.truffle.api.instrumentation.SourceSectionFilter;
import com.oracle.truffle.api.instrumentation.TruffleInstrument;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.source.SourceSection;
import org.enso.interpreter.node.MethodRootNode;
import org.enso.interpreter.node.callable.FunctionCallInstrumentationNode;
import org.enso.pkg.QualifiedName;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Function;
Expand All @@ -22,6 +30,8 @@ public class NodeCountingTestInstrument extends TruffleInstrument {
public static final String INSTRUMENT_ID = "node-count-test";
private Map<Node, Node> all = new ConcurrentHashMap<>();
private Map<Class, List<Node>> counter = new ConcurrentHashMap<>();

private Map<UUID, FunctionCallInfo> calls = new ConcurrentHashMap<>();
private Env env;

@Override
Expand All @@ -33,7 +43,17 @@ protected void onCreate(Env env) {
public void enable() {
this.env
.getInstrumenter()
.attachExecutionEventFactory(SourceSectionFilter.ANY, new CountingFactory());
.attachExecutionEventFactory(SourceSectionFilter.ANY, new CountingAndFunctionCallFactory());
}

public void enable(SourceSectionFilter filter) {
this.env
.getInstrumenter()
.attachExecutionEventFactory(filter, new CountingAndFunctionCallFactory());
}

public Map<UUID, FunctionCallInfo> registeredCalls() {
return calls;
}

public Map<Class, List<Node>> assertNewNodes(String msg, int min, int max) {
Expand Down Expand Up @@ -73,16 +93,100 @@ private void dumpNode(String indent, Node n, StringBuilder sb) {
}
}

private final class CountingFactory implements ExecutionEventNodeFactory {
private final class CountingAndFunctionCallFactory implements ExecutionEventNodeFactory {
@Override
public ExecutionEventNode create(EventContext context) {
final Node node = context.getInstrumentedNode();
if (!"PatchableLiteralNode".equals(node.getClass().getSimpleName())) {
if (all.put(node, node) == null) {
counter.computeIfAbsent(node.getClass(), (__) -> new CopyOnWriteArrayList<>()).add(node);
}
return new NodeWrapper(context, calls);
}
return null;
}
}

private class NodeWrapper extends ExecutionEventNode {

private final EventContext context;

private final Map<UUID, FunctionCallInfo> calls;

public NodeWrapper(EventContext context, Map<UUID, FunctionCallInfo> calls) {
this.context = context;
this.calls = calls;
}

public void onReturnValue(VirtualFrame frame, Object result) {
Node node = context.getInstrumentedNode();
if (node instanceof FunctionCallInstrumentationNode instrumentableNode
&& result instanceof FunctionCallInstrumentationNode.FunctionCall functionCall) {
onFunctionReturn(instrumentableNode, functionCall);
}
}

private void onFunctionReturn(FunctionCallInstrumentationNode node, FunctionCallInstrumentationNode.FunctionCall result) {
if (node.getId() != null) {
calls.put(node.getId(), new FunctionCallInfo(result));
}
}

}

public class FunctionCallInfo {

private final QualifiedName moduleName;
private final QualifiedName typeName;
private final String functionName;

public FunctionCallInfo(FunctionCallInstrumentationNode.FunctionCall call) {
RootNode rootNode = call.getFunction().getCallTarget().getRootNode();
if (rootNode instanceof MethodRootNode methodNode) {
moduleName = methodNode.getModuleScope().getModule().getName();
typeName = methodNode.getType().getQualifiedName();
functionName = methodNode.getMethodName();
} else {
moduleName = null;
typeName = null;
functionName = rootNode.getName();
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionCallInfo that = (FunctionCallInfo) o;
return Objects.equals(moduleName, that.moduleName)
&& Objects.equals(typeName, that.typeName)
&& Objects.equals(functionName, that.functionName);
}

@Override
public int hashCode() {
return Objects.hash(moduleName, typeName, functionName);
}

@Override
public String toString() {
return moduleName + "::" + typeName + "::" + functionName;
}

public QualifiedName getModuleName() {
return moduleName;
}

public QualifiedName getTypeName() {
return typeName;
}

public String getFunctionName() {
return functionName;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ public void sendUpdatesWhenFunctionBodyIsChanged() {
sendUpdatesWhenFunctionBodyIsChangedBySettingValue("4", ConstantsGen.INTEGER, "4", "5", "5", LiteralNode.class);
var m = context.languageContext().findModule(MODULE_NAME).orElse(null);
assertNotNull("Module found", m);
var numbers = m.getIr().preorder().filter((v1) -> {
return v1 instanceof IR$Literal$Number;
});
var numbers = m.getIr().preorder().filter((v1) -> v1 instanceof IR$Literal$Number);
assertEquals("One number found: " + numbers, 1, numbers.size());
if (numbers.head() instanceof IR$Literal$Number n) {
assertEquals("updated to 5", "5", n.value());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package org.enso.interpreter.test.instrument;

import com.oracle.truffle.api.instrumentation.SourceSectionFilter;
import com.oracle.truffle.api.instrumentation.StandardTags;
import org.enso.interpreter.runtime.tag.AvoidIdInstrumentationTag;
import org.enso.interpreter.runtime.tag.IdentifiedTag;
import org.enso.interpreter.test.Metadata;
import org.enso.interpreter.test.NodeCountingTestInstrument;
import org.enso.polyglot.RuntimeOptions;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Language;
import org.graalvm.polyglot.Source;
import static org.junit.Assert.assertEquals;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.OutputStream;
import java.nio.file.Paths;
import java.util.Map;

public class WarningInstrumentationTest {

private Context context;
private NodeCountingTestInstrument instrument;

@Before
public void initContext() {
context = Context.newBuilder()
.allowExperimentalOptions(true)
.option(
RuntimeOptions.LANGUAGE_HOME_OVERRIDE,
Paths.get("../../distribution/component").toFile().getAbsolutePath()
)
.logHandler(OutputStream.nullOutputStream())
.allowExperimentalOptions(true)
.allowIO(true)
.allowAllAccess(true)
.build();

var engine = context.getEngine();
Map<String, Language> langs = engine.getLanguages();
Assert.assertNotNull("Enso found: " + langs, langs.get("enso"));

instrument = engine.getInstruments().get(NodeCountingTestInstrument.INSTRUMENT_ID).lookup(NodeCountingTestInstrument.class);
SourceSectionFilter builder = SourceSectionFilter.newBuilder()
.tagIs(StandardTags.ExpressionTag.class, StandardTags.CallTag.class)
.tagIs(IdentifiedTag.class)
.tagIsNot(AvoidIdInstrumentationTag.class)
.build();
instrument.enable(builder);
}

@After
public void disposeContext() {
context.close();
}

@Test
public void instrumentValueWithWarnings() throws Exception {
var metadata = new Metadata();

var idOp1 = metadata.addItem(151, 34, null);
var idOp2 = metadata.addItem(202, 31, null);
var idOp3 = metadata.addItem(250, 13, null);
var rawCode = """
from Standard.Base import all
from Standard.Base.Warning import Warning
from Standard.Table.Data.Table import Table
run column_name =
operator1 = Table.new [[column_name, [1,2,3]]]
operator2 = Warning.attach "Text" operator1
operator3 = operator2.get
operator3
""";
var code = metadata.appendToCode(rawCode);
var src = Source.newBuilder("enso", code, "TestWarning.enso").build();
var module = context.eval(src);
var res = module.invokeMember("eval_expression", "run");
res.execute("A");

var calls = instrument.registeredCalls();

assertEquals(calls.keySet().size(), 3);
assertEquals(calls.get(idOp1).getFunctionName(), "new");
assertEquals(calls.get(idOp2).getFunctionName(), "attach");
assertEquals(calls.get(idOp3).getTypeName().item(), "Table");
assertEquals(calls.get(idOp3).getFunctionName(), "get");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,15 @@ public Object execute(VirtualFrame frame, Function function, Object state, Objec
*/
@Override
public WrapperNode createWrapper(ProbeNode probeNode) {
return new FunctionCallInstrumentationNodeWrapper(this, probeNode);
var wrapper = new FunctionCallInstrumentationNodeWrapper(this, probeNode);
wrapper.setId(this.getId());
return wrapper;
}

/**
* Makrs this node with relevant runtime tags.
* Marks this node with relevant runtime tags.
*
* @param tag the tag to check agains.
* @param tag the tag to check against.
* @return true if the node carries the {@code tag}, false otherwise.
*/
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ public Object invokeWarnings(
invokeFunctionNode.getDefaultsExecutionMode(),
invokeFunctionNode.getArgumentsExecutionMode()));
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
} finally {
Expand Down Expand Up @@ -356,5 +357,8 @@ public void setId(UUID id) {
invokeFunctionNode.setId(id);
invokeMethodNode.setId(id);
invokeConversionNode.setId(id);
if (childDispatch != null) {
childDispatch.setId(id);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ Object doWarning(
invokeFunctionNode.getArgumentsExecutionMode(),
thatArgumentPosition));
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ Object doWarning(
invokeFunctionNode.getArgumentsExecutionMode(),
thisArgumentPosition));
childDispatch.setTailStatus(getTailStatus());
childDispatch.setId(invokeFunctionNode.getId());
notifyInserted(childDispatch);
}
} finally {
Expand Down Expand Up @@ -525,5 +526,8 @@ ThunkExecutorNode[] buildExecutors() {
*/
public void setId(UUID id) {
invokeFunctionNode.setId(id);
if (childDispatch != null) {
childDispatch.setId(id);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,9 @@ public SourceSection getSourceSection() {
public void setId(UUID id) {
functionCallInstrumentationNode.setId(id);
}

/** Returns expression ID of this node. */
public UUID getId() {
return functionCallInstrumentationNode.getId();
}
}

0 comments on commit 76409b2

Please sign in to comment.