# youtube-summarizer/venv311/lib/python3.11/site-packages/fakeredis/_basefakesocket.py
#
# 396 lines
# 15 KiB
# Python

import itertools
import queue
import time
import weakref
from typing import List, Any, Tuple, Optional, Callable, Union, Match
import redis
from redis.connection import DefaultParser
from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Item, Signature
from ._helpers import (
SimpleError,
valid_response_type,
SimpleString,
NoResponse,
casematch,
compile_pattern,
QUEUED,
encode_command,
)
from ._stream import XStream
from ._zset import ZSet
def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]:
    """Split a parsed request into its command name and its arguments.

    The first field always contributes to the command name. When that name
    is listed in ``COMMANDS_WITH_SUB`` and a second field is present, the
    second field is treated as a sub-command: it is appended to the name
    after a single space, and the arguments start at the third field.

    :param fields: Non-empty list of request fields, command first.
    :return: ``(command, arguments)``. ``command`` is whatever
        ``encode_command`` produces for the name field(s); ``arguments`` is
        the untouched remainder of ``fields``.
    """
    command = encode_command(fields[0])
    has_subcommand = command in COMMANDS_WITH_SUB and len(fields) >= 2
    if not has_subcommand:
        return command, fields[1:]
    command = command + " " + encode_command(fields[1])
    return command, fields[2:]
def bin_reverse(x, bits_count):
    """Return the lowest ``bits_count`` bits of ``x`` in reversed order.

    Bit ``i`` of ``x`` ends up at position ``bits_count - 1 - i`` of the
    result. Bits of ``x`` at or above ``bits_count`` are ignored, and a
    ``bits_count`` of zero (or less) yields ``0``.
    """
    reversed_bits = 0
    for shift in range(bits_count):
        # Push the bits collected so far up by one and append the next bit of
        # x, taken from the least-significant end.
        reversed_bits = (reversed_bits << 1) | ((x >> shift) & 1)
    return reversed_bits
class BaseFakeSocket:
    """Socket stand-in that feeds written bytes to an in-process fake server.

    Data passed to ``sendall`` is handed to an incremental command parser;
    each complete command is executed and its reply is placed on the
    ``responses`` queue.
    """

    # The attributes below are only declared here; nothing in this class body
    # assigns them, so they have to be provided elsewhere.

    # Commands queued so far; ``_process_command`` treats ``None`` as "no
    # transaction is open".
    _transaction: Optional[List[Any]]
    # Called from ``_cleanup`` and when an ``exec`` is aborted.
    _clear_watches: Callable
    # Read by ``_blocking`` to decide whether waiting is allowed.
    _in_transaction: bool
    # Truthy while the connection is restricted to the commands listed below.
    _pubsub: int

    # Command names that ``_run_command`` still accepts while ``_pubsub`` is
    # truthy.
    ACCEPTED_COMMANDS_WHILE_PUBSUB = {
        "ping",
        "psubscribe",
        "punsubscribe",
        "quit",
        "ssubscribe",
        "subscribe",
        "sunsubscribe",
        "unsubscribe",
    }

    # Exception type raised by ``sendall`` when the server is not connected.
    _connection_error_class = redis.ConnectionError
def __init__(self, server, db, *args, **kwargs):
    """Attach the socket to ``server`` and database ``db`` and start the parser.

    Any extra positional and keyword arguments are forwarded to the next
    ``__init__`` in the method resolution order.

    :param server: Object exposing ``dbs`` (indexable by ``db``) and ``version``.
    :param db: Key into ``server.dbs`` selecting the database to use.
    """
    super(BaseFakeSocket, self).__init__(*args, **kwargs)
    self._server = server
    self._db_num = db
    self._db = server.dbs[db]
    self.responses: Optional[queue.Queue] = queue.Queue()
    # Prevents parser from processing commands. Not used in this module,
    # but set by aioredis module to prevent new commands being processed
    # while handling a blocking command.
    self._paused = False
    # Advance the generator to its first ``yield`` so that it can accept data
    # through ``send``. ``_paused`` must already exist because the generator
    # reads it on the way there.
    command_parser = self._parse_commands()
    next(command_parser)
    self._parser = command_parser
    self.version = server.version
def put_response(self, msg: Any) -> None:
    """Put a response message into the responses queue.

    The message is silently dropped when the socket has already been closed
    (``self.responses`` is ``None``).

    :param msg: The response message.
    """
    # redis.Connection.__del__ might call self.close at any time, which
    # will set self.responses to None. We assume this will happen
    # atomically, and the code below then protects us against this.
    responses = self.responses
    # Compare against None explicitly: a queue-like object that implements
    # ``__len__`` is falsy while empty, and a truthiness test would then
    # discard the very first response.
    if responses is not None:
        responses.put(msg)
