Skip to content
This repository has been archived by the owner on Dec 29, 2022. It is now read-only.

Commit

Permalink
add mutable wrapper for treelist
Browse files Browse the repository at this point in the history
  • Loading branch information
Robert Grosse committed Mar 22, 2016
1 parent 5649f6e commit 6ee0b06
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
34 changes: 27 additions & 7 deletions enjarify/jvm/treelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,27 @@
# of sharing memory with previous versions of the list when only a few elements
# are changed. See http://en.wikipedia.org/wiki/Persistent_data_structure#Trees
# Also, default values are not stored, so this is good for sparse arrays
class ImmutableTreeList:
class TreeList:
def __init__(self, default, func, data=None):
self.default = default
self.func = func
self.data = data or _TreeListSub(default)

def __getitem__(self, i):
return self.data[i]

def __setitem__(self, i, val):
self.data = self.data.set(i, val)

def copy(self):
return TreeList(self.default, self.func, self.data)

def merge(self, other):
assert self.func is other.func
self.data = _TreeListSub.merge(self.data, other.data, self.func)


class _TreeListSub:
def __init__(self, default, direct=None, children=None):
self.default = default
if direct is None:
Expand Down Expand Up @@ -52,7 +72,7 @@ def set(self, i, val):

temp = self.direct[:]
temp[i] = val
return ImmutableTreeList(self.default, temp, self.children)
return _TreeListSub(self.default, temp, self.children)

i -= SIZE
i, ci = divmod(i, SPLIT)
Expand All @@ -61,15 +81,15 @@ def set(self, i, val):
if child is None:
if val == self.default:
return self
child = ImmutableTreeList(self.default).set(i, val)
child = _TreeListSub(self.default).set(i, val)
else:
if val == child[i]:
return self
child = child.set(i, val)

temp = self.children[:]
temp[ci] = child
return ImmutableTreeList(self.default, self.direct, temp)
return _TreeListSub(self.default, self.direct, temp)

@staticmethod
def merge(left, right, func):
Expand All @@ -82,18 +102,18 @@ def merge(left, right, func):
left, right = right, left

default = left.default
merge = ImmutableTreeList.merge
merge = _TreeListSub.merge
if right is None:
direct = [func(x, default) for x in left.direct]
children = [merge(child, None, func) for child in left.children]
if direct == left.direct and children == left.children:
return left
return ImmutableTreeList(default, direct, children)
return _TreeListSub(default, direct, children)

direct = [func(x, y) for x, y in zip(left.direct, right.direct)]
children = [merge(c1, c2, func) for c1, c2 in zip(left.children, right.children)]
if direct == left.direct and children == left.children:
return left
if direct == right.direct and children == right.children:
return right
return ImmutableTreeList(default, direct, children)
return _TreeListSub(default, direct, children)
36 changes: 20 additions & 16 deletions enjarify/jvm/typeinference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import arraytypes as arrays
from . import scalartypes as scalars
from . import mathops, jvmops
from .treelist import ImmutableTreeList
from .treelist import TreeList
from .. import flags, dalvik


Expand Down Expand Up @@ -46,13 +46,13 @@ def __init__(self, prims, arrs, tainted):
self.arrs = arrs
self.tainted = tainted

def _copy(self): return TypeInfo(self.prims, self.arrs, self.tainted)
def _copy(self): return TypeInfo(self.prims.copy(), self.arrs.copy(), self.tainted.copy())
def _get(self, reg): return self.prims[reg], self.arrs[reg], self.tainted[reg]

def _set(self, reg, st, at, taint=False):
self.prims = self.prims.set(reg, st)
self.arrs = self.arrs.set(reg, at)
self.tainted = self.tainted.set(reg, taint)
self.prims[reg] = st
self.arrs[reg] = at
self.tainted[reg] = taint
return self

def move(self, src, dest, wide):
Expand All @@ -78,27 +78,31 @@ def assignFromDesc(self, reg, desc):
else:
return self.assign(reg, st, at)

def isSame(self, other):
return (self.prims.data is other.prims.data and
self.arrs.data is other.arrs.data and
self.tainted.data is other.tainted.data)

def merge(old, new):
prims = ImmutableTreeList.merge(old.prims, new.prims, operator.__and__)
arrs = ImmutableTreeList.merge(old.arrs, new.arrs, arrays.merge)
tainted = ImmutableTreeList.merge(old.tainted, new.tainted, operator.__or__)
if prims is old.prims and arrs is old.arrs and tainted is old.tainted:
return old
return TypeInfo(prims, arrs, tainted)
temp = old._copy()
temp.prims.merge(new.prims)
temp.arrs.merge(new.arrs)
temp.tainted.merge(new.tainted)
return old if old.isSame(temp) else temp

def fromParams(method, num_regs):
isstatic = method.access & flags.ACC_STATIC
full_ptypes = method.id.getSpacedParamTypes(isstatic)
offset = num_regs - len(full_ptypes)

prims = ImmutableTreeList(scalars.INVALID)
arrs = ImmutableTreeList(arrays.INVALID)
tainted = ImmutableTreeList(False)
prims = TreeList(scalars.INVALID, operator.__and__)
arrs = TreeList(arrays.INVALID, arrays.merge)
tainted = TreeList(False, operator.__or__)

for i, desc in enumerate(full_ptypes):
if desc is not None:
prims = prims.set(offset + i, scalars.fromDesc(desc))
arrs = arrs.set(offset + i, arrays.fromDesc(desc))
prims[offset + i] = scalars.fromDesc(desc)
arrs[offset + i] = arrays.fromDesc(desc)
return TypeInfo(prims, arrs, tainted)

_MATH_THROW_OPS = [jvmops.IDIV, jvmops.IREM, jvmops.LDIV, jvmops.LREM]
Expand Down

0 comments on commit 6ee0b06

Please sign in to comment.