diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 2892029c..05db0818 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -294,7 +294,7 @@ def set(self, key, value, expire=0, noreply=None): """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b'set', key, expire, noreply, value) + return self._store_cmd(b'set', {key: value}, expire, noreply)[key] def set_many(self, values, expire=0, noreply=None): """ @@ -312,17 +312,10 @@ def set_many(self, values, expire=0, noreply=None): Returns a list of keys that failed to be inserted. If noreply is True, alwais returns empty list. """ - # TODO: make this more performant by sending all the values first, then - # waiting for all the responses. if noreply is None: noreply = self.default_noreply - - failed = [] - for key, value in six.iteritems(values): - result = self.set(key, value, expire, noreply) - if not result: - failed.append(key) - return failed + result = self._store_cmd(b'set', values, expire, noreply) + return [k for k, v in six.iteritems(result) if not v] set_multi = set_many @@ -345,7 +338,7 @@ def add(self, key, value, expire=0, noreply=None): """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b'add', key, expire, noreply, value) + return self._store_cmd(b'add', {key: value}, expire, noreply)[key] def replace(self, key, value, expire=0, noreply=None): """ @@ -366,7 +359,7 @@ def replace(self, key, value, expire=0, noreply=None): """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b'replace', key, expire, noreply, value) + return self._store_cmd(b'replace', {key: value}, expire, noreply)[key] def append(self, key, value, expire=0, noreply=None): """ @@ -385,7 +378,7 @@ def append(self, key, value, expire=0, noreply=None): """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b'append', key, expire, noreply, value) + return self._store_cmd(b'append', {key: value}, expire, noreply)[key] def prepend(self, key, value, expire=0, noreply=None): """ @@ -404,7 +397,7 @@ def prepend(self, key, value, expire=0, noreply=None): """ if noreply is None: noreply = self.default_noreply - return self._store_cmd(b'prepend', key, expire, noreply, value) + return self._store_cmd(b'prepend', {key: value}, expire, noreply)[key] def cas(self, key, value, cas, expire=0, noreply=False): """ @@ -423,7 +416,7 @@ def cas(self, key, value, cas, expire=0, noreply=False): the key didn't exist, False if it existed but had a different cas value and True if it existed and was changed. """ - return self._store_cmd(b'cas', key, expire, noreply, value, cas) + return self._store_cmd(b'cas', {key: value}, expire, noreply, cas)[key] def get(self, key, default=None): """ @@ -769,55 +762,63 @@ def _fetch_cmd(self, name, keys, expect_cas): return {} raise - def _store_cmd(self, name, key, expire, noreply, data, cas=None): - key = self.check_key(key) - if not self.sock: - self._connect() + def _store_cmd(self, name, values, expire, noreply, cas=None): + cmds = [] + keys = [] + for key, data in six.iteritems(values): + # must be able to reliably map responses back to the original order + keys.append(key) - if self.serializer: - data, flags = self.serializer(key, data) - else: - flags = 0 + key = self.check_key(key) + if self.serializer: + data, flags = self.serializer(key, data) + else: + flags = 0 - if not isinstance(data, six.binary_type): - try: - data = six.text_type(data).encode('ascii') - except UnicodeEncodeError as e: - raise MemcacheIllegalInputError(str(e)) + if not isinstance(data, six.binary_type): + try: + data = six.text_type(data).encode('ascii') + except UnicodeEncodeError as e: + raise MemcacheIllegalInputError(str(e)) - extra = b'' - if cas is not None: - extra += b' ' + cas - if noreply: - extra += b' noreply' + extra = b'' + if cas is not None: + extra += b' ' + cas + if noreply: + extra += b' noreply' - cmd = (name + b' ' + key + b' ' + - six.text_type(flags).encode('ascii') + - b' ' + six.text_type(expire).encode('ascii') + - b' ' + six.text_type(len(data)).encode('ascii') + extra + - b'\r\n' + data + b'\r\n') + cmds.append(name + b' ' + key + b' ' + + six.text_type(flags).encode('ascii') + + b' ' + six.text_type(expire).encode('ascii') + + b' ' + six.text_type(len(data)).encode('ascii') + + extra + b'\r\n' + data + b'\r\n') - try: - self.sock.sendall(cmd) + if not self.sock: + self._connect() + try: + self.sock.sendall(b''.join(cmds)) if noreply: - return True + return {k: True for k in keys} + results = {} buf = b'' - buf, line = _readline(self.sock, buf) - self._raise_errors(line, name) - - if line in VALID_STORE_RESULTS[name]: - if line == b'STORED': - return True - if line == b'NOT_STORED': - return False - if line == b'NOT_FOUND': - return None - if line == b'EXISTS': - return False - else: - raise MemcacheUnknownError(line[:32]) + for key in keys: + buf, line = _readline(self.sock, buf) + self._raise_errors(line, name) + + if line in VALID_STORE_RESULTS[name]: + if line == b'STORED': + results[key] = True + if line == b'NOT_STORED': + results[key] = False + if line == b'NOT_FOUND': + results[key] = None + if line == b'EXISTS': + results[key] = False + else: + raise MemcacheUnknownError(line[:32]) + return results except Exception: self.close() raise diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index a20b740f..e7871367 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -15,7 +15,9 @@ import collections import errno +import functools import json +import mock import socket import unittest import pytest @@ -85,7 +87,12 @@ def __getattr__(self, name): class ClientTestMixin(object): def make_client(self, mock_socket_values, **kwargs): client = Client(None, **kwargs) - client.sock = MockSocket(list(mock_socket_values)) + # mock out client._connect() rather than hard-settting client.sock to + # ensure methods are checking whether self.sock is None before + # attempting to use it + sock = MockSocket(list(mock_socket_values)) + client._connect = mock.Mock(side_effect=functools.partial( + setattr, client, "sock", sock)) return client def test_set_success(self):