"""
Helper classes and methods used in mixins implementing various commands.
Unlike _helpers.py, here the methods should be used only in mixins.
"""

import functools
import math
import re
from typing import Tuple, Union, Optional, Any, Type, List, Callable

from . import _msgs as msgs
from ._helpers import null_terminate, SimpleError, Database

MAX_STRING_SIZE = 512 * 1024 * 1024
SUPPORTED_COMMANDS = dict()  # Dictionary of supported commands name => Signature
COMMANDS_WITH_SUB = set()  # Commands with sub-commands


class Key:
    """Marker to indicate that argument in signature is a key.

    :param type_: when not None, ``Signature.apply`` raises WRONGTYPE if an
        existing value is not exactly of this type, and uses ``type_()`` as
        the default value for a missing key (except when ``type_`` is bytes).
    :param missing_return: when given, ``Signature.apply`` short-circuits and
        returns this value if the key is absent from the database.
    """

    UNSPECIFIED = object()

    def __init__(self, type_: Optional[Type[Any]] = None, missing_return: Any = UNSPECIFIED) -> None:
        self.type_ = type_
        self.missing_return = missing_return


class Item:
    """An item stored in the database"""

    __slots__ = ["value", "expireat"]

    def __init__(self, value: Any) -> None:
        self.value = value
        self.expireat = None


class CommandItem:
    """An item referenced by a command.

    It wraps an Item but has extra fields to manage updates and notifications.
    Changes are buffered on this object and only pushed into the database by
    :meth:`writeback`.
    """

    def __init__(self, key: bytes, db: Database, item: Optional[Item] = None, default: Any = None) -> None:
        if item is None:
            self._value = default
            self._expireat = None
        else:
            self._value = item.value
            self._expireat = item.expireat
        self.key = key
        self.db = db
        self._modified = False
        self._expireat_modified = False

    @property
    def value(self) -> Any:
        return self._value

    @value.setter
    def value(self, new_value: Any) -> None:
        # Assigning a new value also clears the expiry; use update() to keep it.
        self._value = new_value
        self._modified = True
        self.expireat = None

    @property
    def expireat(self) -> Optional[int]:
        return self._expireat

    @expireat.setter
    def expireat(self, value: Optional[int]) -> None:
        self._expireat = value
        self._expireat_modified = True
        self._modified = True  # Since redis 6.0.7

    def get(self, default: Any) -> Any:
        """Return the value if the item exists (see ``__bool__``), else ``default``."""
        return self._value if self else default

    def update(self, new_value: Any) -> None:
        """Replace the value while keeping the current expiry."""
        self._value = new_value
        self._modified = True

    def updated(self) -> None:
        """Mark the item as modified after its value was mutated in place."""
        self._modified = True

    def writeback(self, remove_empty_val: bool = True) -> None:
        """Push buffered changes back into the database.

        If the item was modified, watchers of the key are notified and then the
        key is either removed (value is None, or a falsy non-bytes value while
        ``remove_empty_val`` is true) or stored together with its expiry.
        """
        if self._modified:
            self.db.notify_watch(self.key)
            # bytes values (even b"") are always kept; other falsy values such
            # as an empty Hash make the key disappear.
            if (not isinstance(self.value, bytes)
                    and (self.value is None or (not self.value and remove_empty_val))):
                self.db.pop(self.key, None)
                return
            item = self.db.setdefault(self.key, Item(None))
            item.value = self.value
            item.expireat = self.expireat
            return
        # NOTE(review): the expireat setter also sets _modified, so this branch
        # is not reached through the public API; kept as a safety net.
        if self._expireat_modified and self.key in self.db:
            self.db[self.key].expireat = self.expireat

    def __bool__(self) -> bool:
        # A bytes value counts as existing even when empty.
        return bool(self._value) or isinstance(self._value, bytes)

    __nonzero__ = __bool__  # For Python 2


class Hash(dict):  # type:ignore
    DECODE_ERROR = msgs.INVALID_HASH_MSG
    redis_type = b"hash"


class RedisType:
    """Base class for argument converters used in command signatures."""

    @classmethod
    def decode(cls, *args, **kwargs):  # type:ignore
        raise NotImplementedError