def pause(self) -> None:
    """Stop the command parser from starting on a new command.

    The parser checks ``_paused`` before it reads the next command header.
    """
    self._paused = True
def resume(self) -> None:
    """Lift a previous ``pause`` and let the parser continue."""
    self._paused = False
    # Sending an empty chunk adds no data; it only wakes the generator so it
    # re-evaluates its wait condition and handles anything already buffered.
    self._parser.send(b"")
def shutdown(self, _) -> None:
    """Close the command parser generator.

    The positional argument is accepted and ignored.
    """
    self._parser.close()
@staticmethod
def fileno() -> int:
# Our fake socket must return an integer from `FakeSocket.fileno()` since a real selector
# will be created. The value does not matter since we replace the selector with our own
# `FakeSelector` before it is ever used.
return 0
def _cleanup(self, server: Any) -> None: # noqa: F821
"""Remove all the references to `self` from `server`.
This is called with the server lock held, but it may be some time after
self.close.
"""
for subs in server.subscribers.values():
subs.discard(self)
for subs in server.psubscribers.values():
subs.discard(self)
self._clear_watches()
def close(self):
    """Detach the socket from its server and schedule it for cleanup.

    Calling ``close`` more than once is harmless: only the first call
    registers the socket for cleanup, later calls return immediately.
    """
    # Mark ourselves for cleanup. This might be called from
    # redis.Connection.__del__, which the garbage collection could call
    # at any time, and hence we can't safely take the server lock.
    # We rely on list.append being atomic.
    server = self._server
    if server is None:
        # Already closed; without this guard a repeated close would fail
        # with AttributeError on ``None.closed_sockets``.
        return
    server.closed_sockets.append(weakref.ref(self))
    self._server = None
    self._db = None
    self.responses = None
@staticmethod
def _extract_line(buf):
pos = buf.find(b"\n") + 1
assert pos > 0
line = buf[:pos]
buf = buf[pos:]
assert line.endswith(b"\r\n")
return line, buf
def _parse_commands(self):
"""Generator that parses commands.
It is fed pieces of redis protocol data (via `send`) and calls
`_process_command` whenever it has a complete one.
"""
buf = b""
while True:
while self._paused or b"\n" not in buf:
buf += yield
line, buf = self._extract_line(buf)
assert line[:1] == b"*" # array
n_fields = int(line[1:-2])
fields = []
for i in range(n_fields):
while b"\n" not in buf:
buf += yield
line, buf = self._extract_line(buf)
assert line[:1] == b"$" # string
length = int(line[1:-2])
while len(buf) < length + 2:
buf += yield
fields.append(buf[:length])
buf = buf[length + 2:] # +2 to skip the CRLF
self._process_command(fields)
def _run_command(self, func: Callable[..., Any], sig: Signature, args: Tuple[Any], from_script: bool) -> Any:
    """Execute one command and return its result or the error it raised.

    ``sig.apply`` converts the raw arguments first. A one-element return
    value is taken as the final result; otherwise it is unpacked into the
    converted call arguments and the command items to write back, and
    ``func`` is called with those arguments.

    :param func: Implementation of the command.
    :param sig: Signature describing the command (name, flags, argument conversion).
    :param args: Raw command arguments.
    :param from_script: Whether the command is issued from a script; commands
        flagged ``FLAG_NO_SCRIPT`` are rejected in that case.
    :return: The command result, or the ``SimpleError`` instance raised while
        running it.
    """
    # Stays empty unless argument conversion hands back items to write back.
    command_items = ()
    try:
        applied = sig.apply(args, self._db, self.version)
        if from_script and msgs.FLAG_NO_SCRIPT in sig.flags:
            raise SimpleError(msgs.COMMAND_IN_SCRIPT_MSG)
        if self._pubsub and sig.name not in BaseFakeSocket.ACCEPTED_COMMANDS_WHILE_PUBSUB:
            raise SimpleError(msgs.BAD_COMMAND_IN_PUBSUB_MSG)
        if len(applied) == 1:
            result = applied[0]
        else:
            # Bind command_items before calling func so that the write-back
            # below still happens when func raises SimpleError.
            call_args, command_items = applied
            result = func(*call_args)
            assert valid_response_type(result)
    except SimpleError as exc:
        result = exc
    keep_empty = msgs.FLAG_LEAVE_EMPTY_VAL in sig.flags
    for command_item in command_items:
        command_item.writeback(remove_empty_val=not keep_empty)
    return result
