Skip to content

Commit

Permalink
Merge pull request #182 from shargan/shargan/fix-set-multi
Browse files Browse the repository at this point in the history
Properly batch Client.set_many() calls
  • Loading branch information
cgordon authored Aug 30, 2018
2 parents fdca1a3 + abe233f commit 99d312a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 56 deletions.
111 changes: 56 additions & 55 deletions pymemcache/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def set(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'set', key, expire, noreply, value)
return self._store_cmd(b'set', {key: value}, expire, noreply)[key]

def set_many(self, values, expire=0, noreply=None):
"""
Expand All @@ -312,17 +312,10 @@ def set_many(self, values, expire=0, noreply=None):
Returns a list of keys that failed to be inserted.
If noreply is True, alwais returns empty list.
"""
# TODO: make this more performant by sending all the values first, then
# waiting for all the responses.
if noreply is None:
noreply = self.default_noreply

failed = []
for key, value in six.iteritems(values):
result = self.set(key, value, expire, noreply)
if not result:
failed.append(key)
return failed
result = self._store_cmd(b'set', values, expire, noreply)
return [k for k, v in six.iteritems(result) if not v]

set_multi = set_many

Expand All @@ -345,7 +338,7 @@ def add(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'add', key, expire, noreply, value)
return self._store_cmd(b'add', {key: value}, expire, noreply)[key]

def replace(self, key, value, expire=0, noreply=None):
"""
Expand All @@ -366,7 +359,7 @@ def replace(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'replace', key, expire, noreply, value)
return self._store_cmd(b'replace', {key: value}, expire, noreply)[key]

def append(self, key, value, expire=0, noreply=None):
"""
Expand All @@ -385,7 +378,7 @@ def append(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'append', key, expire, noreply, value)
return self._store_cmd(b'append', {key: value}, expire, noreply)[key]

def prepend(self, key, value, expire=0, noreply=None):
"""
Expand All @@ -404,7 +397,7 @@ def prepend(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'prepend', key, expire, noreply, value)
return self._store_cmd(b'prepend', {key: value}, expire, noreply)[key]

def cas(self, key, value, cas, expire=0, noreply=False):
"""
Expand All @@ -423,7 +416,7 @@ def cas(self, key, value, cas, expire=0, noreply=False):
the key didn't exist, False if it existed but had a different cas
value and True if it existed and was changed.
"""
return self._store_cmd(b'cas', key, expire, noreply, value, cas)
return self._store_cmd(b'cas', {key: value}, expire, noreply, cas)[key]

def get(self, key, default=None):
"""
Expand Down Expand Up @@ -769,55 +762,63 @@ def _fetch_cmd(self, name, keys, expect_cas):
return {}
raise

def _store_cmd(self, name, key, expire, noreply, data, cas=None):
key = self.check_key(key)
if not self.sock:
self._connect()
def _store_cmd(self, name, values, expire, noreply, cas=None):
cmds = []
keys = []
for key, data in six.iteritems(values):
# must be able to reliably map responses back to the original order
keys.append(key)

if self.serializer:
data, flags = self.serializer(key, data)
else:
flags = 0
key = self.check_key(key)
if self.serializer:
data, flags = self.serializer(key, data)
else:
flags = 0

if not isinstance(data, six.binary_type):
try:
data = six.text_type(data).encode('ascii')
except UnicodeEncodeError as e:
raise MemcacheIllegalInputError(str(e))
if not isinstance(data, six.binary_type):
try:
data = six.text_type(data).encode('ascii')
except UnicodeEncodeError as e:
raise MemcacheIllegalInputError(str(e))

extra = b''
if cas is not None:
extra += b' ' + cas
if noreply:
extra += b' noreply'
extra = b''
if cas is not None:
extra += b' ' + cas
if noreply:
extra += b' noreply'

cmd = (name + b' ' + key + b' ' +
six.text_type(flags).encode('ascii') +
b' ' + six.text_type(expire).encode('ascii') +
b' ' + six.text_type(len(data)).encode('ascii') + extra +
b'\r\n' + data + b'\r\n')
cmds.append(name + b' ' + key + b' ' +
six.text_type(flags).encode('ascii') +
b' ' + six.text_type(expire).encode('ascii') +
b' ' + six.text_type(len(data)).encode('ascii') +
extra + b'\r\n' + data + b'\r\n')

try:
self.sock.sendall(cmd)
if not self.sock:
self._connect()

try:
self.sock.sendall(b''.join(cmds))
if noreply:
return True
return {k: True for k in keys}

results = {}
buf = b''
buf, line = _readline(self.sock, buf)
self._raise_errors(line, name)

if line in VALID_STORE_RESULTS[name]:
if line == b'STORED':
return True
if line == b'NOT_STORED':
return False
if line == b'NOT_FOUND':
return None
if line == b'EXISTS':
return False
else:
raise MemcacheUnknownError(line[:32])
for key in keys:
buf, line = _readline(self.sock, buf)
self._raise_errors(line, name)

if line in VALID_STORE_RESULTS[name]:
if line == b'STORED':
results[key] = True
if line == b'NOT_STORED':
results[key] = False
if line == b'NOT_FOUND':
results[key] = None
if line == b'EXISTS':
results[key] = False
else:
raise MemcacheUnknownError(line[:32])
return results
except Exception:
self.close()
raise
Expand Down
9 changes: 8 additions & 1 deletion pymemcache/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import collections
import errno
import functools
import json
import mock
import socket
import unittest
import pytest
Expand Down Expand Up @@ -85,7 +87,12 @@ def __getattr__(self, name):
class ClientTestMixin(object):
def make_client(self, mock_socket_values, **kwargs):
client = Client(None, **kwargs)
client.sock = MockSocket(list(mock_socket_values))
# mock out client._connect() rather than hard-settting client.sock to
# ensure methods are checking whether self.sock is None before
# attempting to use it
sock = MockSocket(list(mock_socket_values))
client._connect = mock.Mock(side_effect=functools.partial(
setattr, client, "sock", sock))
return client

def test_set_success(self):
Expand Down

0 comments on commit 99d312a

Please sign in to comment.