Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow keys to be encoded before usage #52

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
pload=None, pid=None,
server_max_key_length=None, server_max_value_length=None,
dead_retry=_DEAD_RETRY, socket_timeout=_SOCKET_TIMEOUT,
cache_cas=False, flush_on_reconnect=0, check_keys=True):
cache_cas=False, flush_on_reconnect=0, check_keys=True,
key_encoder=None):
"""Create a new Client object with the given list of servers.

@param servers: C{servers} is passed to L{set_servers}.
Expand Down Expand Up @@ -224,6 +225,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
@param check_keys: (default True) If True, the key is checked
to ensure it is the correct length and composed of the right
characters.
@param key_encoder: (default None) If provided a functor that will
be called to encode keys before they are checked and used. It will
be expected to take one parameter (the key) and return a new encoded
key as a result.
"""
super(Client, self).__init__()
self.debug = debug
Expand All @@ -240,6 +245,10 @@ def __init__(self, servers, debug=0, pickleProtocol=0,
self.pickleProtocol = pickleProtocol
self.pickler = pickler
self.unpickler = unpickler
if key_encoder is None:
self.key_encoder = lambda key: key
else:
self.key_encoder = key_encoder
self.persistent_load = pload
self.persistent_id = pid
self.server_max_key_length = server_max_key_length
Expand Down Expand Up @@ -474,7 +483,8 @@ def delete(self, key, time=0):
should fail. Defaults to None for no delay.
@rtype: int
'''
return self._deletetouch(['DELETED', 'NOT_FOUND'], "delete", key, time)
return self._deletetouch(['DELETED', 'NOT_FOUND'], "delete",
self.key_encoder(key), time)

def touch(self, key, time=0):
'''Updates the expiration time of a key in memcache.
Expand All @@ -487,7 +497,8 @@ def touch(self, key, time=0):
default to 0 == cache forever.
@rtype: int
'''
return self._deletetouch(['TOUCHED'], "touch", key, time)
return self._deletetouch(['TOUCHED'], "touch",
self.key_encoder(key), time)

def _deletetouch(self, expected, cmd, key, time=0):
if self.do_check_key:
Expand Down Expand Up @@ -542,7 +553,7 @@ def incr(self, key, delta=1):
@return: New value after incrementing.
@rtype: int
"""
return self._incrdecr("incr", key, delta)
return self._incrdecr("incr", self.key_encoder(key), delta)

def decr(self, key, delta=1):
"""Decrement value for C{key} by C{delta}
Expand All @@ -557,7 +568,7 @@ def decr(self, key, delta=1):
@return: New value after decrementing or None on error.
@rtype: int
"""
return self._incrdecr("decr", key, delta)
return self._incrdecr("decr", self.key_encoder(key), delta)

def _incrdecr(self, cmd, key, delta):
if self.do_check_key:
Expand Down Expand Up @@ -588,7 +599,8 @@ def add(self, key, val, time=0, min_compress_len=0):
@return: Nonzero on success.
@rtype: int
'''
return self._set("add", key, val, time, min_compress_len)
return self._set("add", self.key_encoder(key),
val, time, min_compress_len)

def append(self, key, val, time=0, min_compress_len=0):
'''Append the value to the end of the existing key's value.
Expand All @@ -599,7 +611,8 @@ def append(self, key, val, time=0, min_compress_len=0):
@return: Nonzero on success.
@rtype: int
'''
return self._set("append", key, val, time, min_compress_len)
return self._set("append", self.key_encoder(key),
val, time, min_compress_len)

def prepend(self, key, val, time=0, min_compress_len=0):
'''Prepend the value to the beginning of the existing key's value.
Expand All @@ -610,7 +623,8 @@ def prepend(self, key, val, time=0, min_compress_len=0):
@return: Nonzero on success.
@rtype: int
'''
return self._set("prepend", key, val, time, min_compress_len)
return self._set("prepend", self.key_encoder(key),
val, time, min_compress_len)

def replace(self, key, val, time=0, min_compress_len=0):
'''Replace existing key with value.
Expand All @@ -621,7 +635,8 @@ def replace(self, key, val, time=0, min_compress_len=0):
@return: Nonzero on success.
@rtype: int
'''
return self._set("replace", key, val, time, min_compress_len)
return self._set("replace", self.key_encoder(key),
val, time, min_compress_len)

def set(self, key, val, time=0, min_compress_len=0):
'''Unconditionally sets a key to a given value in the memcache.
Expand Down Expand Up @@ -653,7 +668,8 @@ def set(self, key, val, time=0, min_compress_len=0):
ever try to compress.

'''
return self._set("set", key, val, time, min_compress_len)
return self._set("set", self.key_encoder(key),
val, time, min_compress_len)

def cas(self, key, val, time=0, min_compress_len=0):
'''Check and set (CAS)
Expand Down Expand Up @@ -687,7 +703,8 @@ def cas(self, key, val, time=0, min_compress_len=0):
compatability, this parameter defaults to 0, indicating don't
ever try to compress.
'''
return self._set("cas", key, val, time, min_compress_len)
return self._set("cas", self.key_encoder(key),
val, time, min_compress_len)

def _map_and_prefix_keys(self, key_iterable, key_prefix):
"""Compute the mapping of server (_Host instance) -> list of keys to
Expand All @@ -709,15 +726,15 @@ def _map_and_prefix_keys(self, key_iterable, key_prefix):
# Tuple of hashvalue, key ala _get_server(). Caller is
# essentially telling us what server to stuff this on.
# Ensure call to _get_server gets a Tuple as well.
str_orig_key = str(orig_key[1])
str_orig_key = self.key_encoder(str(orig_key[1]))

# Gotta pre-mangle key before hashing to a
# server. Returns the mangled key.
server, key = self._get_server(
(orig_key[0], key_prefix + str_orig_key))
else:
# set_multi supports int / long keys.
str_orig_key = str(orig_key)
str_orig_key = self.key_encoder(str(orig_key))
server, key = self._get_server(key_prefix + str_orig_key)

# Now check to make sure key length is proper ...
Expand Down Expand Up @@ -999,14 +1016,14 @@ def get(self, key):

@return: The value or None.
'''
return self._get('get', key)
return self._get('get', self.key_encoder(key))

def gets(self, key):
'''Retrieves a key from the memcache. Used in conjunction with 'cas'.

@return: The value or None.
'''
return self._get('gets', key)
return self._get('gets', self.key_encoder(key))

def get_multi(self, keys, key_prefix=''):
'''Retrieves multiple keys from the memcache doing just one query.
Expand Down