Skip to content

Commit

Permalink
Re-Resolve function calls if only java.lang.Object was found.
Browse files Browse the repository at this point in the history
This fixes an issue in test-generator where a call to `toString` is dispatched
to the `toString` method from `java.lang.Object` instead of the overridden
method from `ArrayList`.

Quote from the issue about the current implementation:

> It tracks a `visited` set preventing it from listing callees twice when
> multiple inheritance is in play (e.g. Java interfaces). Unfortunately this
> malfunctions when we visit a type twice, once via its interfaces and once via
> its concrete subclass which provides a definition

This extends the original implementation by re-resolving dispatch entries for
which an initial step resolved to a java.lang.Object function. Such a resolution
can be wrong, because some classes may be visited multiple times: first via an
interface and then via the concrete class.

The original implementation kept a visited set that recorded visited classes
which resulted in some functions not correctly being resolved. This set is
replaced with a map from class identifiers to dispatch table entries.
  • Loading branch information
Matthias Güdemann committed Feb 5, 2018
1 parent a619e48 commit 09efc90
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 22 deletions.
122 changes: 100 additions & 22 deletions src/goto-programs/remove_virtual_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class remove_virtual_functionst
const symbol_exprt &,
const irep_idt &,
dispatch_table_entriest &,
std::set<irep_idt> &visited,
dispatch_table_entries_mapt &,
const function_call_resolvert &) const;
exprt
get_method(const irep_idt &class_id, const irep_idt &component_name) const;
Expand Down Expand Up @@ -163,11 +163,18 @@ void remove_virtual_functionst::remove_virtual_function(
newinst->source_location=vcall_source_loc;
}

// get initial identifier for grouping
INVARIANT(!functions.empty(), "Function dispatch table cannot be empty.");
auto last_id = functions.back().symbol_expr.get_identifier();
// record class_ids for disjunction
std::set<irep_idt> class_ids;

std::map<irep_idt, goto_programt::targett> calls;
// Note backwards iteration, to get the fallback candidate first.
for(auto it=functions.crbegin(), itend=functions.crend(); it!=itend; ++it)
{
const auto &fun=*it;
class_ids.insert(fun.class_id);
auto insertit=calls.insert(
{fun.symbol_expr.get_identifier(), goto_programt::targett()});

Expand Down Expand Up @@ -209,15 +216,50 @@ void remove_virtual_functionst::remove_virtual_function(
t3->make_goto(t_final, true_exprt());
}

// Emit target if end of dispatch table is reached or if the next element is
// dispatched to another function call. Assumes entries in the functions
// variable to be sorted for the identifier of the function to be called.
auto l_it = std::next(it);
bool next_emit_target =
(l_it == functions.crend()) ||
l_it->symbol_expr.get_identifier() != fun.symbol_expr.get_identifier();

// The root function call is done via fall-through, so nothing to emit
// explicitly for this.
if(next_emit_target && fun.symbol_expr == last_function_symbol)
{
class_ids.clear();
}

// If this calls the fallback function we just fall through.
// Otherwise branch to the right call:
if(fallback_action!=virtual_dispatch_fallback_actiont::CALL_LAST_FUNCTION ||
fun.symbol_expr!=last_function_symbol)
{
exprt c_id1=constant_exprt(fun.class_id, string_typet());
goto_programt::targett t4=new_code_gotos.add_instruction();
t4->source_location=vcall_source_loc;
t4->make_goto(insertit.first->second, equal_exprt(c_id1, c_id2));
// create a disjunction of class_ids to test
if(next_emit_target && fun.symbol_expr != last_function_symbol)
{
exprt::operandst or_ops;
for(const auto &id : class_ids)
{
const constant_exprt c_id1(id, string_typet());
const equal_exprt class_id_test(c_id1, c_id2);
or_ops.push_back(class_id_test);
}

goto_programt::targett t4 = new_code_gotos.add_instruction();
t4->source_location = vcall_source_loc;
t4->make_goto(insertit.first->second, disjunction(or_ops));

last_id = fun.symbol_expr.get_identifier();
class_ids.clear();
}
// record class_id
else if(next_emit_target)
{
last_id = fun.symbol_expr.get_identifier();
class_ids.clear();
}
}
}

Expand Down Expand Up @@ -252,11 +294,12 @@ void remove_virtual_functionst::remove_virtual_function(

/// Used by get_functions to track the most-derived parent that provides an
/// override of a given function.
/// \par parameters: `this_id`: class name
/// `last_method_defn`: the most-derived parent of `this_id` to define the
/// requested function
/// `component_name`: name of the function searched for
/// `resolve_function_call`: function to resolve abstract method call
/// \param `this_id`: class name
/// \param `last_method_defn`: the most-derived parent of `this_id` to define
/// the requested function
/// \param `component_name`: name of the function searched for
/// \param `entry_map`: map of class identifiers to dispatch table entries
/// \param `resolve_function_call`: function to resolve abstract method call
/// \return `functions` is assigned a list of {class name, function symbol}
/// pairs indicating that if `this` is of the given class, then the call will
/// target the given function. Thus if A <: B <: C and A and C provide
Expand All @@ -267,7 +310,7 @@ void remove_virtual_functionst::get_child_functions_rec(
const symbol_exprt &last_method_defn,
const irep_idt &component_name,
dispatch_table_entriest &functions,
std::set<irep_idt> &visited,
dispatch_table_entries_mapt &entry_map,
const function_call_resolvert &resolve_function_call) const
{
auto findit=class_hierarchy.class_map.find(this_id);
Expand All @@ -276,9 +319,18 @@ void remove_virtual_functionst::get_child_functions_rec(

for(const auto &child : findit->second.children)
{
if(!visited.insert(child).second)
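// Skip this child if it was already visited and its recorded dispatch entry
// resolved to a function other than one from java.lang.Object. Entries that
// resolved to a java.lang.Object function are resolved again below.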
auto it = entry_map.find(child);
if(
it != entry_map.end() &&
!has_prefix(
id2string(it->second.symbol_expr.get_identifier()),
"java::java.lang.Object"))
{
continue;
exprt method=get_method(child, component_name);
}
exprt method = get_method(child, component_name);
dispatch_table_entryt function(child);
if(method.is_not_nil())
{
Expand All @@ -305,37 +357,43 @@ void remove_virtual_functionst::get_child_functions_rec(
}
}
functions.push_back(function);
entry_map.insert({child, function});

get_child_functions_rec(
child,
function.symbol_expr,
component_name,
functions,
visited,
entry_map,
resolve_function_call);
}
}

/// Used to get dispatch entries to call for the given function
/// \param function: function that should be called
/// \param[out] functions: is assigned a list of dispatch entries, i.e., pairs
/// of class names and function symbol to call when encountering the class.
void remove_virtual_functionst::get_functions(
const exprt &function,
dispatch_table_entriest &functions)
{
// class part of function to call
const irep_idt class_id=function.get(ID_C_class);
const std::string class_id_string(id2string(class_id));
const irep_idt component_name=function.get(ID_component_name);
const std::string component_name_string(id2string(component_name));
const irep_idt function_name = function.get(ID_component_name);
const std::string function_name_string(id2string(function_name));
INVARIANT(!class_id.empty(), "All virtual functions must have a class");

resolve_concrete_function_callt get_virtual_call_target(
symbol_table, class_hierarchy);
const function_call_resolvert resolve_function_call =
[&get_virtual_call_target](
const irep_idt &class_id, const irep_idt &component_name) {
return get_virtual_call_target(class_id, component_name);
const irep_idt &class_id, const irep_idt &function_name) {
return get_virtual_call_target(class_id, function_name);
};

const resolve_concrete_function_callt::concrete_function_callt
&resolved_call = get_virtual_call_target(class_id, component_name);
&resolved_call = get_virtual_call_target(class_id, function_name);

dispatch_table_entryt root_function;

Expand All @@ -357,17 +415,37 @@ void remove_virtual_functionst::get_functions(
}

// iterate over all children, transitively
std::set<irep_idt> visited;
dispatch_table_entries_mapt entry_map;
get_child_functions_rec(
class_id,
root_function.symbol_expr,
component_name,
function_name,
functions,
visited,
entry_map,
resolve_function_call);

if(root_function.symbol_expr!=symbol_exprt())
functions.push_back(root_function);

// Sort by the identifier of the function call symbol expression, grouping
// together calls to the same function. Entries whose target identifier starts
// with "java::java.lang.Object" are ordered after all other entries, keeping
// them at the end for fall through. The reasoning is that this is the case
// with most entries in realistic cases.
//
// The comparator takes both operands by const reference and captures nothing:
// a comparison passed to std::sort must not modify its arguments and must be
// callable on const objects.
std::sort(
  functions.begin(),
  functions.end(),
  [](const dispatch_table_entryt &a, const dispatch_table_entryt &b) {
    // A java.lang.Object entry is never ordered before anything else.
    if(
      has_prefix(
        id2string(a.symbol_expr.get_identifier()), "java::java.lang.Object"))
      return false;
    // Any non-Object entry precedes a java.lang.Object entry.
    if(
      has_prefix(
        id2string(b.symbol_expr.get_identifier()), "java::java.lang.Object"))
      return true;
    // Otherwise order by the identifier of the function to be called.
    return a.symbol_expr.get_identifier() < b.symbol_expr.get_identifier();
  });
}

exprt remove_virtual_functionst::get_method(
Expand Down
1 change: 1 addition & 0 deletions src/goto-programs/remove_virtual_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class dispatch_table_entryt
};

typedef std::vector<dispatch_table_entryt> dispatch_table_entriest;
typedef std::map<irep_idt, dispatch_table_entryt> dispatch_table_entries_mapt;

void remove_virtual_function(
goto_modelt &goto_model,
Expand Down

0 comments on commit 09efc90

Please sign in to comment.