class Int(RedisType):
    """Argument converter for 64-bit signed integers"""

    DECODE_ERROR = msgs.INVALID_INT_MSG
    ENCODE_ERROR = msgs.OVERFLOW_MSG
    MIN_VALUE = -(2 ** 63)
    MAX_VALUE = 2 ** 63 - 1

    @classmethod
    def valid(cls, value: int) -> bool:
        return cls.MIN_VALUE <= value <= cls.MAX_VALUE

    @classmethod
    def decode(cls, value: bytes, decode_error: Optional[str] = None) -> int:
        """Parse ``value`` as an integer within [MIN_VALUE, MAX_VALUE].

        :raises SimpleError: with ``decode_error`` (or DECODE_ERROR) when the
            input is not a canonical integer in range.
        """
        try:
            out = int(value)
            # The round-trip comparison rejects inputs int() tolerates but that
            # are not in canonical form: surrounding whitespace, a leading "+",
            # leading zeros, "_" separators.
            if not cls.valid(out) or str(out).encode() != value:
                raise ValueError
            return out
        except ValueError:
            raise SimpleError(decode_error or cls.DECODE_ERROR) from None

    @classmethod
    def encode(cls, value: int) -> bytes:
        """Serialize ``value``; raise SimpleError(ENCODE_ERROR) when out of range."""
        if cls.valid(value):
            return str(value).encode()
        else:
            raise SimpleError(cls.ENCODE_ERROR)


class DbIndex(Int):
    """Argument converter for database indices"""

    DECODE_ERROR = msgs.INVALID_DB_MSG
    MIN_VALUE = 0
    MAX_VALUE = 15


class BitOffset(Int):
    """Argument converter for unsigned bit positions"""

    DECODE_ERROR = msgs.INVALID_BIT_OFFSET_MSG
    MIN_VALUE = 0
    MAX_VALUE = 8 * MAX_STRING_SIZE - 1  # Redis imposes 512MB limit on keys


class BitValue(Int):
    DECODE_ERROR = msgs.INVALID_BIT_VALUE_MSG
    MIN_VALUE = 0
    MAX_VALUE = 1


class Timeout(Int):
    """Argument converter for timeouts"""

    DECODE_ERROR = msgs.TIMEOUT_NEGATIVE_MSG
    MIN_VALUE = 0


class Float(RedisType):
    """Argument converter for floating-point values.

    Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT)
    and double for others (zset scores), but Python doesn't support
    long double.
    """

    DECODE_ERROR = msgs.INVALID_FLOAT_MSG

    @classmethod
    def decode(
            cls,
            value: bytes,
            allow_leading_whitespace: bool = False,
            allow_erange: bool = False,
            allow_empty: bool = False,
            crop_null: bool = False,
            decode_error: Optional[str] = None,
    ) -> float:
        """Parse ``value`` as a float, mimicking Redis' quirks.

        :param allow_leading_whitespace: accept leading whitespace.
        :param allow_erange: accept inputs that over/underflow to inf/0.0.
        :param allow_empty: treat b"" as 0.0.
        :param crop_null: truncate the input at the first NUL byte.
        :param decode_error: error message to use instead of DECODE_ERROR.
        :raises SimpleError: when the input is rejected.
        """
        # redis has some quirks in float parsing, with several variants.
        # See https://github.com/antirez/redis/issues/5706
        try:
            if crop_null:
                value = null_terminate(value)
            if allow_empty and value == b"":
                value = b"0.0"
            # Python's float() accepts PEP 515 digit separators ("1_000"),
            # which Redis' strtod/strtold-based parsing rejects.
            if b"_" in value:
                raise ValueError
            if not allow_leading_whitespace and value[:1].isspace():
                raise ValueError
            if value[-1:].isspace():
                raise ValueError
            out = float(value)
            if math.isnan(out):
                raise ValueError
            if not allow_erange:
                # Values that over- or underflow- are explicitly rejected by
                # redis. This is a crude hack to determine whether the input
                # may have been such a value.
                if out in (math.inf, -math.inf, 0.0) and re.match(b"^[^a-zA-Z]*[1-9]", value):
                    raise ValueError
            return out
        except ValueError:
            raise SimpleError(decode_error or cls.DECODE_ERROR) from None

    @classmethod
    def encode(cls, value: float, humanfriendly: bool) -> bytes:
        """Serialize ``value``; ``humanfriendly`` strips trailing fractional zeros."""
        if math.isinf(value):
            return str(value).encode()
        elif humanfriendly:
            # Algorithm from ld2string in redis
            out = "{:.17f}".format(value)
            out = re.sub(r"\.?0+$", "", out)
            return out.encode()
        else:
            return "{:.17g}".format(value).encode()


