Skip to content

Commit

Permalink
Allow keys to be encoded before use.
Browse files Browse the repository at this point in the history
Ported patch in #52 from @harlowja to current branch.  Added tests.

For the cases where the user wants to transparently
encode keys (say using urllib) before they are used
further allow a encoding function to be passed in that
will perform these types of activities (by default it
is the identity function).
  • Loading branch information
linsomniac committed Apr 18, 2023
1 parent 88b83c6 commit 3c8465f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 18 deletions.
45 changes: 27 additions & 18 deletions memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
pload=None, pid=None,
server_max_key_length=None, server_max_value_length=None,
dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT,
cache_cas=False, flush_on_reconnect=0, check_keys=True):
cache_cas=False, flush_on_reconnect=0, check_keys=True,
key_encoder=None):
"""Create a new Client object with the given list of servers.
@param servers: C{servers} is passed to L{set_servers}.
Expand Down Expand Up @@ -205,6 +206,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
@param check_keys: (default True) If True, the key is checked
to ensure it is the correct length and composed of the right
characters.
@param key_encoder: (default None) If provided a functor that will
be called to encode keys before they are checked and used. It will
be expected to take one parameter (the key) and return a new encoded
key as a result.
"""
super(Client, self).__init__()
self.debug = debug
Expand All @@ -226,6 +231,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
self.persistent_load = pload
self.persistent_id = pid
self.server_max_key_length = server_max_key_length
if key_encoder is None:
def key_encoder(key):
return key
self.key_encoder = key_encoder
if self.server_max_key_length is None:
self.server_max_key_length = SERVER_MAX_KEY_LENGTH
self.server_max_value_length = server_max_value_length
Expand Down Expand Up @@ -494,7 +503,7 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
else:
headers = None
for key in server_keys[server]: # These are mangled keys
cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n')
cmd = self._encode_cmd('delete', self.key_encoder(key), headers, noreply, b'\r\n')
write(cmd)
try:
server.send_cmds(b''.join(bigcmd))
Expand Down Expand Up @@ -532,7 +541,7 @@ def delete(self, key, noreply=False):
reply.
@rtype: int
'''
key = self._encode_key(key)
key = self._encode_key(self.key_encoder(key))
if self.do_check_key:
self.check_key(key)
server, key = self._get_server(key)
Expand Down Expand Up @@ -568,7 +577,7 @@ def touch(self, key, time=0, noreply=False):
reply.
@rtype: int
'''
key = self._encode_key(key)
key = self._encode_key(self.key_encoder(key))
if self.do_check_key:
self.check_key(key)
server, key = self._get_server(key)
Expand Down Expand Up @@ -622,7 +631,7 @@ def incr(self, key, delta=1, noreply=False):
@return: New value after incrementing, no None for noreply or error.
@rtype: int
"""
return self._incrdecr("incr", key, delta, noreply)
return self._incrdecr("incr", self.key_encoder(key), delta, noreply)

def decr(self, key, delta=1, noreply=False):
"""Decrement value for C{key} by C{delta}
Expand All @@ -640,7 +649,7 @@ def decr(self, key, delta=1, noreply=False):
@return: New value after decrementing, or None for noreply or error.
@rtype: int
"""
return self._incrdecr("decr", key, delta, noreply)
return self._incrdecr("decr", self.key_encoder(key), delta, noreply)

def _incrdecr(self, cmd, key, delta, noreply=False):
key = self._encode_key(key)
Expand Down Expand Up @@ -674,7 +683,7 @@ def add(self, key, val, time=0, min_compress_len=0, noreply=False):
@return: Nonzero on success.
@rtype: int
'''
return self._set("add", key, val, time, min_compress_len, noreply)
return self._set("add", self.key_encoder(key), val, time, min_compress_len, noreply)

def append(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Append the value to the end of the existing key's value.
Expand All @@ -685,7 +694,7 @@ def append(self, key, val, time=0, min_compress_len=0, noreply=False):
@return: Nonzero on success.
@rtype: int
'''
return self._set("append", key, val, time, min_compress_len, noreply)
return self._set("append", self.key_encoder(key), val, time, min_compress_len, noreply)

def prepend(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Prepend the value to the beginning of the existing key's value.
Expand All @@ -696,7 +705,7 @@ def prepend(self, key, val, time=0, min_compress_len=0, noreply=False):
@return: Nonzero on success.
@rtype: int
'''
return self._set("prepend", key, val, time, min_compress_len, noreply)
return self._set("prepend", self.key_encoder(key), val, time, min_compress_len, noreply)

def replace(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Replace existing key with value.
Expand All @@ -707,7 +716,7 @@ def replace(self, key, val, time=0, min_compress_len=0, noreply=False):
@return: Nonzero on success.
@rtype: int
'''
return self._set("replace", key, val, time, min_compress_len, noreply)
return self._set("replace", self.key_encoder(key), val, time, min_compress_len, noreply)

def set(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Unconditionally sets a key to a given value in the memcache.
Expand Down Expand Up @@ -743,7 +752,7 @@ def set(self, key, val, time=0, min_compress_len=0, noreply=False):
'''
if isinstance(time, timedelta):
time = int(time.total_seconds())
return self._set("set", key, val, time, min_compress_len, noreply)
return self._set("set", self.key_encoder(key), val, time, min_compress_len, noreply)

def cas(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Check and set (CAS)
Expand Down Expand Up @@ -780,7 +789,7 @@ def cas(self, key, val, time=0, min_compress_len=0, noreply=False):
@param noreply: optional parameter instructs the server to not
send the reply.
'''
return self._set("cas", key, val, time, min_compress_len, noreply)
return self._set("cas", self.key_encoder(key), val, time, min_compress_len, noreply)

def _map_and_prefix_keys(self, key_iterable, key_prefix):
"""Map keys to the servers they will reside on.
Expand All @@ -807,7 +816,7 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
# Ensure call to _get_server gets a Tuple as well.
serverhash, key = orig_key

