Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Eventful dict and custom events #278

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions traitlets/eventful.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

# void functions used as a callback placeholders.
def _default_pre_set(self, key, value):
return value

def _default_post_set(self, key, value):
pass

def _default_pre_del(self, key):
pass

def _default_post_del(self, key):
pass


class edict(dict):

pre_set = _default_pre_set
post_set = _default_post_set
pre_del = _default_pre_del
post_del = _default_post_del

def pop(self, key):
self.pre_del(key)
ret = dict.pop(self, key)
self.post_del(key)
return ret

def popitem(self):
key = next(iter(self))
return key, self.pop(key)

def update(self, other_dict):
for (key, value) in other_dict.items():
self[key] = value

def clear(self):
for key in list(self.keys()):
del self[key]

def __setitem__(self, key, value):
value = self.pre_set(key, value)
ret = dict.__setitem__(self, key, value)
self.post_set(key, value)
return ret

def __delitem__(self, key):
value = self.pre_del(key)
ret = dict.__delitem__(self, key)
value = self.post_del(key)
return ret

25 changes: 23 additions & 2 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
TraitError, Union, All, Undefined, Type, This, Instance, TCPAddress,
List, Tuple, ObjectName, DottedObjectName, CRegExp, link, directional_link,
ForwardDeclaredType, ForwardDeclaredInstance, validate, observe, default,
observe_compat, BaseDescriptor, HasDescriptors,
observe_compat, BaseDescriptor, HasDescriptors, rollback_change, EDict
)

import six

def change_dict(*ordered_values):
change_names = ('name', 'old', 'new', 'owner', 'type')
return dict(zip(change_names, ordered_values))
return dict(zip(change_names, ordered_values), rollback=rollback_change)

#-----------------------------------------------------------------------------
# Helper classes for testing
Expand Down Expand Up @@ -1994,6 +1994,27 @@ def assign_rollback():

self.assertRaises(TraitError, assign_rollback)

def test_edict_rollback(self):
class Foo(HasTraits):
bar = EDict(default_value={1: 1, 2: 2})
baz = Int()

foo = Foo()
try:
# This one should roll back due to baz failed validation
with foo.hold_trait_notifications():
foo.bar[1] = 2
del foo.bar[2]
foo.baz = '' #triggers trait error and rollback
except TraitError:
self.assertEqual(foo.bar, {1: 1, 2: 2})
# This one should roll back due to baz failed validation
with foo.hold_trait_notifications():
foo.bar[1] = 2
del foo.bar[2]
self.assertEqual(foo.bar, {1: 2})



class CacheModification(HasTraits):
foo = Int()
Expand Down
137 changes: 131 additions & 6 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

import six

from .eventful import edict
from .utils.getargspec import getargspec
from .utils.importstring import import_item
from .utils.sentinel import Sentinel
Expand Down Expand Up @@ -974,6 +975,11 @@ def setup_instance(self, *args, **kwargs):
if isinstance(value, BaseDescriptor):
value.instance_init(self)

def rollback_change(change):
if change.old is not Undefined:
change.owner.set_trait(change.name, change.old)
else:
change.owner._trait_values.pop(change.name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail in the case that a default value was defined via TraitType.__init__ since those are only assigned to the HasTraits instance during instance_init. I'm not sure this can be resolved given the current way default values are handled. However with #332, popping off of _trait_values would work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue resolved - #322 was merged.


class HasTraits(six.with_metaclass(MetaHasTraits, HasDescriptors)):

Expand Down Expand Up @@ -1110,12 +1116,8 @@ def hold(change):
self.notify_change = lambda x: None
for name, changes in cache.items():
for change in changes[::-1]:
# TODO: Separate in a rollback function per notification type.
if change.type == 'change':
if change.old is not Undefined:
self.set_trait(name, change.old)
else:
self._trait_values.pop(name)
if 'rollback' in change:
change.rollback(change)
cache = {}
raise e
finally:
Expand All @@ -1140,6 +1142,7 @@ def _notify_trait(self, name, old_value, new_value):
old=old_value,
new=new_value,
owner=self,
rollback=rollback_change,
type='change',
))

Expand Down Expand Up @@ -2572,6 +2575,128 @@ def instance_init(self, obj):
trait.instance_init(obj)
super(Dict, self).instance_init(obj)

def rollback_set_element(change):
evdict = getattr(change.owner, change.name)
if change.old is Undefined:
del evdict[change.key]
else:
evdict[change.key] = change.old

def rollback_del_element(change):
evdict = getattr(change.owner, change.name)
evdict[change.key] = change.old


class EDict(TraitType):

def _validate(self, obj, value):
# Callbacks are wired in _validate instead of set
# so that it is wired even in the case where make dynamic defaults is
# called (delayed initialization).
value = super(EDict, self)._validate(obj, value)
self._set_callbacks(obj, value)
return value

def validate(self, obj, value):
if isinstance(value, dict):
value = edict(value)
return value
else:
self.error(obj, value)

def set(self, obj, value):
new_value = self._validate(obj, value)
old_value = obj._trait_values.get(self.name, {})
obj._trait_values[self.name] = new_value
self._clear_callbacks(old_value)
if old_value != new_value:
obj._notify_trait(self.name, old_value, new_value)

def _set_callbacks(self, obj, value):
value.pre_set = self._pre_set(obj, value)
value.post_set = self._post_set(obj, value)
value.pre_del = self._pre_del(obj, value)
value.post_del = self._post_del(obj, value)

def _clear_callbacks(self, evdict):
if hasattr(evdict, 'pre_set'):
del evdict.pre_set
if hasattr(evdict, 'post_set'):
del evdict.post_set
if hasattr(evdict, 'pre_del'):
del evdict.pre_del
if hasattr(evdict, 'post_del'):
del evdict.post_del

def _pre_set(self, obj, evdict):
def pre_set(k, v):
if k in evdict:
getattr(obj, self._cache_name())[k] = evdict[k]
# validating a copy of the attribute with the deleted key
value = getattr(obj, self.name).copy()
value[k] = v
self._validate(obj, value)
return v
return pre_set

def _post_set(self, obj, evdict):
def post_set(k, v):
obj.notify_change(Bunch(
key=k,
new=v,
old=getattr(obj, self._cache_name()).get(k, Undefined),
owner=obj,
name=self.name,
rollback=rollback_set_element,
type='element_set',
))
return post_set

def _pre_del(self, obj, evdict):
def pre_del(k):
if k in evdict:
getattr(obj, self._cache_name())[k] = evdict[k]
# validating a copy of the attribute with the deleted key
value = getattr(obj, self.name).copy()
del value[k]
self._validate(obj, value)
return pre_del

def _post_del(self, obj, evdict):
def post_del(k):
obj.notify_change(Bunch(
key=k,
owner=obj,
old=getattr(obj, self._cache_name())[k],
name=self.name,
rollback=rollback_del_element,
type='element_del',
))
return post_del

def info(self):
result = 'dictionary'
if self.allow_none:
return result + ' or None'
else:
return result

def make_dynamic_default(self):
if self.default_value is not Undefined:
return edict(self.default_value)
else:
return edict()

def default_value_repr(self):
return repr(self.make_dynamic_default())

def instance_init(self, obj):
setattr(obj, self._cache_name(), {})
super(EDict, self).instance_init(obj)

def _cache_name(self):
return '_' + self.name + '_cache'


class TCPAddress(TraitType):
"""A trait for an (ip, port) tuple.
Expand Down