diff --git a/bmemcached/protocol.py b/bmemcached/protocol.py index b1132fb..59cc767 100644 --- a/bmemcached/protocol.py +++ b/bmemcached/protocol.py @@ -14,7 +14,7 @@ import six from six import binary_type, text_type -from bmemcached.compat import long +from bmemcached.compat import long, pickle from bmemcached.exceptions import AuthenticationNotSupported, InvalidCredentials, MemcachedException from bmemcached.utils import str_to_bytes @@ -54,24 +54,27 @@ class Protocol(threading.local): 'response': 0x81 } - # All structures will be appended to HEADER_STRUCT + # 'packer' is a struct.Struct compiled from HEADER_STRUCT plus the + # fixed-size leading "extras" bytes for that command. Variable-length + # tails (key, value, auth payloads) are concatenated as bytes after + # packer.pack(...). COMMANDS = { - 'get': {'command': 0x00, 'struct': '%ds'}, - 'getk': {'command': 0x0C, 'struct': '%ds'}, - 'getkq': {'command': 0x0D, 'struct': '%ds'}, - 'set': {'command': 0x01, 'struct': 'LL%ds%ds'}, - 'setq': {'command': 0x11, 'struct': 'LL%ds%ds'}, - 'add': {'command': 0x02, 'struct': 'LL%ds%ds'}, - 'addq': {'command': 0x12, 'struct': 'LL%ds%ds'}, - 'replace': {'command': 0x03, 'struct': 'LL%ds%ds'}, - 'delete': {'command': 0x04, 'struct': '%ds'}, - 'incr': {'command': 0x05, 'struct': 'QQL%ds'}, - 'decr': {'command': 0x06, 'struct': 'QQL%ds'}, - 'flush': {'command': 0x08, 'struct': 'I'}, - 'noop': {'command': 0x0a, 'struct': ''}, - 'stat': {'command': 0x10}, - 'auth_negotiation': {'command': 0x20}, - 'auth_request': {'command': 0x21, 'struct': '%ds%ds'}, + 'get': {'command': 0x00, 'packer': struct.Struct(HEADER_STRUCT)}, + 'getk': {'command': 0x0C, 'packer': struct.Struct(HEADER_STRUCT)}, + 'getkq': {'command': 0x0D, 'packer': struct.Struct(HEADER_STRUCT)}, + 'set': {'command': 0x01, 'packer': struct.Struct(HEADER_STRUCT + 'LL')}, + 'setq': {'command': 0x11, 'packer': struct.Struct(HEADER_STRUCT + 'LL')}, + 'add': {'command': 0x02, 'packer': struct.Struct(HEADER_STRUCT + 'LL')}, + 'addq': {'command': 0x12, 'packer': struct.Struct(HEADER_STRUCT + 'LL')}, + 'replace': {'command': 0x03, 'packer': struct.Struct(HEADER_STRUCT + 'LL')}, + 'delete': {'command': 0x04, 'packer': struct.Struct(HEADER_STRUCT)}, + 'incr': {'command': 0x05, 'packer': struct.Struct(HEADER_STRUCT + 'QQL')}, + 'decr': {'command': 0x06, 'packer': struct.Struct(HEADER_STRUCT + 'QQL')}, + 'flush': {'command': 0x08, 'packer': struct.Struct(HEADER_STRUCT + 'I')}, + 'noop': {'command': 0x0a, 'packer': struct.Struct(HEADER_STRUCT)}, + 'stat': {'command': 0x10, 'packer': struct.Struct(HEADER_STRUCT)}, + 'auth_negotiation': {'command': 0x20, 'packer': struct.Struct(HEADER_STRUCT)}, + 'auth_request': {'command': 0x21, 'packer': struct.Struct(HEADER_STRUCT)}, } STATUS = { @@ -297,10 +300,10 @@ def _send_authentication(self): return False logger.debug('Authenticating as %s', self._username) - self._send(struct.pack(self.HEADER_STRUCT, - self.MAGIC['request'], - self.COMMANDS['auth_negotiation']['command'], - 0, 0, 0, 0, 0, 0, 0)) + cmd = self.COMMANDS['auth_negotiation'] + self._send(cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + 0, 0, 0, 0, 0, 0, 0)) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() @@ -324,10 +327,10 @@ def _send_authentication(self): if isinstance(auth, text_type): auth = auth.encode() - self._send(struct.pack(self.HEADER_STRUCT + - self.COMMANDS['auth_request']['struct'] % (len(method), len(auth)), - self.MAGIC['request'], self.COMMANDS['auth_request']['command'], - len(method), 0, 0, 0, len(method) + len(auth), 0, 0, method, auth)) + cmd = self.COMMANDS['auth_request'] + self._send(cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + len(method), 0, 0, 0, len(method) + len(auth), 0, 0) + method + auth) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() @@ -357,7 +360,7 @@ def serialize(self, value, compress_level=-1): -1 = default compression level. :type compress_level: int :return: Serialized type - :rtype: str + :rtype: bytes """ flags = 0 if isinstance(value, binary_type): @@ -366,16 +369,19 @@ def serialize(self, value, compress_level=-1): value = value.encode('utf8') elif isinstance(value, int) and isinstance(value, bool) is False: flags |= self.FLAGS['integer'] - value = str(value) + value = str(value).encode() elif isinstance(value, long) and isinstance(value, bool) is False: flags |= self.FLAGS['long'] - value = str(value) + value = str(value).encode() else: flags |= self.FLAGS['object'] - buf = BytesIO() - pickler = self.pickler(buf, self.pickle_protocol) - pickler.dump(value) - value = buf.getvalue() + if self.pickler is None or self.pickler is pickle.Pickler: + value = pickle.dumps(value, self.pickle_protocol) + else: + buf = BytesIO() + pickler = self.pickler(buf, self.pickle_protocol) + pickler.dump(value) + value = buf.getvalue() if compress_level != 0 and len(value) > self.COMPRESSION_THRESHOLD: if compress_level is not None and compress_level > 0: @@ -415,9 +421,9 @@ def deserialize(self, value, flags): elif flags & FLAGS['long']: return long(value) elif flags & FLAGS['object']: - buf = BytesIO(value) - unpickler = self.unpickler(buf) - return unpickler.load() + if self.unpickler is None or self.unpickler is pickle.Unpickler: + return pickle.loads(value) + return self.unpickler(BytesIO(value)).load() if six.PY3: return value.decode('utf8') @@ -447,11 +453,11 @@ def get(self, key): """ logger.debug('Getting key %s', key) keybytes = str_to_bytes(key) - data = struct.pack(self.HEADER_STRUCT + - self.COMMANDS['get']['struct'] % (len(keybytes),), - self.MAGIC['request'], - self.COMMANDS['get']['command'], - len(keybytes), 0, 0, 0, len(keybytes), 0, 0, keybytes) + cmd = self.COMMANDS['get'] + klen = len(keybytes) + data = cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + klen, 0, 0, 0, klen, 0, 0) + keybytes self._send(data) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, @@ -482,11 +488,10 @@ def noop(self): :rtype: int """ logger.debug('Sending NOOP') - data = struct.pack(self.HEADER_STRUCT + - self.COMMANDS['noop']['struct'], - self.MAGIC['request'], - self.COMMANDS['noop']['command'], - 0, 0, 0, 0, 0, 0, 0) + cmd = self.COMMANDS['noop'] + data = cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + 0, 0, 0, 0, 0, 0, 0) self._send(data) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, @@ -521,38 +526,45 @@ def get_multi(self, keys): if n == 0: return {} + MAGIC_REQ = self.MAGIC['request'] + getkq = self.COMMANDS['getkq'] + GETKQ_CMD = getkq['command'] + pack_header = getkq['packer'].pack # same packer for getk and getkq + GETK_CMD = self.COMMANDS['getk']['command'] + msg = bytearray() - for i, key in enumerate(keys): - keybytes = str_to_bytes(key) - command = self.COMMANDS['getk' if i == n - 1 else 'getkq'] - msg += struct.pack(self.HEADER_STRUCT + - command['struct'] % (len(keybytes),), - self.MAGIC['request'], - command['command'], - len(keybytes), 0, 0, 0, len(keybytes), 0, 0, keybytes) + keybytes_list = [str_to_bytes(k) for k in keys] + last = n - 1 + for i, keybytes in enumerate(keybytes_list): + klen = len(keybytes) + opcode = GETK_CMD if i == last else GETKQ_CMD + msg += pack_header(MAGIC_REQ, opcode, klen, 0, 0, 0, klen, 0, 0) + msg += keybytes self._send(msg) d = {} + SUCCESS = self.STATUS['success'] + DISCONNECTED = self.STATUS['server_disconnected'] + NOT_FOUND = self.STATUS['key_not_found'] opcode = -1 - while opcode != self.COMMANDS['getk']['command']: + while opcode != GETK_CMD: (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() - if status == self.STATUS['success']: + if status == SUCCESS: flags, key, value = struct.unpack('!L%ds%ds' % (keylen, bodylen - keylen - 4), extra_content) d[key] = self.deserialize(value, flags), cas - elif status == self.STATUS['server_disconnected']: + elif status == DISCONNECTED: break - elif status != self.STATUS['key_not_found']: + elif status != NOT_FOUND: raise MemcachedException('Code: %d Message: %s' % (status, extra_content), status) ret = {} - for key in keys: - keybytes = str_to_bytes(key) + for key, keybytes in zip(keys, keybytes_list): if keybytes in d: ret[key] = d[keybytes] return ret @@ -581,16 +593,15 @@ def _set_add_replace(self, command, key, value, time, cas=0, compress_level=-1): logger.debug('Setting/adding/replacing key %s.', key) flags, value = self.serialize(value, compress_level=compress_level) logger.debug('Value bytes %s.', len(value)) - if isinstance(value, text_type): - value = value.encode('utf8') keybytes = str_to_bytes(key) - self._send(struct.pack(self.HEADER_STRUCT + - self.COMMANDS[command]['struct'] % (len(keybytes), len(value)), - self.MAGIC['request'], - self.COMMANDS[command]['command'], - len(keybytes), 8, 0, 0, len(keybytes) + len(value) + 8, 0, cas, flags, - time, keybytes, value)) + cmd = self.COMMANDS[command] + klen = len(keybytes) + vlen = len(value) + self._send(cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + klen, 8, 0, 0, klen + vlen + 8, 0, cas, + flags, time) + keybytes + value) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() @@ -741,6 +752,12 @@ def set_multi(self, mappings, time=100, compress_level=-1): mappings = list(mappings.items()) msg = bytearray() + MAGIC_REQ = self.MAGIC['request'] + addq = self.COMMANDS['addq'] + ADDQ_CMD = addq['command'] + pack_set_prefix = addq['packer'].pack # same packer for setq/addq + SETQ_CMD = self.COMMANDS['setq']['command'] + for opaque, (key, value) in enumerate(mappings): if isinstance(key, tuple): key, cas = key @@ -750,37 +767,38 @@ def set_multi(self, mappings, time=100, compress_level=-1): if cas == 0: # Like cas(), if the cas value is 0, treat it as compare-and-set against not # existing. - command = 'addq' + opcode = ADDQ_CMD else: - command = 'setq' + opcode = SETQ_CMD keybytes = str_to_bytes(key) flags, value = self.serialize(value, compress_level=compress_level) - msg += struct.pack(self.HEADER_STRUCT + - self.COMMANDS[command]['struct'] % (len(keybytes), len(value)), - self.MAGIC['request'], - self.COMMANDS[command]['command'], - len(keybytes), - 8, 0, 0, len(keybytes) + len(value) + 8, opaque, cas or 0, - flags, time, keybytes, value) - - msg += struct.pack(self.HEADER_STRUCT + - self.COMMANDS['noop']['struct'], - self.MAGIC['request'], - self.COMMANDS['noop']['command'], - 0, 0, 0, 0, 0, 0, 0) + klen = len(keybytes) + vlen = len(value) + msg += pack_set_prefix(MAGIC_REQ, opcode, klen, + 8, 0, 0, klen + vlen + 8, opaque, cas or 0, + flags, time) + msg += keybytes + msg += value + + noop = self.COMMANDS['noop'] + NOOP_CMD = noop['command'] + msg += noop['packer'].pack(MAGIC_REQ, NOOP_CMD, + 0, 0, 0, 0, 0, 0, 0) self._send(msg) opcode = -1 failed = [] - while opcode != self.COMMANDS['noop']['command']: + DISCONNECTED = self.STATUS['server_disconnected'] + SUCCESS = self.STATUS['success'] + while opcode != NOOP_CMD: (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() - if status == self.STATUS['server_disconnected']: + if status == DISCONNECTED: # Assume that the entire operation failed. return list(key for key, value in mappings) - if status != self.STATUS['success']: + if status != SUCCESS: key, value = mappings[opaque] if isinstance(key, tuple): failed.append((key[0], cas)) @@ -817,6 +835,12 @@ def set_multi_cas(self, mappings, time=100, compress_level=-1): msg = bytearray() result = {} + MAGIC_REQ = self.MAGIC['request'] + add = self.COMMANDS['add'] + ADD_CMD = add['command'] + pack_set_prefix = add['packer'].pack # same packer for set/add + SET_CMD = self.COMMANDS['set']['command'] + for opaque, (key, value) in enumerate(mappings): if isinstance(key, tuple): str_key, cas = key @@ -825,30 +849,32 @@ def set_multi_cas(self, mappings, time=100, compress_level=-1): result[str_key] = None if cas == 0: - command = 'add' + opcode = ADD_CMD else: - command = 'set' + opcode = SET_CMD keybytes = str_to_bytes(str_key) flags, value = self.serialize(value, compress_level=compress_level) - msg += struct.pack(self.HEADER_STRUCT + - self.COMMANDS[command]['struct'] % (len(keybytes), len(value)), - self.MAGIC['request'], - self.COMMANDS[command]['command'], - len(keybytes), - 8, 0, 0, len(keybytes) + len(value) + 8, opaque, cas or 0, - flags, time, keybytes, value) + klen = len(keybytes) + vlen = len(value) + msg += pack_set_prefix(MAGIC_REQ, opcode, klen, + 8, 0, 0, klen + vlen + 8, opaque, cas or 0, + flags, time) + msg += keybytes + msg += value self._send(msg) # Non-quiet set/add return exactly one response per request, so we can # read a fixed count rather than relying on a trailing noop sentinel. + DISCONNECTED = self.STATUS['server_disconnected'] + SUCCESS = self.STATUS['success'] for _ in range(len(mappings)): (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() - if status == self.STATUS['server_disconnected']: + if status == DISCONNECTED: return result - if status == self.STATUS['success']: + if status == SUCCESS: key, value = mappings[opaque] str_key = key[0] if isinstance(key, tuple) else key result[str_key] = cas @@ -872,13 +898,12 @@ def _incr_decr(self, command, key, value, default, time): """ keybytes = str_to_bytes(key) time = time if time >= 0 else self.MAXIMUM_EXPIRE_TIME - self._send(struct.pack(self.HEADER_STRUCT + - self.COMMANDS[command]['struct'] % len(key), - self.MAGIC['request'], - self.COMMANDS[command]['command'], - len(keybytes), - 20, 0, 0, len(keybytes) + 20, 0, 0, value, - default, time, keybytes)) + cmd = self.COMMANDS[command] + klen = len(keybytes) + self._send(cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + klen, 20, 0, 0, klen + 20, 0, 0, + value, default, time) + keybytes) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() @@ -938,11 +963,11 @@ def delete(self, key, cas=0): """ logger.debug('Deleting key %s', key) keybytes = str_to_bytes(key) - self._send(struct.pack(self.HEADER_STRUCT + - self.COMMANDS['delete']['struct'] % (len(keybytes),), - self.MAGIC['request'], - self.COMMANDS['delete']['command'], - len(keybytes), 0, 0, 0, len(keybytes), 0, cas, keybytes)) + cmd = self.COMMANDS['delete'] + klen = len(keybytes) + self._send(cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + klen, 0, 0, 0, klen, 0, cas) + keybytes) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() @@ -966,27 +991,25 @@ def delete_multi(self, keys): """ logger.debug('Deleting keys %r', keys) msg = bytearray() + delete = self.COMMANDS['delete'] + DELETE_CMD = delete['command'] + pack_header = delete['packer'].pack # same packer as noop + MAGIC_REQ = self.MAGIC['request'] for key in keys: keybytes = str_to_bytes(key) - msg += struct.pack( - self.HEADER_STRUCT + - self.COMMANDS['delete']['struct'] % (len(keybytes),), - self.MAGIC['request'], - self.COMMANDS['delete']['command'], - len(keybytes), 0, 0, 0, len(keybytes), 0, 0, keybytes) - - msg += struct.pack( - self.HEADER_STRUCT + - self.COMMANDS['noop']['struct'], - self.MAGIC['request'], - self.COMMANDS['noop']['command'], - 0, 0, 0, 0, 0, 0, 0) + klen = len(keybytes) + msg += pack_header(MAGIC_REQ, DELETE_CMD, klen, 0, 0, 0, klen, 0, 0) + msg += keybytes + + noop = self.COMMANDS['noop'] + NOOP_CMD = noop['command'] + msg += noop['packer'].pack(MAGIC_REQ, NOOP_CMD, 0, 0, 0, 0, 0, 0, 0) self._send(msg) opcode = -1 retval = True - while opcode != self.COMMANDS['noop']['command']: + while opcode != NOOP_CMD: (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() if status != self.STATUS['success']: @@ -1006,11 +1029,10 @@ def flush_all(self, time): :rtype: bool """ logger.info('Flushing memcached') - self._send(struct.pack(self.HEADER_STRUCT + - self.COMMANDS['flush']['struct'], - self.MAGIC['request'], - self.COMMANDS['flush']['command'], - 0, 4, 0, 0, 4, 0, 0, time)) + cmd = self.COMMANDS['flush'] + self._send(cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + 0, 4, 0, 0, 4, 0, 0, time)) (magic, opcode, keylen, extlen, datatype, status, bodylen, opaque, cas, extra_content) = self._get_response() @@ -1031,20 +1053,17 @@ def stats(self, key=None): :rtype: dict """ # TODO: Stats with key is not working. + cmd = self.COMMANDS['stat'] if key is not None: if isinstance(key, text_type): key = str_to_bytes(key) keylen = len(key) - packed = struct.pack( - self.HEADER_STRUCT + '%ds' % keylen, - self.MAGIC['request'], - self.COMMANDS['stat']['command'], - keylen, 0, 0, 0, keylen, 0, 0, key) + packed = cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], + keylen, 0, 0, 0, keylen, 0, 0) + key else: - packed = struct.pack( - self.HEADER_STRUCT, - self.MAGIC['request'], - self.COMMANDS['stat']['command'], + packed = cmd['packer'].pack( + self.MAGIC['request'], cmd['command'], 0, 0, 0, 0, 0, 0, 0) self._send(packed) diff --git a/test/test_simple_functions.py b/test/test_simple_functions.py index ae3c024..0135798 100644 --- a/test/test_simple_functions.py +++ b/test/test_simple_functions.py @@ -42,6 +42,23 @@ def testSetMultiBigData(self): self.client.set_multi( dict((unicode(k), b'value') for k in range(32767))) + def testSetMultiNumericValues(self): + six.assertCountEqual(self, self.client.set_multi({ + 'test_key': 42, + 'test_key2': long(2 ** 40), + }), []) + self.assertEqual(self.client.get('test_key'), 42) + self.assertEqual(self.client.get('test_key2'), 2 ** 40) + + result = self.client.set_multi_cas({ + 'test_key': 7, + 'test_key2': long(2 ** 40 + 1), + }) + self.assertTrue(result['test_key'] is not None) + self.assertTrue(result['test_key2'] is not None) + self.assertEqual(self.client.get('test_key'), 7) + self.assertEqual(self.client.get('test_key2'), 2 ** 40 + 1) + def testGetSimple(self): self.client.set('test_key', 'test') self.assertEqual('test', self.client.get('test_key')) @@ -385,6 +402,53 @@ def testDecrementInitialize(self): self.assertEqual(10, self.client.decr('test_key', 1, default=10)) self.assertEqual(9, self.client.decr('test_key', 1, default=10)) + def testNonAsciiKeySingle(self): + key = u'シシ' + try: + self.assertEqual(0, self.client.incr(key, 1)) + self.assertEqual(1, self.client.incr(key, 1)) + self.assertEqual(0, self.client.decr(key, 1)) + self.client.delete(key) + + self.assertTrue(self.client.set(key, 'v1')) + self.assertEqual('v1', self.client.get(key)) + + self.assertFalse(self.client.add(key, 'v2')) + self.assertTrue(self.client.replace(key, 'v3')) + self.assertEqual('v3', self.client.get(key)) + + value, cas = self.client.gets(key) + self.assertEqual('v3', value) + self.assertTrue(self.client.cas(key, 'v4', cas)) + self.assertEqual('v4', self.client.get(key)) + + self.assertTrue(self.client.delete(key)) + self.assertEqual(None, self.client.get(key)) + finally: + self.client.delete(key) + + def testSetLargeNumeric(self): + big = 10 ** 200 + self.client.set('test_key', big) + self.assertEqual(big, self.client.get('test_key')) + + def testNonAsciiKeyBulk(self): + keys = [u'café', u'日本語'] + try: + self.assertEqual([], self.client.set_multi({k: 'v' for k in keys})) + self.assertEqual({k: 'v' for k in keys}, self.client.get_multi(keys)) + + self.client.delete_multi(keys) + self.assertEqual({}, self.client.get_multi(keys)) + + result = self.client.set_multi_cas({k: 'w' for k in keys}) + for k in keys: + self.assertTrue(result[k] is not None) + self.assertEqual({k: 'w' for k in keys}, self.client.get_multi(keys)) + finally: + for k in keys: + self.client.delete(k) + def testFlush(self): self.client.set('test_key', 'test') self.assertTrue(self.client.flush_all())