youtube-summarizer/venv311/lib/python3.11/site-packages/fakeredis/commands_mixins/set_mixin.py

202 lines
7.1 KiB
Python

import random
from typing import Callable, Tuple, Any, Optional, List, Union
from fakeredis import _msgs as msgs
from fakeredis._commands import command, Key, Int, CommandItem
from fakeredis._helpers import OK, SimpleError, casematch, Database, SimpleString
def _calc_setop(op: Callable[..., Any], stop_if_missing: bool, key: CommandItem, *keys: CommandItem) -> Any:
if stop_if_missing and not key.value:
return set()
value = key.value
if not isinstance(value, set):
raise SimpleError(msgs.WRONGTYPE_MSG)
ans = value.copy()
for other in keys:
value = other.value if other.value is not None else set()
if not isinstance(value, set):
raise SimpleError(msgs.WRONGTYPE_MSG)
if stop_if_missing and not value:
return set()
ans = op(ans, value)
return ans
def _setop(
        op: Callable[..., Any],
        stop_if_missing: bool,
        dst: Optional[CommandItem],
        key: CommandItem,
        *keys: CommandItem) -> Any:
    """Shared implementation of SINTER[STORE], SUNION[STORE] and SDIFF[STORE].

    The combined set is computed by ``_calc_setop``.  When
    ``stop_if_missing`` is set, the result collapses to an empty set as soon
    as an empty input set is encountered (used for SINTER[STORE]).  ``key``
    may be assumed to be a set (or empty); ``keys`` may hold anything.

    :param dst: destination item for the *STORE variants, or ``None``.
    :return: the members as a list when ``dst`` is ``None``; otherwise the
        size of the set written to ``dst``.
    """
    result = _calc_setop(op, stop_if_missing, key, *keys)
    if dst is not None:
        # *STORE variant: persist into the destination and reply with its size.
        dst.value = result
        return len(dst.value)
    # Plain variant: reply with the members themselves.
    return list(result)