class SortFloat(Float):
    """Float converter whose defaults are lenient: leading whitespace and
    empty input are accepted and input is cropped at the first NUL byte."""

    DECODE_ERROR = msgs.INVALID_SORT_FLOAT_MSG

    @classmethod
    def decode(
            cls,
            value: bytes,
            allow_leading_whitespace: bool = True,
            allow_erange: bool = False,
            allow_empty: bool = True,
            crop_null: bool = True,
            decode_error: Optional[str] = None,
    ) -> float:
        # Forward every option so that callers can override the lenient
        # defaults instead of having their arguments silently ignored.
        return super().decode(
            value,
            allow_leading_whitespace=allow_leading_whitespace,
            allow_erange=allow_erange,
            allow_empty=allow_empty,
            crop_null=crop_null,
            decode_error=decode_error,
        )


@functools.total_ordering
class BeforeAny:
    """Sentinel that compares less than any object except another BeforeAny."""

    def __gt__(self, other: Any) -> bool:
        return False

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, BeforeAny)

    def __hash__(self) -> int:
        return 1


@functools.total_ordering
class AfterAny:
    """Sentinel that compares greater than any object except another AfterAny."""

    def __lt__(self, other: Any) -> bool:
        return False

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, AfterAny)

    def __hash__(self) -> int:
        return 1


class ScoreTest(RedisType):
    """Argument converter for sorted set score endpoints."""

    def __init__(self, value: float, exclusive: bool = False, bytes_val: Optional[bytes] = None):
        self.value = value
        self.exclusive = exclusive
        self.bytes_val = bytes_val

    @classmethod
    def decode(cls, value: bytes) -> "ScoreTest":
        """Parse a score bound; a leading "(" makes the bound exclusive."""
        try:
            original_value = value
            exclusive = False
            if value[:1] == b"(":
                exclusive = True
                value = value[1:]
            fvalue = Float.decode(
                value,
                allow_leading_whitespace=True,
                allow_erange=True,
                allow_empty=True,
                crop_null=True,
            )
            return cls(fvalue, exclusive, original_value)
        except SimpleError:
            raise SimpleError(msgs.INVALID_MIN_MAX_FLOAT_MSG) from None

    def __str__(self) -> str:
        if self.exclusive:
            return "({!r}".format(self.value)
        else:
            return repr(self.value)

    @property
    def lower_bound(self) -> Tuple[float, Union[AfterAny, BeforeAny]]:
        # The sentinel as second element makes (score, member) tuples with the
        # same score sort after (exclusive) or before (inclusive) the bound.
        return self.value, AfterAny() if self.exclusive else BeforeAny()

    @property
    def upper_bound(self) -> Tuple[float, Union[AfterAny, BeforeAny]]:
        return self.value, BeforeAny() if self.exclusive else AfterAny()


class StringTest(RedisType):
    """Argument converter for sorted set LEX endpoints."""

    def __init__(self, value: Union[bytes, BeforeAny, AfterAny], exclusive: bool):
        self.value = value
        self.exclusive = exclusive

    @classmethod
    def decode(cls, value: bytes) -> "StringTest":
        """Parse "-", "+", "(member" (exclusive) or "[member" (inclusive)."""
        if value == b"-":
            return cls(BeforeAny(), True)
        elif value == b"+":
            return cls(AfterAny(), True)
        elif value[:1] == b"(":
            return cls(value[1:], True)
        elif value[:1] == b"[":
            return cls(value[1:], False)
        else:
            raise SimpleError(msgs.INVALID_MIN_MAX_STR_MSG)


