"""Core fake-socket machinery: RESP parsing, command dispatch, and the
shared state/lifecycle logic used by both the sync and async fakeredis
socket implementations."""
import itertools
import queue
import time
import weakref
from typing import List, Any, Tuple, Optional, Callable, Union, Match

import redis
from redis.connection import DefaultParser

from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Item, Signature
from ._helpers import (
    SimpleError,
    valid_response_type,
    SimpleString,
    NoResponse,
    casematch,
    compile_pattern,
    QUEUED,
    encode_command,
)
from ._stream import XStream
from ._zset import ZSet


def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]:
    """Extracts the command and command arguments from a list of bytes fields.

    For commands with subcommands (per ``COMMANDS_WITH_SUB``) the first two
    fields are joined into a single space-separated command name.

    :param fields: A list of bytes fields containing the command and command arguments.
    :return: A tuple of the command and command arguments.

    Example:
        fields = [b'GET', b'key1']
        result = _extract_command(fields)
        print(result)  # ('GET', ['key1'])
    """
    cmd = encode_command(fields[0])
    if cmd in COMMANDS_WITH_SUB and len(fields) >= 2:
        cmd += " " + encode_command(fields[1])
        cmd_arguments = fields[2:]
    else:
        cmd_arguments = fields[1:]
    return cmd, cmd_arguments


def bin_reverse(x, bits_count):
    """Return ``x`` with its lowest ``bits_count`` bits reversed.

    Used by ``_scan`` to emulate redis' reverse-binary cursor iteration,
    which keeps cursors valid across rehashing.
    """
    result = 0
    for i in range(bits_count):
        if (x >> i) & 1:
            result |= 1 << (bits_count - 1 - i)
    return result


