Skip to content

Commit

Permalink
Merge pull request #1740 from Robbybp/reference-api
Browse files Browse the repository at this point in the history
Add is_reference API
  • Loading branch information
jsiirola authored Jan 19, 2021
2 parents 6e33b20 + 778aee5 commit dcd3fd8
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pyomo/core/base/indexed_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,13 @@ def is_indexed(self):
"""Return true if this component is indexed"""
return self._index is not UnindexedComponent_set

def is_reference(self):
"""Return True if this component is a reference, where
"reference" is interpreted as any component that does not
own its own data.
"""
return self._data is not None and type(self._data) is not dict

def dim(self):
"""Return the dimension of the index"""
if not self.is_indexed():
Expand Down Expand Up @@ -286,6 +293,8 @@ def __iter__(self):
# user iterates over the set when the _data dict is empty.
#
return self._data.__iter__()
elif self.is_reference():
return self._data.__iter__()
elif len(self._data) == len(self._index):
#
# If the data is dense then return the index iterator.
Expand Down
2 changes: 2 additions & 0 deletions pyomo/core/base/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ def Reference(reference, ctype=_NotSpecified):
4 : 1 : 10 : None : False : False : Reals
"""
referent = reference
if isinstance(reference, IndexedComponent_slice):
_data = _ReferenceDict(reference)
_iter = iter(reference)
Expand Down Expand Up @@ -692,4 +693,5 @@ def Reference(reference, ctype=_NotSpecified):
obj = ctype(index, ctype=ctype)
obj._constructed = True
obj._data = _data
obj.referent = referent
return obj
4 changes: 4 additions & 0 deletions pyomo/core/pyomoobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ def is_named_expression_type(self):
def is_logical_type(self):
"""Return True if this class is a Pyomo Boolean value, variable, or expression."""
return False

def is_reference(self):
"""Return True if this object is a reference."""
return False
6 changes: 6 additions & 0 deletions pyomo/core/tests/unit/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ def test_component_data_pprint(self):
'2 : None : None : None : False : True : Reals\n'
self.assertEqual(correct_s, stream.getvalue())

def test_is_reference(self):
m = ConcreteModel()
class _NotSpecified(object):
pass
m.comp = Component(ctype=_NotSpecified)
self.assertFalse(m.comp.is_reference())

class TestEnviron(unittest.TestCase):

Expand Down
52 changes: 52 additions & 0 deletions pyomo/core/tests/unit/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyomo.core.base.indexed_component import (
UnindexedComponent_set, IndexedComponent
)
from pyomo.core.base.indexed_component_slice import IndexedComponent_slice
from pyomo.core.base.reference import (
_ReferenceDict, _ReferenceSet, Reference
)
Expand Down Expand Up @@ -803,6 +804,57 @@ def test_reference_to_list(self):
KeyError, "Index '1' is not valid for indexed component 'r'"):
m.r[1] = m.x

def test_is_reference(self):
m = ConcreteModel()
m.v0 = Var()
m.v1 = Var([1,2,3])

m.ref0 = Reference(m.v0)
m.ref1 = Reference(m.v1)

self.assertFalse(m.v0.is_reference())
self.assertFalse(m.v1.is_reference())

self.assertTrue(m.ref0.is_reference())
self.assertTrue(m.ref1.is_reference())

unique_vars = list(
v for v in m.component_objects(Var) if not v.is_reference())
self.assertEqual(len(unique_vars), 2)

def test_referent(self):
m = ConcreteModel()
m.v0 = Var()
m.v2 = Var([1, 2, 3],['a', 'b'])

varlist = [m.v2[1, 'a'], m.v2[1, 'b']]

vardict = {
0: m.v0,
1: m.v2[1, 'a'],
2: m.v2[2, 'a'],
3: m.v2[3, 'a'],
}

scalar_ref = Reference(m.v0)
self.assertIs(scalar_ref.referent, m.v0)

sliced_ref = Reference(m.v2[:,'a'])
referent = sliced_ref.referent
self.assertIs(type(referent), IndexedComponent_slice)
self.assertEqual(len(referent._call_stack), 1)
call, info = referent._call_stack[0]
self.assertEqual(call, IndexedComponent_slice.slice_info)
self.assertIs(info[0], m.v2)
self.assertEqual(info[1], {1: 'a'}) # Fixed
self.assertEqual(info[2], {0: slice(None)}) # Sliced
self.assertIs(info[3], None) # Ellipsis

list_ref = Reference(varlist)
self.assertIs(list_ref.referent, varlist)

dict_ref = Reference(vardict)
self.assertIs(dict_ref.referent, vardict)

if __name__ == "__main__":
unittest.main()

0 comments on commit dcd3fd8

Please sign in to comment.