def _decode_error(self, error):
    """Convert an error reply into whatever the redis-py parser produces for it.

    :param error: Object whose ``value`` attribute holds the error text.
    :return: The result of ``DefaultParser.parse_error`` for that text.
    """
    parser = DefaultParser(socket_read_size=65536)
    return parser.parse_error(error.value)  # type: ignore
def _decode_result(self, result):
    """Convert SimpleString and SimpleError, recursively.

    Lists are rebuilt element by element, a ``SimpleString`` is replaced by
    its ``value``, a ``SimpleError`` goes through ``_decode_error``, and any
    other object is returned unchanged.
    """
    if isinstance(result, list):
        return [self._decode_result(element) for element in result]
    if isinstance(result, SimpleString):
        return result.value
    if isinstance(result, SimpleError):
        return self._decode_error(result)
    return result
def _blocking(self, timeout: Optional[Union[float, int]], func: Callable):
"""Run a function until it succeeds or timeout is reached.
The timeout is in seconds, and 0 means infinite. The function
is called with a boolean to indicate whether this is the first call.
If it returns None, it is considered to have "failed" and is retried
each time the condition variable is notified, until the timeout is
reached.
Returns the function return value, or None if the timeout has passed.
"""
ret = func(True)
if ret is not None or self._in_transaction:
return ret
deadline = time.time() + timeout if timeout else None
while True:
timeout = (deadline - time.time()) if deadline is not None else None
if timeout is not None and timeout <= 0:
return None
if self._db.condition.wait(timeout=timeout) is False:
return None # Timeout expired
ret = func(False)
if ret is not None:
return ret
def _name_to_func(self, cmd_name: str):
    """Get the signature and the method from the command name.

    :param cmd_name: Command name as used for the keys of ``SUPPORTED_COMMANDS``.
    :return: ``(func, sig)``; ``func`` is ``None`` when this object has no
        attribute named ``sig.func_name``.
    :raises SimpleError: If ``cmd_name`` is not a supported command.
    """
    if cmd_name in SUPPORTED_COMMANDS:
        sig = SUPPORTED_COMMANDS[cmd_name]
        return getattr(self, sig.func_name, None), sig
    # redis remaps \r or \n in an error to ' ' to make it legal protocol
    clean_name = cmd_name.translate(str.maketrans("\r\n", "  "))
    raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name))
def sendall(self, data):
    """Feed ``data`` to the command parser.

    :param data: Protocol data; a ``str`` is encoded as ASCII first, anything
        else is passed to the parser unchanged.
    :raises redis.ConnectionError: (``_connection_error_class``) if the
        server reports that it is not connected.
    """
    if not self._server.connected:
        raise self._connection_error_class(msgs.CONNECTION_ERROR_MSG)
    payload = data.encode("ascii") if isinstance(data, str) else data
    self._parser.send(payload)
def _process_command(self, fields: List[bytes]):
    """Execute (or queue) one parsed command and publish its reply.

    Under the server lock this first cleans up sockets that were closed
    since the last command, stamps every database with the current time and
    checks the argument count. The command is then either appended to the
    open transaction or run immediately. A ``SimpleError`` raised anywhere
    along the way becomes the reply. Replies that are ``NoResponse``
    instances are not put on the response queue.

    :param fields: Command name followed by its arguments; an empty list is ignored.
    """
    if not fields:
        return
    result: Any
    cmd, cmd_arguments = _extract_command(fields)
    try:
        func, sig = self._name_to_func(cmd)
        with self._server.lock:
            # Clean out old connections
            while True:
                try:
                    weak_sock = self._server.closed_sockets.pop()
                except IndexError:
                    break
                sock = weak_sock()
                if sock:
                    sock._cleanup(self._server)
            now = time.time()
            for db in self._server.dbs.values():
                db.time = now
            sig.check_arity(cmd_arguments, self.version)
            must_queue = self._transaction is not None and msgs.FLAG_TRANSACTION not in sig.flags
            if must_queue:
                self._transaction.append((func, sig, cmd_arguments))
                result = QUEUED
            else:
                result = self._run_command(func, sig, cmd_arguments, False)
    except SimpleError as exc:
        if self._transaction is not None:
            # TODO: should not apply if the exception is from _run_command
            # e.g. watch inside multi
            self._transaction_failed = True
        if cmd == "exec" and exc.value.startswith("ERR "):
            # Replace the "ERR " prefix and drop all transaction state.
            exc.value = "EXECABORT Transaction discarded because of: " + exc.value[4:]
            self._transaction = None
            self._transaction_failed = False
            self._clear_watches()
        result = exc
    result = self._decode_result(result)
    if isinstance(result, NoResponse):
        return
    self.put_response(result)
