202 lines
7.1 KiB
Python
202 lines
7.1 KiB
Python
import random
|
|
from typing import Callable, Tuple, Any, Optional, List, Union
|
|
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._commands import command, Key, Int, CommandItem
|
|
from fakeredis._helpers import OK, SimpleError, casematch, Database, SimpleString
|
|
|
|
|
|
def _calc_setop(op: Callable[..., Any], stop_if_missing: bool, key: CommandItem, *keys: CommandItem) -> Any:
    """Fold the binary set operation `op` over the values of `key` and `keys`.

    The result starts as a copy of the value of `key` and is combined, left to
    right, with the value of every item in `keys` via ``op(result, value)``.
    No input set is mutated.

    :param op: Binary callable taking two sets and returning their combination.
    :param stop_if_missing: When true, return an empty set as soon as any
        input is empty or missing (the intersection behaviour).
    :param key: First operand.
    :param keys: Remaining operands; a value of ``None`` counts as an empty set.
    :returns: The resulting set.
    :raises SimpleError: With ``WRONGTYPE_MSG`` if any inspected value is not a set.
    """
    first = key.value
    if stop_if_missing and not first:
        return set()
    if first is None:
        # A missing first key is an empty set, exactly like a missing key in
        # `keys` below; it is not a wrong-type error.
        first = set()
    if not isinstance(first, set):
        raise SimpleError(msgs.WRONGTYPE_MSG)

    ans = first.copy()
    for other in keys:
        value = other.value if other.value is not None else set()
        if not isinstance(value, set):
            raise SimpleError(msgs.WRONGTYPE_MSG)
        if stop_if_missing and not value:
            return set()
        ans = op(ans, value)
    return ans
|
|
|
|
|
|
def _setop(
    op: Callable[..., Any],
    stop_if_missing: bool,
    dst: Optional[CommandItem],
    key: CommandItem,
    *keys: CommandItem,
) -> Any:
    """Run a SINTER[STORE] / SUNION[STORE] / SDIFF[STORE] style operation.

    The combined set is computed by `_calc_setop`. Pass ``stop_if_missing=True``
    for intersections, so that the result collapses to an empty set as soon as
    an empty input is seen. `key` may be assumed to be a set (or empty), but
    the entries of `keys` could be anything.

    :param dst: Destination item for the ``*STORE`` variants, or ``None`` to
        return the members instead of storing them.
    :returns: A list of the resulting members when `dst` is ``None``;
        otherwise the size of the set stored into `dst`.
    """
    result = _calc_setop(op, stop_if_missing, key, *keys)
    if dst is not None:
        # STORE variant: write the result and report its cardinality.
        dst.value = result
        return len(dst.value)
    return list(result)
