From 79372d7523e723a7cddd6cd28beee147f7375796 Mon Sep 17 00:00:00 2001 From: Jonathan Dorn Date: Thu, 18 Jul 2024 17:39:15 -0400 Subject: [PATCH] Only update IntervalTrees when we need them --- python/gtirb/byteinterval.py | 39 ++++---- python/gtirb/lazyintervaltree.py | 118 +++++++++++++++++++++++++ python/gtirb/section.py | 36 ++++---- python/stubs/intervaltree/interval.pyi | 4 +- python/tests/test_blocks_at_offset.py | 12 ++- 5 files changed, 168 insertions(+), 41 deletions(-) create mode 100644 python/gtirb/lazyintervaltree.py diff --git a/python/gtirb/byteinterval.py b/python/gtirb/byteinterval.py index 78071c818..f48c85021 100644 --- a/python/gtirb/byteinterval.py +++ b/python/gtirb/byteinterval.py @@ -1,11 +1,10 @@ -import itertools import typing from uuid import UUID -from intervaltree import IntervalTree from sortedcontainers import SortedDict from .block import ByteBlock, CodeBlock, DataBlock +from .lazyintervaltree import LazyIntervalTree from .node import Node, _NodeMessage from .proto import ByteInterval_pb2, SymbolicExpression_pb2 from .symbolicexpression import SymAddrAddr, SymAddrConst, SymbolicExpression @@ -162,15 +161,19 @@ def __init__( raise ValueError("initialized_size must be <= size!") super().__init__(uuid=uuid) - self._interval_tree: "IntervalTree[int, ByteBlock]" = IntervalTree() self._section: typing.Optional["Section"] = None self.address = address self.size = size self.contents = bytearray(contents) self.initialized_size = initialized_size - self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet( - self, blocks + + # Both blocks and _interval_tree must exist before adding any blocks. + self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet(self) + self._interval_tree = LazyIntervalTree[int, ByteBlock]( + self.blocks, _offset_interval ) + self.blocks.update(blocks) + self._symbolic_expressions = ByteInterval._SymbolicExprDict( self, symbolic_expressions ) @@ -186,20 +189,14 @@ def _index_add_multiple( old_blocks: typing.Collection[ByteBlock], new_blocks: typing.Collection[ByteBlock], ) -> None: - if len(old_blocks) < len(new_blocks): - self._interval_tree = IntervalTree( - _offset_interval(block) - for block in itertools.chain(old_blocks, new_blocks) - ) - else: - for block in new_blocks: - self._index_add(block) + for block in new_blocks: + self._interval_tree.add(block) def _index_add(self, block: ByteBlock) -> None: - self._interval_tree.add(_offset_interval(block)) + self._interval_tree.add(block) def _index_discard(self, block: ByteBlock) -> None: - self._interval_tree.discard(_offset_interval(block)) + self._interval_tree.discard(block) @property def initialized_size(self) -> int: @@ -444,7 +441,7 @@ def byte_blocks_on( return () return _nodes_on_interval_tree( - self._interval_tree, addrs, -self.address + self._interval_tree.get(), addrs, -self.address ) def byte_blocks_at( @@ -460,7 +457,7 @@ def byte_blocks_at( return () return _nodes_at_interval_tree( - self._interval_tree, addrs, -self.address + self._interval_tree.get(), addrs, -self.address ) def code_blocks_on( @@ -524,7 +521,9 @@ def byte_blocks_on_offset( :param offsets: Either a ``range`` object or a single offset. """ - return _nodes_on_interval_tree_offset(self._interval_tree, offsets) + return _nodes_on_interval_tree_offset( + self._interval_tree.get(), offsets + ) def byte_blocks_at_offset( self, offsets: typing.Union[int, range] @@ -535,7 +534,9 @@ def byte_blocks_at_offset( :param offsets: Either a ``range`` object or a single offset. """ - return _nodes_at_interval_tree_offset(self._interval_tree, offsets) + return _nodes_at_interval_tree_offset( + self._interval_tree.get(), offsets + ) def code_blocks_on_offset( self, offsets: typing.Union[int, range] diff --git a/python/gtirb/lazyintervaltree.py b/python/gtirb/lazyintervaltree.py new file mode 100644 index 000000000..4e7642429 --- /dev/null +++ b/python/gtirb/lazyintervaltree.py @@ -0,0 +1,118 @@ +""" +Implements a simple wrapper that lazily initializes and updates an +IntervalTree. + +GTIRB uses IntervalTrees to accelerate certain operations. However, these +operations are not always needed for a given GTIRB object or by a given GTIRB +analysis. To prevent scripts that do not need the IntervalTrees from wasting +time updating the data structures, the LazyIntervalTree in this module delays +instantiating or updating the tree. Instead, it queues the updates so they can +be rapidly applied when the script invokes an operation that requires an +up-to-date tree. +""" + +import enum +from typing import ( + Collection, + Generic, + Iterator, + List, + Optional, + Protocol, + Tuple, + TypeVar, +) + +from intervaltree import Interval, IntervalTree + +_K = TypeVar("_K") +_Kco = TypeVar("_Kco", covariant=True) +_V = TypeVar("_V") + + +class _EventType(enum.Enum): + """Whether an interval is to be added or discarded.""" + + ADDED = enum.auto() + DISCARDED = enum.auto() + + +class IntervalBuilder(Protocol[_Kco, _V]): + """Gets an interval for certain values. + + If no interval is available for a particular value, returns None instead. + """ + + def __call__(self, value: _V, /) -> Optional["Interval[_Kco, _V]"]: + ... + + +class LazyIntervalTree(Generic[_K, _V]): + """Simple wrapper to lazily initialize and update an IntervalTree. + + The underlying IntervalTree can be retrieved by calling get(). This will + ensure that the tree is up-to-date with all intermediate modifications + before returning it. + + In many algorithms, the tree may receive large numbers of modifications, + adding and removing the same intervals several times before querying. In + these cases, it may be faster to rebuild the tree from scratch rather than + perform all of the intermediate modifications. For this reason, get() is + not guaranteed to always return the same tree object. That is, the tree + returned by get() should not be cached; calling get() may return a new tree + rather than updating the tree it returned previously. + """ + + def __init__( + self, + values: Collection[_V], + make_interval: IntervalBuilder[_K, _V], + ): + """Create a new lazy tree. + + :param values: collection of values from which the tree can be rebuilt + :param make_interval: callable to get an interval for a value + """ + self._interval_index: Optional["IntervalTree[_K, _V]"] = None + self._interval_events: List[Tuple[_EventType, "Interval[_K, _V]"]] = [] + self._value_collection = values + self._make_interval = make_interval + + def add(self, value: _V) -> None: + """Add a value to the tree.""" + interval = self._make_interval(value) + if interval is not None: + self._interval_events.append((_EventType.ADDED, interval)) + + def discard(self, value: _V) -> None: + """Remove a value from the tree. + + Does nothing if the interval with that value is not present. + """ + interval = self._make_interval(value) + if interval is not None: + self._interval_events.append((_EventType.DISCARDED, interval)) + + def get(self) -> "IntervalTree[_K, _V]": + """Get the most up-to-date tree reflecting all pending updates.""" + + def intervals() -> Iterator["Interval[_K, _V]"]: + for value in self._value_collection: + interval = self._make_interval(value) + if interval: + yield interval + + if self._interval_index is None: + self._interval_index = IntervalTree(intervals()) + elif len(self._value_collection) <= len(self._interval_events): + # Constructing a new tree involves one update for each value. + self._interval_index = IntervalTree(intervals()) + else: + # There are fewer updates than constructing a new tree would use. + for event, interval in self._interval_events: + if event == _EventType.ADDED: + self._interval_index.add(interval) + else: + self._interval_index.discard(interval) + self._interval_events.clear() + return self._interval_index diff --git a/python/gtirb/section.py b/python/gtirb/section.py index cbe89d174..8584e25a2 100644 --- a/python/gtirb/section.py +++ b/python/gtirb/section.py @@ -3,10 +3,9 @@ from enum import Enum from uuid import UUID -from intervaltree import IntervalTree - from .block import ByteBlock, CodeBlock, DataBlock from .byteinterval import ByteInterval, SymbolicExpressionElement +from .lazyintervaltree import LazyIntervalTree from .node import Node, _NodeMessage from .proto import Section_pb2 from .util import ( @@ -102,24 +101,27 @@ def __init__( """ super().__init__(uuid) - self._interval_index: "IntervalTree[int,ByteInterval]" = IntervalTree() self._module: typing.Optional["Module"] = None self.name = name - self.byte_intervals = Section._ByteIntervalSet(self, byte_intervals) + + # Both byte_intervals and _interval_index must exist before adding any + # intervals. + self.byte_intervals = Section._ByteIntervalSet(self) + self._interval_index = LazyIntervalTree[int, ByteInterval]( + self.byte_intervals, _address_interval + ) + self.byte_intervals.update(byte_intervals) + self.flags = set(flags) # Use the property setter to ensure correct invariants. self.module = module def _index_add(self, byte_interval: ByteInterval) -> None: - address_interval = _address_interval(byte_interval) - if address_interval: - self._interval_index.add(address_interval) + self._interval_index.add(byte_interval) def _index_discard(self, byte_interval: ByteInterval) -> None: - address_interval = _address_interval(byte_interval) - if address_interval: - self._interval_index.discard(address_interval) + self._interval_index.discard(byte_interval) @classmethod def _decode_protobuf( @@ -233,8 +235,9 @@ def address(self) -> typing.Optional[int]: size, so it will be ``None`` in that case. """ - if 0 < len(self._interval_index) == len(self.byte_intervals): - return self._interval_index.begin() + index = self._interval_index.get() + if 0 < len(index) == len(self.byte_intervals): + return index.begin() return None @@ -251,8 +254,9 @@ def size(self) -> typing.Optional[int]: it has no address or size, so it will be ``None`` in that case. """ - if 0 < len(self._interval_index) == len(self.byte_intervals): - return self._interval_index.span() - 1 + index = self._interval_index.get() + if 0 < len(index) == len(self.byte_intervals): + return index.span() - 1 return None @@ -265,7 +269,7 @@ def byte_intervals_on( :param addrs: Either a ``range`` object or a single address. """ - return _nodes_on_interval_tree(self._interval_index, addrs) + return _nodes_on_interval_tree(self._interval_index.get(), addrs) def byte_intervals_at( self, addrs: typing.Union[int, range] @@ -276,7 +280,7 @@ def byte_intervals_at( :param addrs: Either a ``range`` object or a single address. """ - return _nodes_at_interval_tree(self._interval_index, addrs) + return _nodes_at_interval_tree(self._interval_index.get(), addrs) def byte_blocks_on( self, addrs: typing.Union[int, range] diff --git a/python/stubs/intervaltree/interval.pyi b/python/stubs/intervaltree/interval.pyi index 7567ca381..3f8157b91 100644 --- a/python/stubs/intervaltree/interval.pyi +++ b/python/stubs/intervaltree/interval.pyi @@ -1,7 +1,7 @@ from typing import Generic, TypeVar -PointT = TypeVar("PointT") -DataT = TypeVar("DataT") +PointT = TypeVar("PointT", covariant=True) +DataT = TypeVar("DataT", covariant=True) class Interval(Generic[PointT, DataT]): begin: PointT diff --git a/python/tests/test_blocks_at_offset.py b/python/tests/test_blocks_at_offset.py index b4f455fa5..f94b8eda5 100644 --- a/python/tests/test_blocks_at_offset.py +++ b/python/tests/test_blocks_at_offset.py @@ -10,28 +10,32 @@ class BlocksAtOffsetTests(unittest.TestCase): def test_blocks_at_offset_simple(self): ir, m, s, bi = create_interval_etc(address=None, size=4) + # Ensure we always have a couple blocks in the index beyond what we + # are querying so that we don't just rebuild the tree from scratch + # every time. code_block = gtirb.CodeBlock(offset=0, size=1, byte_interval=bi) code_block2 = gtirb.CodeBlock(offset=1, size=1, byte_interval=bi) + code_block3 = gtirb.CodeBlock(offset=2, size=1, byte_interval=bi) found = set(bi.byte_blocks_at_offset(0)) self.assertEqual(found, {code_block}) # Change the offset to verify we update the index - code_block.offset = 2 + code_block.offset = 3 found = set(bi.byte_blocks_at_offset(0)) self.assertEqual(found, set()) - found = set(bi.byte_blocks_at_offset(2)) + found = set(bi.byte_blocks_at_offset(3)) self.assertEqual(found, {code_block}) # Discard the block to verify we update the index bi.blocks.discard(code_block) - found = set(bi.byte_blocks_at_offset(2)) + found = set(bi.byte_blocks_at_offset(3)) self.assertEqual(found, set()) # Now add it back to verify we update the index bi.blocks.add(code_block) - found = set(bi.byte_blocks_at_offset(2)) + found = set(bi.byte_blocks_at_offset(3)) self.assertEqual(found, {code_block}) def test_blocks_at_offset_overlapping(self):