Skip to content

Commit

Permalink
Iteration limit in smart sampling to fix behavior for step functions (#…
Browse files Browse the repository at this point in the history
…928)

Closes #923 

The smart sampling algorithm was not tested on step functions and was
broken when applied to such functions. The fix is to stop iterating when
an interval becomes too narrow. In addition, further stopping criteria
were added, based on number of iterations and time spend.
  • Loading branch information
HDembinski authored Aug 15, 2023
1 parent 47bd39c commit f8033ee
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 15 deletions.
29 changes: 22 additions & 7 deletions src/iminuit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,18 +1542,32 @@ def _histogram_segments(mask, xe, masked):
return segments


def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3):
def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3, maxiter=20, maxtime=10):
t0 = monotonic()
x = np.linspace(xmin, xmax, start)
ynew = f(x)
ymin = np.min(ynew)
ymax = np.max(ynew)
y = {xi: yi for (xi, yi) in zip(x, ynew)}
a = x[:-1]
b = x[1:]
niter = 0
while len(a):
if len(y) > 10000:
warnings.warn("Too many points", RuntimeWarning) # pragma: no cover
break # pragma: no cover
niter += 1
if niter > maxiter:
msg = (
f"Iteration limit {maxiter} in smart sampling reached, "
f"produced {len(y)} points"
)
warnings.warn(msg, RuntimeWarning)
break
if monotonic() - t0 > maxtime:
msg = (
f"Time limit {maxtime} in smart sampling reached, "
f"produced {len(y)} points"
)
warnings.warn(msg, RuntimeWarning)
break
xnew = 0.5 * (a + b)
ynew = f(xnew)
ymin = min(ymin, np.min(ynew))
Expand All @@ -1565,10 +1579,11 @@ def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3):
+ np.fromiter((y[bi] for bi in b), float)
)
dy = np.abs(ynew - yint)
dx = np.abs(b - a)

mask = dy > tol * (ymax - ymin)

# intervals which do not pass interpolation test
# in next iteration, handle intervals which do not
# pass interpolation test and are not too narrow
mask = (dy > tol * (ymax - ymin)) & (dx > tol * abs(xmax - xmin))
a = a[mask]
b = b[mask]
xnew = xnew[mask]
Expand Down
58 changes: 52 additions & 6 deletions tests/test_issue.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from iminuit import Minuit
from iminuit.util import IMinuitWarning
import pickle
import pytest
import numpy as np


def test_issue_424():
from iminuit import Minuit

def fcn(x, y, z):
return (x - 1) ** 2 + (y - 4) ** 2 / 2 + (z - 9) ** 2 / 3

Expand All @@ -20,6 +18,10 @@ def fcn(x, y, z):


def test_issue_544():
import pytest
from iminuit import Minuit
from iminuit.util import IMinuitWarning

def fcn(x, y):
return x**2 + y**2

Expand All @@ -30,6 +32,8 @@ def fcn(x, y):


def test_issue_648():
from iminuit import Minuit

class F:
first = True

Expand All @@ -45,6 +49,8 @@ def __call__(self, a, b):


def test_issue_643():
from iminuit import Minuit

def fcn(x, y, z):
return (x - 2) ** 2 + (y - 3) ** 2 + (z - 4) ** 2

Expand All @@ -64,6 +70,8 @@ def fcn(x, y, z):


def test_issue_669():
from iminuit import Minuit

def fcn(x, y):
return x**2 + (y / 2) ** 2

Expand All @@ -84,15 +92,21 @@ def fcn(x, y):
assert match


# cannot define this inside function, pickle will not allow it
def fcn(par):
return np.sum(par**2)


# cannot define this inside function, pickle will not allow it
def grad(par):
return 2 * par


def test_issue_687():
import pickle
import numpy as np
from iminuit import Minuit

start = np.zeros(3)
m = Minuit(fcn, start)

Expand All @@ -107,10 +121,13 @@ def test_issue_687():


def test_issue_694():
stats = pytest.importorskip("scipy.stats")

import pytest
import numpy as np
from iminuit import Minuit
from iminuit.cost import ExtendedUnbinnedNLL

stats = pytest.importorskip("scipy.stats")

xmus = 1.0
xmub = 5.0
xsigma = 1.0
Expand Down Expand Up @@ -142,3 +159,32 @@ def model(x, sig_n, sig_mu, sig_sigma, bkg_n, bkg_tau):
break
else:
assert False


def test_issue_923():
from iminuit import Minuit
from iminuit.cost import LeastSquares
import numpy as np
import pytest

# implicitly needed by visualize
pytest.importorskip("matplotlib")

def model(x, c1):
c2 = 100
res = np.zeros(len(x))
mask = x < 47
res[mask] = c1
res[~mask] = c2
return res

xtest = np.linspace(0, 74)
ytest = xtest * 0 + 1
ytesterr = ytest

least_squares = LeastSquares(xtest, ytest, ytesterr, model)

m = Minuit(least_squares, c1=1)
m.migrad()
# this used to trigger an endless (?) loop
m.visualize()
26 changes: 24 additions & 2 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,8 +711,30 @@ def test_smart_sampling_1(fn_expected):


def test_smart_sampling_2():
with pytest.warns(RuntimeWarning):
util._smart_sampling(np.log, 1e-10, 1, tol=1e-10)
# should not raise a warning
x, y = util._smart_sampling(np.log, 1e-10, 1, tol=1e-5)
assert 0 < len(x) < 1000


def test_smart_sampling_3():
def step(x):
return np.where(x > 0.5, 0, 1)

with pytest.warns(RuntimeWarning, match="Iteration limit"):
x, y = util._smart_sampling(step, 0, 1, tol=0)
assert 0 < len(x) < 80


def test_smart_sampling_4():
from time import sleep

def step(x):
sleep(0.1)
return np.where(x > 0.5, 0, 1)

with pytest.warns(RuntimeWarning, match="Time limit"):
x, y = util._smart_sampling(step, 0, 1, maxtime=0)
assert 0 < len(x) < 10


@pytest.mark.parametrize(
Expand Down

0 comments on commit f8033ee

Please sign in to comment.