youtube-summarizer/venv311/lib/python3.11/site-packages/fakeredis/_server.py

220 lines
7.6 KiB
Python

import inspect
import logging
import queue
import threading
import time
import uuid
import warnings
import weakref
from collections import defaultdict
from typing import Dict, Tuple, Any, List
import redis
from fakeredis._fakesocket import FakeSocket
from fakeredis._helpers import Database, FakeSelector
from . import _msgs as msgs
# Module-level logger registered under the "fakeredis" name.
LOGGER = logging.getLogger("fakeredis")
def _create_version(v) -> Tuple[int]:
if isinstance(v, tuple):
return v # type: ignore
if isinstance(v, int):
return (v,)
if isinstance(v, str):
v = v.split(".")
return tuple(int(x) for x in v) # type: ignore
return v
class FakeServer:
    """In-process stand-in for a Redis server.

    Holds the per-index databases, the pub/sub subscription tables and a
    ``connected`` flag that fake connections consult before connecting.
    Instances can be shared between connections through ``get_server``.
    """

    # Registry of shared servers, keyed by the caller-supplied key.
    _servers_map: Dict[str, "FakeServer"] = dict()

    def __init__(self, version: Tuple[int, ...] = (7,)):
        # One lock shared by every Database belonging to this server.
        self.lock = threading.Lock()
        # Databases are created lazily the first time an index is accessed.
        self.dbs: Dict[int, Database] = defaultdict(lambda: Database(self.lock))
        # Maps channel/pattern to weak set of sockets
        self.subscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet)
        self.psubscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet)
        self.ssubscribers: Dict[bytes, weakref.WeakSet] = defaultdict(weakref.WeakSet)
        # Creation time in whole seconds since the epoch
        # (presumably what LASTSAVE reports until a save happens -- confirm).
        self.lastsave = int(time.time())
        self.connected = True
        # List of weakrefs to sockets that are being closed lazily
        self.closed_sockets: List[Any] = []
        self.version = _create_version(version)

    @staticmethod
    def get_server(key, version: Tuple[int, ...]):
        """Return the server registered under *key*, creating it if missing.

        If a server already exists for *key*, it is returned as-is and
        *version* is ignored.

        The fast path avoids constructing a throwaway FakeServer (lock,
        dictionaries) on every lookup. The insert still goes through
        ``dict.setdefault``, so concurrent first calls agree on a single
        stored instance.
        """
        server = FakeServer._servers_map.get(key)
        if server is None:
            server = FakeServer._servers_map.setdefault(
                key, FakeServer(version=version)
            )
        return server
class FakeBaseConnectionMixin:
    """Mixin that attaches a redis-py connection class to a FakeServer.

    It consumes the fakeredis-only keyword arguments ``server``, ``path``,
    ``version`` and ``connected`` and forwards the remaining arguments to the
    next class in the MRO.
    """

    def __init__(self, *args, **kwargs):
        self.client_name = None
        self._sock = None
        self._selector = None

        # Strip fakeredis-only options so the base connection never sees them.
        self._server = kwargs.pop("server", None)
        path = kwargs.pop("path", None)
        version = kwargs.pop("version", (7, 0))
        connected = kwargs.pop("connected", True)

        if self._server is None:
            # No explicit server: share one from the registry, keyed by the
            # ``path`` option (presumably a unix-socket path) when given, else
            # by host:port, with the requested version appended.
            location = path if path else f"{kwargs.get('host')}:{kwargs.get('port')}"
            self.server_key = location + f":v{version}"
            self._server = FakeServer.get_server(self.server_key, version=version)
            # Every connection created this way overwrites the shared
            # server's connected flag.
            self._server.connected = connected

        super().__init__(*args, **kwargs)
class FakeConnection(FakeBaseConnectionMixin, redis.Connection):
    """redis-py connection whose socket is an in-process FakeSocket.

    Commands are executed against ``self._server`` instead of a network
    peer. Replies are taken from the fake socket's ``responses`` queue.
    """

    def connect(self):
        """Connect through redis-py, then wrap the fake socket in a selector."""
        super().connect()
        # The selector is set in redis.Connection.connect() after _connect() is called
        self._selector = FakeSelector(self._sock)

    def _connect(self):
        """Return a FakeSocket bound to this connection's server and db.

        Raises:
            redis.ConnectionError: if the fake server is marked disconnected.
        """
        if not self._server.connected:
            raise redis.ConnectionError(msgs.CONNECTION_ERROR_MSG)
        return FakeSocket(self._server, db=self.db)

    def can_read(self, timeout=0):
        """Return True if a reply can be read, connecting lazily if needed.

        A disconnected server reports True, presumably so the caller goes on
        to ``read_response``. That method then drains already-queued replies
        or raises ``redis.ConnectionError``.
        """
        if not self._server.connected:
            return True
        if not self._sock:
            self.connect()
        # We use check_can_read rather than can_read, because on redis-py<3.2,
        # FakeSelector inherits from a stub BaseSelector which doesn't
        # implement can_read. Normally can_read provides retries on EINTR,
        # but that's not necessary for the implementation of
        # FakeSelector.check_can_read.
        return bool(self._selector and self._selector.check_can_read(timeout))

    def _decode(self, response):
        """Decode bytes with ``self.encoder``, recursing into lists."""
        if isinstance(response, list):
            return [self._decode(item) for item in response]
        if isinstance(response, bytes):
            return self.encoder.decode(response)
        return response

    def read_response(self, disable_decoding=False, **kwargs):
        """Pop the next reply queued on the fake socket.

        Args:
            disable_decoding: return the raw reply without decoding bytes.
                It may be passed positionally or by keyword, like the leading
                parameter of ``redis.Connection.read_response``.
            **kwargs: ``disconnect_on_error`` (default True) controls whether
                the connection is dropped when the server is disconnected.
                Other keywords are accepted and ignored.

        Raises:
            redis.ConnectionError: if there is no socket, or the server is
                disconnected and no reply is already queued.
            redis.ResponseError: when the queued reply is an error object.
        """
        if not self._sock:
            raise redis.ConnectionError(msgs.CONNECTION_ERROR_MSG)
        if not self._server.connected:
            # Disconnected: only replies that were already queued can be read.
            try:
                response = self._sock.responses.get_nowait()
            except queue.Empty:
                if kwargs.get("disconnect_on_error", True):
                    self.disconnect()
                raise redis.ConnectionError(msgs.CONNECTION_ERROR_MSG)
        else:
            response = self._sock.responses.get()
        if isinstance(response, redis.ResponseError):
            raise response
        if disable_decoding:
            return response
        return self._decode(response)

    def repr_pieces(self):
        """Return (name, value) pairs used to build this connection's repr."""
        pieces = [("server", self._server), ("db", self.db)]
        if self.client_name:
            pieces.append(("client_name", self.client_name))
        return pieces

    def __str__(self):
        # server_key is only assigned by FakeBaseConnectionMixin.__init__ when
        # no explicit ``server`` was passed. Fall back to the default string
        # form instead of raising AttributeError.
        server_key = getattr(self, "server_key", None)
        if server_key is None:
            return super().__str__()
        return server_key