key = self._encode_key(key)
key = self._encode_key(self.key_encoder(key))
if not isinstance(key, six.binary_type):
# set_multi supports int / long keys.
key = str(key).encode('utf8')
Expand All @@ -818,7 +827,7 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
server, key = self._get_server(
(serverhash, key_prefix + key))
else:
key = self._encode_key(orig_key)
key = self._encode_key(self.key_encoder(orig_key))
if not isinstance(key, six.binary_type):
# set_multi supports int / long keys.
key = str(key).encode('utf8')
Expand Down Expand Up @@ -923,7 +932,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
if store_info:
flags, len_val, val = store_info
headers = "%d %d %d" % (flags, time, len_val)
fullcmd = self._encode_cmd('set', key, headers,
fullcmd = self._encode_cmd('set', self.key_encoder(key), headers,
noreply,
b'\r\n', val, b'\r\n')
write(fullcmd)
Expand Down Expand Up @@ -1121,14 +1130,14 @@ def get(self, key, default=None):
@return: The value or None.
'''
return self._get('get', key, default)
return self._get('get', self.key_encoder(key), default)

def gets(self, key):
'''Retrieves a key from the memcache. Used in conjunction with 'cas'.
@return: The value or None.
'''
return self._get('gets', key)
return self._get('gets', self.key_encoder(key))

def get_multi(self, keys, key_prefix=''):
'''Retrieves multiple keys from the memcache doing just one query.
Expand Down Expand Up @@ -1188,7 +1197,7 @@ def get_multi(self, keys, key_prefix=''):
self._statlog('get_multi')

server_keys, prefixed_to_orig_key = self._map_and_prefix_keys(
keys, key_prefix)
[self.key_encoder(k) for k in keys], key_prefix)

# send out all requests on each server before reading anything
dead_servers = []
Expand Down
27 changes: 27 additions & 0 deletions tests/test_memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,5 +252,32 @@ def test_touch_unexpected_reply(self, mock_readline, mock_send_cmd):
)


class TestMemcacheEncoder(unittest.TestCase):
def setUp(self):
# TODO(): unix socket server stuff
servers = ["127.0.0.1:11211"]
self.mc = Client(servers, debug=1, key_encoder=self.encoder)

def tearDown(self):
self.mc.flush_all()
self.mc.disconnect_all()

def encoder(self, key):
return key.lower()

def check_setget(self, key, val, noreply=False):
self.mc.set(key, val, noreply=noreply)
newval = self.mc.get(key)
self.assertEqual(newval, val)

def test_setget(self):
self.check_setget("a_string", "some random string")
self.check_setget("A_String2", "some random string")
self.check_setget("an_integer", 42)
self.assertEqual("some random string", self.mc.get("A_String"))
self.assertEqual("some random string", self.mc.get("a_sTRing2"))
self.assertEqual(42, self.mc.get("An_Integer"))


if __name__ == '__main__':
unittest.main()

0 comments on commit 3c8465f

Please sign in to comment.