class Signature:
    """Describes a command's arguments: fixed-position converters followed by
    an optional repeating group, plus command flags."""

    def __init__(
            self,
            name: str,
            func_name: str,
            fixed: Tuple[Type[Union[RedisType, bytes]]],
            repeat: Tuple[Type[Union[RedisType, bytes]]] = (),  # type:ignore
            args: Tuple[str] = (),  # type:ignore
            flags: str = "",
    ):
        self.name = name
        self.func_name = func_name
        self.fixed = fixed
        self.repeat = repeat
        self.flags = set(flags)
        self.command_args = args

    def check_arity(self, args: Tuple[Any], version: Tuple[int]) -> None:
        """Raise SimpleError unless ``args`` is the fixed arguments followed by
        zero or more complete copies of the repeating group."""
        if len(args) == len(self.fixed):
            return
        delta = len(args) - len(self.fixed)
        if delta < 0 or not self.repeat:
            msg = msgs.WRONG_ARGS_MSG6.format(self.name)
            raise SimpleError(msg)
        if delta % len(self.repeat) != 0:
            msg = (
                msgs.WRONG_ARGS_MSG7
                if version >= (7,)
                else msgs.WRONG_ARGS_MSG6.format(self.name)
            )
            raise SimpleError(msg)

    def apply(
            self, args: Tuple[Any], db: Database, version: Tuple[int]
    ) -> Union[Tuple[Any], Tuple[List[Any], List[CommandItem]]]:
        """Returns a tuple, which is either:

        - the transformed argument list and the list of CommandItems created
          for key arguments; or
        - a single-element tuple containing a short-circuit return value
          (a Key with ``missing_return`` referenced an absent key).

        :raises SimpleError: on wrong arity, undecodable arguments, or a key
            holding a value of the wrong type.
        """
        self.check_arity(args, version)

        types = list(self.fixed)
        for i in range(len(args) - len(types)):
            types.append(self.repeat[i % len(self.repeat)])

        args_list = list(args)
        # First pass: convert/validate non-keys, and short-circuit on missing keys
        for i, (arg, type_) in enumerate(zip(args_list, types)):
            if isinstance(type_, Key):
                if type_.missing_return is not Key.UNSPECIFIED and arg not in db:
                    return (type_.missing_return,)
            elif type_ != bytes:
                args_list[i] = type_.decode(args_list[i])

        # Second pass: read keys and check their types
        command_items: List[CommandItem] = []
        for i, (arg, type_) in enumerate(zip(args_list, types)):
            if isinstance(type_, Key):
                item = db.get(arg)
                default = None
                if (
                        type_.type_ is not None
                        and item is not None
                        and type(item.value) is not type_.type_
                ):
                    raise SimpleError(msgs.WRONGTYPE_MSG)
                if (
                        type_.type_ is not None
                        and item is None
                        and type_.type_ is not bytes
                ):
                    default = type_.type_()
                args_list[i] = CommandItem(arg, db, item, default=default)
                command_items.append(args_list[i])

        return args_list, command_items


def command(*args, **kwargs) -> Callable:  # type:ignore
    """Decorator factory registering a method as a command in SUPPORTED_COMMANDS.

    ``name`` (a string or a list of alias strings) overrides the command name,
    which otherwise defaults to the decorated function's name. All other
    arguments are forwarded to :class:`Signature`.
    """
    # Pop "name" once here rather than inside ``decorator`` so the decorator
    # can be applied to several functions without losing the override.
    missing = object()
    name_override = kwargs.pop("name", missing)

    def create_signature(func: Callable[..., Any], cmd_name: str) -> None:
        if " " in cmd_name:
            COMMANDS_WITH_SUB.add(cmd_name.split(" ")[0])
        SUPPORTED_COMMANDS[cmd_name] = Signature(
            cmd_name, func.__name__, *args, **kwargs
        )

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        cmd_names = func.__name__ if name_override is missing else name_override
        if isinstance(cmd_names, list):  # Support for alias commands
            for cmd_name in cmd_names:
                create_signature(func, cmd_name.lower())
        elif isinstance(cmd_names, str):
            create_signature(func, cmd_names.lower())
        else:
            raise ValueError("command name should be a string or list of strings")
        return func

    return decorator


def delete_keys(*keys: CommandItem) -> int:
    """Clear every existing, distinct key; return how many were cleared."""
    ans = 0
    done = set()
    for key in keys:
        if key and key.key not in done:
            key.value = None
            done.add(key.key)
            ans += 1
    return ans


def fix_range(start: int, end: int, length: int) -> Tuple[int, int]:
    """Convert an inclusive [start, end] range (negatives count from the end)
    into half-open slice bounds, or (-1, -1) when the range is empty."""
    # Redis handles negative slightly differently for zrange
    if start < 0:
        start = max(0, start + length)
    if end < 0:
        end += length
    if start > end or start >= length:
        return -1, -1
    end = min(end, length - 1)
    return start, end + 1


def fix_range_string(start: int, end: int, length: int) -> Tuple[int, int]:
    """Convert an inclusive [start, end] string range (negatives count from
    the end) into half-open slice bounds; the slice may be empty."""
    # Negative number handling is based on the redis source code
    if 0 > start > end:  # implies end < 0 as well
        return -1, -1
    if start < 0:
        start = max(0, start + length)
    if end < 0:
        end = max(0, end + length)
    end = min(end, length - 1)
    return start, end + 1