diff --git a/pymemcache/serde.py b/pymemcache/serde.py index cf57ace7..6291d61d 100644 --- a/pymemcache/serde.py +++ b/pymemcache/serde.py @@ -14,6 +14,7 @@ import logging from io import BytesIO +import six from six.moves import cPickle as pickle try: @@ -22,22 +23,35 @@ long_type = None +FLAG_BYTES = 0 FLAG_PICKLE = 1 << 0 FLAG_INTEGER = 1 << 1 FLAG_LONG = 1 << 2 +FLAG_COMPRESSED = 1 << 3 # unused, to main compatability with python-memcached +FLAG_TEXT = 1 << 4 def python_memcache_serializer(key, value): flags = 0 + value_type = type(value) - if isinstance(value, str): + # Check against exact types so that subclasses of native types will be + # restored as their native type + if value_type is bytes: pass - elif isinstance(value, int): + + elif value_type is six.text_type: + flags |= FLAG_TEXT + value = value.encode('utf8') + + elif value_type is int: flags |= FLAG_INTEGER value = "%d" % value - elif long_type is not None and isinstance(value, long_type): + + elif six.PY2 and value_type is long_type: flags |= FLAG_LONG value = "%d" % value + else: flags |= FLAG_PICKLE output = BytesIO() @@ -52,13 +66,19 @@ def python_memcache_deserializer(key, value, flags): if flags == 0: return value - if flags & FLAG_INTEGER: + elif flags & FLAG_TEXT: + return value.decode('utf8') + + elif flags & FLAG_INTEGER: return int(value) - if flags & FLAG_LONG: - return long_type(value) + elif flags & FLAG_LONG: + if six.PY3: + return int(value) + else: + return long_type(value) - if flags & FLAG_PICKLE: + elif flags & FLAG_PICKLE: try: buf = BytesIO(value) unpickler = pickle.Unpickler(buf) diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 14e9f03b..4cbf30eb 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict import json import pytest import six @@ -22,6 +23,10 @@ MemcacheIllegalInputError, MemcacheClientError ) +from pymemcache.serde import ( + python_memcache_serializer, + python_memcache_deserializer +) def get_set_helper(client, key, value, key2, value2): @@ -231,6 +236,33 @@ def _des(key, value, flags): assert result == value +@pytest.mark.integration() +def test_serde_serialization(client_class, host, port, socket_module): + def check(value): + client.set(b'key', value, noreply=False) + result = client.get(b'key') + assert result == value + assert type(result) is type(value) + + client = client_class((host, port), serializer=python_memcache_serializer, + deserializer=python_memcache_deserializer, + socket_module=socket_module) + client.flush_all() + + check(b'byte string') + check(u'unicode string') + check('olé') + check(u'olé') + check(1) + check(123123123123123123123) + check({'a': 'pickle'}) + check([u'one pickle', u'two pickle']) + testdict = defaultdict(int) + testdict[u'one pickle'] + testdict[b'two pickle'] + check(testdict) + + @pytest.mark.integration() def test_errors(client_class, host, port, socket_module): client = client_class((host, port), socket_module=socket_module) diff --git a/pymemcache/test/test_serde.py b/pymemcache/test/test_serde.py index 9849cdc8..04c8e079 100644 --- a/pymemcache/test/test_serde.py +++ b/pymemcache/test/test_serde.py @@ -1,24 +1,61 @@ +# -*- coding: utf-8 -*- from unittest import TestCase from pymemcache.serde import (python_memcache_serializer, - python_memcache_deserializer) + python_memcache_deserializer, FLAG_BYTES, + FLAG_PICKLE, FLAG_INTEGER, FLAG_LONG, FLAG_TEXT) +import pytest +import six +class CustomInt(int): + """ + Custom integer type for testing. + + Entirely useless, but used to show that built in types get serialized and + deserialized back as the same type of object. + """ + pass + + +@pytest.mark.unit() class TestSerde(TestCase): - def check(self, value): + def check(self, value, expected_flags): serialized, flags = python_memcache_serializer(b'key', value) + assert flags == expected_flags + + # pymemcache stores values as byte strings, so we immediately the value + # if needed so deserialized works as it would with a real server + if not isinstance(serialized, six.binary_type): + serialized = six.text_type(serialized).encode('ascii') + deserialized = python_memcache_deserializer(b'key', serialized, flags) assert deserialized == value - def test_str(self): - self.check('value') + def test_bytes(self): + self.check(b'value', FLAG_BYTES) + self.check(b'\xc2\xa3 $ \xe2\x82\xac', FLAG_BYTES) # £ $ € + + def test_unicode(self): + self.check(u'value', FLAG_TEXT) + self.check(u'£ $ €', FLAG_TEXT) def test_int(self): - self.check(1) + self.check(1, FLAG_INTEGER) def test_long(self): - self.check(123123123123123123123) + # long only exists with Python 2, so we're just testing for another + # integer with Python 3 + if six.PY2: + expected_flags = FLAG_LONG + else: + expected_flags = FLAG_INTEGER + self.check(123123123123123123123, expected_flags) def test_pickleable(self): - self.check({'a': 'dict'}) + self.check({'a': 'dict'}, FLAG_PICKLE) + + def test_subtype(self): + # Subclass of a native type will be restored as the same type + self.check(CustomInt(123123), FLAG_PICKLE)