class SetCommandsMixin:
    """Redis set (S*) commands, plus HyperLogLog (PF*) commands emulated on sets.

    The host class must provide ``version``, ``_db`` and ``_scan``.
    """

    version: Tuple[int, ...]
    _db: Database
    _scan: Callable[..., Any]

    @command((Key(set), bytes), (bytes,))
    def sadd(self, key: CommandItem, *members: bytes) -> int:
        """Add ``members`` to the set at ``key``; return how many were new."""
        old_size = len(key.value)
        key.value.update(members)
        added = len(key.value) - old_size
        # Mirror ``srem`` (and Redis, which only signals the key when
        # ``added > 0``): only mark the key updated if its contents changed.
        if added:
            key.updated()
        return added

    @command((Key(set),))
    def scard(self, key: CommandItem) -> int:
        """Return the number of members in the set at ``key``."""
        return len(key.value)

    @command((Key(set),), (Key(set),))
    def sdiff(self, *keys: CommandItem) -> Any:
        """Return members of the first set that are in none of the others."""
        return _setop(lambda a, b: a - b, False, None, *keys)

    @command((Key(), Key(set)), (Key(set),))
    def sdiffstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
        """Store SDIFF of ``keys`` in ``dst``; return the stored set's size."""
        return _setop(lambda a, b: a - b, False, dst, *keys)

    @command((Key(set),), (Key(set),))
    def sinter(self, *keys: CommandItem) -> Any:
        """Return the members common to all ``keys``."""
        return _setop(lambda a, b: a & b, True, None, *keys)

    @command((Int, bytes), (bytes,))
    def sintercard(self, numkeys: int, *args: bytes) -> int:
        """SINTERCARD numkeys key [key ...] [LIMIT limit] (Redis >= 7 only).

        Return the cardinality of the intersection of the first ``numkeys``
        arguments, capped at ``limit`` when a non-zero LIMIT is given.

        :raises SimpleError: unknown command before version 7; syntax error
            for a bad ``numkeys`` or malformed options; error for a negative
            LIMIT.
        """
        if self.version < (7,):
            raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format("sintercard"))
        if numkeys < 1 or numkeys > len(args):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        # The first ``numkeys`` arguments are always key names (a key may be
        # literally called "limit"); only what follows them is parsed as
        # options.  Slicing also avoids indexing past a single-key ``args``.
        names, options = args[:numkeys], args[numkeys:]
        limit = 0
        pos = 0
        while pos < len(options):
            if not casematch(options[pos], b"limit") or pos + 1 >= len(options):
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
            limit = Int.decode(options[pos + 1])
            if limit < 0:
                # A negative cap would otherwise yield a negative cardinality.
                raise SimpleError("ERR LIMIT can't be negative")
            pos += 2
        keys = [
            CommandItem(name, self._db, item=self._db.get(name, default=None))
            for name in names
        ]
        count = len(_setop(lambda a, b: a & b, False, None, *keys))
        return count if limit == 0 else min(limit, count)

    @command((Key(), Key(set)), (Key(set),))
    def sinterstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
        """Store SINTER of ``keys`` in ``dst``; return the stored set's size."""
        return _setop(lambda a, b: a & b, True, dst, *keys)

    @command((Key(set), bytes))
    def sismember(self, key: CommandItem, member: bytes) -> int:
        """Return 1 if ``member`` is in the set at ``key``, else 0."""
        return int(member in key.value)

    @command((Key(set), bytes), (bytes,))
    def smismember(self, key: CommandItem, *members: bytes) -> List[int]:
        """Return a 0/1 membership flag for each of ``members``."""
        return [self.sismember(key, member) for member in members]

    @command((Key(set),))
    def smembers(self, key: CommandItem) -> List[bytes]:
        """Return all members of the set at ``key``."""
        return list(key.value)

    @command((Key(set, 0), Key(set), bytes))
    def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int:
        """Move ``member`` from ``src`` to ``dst``; return 1 if moved, else 0."""
        try:
            src.value.remove(member)
            src.updated()
        except KeyError:
            return 0
        else:
            dst.value.add(member)
            # NOTE(review): Redis only signals ``dst`` when the member was newly
            # added there; here ``dst`` is always marked updated.  Left as-is
            # because ``src`` and ``dst`` may name the same key.
            dst.updated()
            return 1

    @command((Key(set),), (Int,))
    def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
        """Remove and return one random member, or ``count`` distinct ones.

        :return: a single member (or ``None`` for an empty set) when ``count``
            is omitted; otherwise a list of up to ``count`` members.
        :raises SimpleError: if ``count`` is negative.
        """
        if count is None:
            if not key.value:
                return None
            item = random.sample(list(key.value), 1)[0]
            key.value.remove(item)
            key.updated()
            return item  # type: ignore
        if count < 0:
            raise SimpleError(msgs.INDEX_ERROR_MSG)
        # A non-negative count makes srandmember return distinct members.
        items: Union[bytes, List[bytes]] = self.srandmember(key, count)
        for item in items:
            key.value.remove(item)
        # Only signal a modification if something was removed, so that
        # ``SPOP key 0`` leaves the key untouched.
        if items:
            key.updated()
        return items

    @command((Key(set),), (Int,))
    def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
        """Return random members without removing them.

        With no ``count``: one member, or ``None`` for an empty set.  With
        ``count >= 0``: up to ``count`` distinct members.  With ``count < 0``:
        exactly ``-count`` members, possibly repeated (empty list if the set
        is empty).
        """
        if count is None:
            if not key.value:
                return None
            return random.sample(list(key.value), 1)[0]  # type: ignore
        members = list(key.value)
        if count >= 0:
            return random.sample(members, min(count, len(members)))
        if not members:
            # random.choice([]) would raise IndexError; Redis replies with an
            # empty array for a missing/empty set.
            return []
        return [random.choice(members) for _ in range(-count)]

    @command((Key(set), bytes), (bytes,))
    def srem(self, key: CommandItem, *members: bytes) -> int:
        """Remove ``members`` from the set at ``key``; return how many were removed."""
        old_size = len(key.value)
        for member in members:
            key.value.discard(member)
        deleted = old_size - len(key.value)
        if deleted:
            key.updated()
        return deleted

    @command((Key(set), Int), (bytes, bytes))
    def sscan(self, key: CommandItem, cursor: int, *args: bytes) -> Any:
        """Incrementally iterate the set at ``key`` via the shared ``_scan`` helper."""
        return self._scan(key.value, cursor, *args)

    @command((Key(set),), (Key(set),))
    def sunion(self, *keys: CommandItem) -> Any:
        """Return the members present in any of ``keys``."""
        return _setop(lambda a, b: a | b, False, None, *keys)

    @command((Key(), Key(set)), (Key(set),))
    def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
        """Store SUNION of ``keys`` in ``dst``; return the stored set's size."""
        return _setop(lambda a, b: a | b, False, dst, *keys)

    # Hyperloglog commands
    # These are not quite the same as the real redis ones, which are
    # approximate and store the results in a string. Instead, they are
    # implemented on top of sets.

    @command((Key(set),), (bytes,))
    def pfadd(self, key: CommandItem, *elements: bytes) -> int:
        """Add ``elements``; return 1 if the stored set changed, else 0."""
        result = self.sadd(key, *elements)
        # Per the documentation:
        # - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise.
        return 1 if result > 0 else 0

    @command((Key(set),), (Key(set),))
    def pfcount(self, *keys: CommandItem) -> int:
        """
        Return the approximated cardinality of
        the set observed by the HyperLogLog at key(s).
        """
        return len(self.sunion(*keys))

    @command((Key(set), Key(set)), (Key(set),))
    def pfmerge(self, dest: CommandItem, *sources: CommandItem) -> SimpleString:
        """Merge N different HyperLogLogs into a single one."""
        self.sunionstore(dest, *sources)
        return OK