class BaseFakeSocket:
    """A fake socket that speaks enough of the redis protocol (RESP) to
    serve a real ``redis`` client.  Subclasses/mixins supply transaction
    (``_transaction``, ``_clear_watches``) and pubsub (``_pubsub``) state.
    """

    _transaction: Optional[List[Any]]
    _clear_watches: Callable
    _in_transaction: bool
    _pubsub: int

    # Commands that remain legal while the connection is in subscribe mode.
    ACCEPTED_COMMANDS_WHILE_PUBSUB = {
        "ping",
        "subscribe",
        "unsubscribe",
        "psubscribe",
        "punsubscribe",
        "quit",
        "ssubscribe",
        "sunsubscribe",
    }
    _connection_error_class = redis.ConnectionError

    def __init__(self, server, db, *args, **kwargs):
        super(BaseFakeSocket, self).__init__(*args, **kwargs)
        self._server = server
        self._db_num = db
        self._db = server.dbs[self._db_num]
        self.responses: Optional[queue.Queue] = queue.Queue()
        # Prevents parser from processing commands. Not used in this module,
        # but set by aioredis module to prevent new commands being processed
        # while handling a blocking command.
        self._paused = False
        self._parser = self._parse_commands()
        self._parser.send(None)  # prime the generator to its first `yield`
        self.version = server.version

    def put_response(self, msg: Any) -> None:
        """Put a response message into the responses queue.

        :param msg: The response message.
        """
        # redis.Connection.__del__ might call self.close at any time, which
        # will set self.responses to None. We assume this will happen
        # atomically, and the code below then protects us against this.
        responses = self.responses
        if responses:
            responses.put(msg)

    def pause(self) -> None:
        """Stop the parser from consuming further commands (see ``_paused``)."""
        self._paused = True

    def resume(self) -> None:
        """Re-enable parsing and kick the parser so buffered data is processed."""
        self._paused = False
        self._parser.send(b"")

    def shutdown(self, _) -> None:
        # The single (ignored) argument mirrors socket.shutdown(how).
        self._parser.close()

    @staticmethod
    def fileno() -> int:
        # Our fake socket must return an integer from `FakeSocket.fileno()` since a real selector
        # will be created. The value does not matter since we replace the selector with our own
        # `FakeSelector` before it is ever used.
        return 0

    def _cleanup(self, server: Any) -> None:  # noqa: F821
        """Remove all the references to `self` from `server`.

        This is called with the server lock held, but it may be some time
        after self.close.
        """
        for subs in server.subscribers.values():
            subs.discard(self)
        for subs in server.psubscribers.values():
            subs.discard(self)
        self._clear_watches()

    def close(self):
        # Mark ourselves for cleanup. This might be called from
        # redis.Connection.__del__, which the garbage collection could call
        # at any time, and hence we can't safely take the server lock.
        # We rely on list.append being atomic.
        self._server.closed_sockets.append(weakref.ref(self))
        self._server = None
        self._db = None
        self.responses = None

    @staticmethod
    def _extract_line(buf):
        """Split ``buf`` at the first newline; the line must be CRLF-terminated."""
        pos = buf.find(b"\n") + 1
        assert pos > 0
        line = buf[:pos]
        buf = buf[pos:]
        assert line.endswith(b"\r\n")
        return line, buf

    def _parse_commands(self):
        """Generator that parses commands.

        It is fed pieces of redis protocol data (via `send`) and calls
        `_process_command` whenever it has a complete one.
        """
        buf = b""
        while True:
            while self._paused or b"\n" not in buf:
                buf += yield
            line, buf = self._extract_line(buf)
            assert line[:1] == b"*"  # array
            n_fields = int(line[1:-2])
            fields = []
            for i in range(n_fields):
                while b"\n" not in buf:
                    buf += yield
                line, buf = self._extract_line(buf)
                assert line[:1] == b"$"  # string
                length = int(line[1:-2])
                while len(buf) < length + 2:
                    buf += yield
                fields.append(buf[:length])
                buf = buf[length + 2:]  # +2 to skip the CRLF
            self._process_command(fields)

    def _run_command(
        self, func: Callable[..., Any], sig: Signature, args: Tuple[Any], from_script: bool
    ) -> Any:
        """Validate and execute a single command implementation.

        ``SimpleError`` raised during validation or execution becomes the
        result rather than propagating; modified command items are written
        back to the database afterwards in either case.
        """
        command_items = {}
        try:
            ret = sig.apply(args, self._db, self.version)
            if from_script and msgs.FLAG_NO_SCRIPT in sig.flags:
                raise SimpleError(msgs.COMMAND_IN_SCRIPT_MSG)
            if self._pubsub and sig.name not in BaseFakeSocket.ACCEPTED_COMMANDS_WHILE_PUBSUB:
                raise SimpleError(msgs.BAD_COMMAND_IN_PUBSUB_MSG)
            if len(ret) == 1:
                # sig.apply produced the final result itself (e.g. arity/type error path).
                result = ret[0]
            else:
                args, command_items = ret
                result = func(*args)
                assert valid_response_type(result)
        except SimpleError as exc:
            result = exc
        for command_item in command_items:
            command_item.writeback(
                remove_empty_val=msgs.FLAG_LEAVE_EMPTY_VAL not in sig.flags
            )
        return result

    def _decode_error(self, error):
        # Delegate to redis-py's parser so error classes match the real client's.
        return DefaultParser(socket_read_size=65536).parse_error(error.value)  # type: ignore

    def _decode_result(self, result):
        """Convert SimpleString and SimpleError, recursively"""
        if isinstance(result, list):
            return [self._decode_result(r) for r in result]
        elif isinstance(result, SimpleString):
            return result.value
        elif isinstance(result, SimpleError):
            return self._decode_error(result)
        else:
            return result

    def _blocking(self, timeout: Optional[Union[float, int]], func: Callable):
        """Run a function until it succeeds or timeout is reached.

        The timeout is in seconds, and 0 means infinite. The function
        is called with a boolean to indicate whether this is the first call.
        If it returns None, it is considered to have "failed" and is retried
        each time the condition variable is notified, until the timeout is
        reached.

        Returns the function return value, or None if the timeout has passed.
        """
        ret = func(True)
        # Inside MULTI/EXEC blocking commands must not block.
        if ret is not None or self._in_transaction:
            return ret
        deadline = time.time() + timeout if timeout else None
        while True:
            timeout = (deadline - time.time()) if deadline is not None else None
            if timeout is not None and timeout <= 0:
                return None
            if self._db.condition.wait(timeout=timeout) is False:
                return None  # Timeout expired
            ret = func(False)
            if ret is not None:
                return ret

    def _name_to_func(self, cmd_name: str):
        """Get the signature and the method from the command name."""
        if cmd_name not in SUPPORTED_COMMANDS:
            # redis remaps \r or \n in an error to ' ' to make it legal protocol
            clean_name = cmd_name.replace("\r", " ").replace("\n", " ")
            raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name))
        sig = SUPPORTED_COMMANDS[cmd_name]
        func = getattr(self, sig.func_name, None)
        return func, sig

    def sendall(self, data):
        """Feed raw protocol ``data`` (bytes or ASCII str) into the parser."""
        if not self._server.connected:
            raise self._connection_error_class(msgs.CONNECTION_ERROR_MSG)
        if isinstance(data, str):
            data = data.encode("ascii")
        self._parser.send(data)

    def _process_command(self, fields: List[bytes]):
        """Dispatch one parsed command: look it up, run (or queue, inside a
        transaction) it under the server lock, and enqueue the response."""
        if not fields:
            return
        result: Any
        cmd, cmd_arguments = _extract_command(fields)
        try:
            func, sig = self._name_to_func(cmd)
            with self._server.lock:
                # Clean out old connections
                while True:
                    try:
                        weak_sock = self._server.closed_sockets.pop()
                    except IndexError:
                        break
                    else:
                        sock = weak_sock()
                        if sock:
                            sock._cleanup(self._server)
                now = time.time()
                for db in self._server.dbs.values():
                    db.time = now
                sig.check_arity(cmd_arguments, self.version)
                if self._transaction is not None and msgs.FLAG_TRANSACTION not in sig.flags:
                    # Inside MULTI: queue the command instead of running it.
                    self._transaction.append((func, sig, cmd_arguments))
                    result = QUEUED
                else:
                    result = self._run_command(func, sig, cmd_arguments, False)
        except SimpleError as exc:
            if self._transaction is not None:
                # TODO: should not apply if the exception is from _run_command
                # e.g. watch inside multi
                self._transaction_failed = True
            if cmd == "exec" and exc.value.startswith("ERR "):
                exc.value = "EXECABORT Transaction discarded because of: " + exc.value[4:]
                self._transaction = None
                self._transaction_failed = False
                self._clear_watches()
            result = exc
        result = self._decode_result(result)
        if not isinstance(result, NoResponse):
            self.put_response(result)

    def _scan(self, keys, cursor, *args):
        """This is the basis of most of the ``scan`` methods.

        This implementation is KNOWN to be un-performant, as it requires
        grabbing the full set of keys over which we are investigating subsets.

        The SCAN command, and the other commands in the SCAN family, are able
        to provide to the user a set of guarantees associated with full
        iterations.

        - A full iteration always retrieves all the elements that were present
          in the collection from the start to the end of a full iteration.
          This means that if a given element is inside the collection when an
          iteration is started, and is still there when an iteration
          terminates, then at some point the SCAN command returned it to the
          user.
        - A full iteration never returns any element that was NOT present in
          the collection from the start to the end of a full iteration. So if
          an element was removed before the start of an iteration, and is
          never added back to the collection for all the time an iteration
          lasts, the SCAN command ensures that this element will never be
          returned.

        However, because the SCAN command has very little state associated
        (just the cursor), it has the following drawbacks:

        - A given element may be returned multiple times. It is up to the
          application to handle the case of duplicated elements, for example,
          only using the returned elements in order to perform operations that
          are safe when re-applied multiple times.
        - Elements that were not constantly present in the collection during a
          full iteration may be returned or not: it is undefined.
        """
        cursor = int(cursor)
        (pattern, _type, count), _ = extract_args(args, ("*match", "*type", "+count"))
        count = 10 if count is None else count
        data = sorted(keys)
        # Reverse-binary cursor, as in real redis, so cursors survive resizes.
        bits_len = (len(keys) - 1).bit_length()
        cursor = bin_reverse(cursor, bits_len)
        if cursor >= len(keys):
            # Iteration finished.  Return the cursor as bytes for consistency
            # with the normal return path below (clients coerce it with int()).
            return [b"0", []]
        result_cursor = cursor + count
        result_data = []
        regex = compile_pattern(pattern) if pattern is not None else None

        def match_key(key: bytes) -> Union[bool, Match[bytes], None]:
            return regex.match(key) if regex is not None else True

        def match_type(key) -> bool:
            return _type is None or casematch(
                BaseFakeSocket._key_value_type(self._db[key]).value, _type
            )

        if pattern is not None or _type is not None:
            for val in itertools.islice(data, cursor, cursor + count):
                # HSCAN/ZSCAN iterate (field, value) tuples; match on the key part.
                compare_val = val[0] if isinstance(val, tuple) else val
                if match_key(compare_val) and match_type(compare_val):
                    result_data.append(val)
        else:
            result_data = data[cursor: cursor + count]
        if result_cursor >= len(data):
            result_cursor = 0
        return [str(bin_reverse(result_cursor, bits_len)).encode(), result_data]

    def _ttl(self, key, scale) -> int:
        """TTL of ``key`` in units of ``scale`` per second; -2 if the key does
        not exist, -1 if it has no expiry (matching redis TTL/PTTL)."""
        if not key:
            return -2
        elif key.expireat is None:
            return -1
        else:
            return int(round((key.expireat - self._db.time) * scale))

    def _encodefloat(self, value: float, humanfriendly: bool) -> bytes:
        if self.version >= (7,):
            value = 0 + value  # normalizes -0.0 to 0.0, matching redis 7 output
        return Float.encode(value, humanfriendly)

    def _encodeint(self, value: int) -> bytes:
        if self.version >= (7,):
            value = 0 + value  # coerce bool/int subclasses to plain int
        return Int.encode(value)

    @staticmethod
    def _key_value_type(key: Item) -> SimpleString:
        """Map a stored value to its redis TYPE name (e.g. for the TYPE command)."""
        if key.value is None:
            return SimpleString(b"none")
        elif isinstance(key.value, bytes):
            return SimpleString(b"string")
        elif isinstance(key.value, list):
            return SimpleString(b"list")
        elif isinstance(key.value, set):
            return SimpleString(b"set")
        elif isinstance(key.value, ZSet):
            return SimpleString(b"zset")
        elif isinstance(key.value, dict):
            return SimpleString(b"hash")
        elif isinstance(key.value, XStream):
            return SimpleString(b"stream")
        else:
            assert False  # pragma: nocover