481 lines
15 KiB
Python
481 lines
15 KiB
Python
"""
|
|
Helper classes and methods used in mixins implementing various commands.
|
|
Unlike _helpers.py, here the methods should be used only in mixins.
|
|
"""
|
|
import functools
|
|
import math
|
|
import re
|
|
from typing import Tuple, Union, Optional, Any, Type, List, Callable
|
|
|
|
from . import _msgs as msgs
|
|
from ._helpers import null_terminate, SimpleError, Database
|
|
|
|
MAX_STRING_SIZE = 512 * 1024 * 1024
|
|
SUPPORTED_COMMANDS = dict() # Dictionary of supported commands name => Signature
|
|
COMMANDS_WITH_SUB = set() # Commands with sub-commands
|
|
|
|
|
|
class Key:
|
|
"""Marker to indicate that argument in signature is a key"""
|
|
|
|
UNSPECIFIED = object()
|
|
|
|
def __init__(self, type_: Optional[Type[Any]] = None, missing_return: Any = UNSPECIFIED) -> None:
|
|
self.type_ = type_
|
|
self.missing_return = missing_return
|
|
|
|
|
|
class Item:
|
|
"""An item stored in the database"""
|
|
|
|
__slots__ = ["value", "expireat"]
|
|
|
|
def __init__(self, value: Any) -> None:
|
|
self.value = value
|
|
self.expireat = None
|
|
|
|
|
|
class CommandItem:
|
|
"""An item referenced by a command.
|
|
|
|
It wraps an Item but has extra fields to manage updates and notifications.
|
|
"""
|
|
|
|
def __init__(self, key: bytes, db: Database, item: Optional["CommandItem"] = None, default: Any = None) -> None:
|
|
if item is None:
|
|
self._value = default
|
|
self._expireat = None
|
|
else:
|
|
self._value = item.value
|
|
self._expireat = item.expireat
|
|
self.key = key
|
|
self.db = db
|
|
self._modified = False
|
|
self._expireat_modified = False
|
|
|
|
@property
|
|
def value(self) -> Any:
|
|
return self._value
|
|
|
|
@value.setter
|
|
def value(self, new_value: Any) -> None:
|
|
self._value = new_value
|
|
self._modified = True
|
|
self.expireat = None
|
|
|
|
@property
|
|
def expireat(self) -> Optional[int]:
|
|
return self._expireat
|
|
|
|
@expireat.setter
|
|
def expireat(self, value: int) -> None:
|
|
self._expireat = value
|
|
self._expireat_modified = True
|
|
self._modified = True # Since redis 6.0.7
|
|
|
|
def get(self, default: Any) -> Any:
|
|
return self._value if self else default
|
|
|
|
def update(self, new_value: Any) -> None:
|
|
self._value = new_value
|
|
self._modified = True
|
|
|
|
def updated(self) -> None:
|
|
self._modified = True
|
|
|
|
def writeback(self, remove_empty_val: bool = True) -> None:
|
|
if self._modified:
|
|
self.db.notify_watch(self.key)
|
|
if (not isinstance(self.value, bytes)
|
|
and (self.value is None or (not self.value and remove_empty_val))):
|
|
self.db.pop(self.key, None)
|
|
return
|
|
item = self.db.setdefault(self.key, Item(None))
|
|
item.value = self.value
|
|
item.expireat = self.expireat
|
|
return
|
|
|
|
if self._expireat_modified and self.key in self.db:
|
|
self.db[self.key].expireat = self.expireat
|
|
|
|
def __bool__(self) -> bool:
|
|
return bool(self._value) or isinstance(self._value, bytes)
|
|
|
|
__nonzero__ = __bool__ # For Python 2
|
|
|
|
|
|
class Hash(dict): # type:ignore
|
|
DECODE_ERROR = msgs.INVALID_HASH_MSG
|
|
redis_type = b"hash"
|
|
|
|
|
|
class RedisType:
|
|
@classmethod
|
|
def decode(cls, *args, **kwargs): # type:ignore
|
|
raise NotImplementedError
|
|
|
|
|
|
class Int(RedisType):
|
|
"""Argument converter for 64-bit signed integers"""
|
|
|
|
DECODE_ERROR = msgs.INVALID_INT_MSG
|
|
ENCODE_ERROR = msgs.OVERFLOW_MSG
|
|
MIN_VALUE = -(2 ** 63)
|
|
MAX_VALUE = 2 ** 63 - 1
|
|
|
|
@classmethod
|
|
def valid(cls, value: int) -> bool:
|
|
return cls.MIN_VALUE <= value <= cls.MAX_VALUE
|
|
|
|
@classmethod
|
|
def decode(cls, value: bytes, decode_error: Optional[str] = None) -> int:
|
|
try:
|
|
out = int(value)
|
|
if not cls.valid(out) or str(out).encode() != value:
|
|
raise ValueError
|
|
return out
|
|
except ValueError:
|
|
raise SimpleError(decode_error or cls.DECODE_ERROR)
|
|
|
|
@classmethod
|
|
def encode(cls, value: int) -> bytes:
|
|
if cls.valid(value):
|
|
return str(value).encode()
|
|
else:
|
|
raise SimpleError(cls.ENCODE_ERROR)
|
|
|
|
|
|
class DbIndex(Int):
|
|
"""Argument converter for database indices"""
|
|
|
|
DECODE_ERROR = msgs.INVALID_DB_MSG
|
|
MIN_VALUE = 0
|
|
MAX_VALUE = 15
|
|
|
|
|
|
class BitOffset(Int):
|
|
"""Argument converter for unsigned bit positions"""
|
|
|
|
DECODE_ERROR = msgs.INVALID_BIT_OFFSET_MSG
|
|
MIN_VALUE = 0
|
|
MAX_VALUE = 8 * MAX_STRING_SIZE - 1 # Redis imposes 512MB limit on keys
|
|
|
|
|
|
class BitValue(Int):
|
|
DECODE_ERROR = msgs.INVALID_BIT_VALUE_MSG
|
|
MIN_VALUE = 0
|
|
MAX_VALUE = 1
|
|
|
|
|
|
class Timeout(Int):
|
|
"""Argument converter for timeouts"""
|
|
|
|
DECODE_ERROR = msgs.TIMEOUT_NEGATIVE_MSG
|
|
MIN_VALUE = 0
|
|
|
|
|
|
class Float(RedisType):
|
|
"""Argument converter for floating-point values.
|
|
|
|
Redis uses long double for some cases (INCRBYFLOAT, HINCRBYFLOAT)
|
|
and double for others (zset scores), but Python doesn't support
|
|
long double.
|
|
"""
|
|
|
|
DECODE_ERROR = msgs.INVALID_FLOAT_MSG
|
|
|
|
@classmethod
|
|
def decode(
|
|
cls,
|
|
value: bytes,
|
|
allow_leading_whitespace: bool = False,
|
|
allow_erange: bool = False,
|
|
allow_empty: bool = False,
|
|
crop_null: bool = False,
|
|
decode_error: Optional[str] = None,
|
|
) -> float:
|
|
# redis has some quirks in float parsing, with several variants.
|
|
# See https://github.com/antirez/redis/issues/5706
|
|
try:
|
|
if crop_null:
|
|
value = null_terminate(value)
|
|
if allow_empty and value == b"":
|
|
value = b"0.0"
|
|
if not allow_leading_whitespace and value[:1].isspace():
|
|
raise ValueError
|
|
if value[-1:].isspace():
|
|
raise ValueError
|
|
out = float(value)
|
|
if math.isnan(out):
|
|
raise ValueError
|
|
if not allow_erange:
|
|
# Values that over- or underflow- are explicitly rejected by
|
|
# redis. This is a crude hack to determine whether the input
|
|
# may have been such a value.
|
|
if out in (math.inf, -math.inf, 0.0) and re.match(b"^[^a-zA-Z]*[1-9]", value):
|
|
raise ValueError
|
|
return out
|
|
except ValueError:
|
|
raise SimpleError(decode_error or cls.DECODE_ERROR)
|
|
|
|
@classmethod
|
|
def encode(cls, value: float, humanfriendly: bool) -> bytes:
|
|
if math.isinf(value):
|
|
return str(value).encode()
|
|
elif humanfriendly:
|
|
# Algorithm from ld2string in redis
|
|
out = "{:.17f}".format(value)
|
|
out = re.sub(r"\.?0+$", "", out)
|
|
return out.encode()
|
|
else:
|
|
return "{:.17g}".format(value).encode()
|
|
|
|
|
|
class SortFloat(Float):
|
|
DECODE_ERROR = msgs.INVALID_SORT_FLOAT_MSG
|
|
|
|
@classmethod
|
|
def decode(
|
|
cls,
|
|
value: bytes,
|
|
allow_leading_whitespace: bool = True,
|
|
allow_erange: bool = False,
|
|
allow_empty: bool = True,
|
|
crop_null: bool = True,
|
|
decode_error: Optional[str] = None,
|
|
) -> float:
|
|
return super().decode(
|
|
value, allow_leading_whitespace=True, allow_empty=True, crop_null=True
|
|
)
|
|
|
|
|
|
@functools.total_ordering
|
|
class BeforeAny:
|
|
def __gt__(self, other: Any) -> bool:
|
|
return False
|
|
|
|
def __eq__(self, other: Any) -> bool:
|
|
return isinstance(other, BeforeAny)
|
|
|
|
def __hash__(self) -> int:
|
|
return 1
|
|
|
|
|
|
@functools.total_ordering
|
|
class AfterAny:
|
|
def __lt__(self, other: Any) -> bool:
|
|
return False
|
|
|
|
def __eq__(self, other: Any) -> bool:
|
|
return isinstance(other, AfterAny)
|
|
|
|
def __hash__(self) -> int:
|
|
return 1
|
|
|
|
|
|
class ScoreTest(RedisType):
|
|
"""Argument converter for sorted set score endpoints."""
|
|
|
|
def __init__(self, value: float, exclusive: bool = False, bytes_val: Optional[bytes] = None):
|
|
self.value = value
|
|
self.exclusive = exclusive
|
|
self.bytes_val = bytes_val
|
|
|
|
@classmethod
|
|
def decode(cls, value: bytes) -> "ScoreTest":
|
|
try:
|
|
original_value = value
|
|
exclusive = False
|
|
if value[:1] == b"(":
|
|
exclusive = True
|
|
value = value[1:]
|
|
fvalue = Float.decode(
|
|
value,
|
|
allow_leading_whitespace=True,
|
|
allow_erange=True,
|
|
allow_empty=True,
|
|
crop_null=True,
|
|
)
|
|
return cls(fvalue, exclusive, original_value)
|
|
except SimpleError:
|
|
raise SimpleError(msgs.INVALID_MIN_MAX_FLOAT_MSG)
|
|
|
|
def __str__(self) -> str:
|
|
if self.exclusive:
|
|
return "({!r}".format(self.value)
|
|
else:
|
|
return repr(self.value)
|
|
|
|
@property
|
|
def lower_bound(self) -> Tuple[float, Union[AfterAny, BeforeAny]]:
|
|
return self.value, AfterAny() if self.exclusive else BeforeAny()
|
|
|
|
@property
|
|
def upper_bound(self) -> Tuple[float, Union[AfterAny, BeforeAny]]:
|
|
return self.value, BeforeAny() if self.exclusive else AfterAny()
|
|
|
|
|
|
class StringTest(RedisType):
|
|
"""Argument converter for sorted set LEX endpoints."""
|
|
|
|
def __init__(self, value: Union[bytes, BeforeAny, AfterAny], exclusive: bool):
|
|
self.value = value
|
|
self.exclusive = exclusive
|
|
|
|
@classmethod
|
|
def decode(cls, value: bytes) -> "StringTest":
|
|
if value == b"-":
|
|
return cls(BeforeAny(), True)
|
|
elif value == b"+":
|
|
return cls(AfterAny(), True)
|
|
elif value[:1] == b"(":
|
|
return cls(value[1:], True)
|
|
elif value[:1] == b"[":
|
|
return cls(value[1:], False)
|
|
else:
|
|
raise SimpleError(msgs.INVALID_MIN_MAX_STR_MSG)
|
|
|
|
# def to_scoretest(self, zset: ZSet) -> ScoreTest:
|
|
# if isinstance(self.value, BeforeAny):
|
|
# return ScoreTest(float("-inf"), False)
|
|
# if isinstance(self.value, AfterAny):
|
|
# return ScoreTest(float("inf"), False)
|
|
# val: float = zset.get(self.value, None)
|
|
# return ScoreTest(val, self.exclusive)
|
|
|
|
|
|
class Signature:
|
|
def __init__(
|
|
self, name: str,
|
|
func_name: str,
|
|
fixed: Tuple[Type[Union[RedisType, bytes]]],
|
|
repeat: Tuple[Type[Union[RedisType, bytes]]] = (), # type:ignore
|
|
args: Tuple[str] = (), # type:ignore
|
|
flags: str = "",
|
|
):
|
|
self.name = name
|
|
self.func_name = func_name
|
|
self.fixed = fixed
|
|
self.repeat = repeat
|
|
self.flags = set(flags)
|
|
self.command_args = args
|
|
|
|
def check_arity(self, args: Tuple[Any], version: Tuple[int]) -> None:
|
|
if len(args) == len(self.fixed):
|
|
return
|
|
delta = len(args) - len(self.fixed)
|
|
if delta < 0 or not self.repeat:
|
|
msg = msgs.WRONG_ARGS_MSG6.format(self.name)
|
|
raise SimpleError(msg)
|
|
if delta % len(self.repeat) != 0:
|
|
msg = (
|
|
msgs.WRONG_ARGS_MSG7
|
|
if version >= (7,)
|
|
else msgs.WRONG_ARGS_MSG6.format(self.name)
|
|
)
|
|
raise SimpleError(msg)
|
|
|
|
def apply(
|
|
self, args: Tuple[Any], db: Database, version: Tuple[int]
|
|
) -> Union[Tuple[Any], Tuple[List[Any], List[CommandItem]]]:
|
|
"""Returns a tuple, which is either:
|
|
- transformed args and a dict of CommandItems; or
|
|
- a single containing a short-circuit return value
|
|
"""
|
|
self.check_arity(args, version)
|
|
|
|
types = list(self.fixed)
|
|
for i in range(len(args) - len(types)):
|
|
types.append(self.repeat[i % len(self.repeat)])
|
|
|
|
args_list = list(args)
|
|
# First pass: convert/validate non-keys, and short-circuit on missing keys
|
|
for i, (arg, type_) in enumerate(zip(args_list, types)):
|
|
if isinstance(type_, Key):
|
|
if type_.missing_return is not Key.UNSPECIFIED and arg not in db:
|
|
return (type_.missing_return,)
|
|
elif type_ != bytes:
|
|
args_list[i] = type_.decode(
|
|
args_list[i],
|
|
)
|
|
|
|
# Second pass: read keys and check their types
|
|
command_items: List[CommandItem] = []
|
|
for i, (arg, type_) in enumerate(zip(args_list, types)):
|
|
if isinstance(type_, Key):
|
|
item = db.get(arg)
|
|
default = None
|
|
if (
|
|
type_.type_ is not None
|
|
and item is not None
|
|
and type(item.value) is not type_.type_
|
|
):
|
|
raise SimpleError(msgs.WRONGTYPE_MSG)
|
|
if (
|
|
type_.type_ is not None
|
|
and item is None
|
|
and type_.type_ is not bytes
|
|
):
|
|
default = type_.type_()
|
|
args_list[i] = CommandItem(arg, db, item, default=default)
|
|
command_items.append(args_list[i])
|
|
|
|
return args_list, command_items
|
|
|
|
|
|
def command(*args, **kwargs) -> Callable: # type:ignore
|
|
def create_signature(func: Callable[..., Any], cmd_name: str) -> None:
|
|
if " " in cmd_name:
|
|
COMMANDS_WITH_SUB.add(cmd_name.split(" ")[0])
|
|
SUPPORTED_COMMANDS[cmd_name] = Signature(
|
|
cmd_name, func.__name__, *args, **kwargs
|
|
)
|
|
|
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
cmd_names = kwargs.pop("name", func.__name__)
|
|
if isinstance(cmd_names, list): # Support for alias commands
|
|
for cmd_name in cmd_names:
|
|
create_signature(func, cmd_name.lower())
|
|
elif isinstance(cmd_names, str):
|
|
create_signature(func, cmd_names.lower())
|
|
else:
|
|
raise ValueError("command name should be a string or list of strings")
|
|
return func
|
|
|
|
return decorator
|
|
|
|
|
|
def delete_keys(*keys: CommandItem) -> int:
|
|
ans = 0
|
|
done = set()
|
|
for key in keys:
|
|
if key and key.key not in done:
|
|
key.value = None
|
|
done.add(key.key)
|
|
ans += 1
|
|
return ans
|
|
|
|
|
|
def fix_range(start: int, end: int, length: int) -> Tuple[int, int]:
|
|
# Redis handles negative slightly differently for zrange
|
|
if start < 0:
|
|
start = max(0, start + length)
|
|
if end < 0:
|
|
end += length
|
|
if start > end or start >= length:
|
|
return -1, -1
|
|
end = min(end, length - 1)
|
|
return start, end + 1
|
|
|
|
|
|
def fix_range_string(start: int, end: int, length: int) -> Tuple[int, int]:
|
|
# Negative number handling is based on the redis source code
|
|
if 0 > start > end and end < 0:
|
|
return -1, -1
|
|
if start < 0:
|
|
start = max(0, start + length)
|
|
if end < 0:
|
|
end = max(0, end + length)
|
|
end = min(end, length - 1)
|
|
return start, end + 1
|