|
|
|
|
|
|
class SetCommandsMixin:
    """Mixin providing the Redis set commands, plus HyperLogLog commands built on sets.

    The host class is expected to supply `version`, `_db` and `_scan`.
    """

    # Emulated server version; compared against tuples such as ``(7,)``.
    version: Tuple[int, ...]
    _db: Database
    _scan: Callable[..., Any]

    @command((Key(set), bytes), (bytes,))
    def sadd(self, key: CommandItem, *members: bytes) -> int:
        """Add `members` to the set at `key`; return how many were newly added."""
        old_size = len(key.value)
        key.value.update(members)
        key.updated()
        return len(key.value) - old_size

    @command((Key(set),))
    def scard(self, key: CommandItem) -> int:
        """Return the number of members in the set at `key`."""
        return len(key.value)

    @command((Key(set),), (Key(set),))
    def sdiff(self, *keys: CommandItem) -> Any:
        """Return the members of the first set that are in none of the other sets."""
        return _setop(lambda a, b: a - b, False, None, *keys)

    @command((Key(), Key(set)), (Key(set),))
    def sdiffstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
        """Store the difference of `keys` into `dst`; return the stored set's size."""
        return _setop(lambda a, b: a - b, False, dst, *keys)

    @command((Key(set),), (Key(set),))
    def sinter(self, *keys: CommandItem) -> Any:
        """Return the members common to all sets in `keys`."""
        return _setop(lambda a, b: a & b, True, None, *keys)

    @command((Int, bytes), (bytes,))
    def sintercard(self, numkeys: int, *args: bytes) -> int:
        """Return the cardinality of the intersection of the given sets.

        `args` holds `numkeys` key names, optionally followed by ``LIMIT n``.
        A limit of 0 (the default) means no limit; otherwise the returned
        value is capped at the limit.

        :raises SimpleError: Unknown command when the emulated version is
            below 7; syntax error when `numkeys` is less than 1, exceeds the
            number of arguments, or the trailing arguments are not exactly
            ``LIMIT n``.
        """
        if self.version < (7,):
            raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format("sintercard"))
        if numkeys < 1 or numkeys > len(args):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        # Split by `numkeys` first, so a lone key (no room for LIMIT) cannot
        # cause an out-of-range index, and a key named "limit" stays a key.
        key_names, options = args[:numkeys], args[numkeys:]
        limit = 0
        if options:
            if len(options) != 2 or not casematch(options[0], b"limit"):
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
            limit = Int.decode(options[1])

        keys = [CommandItem(name, self._db, item=self._db.get(name, default=None)) for name in key_names]

        # stop_if_missing=True (as in SINTER/SINTERSTORE): an empty or missing
        # input makes the intersection empty rather than a wrong-type error.
        res = _setop(lambda a, b: a & b, True, None, *keys)
        return len(res) if limit == 0 else min(limit, len(res))

    @command((Key(), Key(set)), (Key(set),))
    def sinterstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
        """Store the intersection of `keys` into `dst`; return the stored set's size."""
        return _setop(lambda a, b: a & b, True, dst, *keys)

    @command((Key(set), bytes))
    def sismember(self, key: CommandItem, member: bytes) -> int:
        """Return 1 if `member` is in the set at `key`, else 0."""
        return int(member in key.value)

    @command((Key(set), bytes), (bytes,))
    def smismember(self, key: CommandItem, *members: bytes) -> List[int]:
        """Return a 1/0 membership flag for each of `members`, in order."""
        return [self.sismember(key, member) for member in members]

    @command((Key(set),))
    def smembers(self, key: CommandItem) -> List[bytes]:
        """Return all members of the set at `key` as a list."""
        return list(key.value)

    @command((Key(set, 0), Key(set), bytes))
    def smove(self, src: CommandItem, dst: CommandItem, member: bytes) -> int:
        """Move `member` from `src` to `dst`.

        :returns: 1 if the member was moved, 0 if it was not present in `src`
            (in which case neither set is touched).
        """
        try:
            src.value.remove(member)
            src.updated()
        except KeyError:
            return 0
        else:
            dst.value.add(member)
            dst.updated()  # TODO: is it updated if member was already present?
            return 1

    @command((Key(set),), (Int,))
    def spop(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
        """Remove and return random member(s) of the set at `key`.

        Without `count`, a single member is returned (``None`` for an empty
        set). With `count`, a list of up to `count` distinct members is
        returned.

        :raises SimpleError: With ``INDEX_ERROR_MSG`` if `count` is negative.
        """
        if count is None:
            if not key.value:
                return None
            item = random.sample(list(key.value), 1)[0]
            key.value.remove(item)
            key.updated()
            return item  # type: ignore
        else:
            if count < 0:
                raise SimpleError(msgs.INDEX_ERROR_MSG)
            items: Union[bytes, List[bytes]] = self.srandmember(key, count)
            for item in items:
                key.value.remove(item)
                key.updated()  # Inside the loop because redis special-cases count=0
            return items

    @command((Key(set),), (Int,))
    def srandmember(self, key: CommandItem, count: Optional[int] = None) -> Union[bytes, List[bytes], None]:
        """Return random member(s) of the set at `key` without removing them.

        * No `count`: one member, or ``None`` if the set is empty.
        * ``count >= 0``: up to `count` distinct members.
        * ``count < 0``: exactly ``-count`` members, repeats allowed; an empty
          set yields an empty list.
        """
        if count is None:
            if not key.value:
                return None
            else:
                return random.sample(list(key.value), 1)[0]  # type: ignore
        elif count >= 0:
            count = min(count, len(key.value))
            return random.sample(list(key.value), count)
        else:
            items = list(key.value)
            if not items:
                # random.choice() raises IndexError on an empty sequence.
                return []
            return [random.choice(items) for _ in range(-count)]

    @command((Key(set), bytes), (bytes,))
    def srem(self, key: CommandItem, *members: bytes) -> int:
        """Remove `members` from the set at `key`; return how many were removed."""
        old_size = len(key.value)
        for member in members:
            key.value.discard(member)
        deleted = old_size - len(key.value)
        if deleted:
            # Only flag the key as updated when something actually changed.
            key.updated()
        return deleted

    @command((Key(set), Int), (bytes, bytes))
    def sscan(self, key: CommandItem, cursor: int, *args: bytes) -> Any:
        """Iterate over the set at `key`, delegating to the host's `_scan`."""
        return self._scan(key.value, cursor, *args)

    @command((Key(set),), (Key(set),))
    def sunion(self, *keys: CommandItem) -> Any:
        """Return the members present in at least one of the sets in `keys`."""
        return _setop(lambda a, b: a | b, False, None, *keys)

    @command((Key(), Key(set)), (Key(set),))
    def sunionstore(self, dst: CommandItem, *keys: CommandItem) -> Any:
        """Store the union of `keys` into `dst`; return the stored set's size."""
        return _setop(lambda a, b: a | b, False, dst, *keys)

    # Hyperloglog commands
    # These are not quite the same as the real redis ones, which are
    # approximate and store the results in a string. Instead, it is implemented
    # on top of sets.

    @command((Key(set),), (bytes,))
    def pfadd(self, key: CommandItem, *elements: bytes) -> int:
        """Add `elements` to the set-backed HyperLogLog at `key`.

        :returns: 1 if at least one element was new, 0 otherwise.
        """
        result = self.sadd(key, *elements)
        # Per the documentation:
        #  - 1 if at least 1 HyperLogLog internal register was altered. 0 otherwise.
        return 1 if result > 0 else 0

    @command((Key(set),), (Key(set),))
    def pfcount(self, *keys: CommandItem) -> int:
        """
        Return the cardinality of the set observed by the HyperLogLog at key(s).

        Because this implementation is backed by sets, the count is the exact
        size of the union of `keys`.
        """
        return len(self.sunion(*keys))

    @command((Key(set), Key(set)), (Key(set),))
    def pfmerge(self, dest: CommandItem, *sources: CommandItem) -> SimpleString:
        """Merge N different HyperLogLogs into a single one."""
        self.sunionstore(dest, *sources)
        return OK
|