diff --git a/rpyc/core/netref.py b/rpyc/core/netref.py index 3a9542ad..329c10e4 100644 --- a/rpyc/core/netref.py +++ b/rpyc/core/netref.py @@ -21,6 +21,7 @@ '__init__', '__metaclass__', '__module__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__slots__', '__str__', '__weakref__', '__dict__', '__members__', '__methods__', '__exit__', + '__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__', ]) | _deleted_netref_attrs """the set of attributes that are local to the netref object""" @@ -172,7 +173,19 @@ def __dir__(self): def __hash__(self): return syncreq(self, consts.HANDLE_HASH) def __cmp__(self, other): - return syncreq(self, consts.HANDLE_CMP, other) + return syncreq(self, consts.HANDLE_CMP, other, '__cmp__') + def __eq__(self, other): + return syncreq(self, consts.HANDLE_CMP, other, '__eq__') + def __ne__(self, other): + return syncreq(self, consts.HANDLE_CMP, other, '__ne__') + def __lt__(self, other): + return syncreq(self, consts.HANDLE_CMP, other, '__lt__') + def __gt__(self, other): + return syncreq(self, consts.HANDLE_CMP, other, '__gt__') + def __le__(self, other): + return syncreq(self, consts.HANDLE_CMP, other, '__le__') + def __ge__(self, other): + return syncreq(self, consts.HANDLE_CMP, other, '__ge__') def __repr__(self): return syncreq(self, consts.HANDLE_REPR) def __str__(self): @@ -278,4 +291,3 @@ def class_factory(clsname, modname, methods): for cls in _builtin_types: builtin_classes_cache[cls.__name__, cls.__module__] = class_factory( cls.__name__, cls.__module__, inspect_methods(cls)) - diff --git a/rpyc/core/protocol.py b/rpyc/core/protocol.py index efcb1a56..ad751c10 100644 --- a/rpyc/core/protocol.py +++ b/rpyc/core/protocol.py @@ -577,11 +577,11 @@ def _handle_repr(self, obj): return repr(obj) def _handle_str(self, obj): return str(obj) - def _handle_cmp(self, obj, other): + def _handle_cmp(self, obj, other, op='__cmp__'): # cmp() might enter recursive resonance... yet another workaround #return cmp(obj, other) try: - return type(obj).__cmp__(obj, other) + return getattr(type(obj), op)(obj, other) except (AttributeError, TypeError): return NotImplemented def _handle_hash(self, obj): diff --git a/tests/test_magic.py b/tests/test_magic.py new file mode 100644 index 00000000..4ab14122 --- /dev/null +++ b/tests/test_magic.py @@ -0,0 +1,68 @@ +import sys +import rpyc +import unittest + +is_py3 = sys.version_info >= (3,) + +class Meta(type): + + def __hash__(self): + return 4321 + +Base = Meta('Base', (object,), {}) + +class Foo(Base): + def __hash__(self): + return 1234 + +class Bar(Foo): + pass + +class Mux(Foo): + def __eq__(self, other): + return True + + +class TestContextManagers(unittest.TestCase): + def setUp(self): + self.conn = rpyc.classic.connect_thread() + + def tearDown(self): + self.conn.close() + + def test_hash_class(self): + hesh = self.conn.builtins.hash + mod = self.conn.modules.test_magic + self.assertEqual(hash(mod.Base), 4321) + self.assertEqual(hash(mod.Foo), 4321) + self.assertEqual(hash(mod.Bar), 4321) + self.assertEqual(hash(mod.Base().__class__), 4321) + self.assertEqual(hash(mod.Foo().__class__), 4321) + self.assertEqual(hash(mod.Bar().__class__), 4321) + + basecl_ = mod.Foo().__class__.__mro__[1] + object_ = mod.Foo().__class__.__mro__[2] + self.assertEqual(hash(basecl_), hesh(basecl_)) + self.assertEqual(hash(object_), hesh(object_)) + self.assertEqual(hash(object_), hesh(self.conn.builtins.object)) + + def test_hash_obj(self): + hesh = self.conn.builtins.hash + mod = self.conn.modules.test_magic + obj = mod.Base() + + self.assertNotEqual(hash(obj), 1234) + self.assertNotEqual(hash(obj), 4321) + self.assertEqual(hash(obj), hesh(obj)) + + self.assertEqual(hash(mod.Foo()), 1234) + self.assertEqual(hash(mod.Bar()), 1234) + if is_py3: + with self.assertRaises(TypeError): + hash(mod.Mux()) + else: + self.assertEqual(hash(mod.Mux()), 1234) + + +if __name__ == "__main__": + unittest.main()