class FakeRedisMixin:
    """Constructor logic shared by FakeRedis and FakeStrictRedis.

    Takes the same arguments as ``redis.Redis.__init__``, plus the
    fakeredis-specific keywords ``server``, ``version`` and ``connected``.
    Unless a ``connection_pool`` is supplied, it builds a
    ``redis.connection.ConnectionPool`` that creates ``FakeConnection``
    objects. When a pool is supplied, ``server``, ``version`` and
    ``connected`` are ignored.
    """

    def __init__(self, *args, server=None, version=(7,), **kwargs):
        # Interpret the positional and keyword arguments according to the
        # version of redis in use.
        parameters = list(
            inspect.signature(redis.Redis.__init__).parameters.values()
        )[1:]  # drop ``self``
        # Convert args => kwargs, mapping positionals onto redis.Redis's
        # parameters in declaration order.
        kwargs.update({parameters[i].name: args[i] for i in range(len(args))})
        # Without an explicit host, use a random one. The connection's server
        # key is then unique, so this client gets its own FakeServer.
        kwargs.setdefault("host", uuid.uuid4().hex)
        # Keep only redis.Redis parameters that have defaults, filling in any
        # the caller did not pass.
        kwds = {
            p.name: kwargs.get(p.name, p.default)
            for p in parameters
            if p.default is not inspect.Parameter.empty
        }
        # ``connected`` is a fakeredis-only option, so the filter above drops
        # it. Forward it explicitly so ``connected=False`` reaches the pool's
        # FakeConnection objects instead of being silently ignored.
        if "connected" in kwargs:
            kwds["connected"] = kwargs["connected"]
        if not kwds.get("connection_pool", None):
            charset = kwds.get("charset", None)
            errors = kwds.get("errors", None)
            # Adapted from redis-py
            if charset is not None:
                warnings.warn(
                    DeprecationWarning(
                        '"charset" is deprecated. Use "encoding" instead'
                    )
                )
                kwds["encoding"] = charset
            if errors is not None:
                warnings.warn(
                    DeprecationWarning(
                        '"errors" is deprecated. Use "encoding_errors" instead'
                    )
                )
                kwds["encoding_errors"] = errors
            # Options copied from kwds into the pool's connection kwargs.
            conn_pool_args = {
                "host",
                "port",
                "db",
                # Ignoring because AUTH is not implemented
                # 'username',
                # 'password',
                "socket_timeout",
                "encoding",
                "encoding_errors",
                "decode_responses",
                "retry_on_timeout",
                "max_connections",
                "health_check_interval",
                "client_name",
                "connected",
            }
            connection_kwargs = {
                "connection_class": FakeConnection,
                "server": server,
                "version": version,
            }
            connection_kwargs.update(
                {arg: kwds[arg] for arg in conn_pool_args if arg in kwds}
            )
            kwds["connection_pool"] = redis.connection.ConnectionPool(
                **connection_kwargs
            )
        # Remove fakeredis-only options before delegating to redis.Redis.
        kwds.pop("server", None)
        kwds.pop("connected", None)
        kwds.pop("version", None)
        super().__init__(**kwds)

    @classmethod
    def from_url(cls, *args, **kwargs):
        """Build a client from a Redis URL, using FakeConnection connections."""
        pool = redis.ConnectionPool.from_url(*args, **kwargs)
        # Now override how it creates connections
        pool.connection_class = FakeConnection
        # Using username and password fails since AUTH is not implemented.
        # https://github.com/cunla/fakeredis-py/issues/9
        pool.connection_kwargs.pop("username", None)
        pool.connection_kwargs.pop("password", None)
        return cls(connection_pool=pool)
class FakeStrictRedis(FakeRedisMixin, redis.StrictRedis):
    """Fake ``redis.StrictRedis``; construction is handled by FakeRedisMixin."""
class FakeRedis(FakeRedisMixin, redis.Redis):
    """Fake ``redis.Redis``; construction is handled by FakeRedisMixin."""