def _scan(self, keys, cursor, *args):
    """This is the basis of most of the ``scan`` methods.

    This implementation is KNOWN to be un-performant, as it requires grabbing
    and sorting the full set of keys on every call. The cursor handed to the
    client is the bit-reversed index into that sorted list.

    The Redis documentation describes the following guarantees for the SCAN
    family over a full iteration:

    - Every element present in the collection from the start to the end of
      the iteration is returned at some point.
    - No element is returned that was absent from the collection for the
      whole duration of the iteration.

    Because the only state is the cursor, it also lists these drawbacks:

    - An element may be returned several times; callers should only perform
      operations that are safe to repeat.
    - Whether elements that were not constantly present during the iteration
      are returned is undefined.

    :param keys: Collection of elements to page through; elements may be
        tuples, in which case their first item is what gets matched.
    :param cursor: Cursor from the previous call (``0`` to start).
    :param args: Optional ``MATCH pattern``, ``TYPE type`` and ``COUNT n``
        arguments; ``COUNT`` defaults to 10.
    :return: ``[next_cursor, elements]``. ``next_cursor`` is always ``bytes``
        and equals ``b"0"`` once the iteration is complete.
    """
    cursor = int(cursor)
    (pattern, _type, count), _ = extract_args(args, ("*match", "*type", "+count"))
    count = 10 if count is None else count
    data = sorted(keys)
    total = len(data)
    bits_len = (total - 1).bit_length()
    start = bin_reverse(cursor, bits_len)
    if start >= total:
        # Reply with the cursor as bytes, exactly like the regular exit at
        # the bottom, so callers see one cursor type on every path.
        return [b"0", []]

    window = data[start: start + count]
    if pattern is None and _type is None:
        result_data = window
    else:
        regex = compile_pattern(pattern) if pattern is not None else None

        def matches(val) -> bool:
            # For tuple elements only the first item is compared.
            name = val[0] if isinstance(val, tuple) else val
            if regex is not None and not regex.match(name):
                return False
            if _type is None:
                return True
            return casematch(BaseFakeSocket._key_value_type(self._db[name]).value, _type)

        result_data = [val for val in window if matches(val)]

    next_index = start + count
    if next_index >= total:
        next_index = 0  # iteration finished
    return [str(bin_reverse(next_index, bits_len)).encode(), result_data]
def _ttl(self, key, scale) -> int:
if not key:
return -2
elif key.expireat is None:
return -1
else:
return int(round((key.expireat - self._db.time) * scale))
def _encodefloat(self, value: float, humanfriendly: bool) -> bytes:
    """Encode ``value`` with ``Float.encode``.

    For server versions 7 and later the value is added to zero first; for a
    Python float this turns ``-0.0`` into ``0.0`` and leaves every other
    value as it is.

    :param value: Number to encode.
    :param humanfriendly: Passed through to ``Float.encode``.
    """
    normalized = 0 + value if self.version >= (7,) else value
    return Float.encode(normalized, humanfriendly)
def _encodeint(self, value: int) -> bytes:
    """Encode ``value`` with ``Int.encode``.

    For server versions 7 and later the value is added to zero before it is
    encoded, mirroring ``_encodefloat``.

    :param value: Integer to encode.
    """
    normalized = 0 + value if self.version >= (7,) else value
    return Int.encode(normalized)
@staticmethod
def _key_value_type(key: Item) -> SimpleString:
    """Return the type name of the value stored in ``key``.

    :param key: Item whose ``value`` attribute is inspected.
    :return: ``SimpleString`` holding one of ``none``, ``string``, ``list``,
        ``set``, ``zset``, ``hash`` or ``stream``.
    """
    value = key.value
    if value is None:
        return SimpleString(b"none")
    # Checked top to bottom; the first class the value is an instance of wins.
    type_names = (
        (bytes, b"string"),
        (list, b"list"),
        (set, b"set"),
        (ZSet, b"zset"),
        (dict, b"hash"),
        (XStream, b"stream"),
    )
    for value_class, type_name in type_names:
        if isinstance(value, value_class):
            return SimpleString(type_name)
    assert False  # pragma: nocover