diff --git a/src/ruptures/utils/bnode.py b/src/ruptures/utils/bnode.py index dae8e19f..36c21d96 100644 --- a/src/ruptures/utils/bnode.py +++ b/src/ruptures/utils/bnode.py @@ -1,5 +1,6 @@ """Binary node.""" import functools +import numpy as np @functools.total_ordering @@ -25,6 +26,8 @@ def gain(self): return 0 elif abs(self.val) < 1e-8: return 0 + elif np.isinf(self.val) and self.val < 0: + return 0 return self.val - (self.left.val + self.right.val) def __lt__(self, other):