Skip to content

Commit

Permalink
Merge pull request #332 from rmorshea/refactor_defaults
Browse files Browse the repository at this point in the history
Create All Default Values With Generators
  • Loading branch information
minrk authored Nov 1, 2016
2 parents 6de87c0 + a8e0443 commit 5a49dda
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 50 deletions.
2 changes: 1 addition & 1 deletion traitlets/config/configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _find_my_config(self, cfg):
If I am Bar and my parent is Foo, and their parent is Tim,
this will return merge following config sections, in this order::
[Bar, Foo.bar, Tim.Foo.Bar]
[Bar, Foo.Bar, Tim.Foo.Bar]
With the last item being the highest priority.
"""
Expand Down
2 changes: 1 addition & 1 deletion traitlets/config/tests/test_configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class SomeSingleton(SingletonConfigurable):

# reset deprecation limiter
_deprecations_shown.clear()
with expected_warnings(['metadata should be set using the \.tag\(\) method']):
with expected_warnings(['metadata should be set using the \.tag\(\) method', "use @default decorator instead\\."]):
class DefaultConfigurable(Configurable):
a = Integer(config=True)
def _config_default(self):
Expand Down
4 changes: 2 additions & 2 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class A(HasTraits):
# Defaults are validated when the HasTraits is instantiated
class B(HasTraits):
tt = MyIntTT('bad default')
self.assertRaises(TraitError, B)
self.assertRaises(TraitError, getattr, B(), 'tt')

def test_info(self):
class A(HasTraits):
Expand Down Expand Up @@ -896,7 +896,7 @@ class A(HasTraits):
class C(HasTraits):
klass = Type(None, B)

self.assertRaises(TraitError, C)
self.assertRaises(TraitError, getattr, C(), 'klass')

def test_str_klass(self):

Expand Down
106 changes: 60 additions & 46 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ def class_init(self, cls, name):
self.this_class = cls
self.name = name

def subclass_init(self, cls):
pass

def instance_init(self, obj):
"""Part of the initialization which may depend on the underlying
HasDescriptors instance.
Expand All @@ -416,6 +419,23 @@ class TraitType(BaseDescriptor):
read_only = False
info_text = 'any value'

def class_init(self, cls, name):
super(TraitType, self).class_init(cls, name)
if self.name not in cls._trait_default_generators:
if hasattr(self, 'make_dynamic_default'):
cls._trait_default_generators[self.name] = (
lambda obj: self.make_dynamic_default())
elif self.default_value is not Undefined:
cls._trait_default_generators[self.name] = (
lambda obj : self.default_value)

def subclass_init(self, cls):
if '_%s_default' % self.name in cls.__dict__:
method = getattr(cls, '_%s_default' % self.name)
_deprecated_method(method, cls, '_%s_default' % self.name,
"use @default decorator instead.")
cls._trait_default_generators[self.name] = method

def __init__(self, default_value=Undefined, allow_none=False, read_only=None, help=None, **kwargs):
"""Declare a traitlet.
Expand Down Expand Up @@ -480,57 +500,18 @@ def init_default_value(self, obj):
obj._trait_values[self.name] = value
return value

def _dynamic_default_callable(self, obj):
"""Retrieve a callable to calculate the default for this traitlet.
This looks for:
* default generators registered with the @default descriptor.
* obj._{name}_default() on the class with the traitlet, or a subclass
that obj belongs to.
* trait.make_dynamic_default, which is defined by Instance
If neither exist, it returns None
"""
# Traitlets without a name are not on the instance, e.g. in List or Union
if self.name:

# Only look for default handlers in classes derived from self.this_class.
mro = type(obj).mro()
meth_name = '_%s_default' % self.name
for cls in mro[:mro.index(self.this_class) + 1]:
if hasattr(cls, '_trait_default_generators'):
default_handler = cls._trait_default_generators.get(self.name)
if default_handler is not None and default_handler.this_class == cls:
return types.MethodType(default_handler.func, obj)

if meth_name in cls.__dict__:
method = getattr(obj, meth_name)
_deprecated_method(method, cls, meth_name, "use @default decorator instead.")
return method

return getattr(self, 'make_dynamic_default', None)

def instance_init(self, obj):
# If no dynamic initialiser is present, and the trait implementation or
# use provides a static default, transfer that to obj._trait_values.
with obj.cross_validation_lock:
if (self._dynamic_default_callable(obj) is None) \
and (self.default_value is not Undefined):
v = self._validate(obj, self.default_value)
if self.name is not None:
obj._trait_values[self.name] = v

def get(self, obj, cls=None):
try:
value = obj._trait_values[self.name]
except KeyError:
# Check for a dynamic initializer.
dynamic_default = self._dynamic_default_callable(obj)
if dynamic_default is None:
raise TraitError("No default value found for %s trait of %r"
% (self.name, obj))
value = self._validate(obj, dynamic_default())
try:
dgen = cls._trait_default_generators[self.name]
except:
raise TraitError("No default value found for "
"the '%s' trait named '%s' of %r" % (
type(self).__name__, self.name, obj))
value = self._validate(obj, dgen(obj))
obj._trait_values[self.name] = value
return value
except Exception:
Expand Down Expand Up @@ -748,13 +729,24 @@ def setup_class(cls, classdict):
if isinstance(v, BaseDescriptor):
v.class_init(cls, k)

for k, v in getmembers(cls):
if isinstance(v, BaseDescriptor):
v.subclass_init(cls)


class MetaHasTraits(MetaHasDescriptors):
"""A metaclass for HasTraits."""

def setup_class(cls, classdict):
cls._trait_default_generators = {}
super(MetaHasTraits, cls).setup_class(classdict)
new = {}
for c in reversed(cls.mro()):
if hasattr(c, "_trait_default_generators"):
new.update(c._trait_default_generators)
cls._trait_default_generators = new




def observe(*names, **kwargs):
Expand Down Expand Up @@ -1398,6 +1390,28 @@ def has_trait(self, name):
"""Returns True if the object has a trait with the specified name."""
return isinstance(getattr(self.__class__, name, None), TraitType)

def trait_defaults(self, *names, **metadata):
"""Return a trait's default value or a dictionary of them
Notes
-----
Dynamically generated default values may
depend on the current state of the object."""
if len(names) == 1 and len(metadata) == 0:
return self._trait_default_generators[names[0]](self)

for n in names:
if not has_trait(self, n):
raise TraitError("'%s' is not a trait of '%s' "
"instances" % (n, type(self).__name__))
trait_names = self.trait_names(**metadata)
trait_names.extend(names)

defaults = {}
for n in trait_names:
defaults[n] = self._trait_default_generators[n](self)
return defaults

def trait_names(self, **metadata):
"""Get a list of all the names of this class' traits."""
return list(self.traits(**metadata))
Expand Down

0 comments on commit 5a49dda

Please sign in to comment.