Skip to content

Commit

Permalink
Merge pull request #144 from bluescarni/pr/cfunc_perf_fix
Browse files Browse the repository at this point in the history
Fix several performance issues when creating large cfuncs
  • Loading branch information
bluescarni authored Nov 9, 2023
2 parents 2d20e0b + 38806af commit caa0314
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
2 changes: 2 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ Changes
Fix
~~~

- Fix slow performance when creating very large compiled functions
(`#144 <https://github.com/bluescarni/heyoka.py/pull/144>`__).
- Fix building against Python 3.12
(`#139 <https://github.com/bluescarni/heyoka.py/pull/139>`__).

Expand Down
49 changes: 47 additions & 2 deletions heyoka/_test_cfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class cfunc_test_case(_ut.TestCase):
def test_basic(self):
from . import make_cfunc, make_vars, cfunc_dbl, core
from . import make_cfunc, make_vars, cfunc_dbl, core, par, time
import pickle
from copy import copy, deepcopy

Expand All @@ -34,7 +34,7 @@ def test_basic(self):
"Cannot invoke a default-constructed compiled function" in str(cm.exception)
)

x, y, z = make_vars("x", "y", "z")
x, y, z, s = make_vars("x", "y", "z", "s")
cf = make_cfunc([y * (x + z)])

self.assertFalse(cf.llvm_state_scalar.force_avx512)
Expand Down Expand Up @@ -84,6 +84,51 @@ def test_basic(self):
self.assertTrue(cf.llvm_state_batch.force_avx512)
self.assertTrue(cf.llvm_state_batch.slp_vectorize)

# Tests for correct detection of number of params, time dependency
# and list of variables.
cf = make_cfunc(
[y * (x + z), x], vars=[y, z, x]
)
self.assertEqual(cf.param_size, 0)
cf = make_cfunc(
[y * (x + z), par[0]], vars=[y, z, x]
)
self.assertEqual(cf.param_size, 1)
cf = make_cfunc(
[y * (x + z) - par[89], par[0]], vars=[y, z, x]
)
self.assertEqual(cf.param_size, 90)

cf = make_cfunc(
[y * (x + z), x], vars=[y, z, x]
)
self.assertFalse(cf.is_time_dependent)
cf = make_cfunc(
[y * (x + z) + time, x], vars=[y, z, x]
)
self.assertTrue(cf.is_time_dependent)
cf = make_cfunc(
[y * (x + z), x + time], vars=[y, z, x]
)
self.assertTrue(cf.is_time_dependent)

cf = make_cfunc(
[y * (x + z), x + time]
)
self.assertEqual(cf.list_var, [x, y, z])
cf = make_cfunc(
[y * (x + z), x + time], vars=[y, z, x]
)
self.assertEqual(cf.list_var, [y, z, x])
cf = make_cfunc(
[y * (x + z), x + time], vars=[y, z, x, s]
)
self.assertEqual(cf.list_var, [y, z, x, s])
cf = make_cfunc(
[y * (x + z), x + time], vars=[s, y, z, x]
)
self.assertEqual(cf.list_var, [s, y, z, x])

# NOTE: test for a bug in the multiprecision
# implementation where the precision is not
# correctly copied.
Expand Down
36 changes: 13 additions & 23 deletions heyoka/cfunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <optional>
#include <ostream>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -685,6 +683,9 @@ void expose_add_cfunc_impl(py::module &m, const char *suffix)
cfunc_inst.def(py::init<>());
cfunc_inst.def("__call__", &cfunc<T>::operator(), "inputs"_a, "outputs"_a = py::none{}, "pars"_a = py::none{},
"time"_a = py::none{});
cfunc_inst.def_readonly("param_size", &cfunc<T>::nparams);
cfunc_inst.def_readonly("is_time_dependent", &cfunc<T>::is_time_dependent);
// NOTE: these can be probably simplified into def_readonly().
cfunc_inst.def_property_readonly("llvm_state_scalar",
[](const cfunc<T> &cf) -> const hey::llvm_state & { return cf.s_scal; });
cfunc_inst.def_property_readonly("llvm_state_batch",
Expand Down Expand Up @@ -780,12 +781,8 @@ void expose_add_cfunc_impl(py::module &m, const char *suffix)
}

// Let's figure out if fn contains params and if it is time-dependent.
std::uint32_t nparams = 0;
bool is_time_dependent = false;
for (const auto &ex : fn) {
nparams = std::max<std::uint32_t>(nparams, hey::get_param_size(ex));
is_time_dependent = is_time_dependent || hey::is_time_dependent(ex);
}
const auto nparams = hey::get_param_size(fn);
const auto is_time_dependent = hey::is_time_dependent(fn);

// Cache the number of variables and outputs.
// NOTE: static casts are fine, because add_cfunc()
Expand All @@ -803,21 +800,14 @@ void expose_add_cfunc_impl(py::module &m, const char *suffix)

list_var = std::move(*vars);
} else {
// NOTE: this is a bit of repetition from add_cfunc().
// If this becomes an issue, we can consider in the
// future changing add_cfunc() to return also the number
// of detected variables.
std::set<std::string> dvars;
for (const auto &ex : fn) {
for (const auto &var : hey::get_variables(ex)) {
dvars.emplace(var);
}
}

nvars = static_cast<std::uint32_t>(dvars.size());

std::transform(dvars.begin(), dvars.end(), std::back_inserter(list_var),
[](const auto &str) { return hey::expression{str}; });
// NOTE: get_variables() returns an ordered list of strings,
// we need to convert it into a list of expressions.
const auto var_slist = hey::get_variables(fn);
list_var.reserve(var_slist.size());
std::transform(var_slist.begin(), var_slist.end(), std::back_inserter(list_var),
[](const auto &name) { return hey::expression{name}; });

nvars = static_cast<std::uint32_t>(list_var.size());
}

// Prepare local buffers to store inputs, outputs, pars and time
Expand Down

0 comments on commit caa0314

Please sign in to comment.