305 lines
9.0 KiB
Python
305 lines
9.0 KiB
Python
import hashlib
|
|
import pickle
|
|
import random
|
|
from typing import Tuple, Any, Callable
|
|
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._command_args_parsing import extract_args
|
|
from fakeredis._commands import (
|
|
command,
|
|
Key,
|
|
Int,
|
|
DbIndex,
|
|
BeforeAny,
|
|
CommandItem,
|
|
SortFloat,
|
|
delete_keys,
|
|
)
|
|
from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, Database
|
|
from fakeredis._zset import ZSet
|
|
|
|
|
|
class GenericCommandsMixin:
    """Redis keyspace ("generic") commands: DEL, EXPIRE*, KEYS, SORT, ...

    This is a mixin: the attributes annotated below are provided by the
    class it is mixed into.
    """

    # Server version as a tuple of ints; compared against e.g. ``(7,)``.
    version: Tuple[int, ...]
    # Server object; ``_server.dbs`` is indexed by database number (see ``move``).
    _server: Any
    # The currently selected database.
    _db: Database
    # Index of the currently selected database.
    _db_num: int
    # Helpers implemented by sibling mixins.
    _ttl: Callable
    _scan: Callable
    _key_value_type: Callable

    def _lookup_key(self, key, pattern):
        """Python implementation of lookupKeyByPattern from redis.

        Resolves a SORT ``BY`` / ``GET`` pattern for a single element ``key``:

        * ``b"#"`` yields the element itself.
        * Otherwise the first ``*`` in ``pattern`` is replaced by ``key`` to
          form a key name. A ``->field`` suffix (with a non-empty field) reads
          that field from a dict (hash) value; without it, the value must be
          a plain bytes string.

        Returns the resolved bytes, or ``None`` when the pattern has no ``*``,
        the target key is missing, or its value has the wrong type.
        """
        if pattern == b"#":
            return key
        star = pattern.find(b"*")
        if star == -1:
            return None
        prefix = pattern[:star]
        suffix = pattern[star + 1:]
        # The search excludes the last byte, so "->" only counts as a field
        # separator when at least one byte of field name follows it.
        arrow = suffix.find(b"->", 0, -1)
        if arrow != -1:
            field = suffix[arrow + 2:]
            suffix = suffix[:arrow]
        else:
            field = None
        new_key = prefix + key + suffix
        item = CommandItem(new_key, self._db, item=self._db.get(new_key))
        if item.value is None:
            return None
        if field is not None:
            if not isinstance(item.value, dict):
                return None
            return item.value.get(field)
        if not isinstance(item.value, bytes):
            return None
        return item.value

    def _expireat(self, key, timestamp, *args, command_name="expire"):
        """Set ``key``'s absolute expiry time, honouring NX / XX / GT / LT.

        :param key: the target key item (as produced by the ``Key()`` spec).
        :param timestamp: absolute expiry time in seconds (may be fractional).
        :param args: optional NX / XX / GT / LT flags (Redis 7+ only).
        :param command_name: command name reported in the "wrong number of
            arguments" error raised when flags are used before Redis 7.
        :returns: 1 if the expiry was set; 0 if the key does not exist or a
            condition flag prevented the update.
        :raises SimpleError: flags used before Redis 7, incompatible flag
            combinations, or an unsupported option.
        """
        (nx, xx, gt, lt), _ = extract_args(
            args,
            ("nx", "xx", "gt", "lt"),
            exception=msgs.EXPIRE_UNSUPPORTED_OPTION,
        )
        if self.version < (7,) and any((nx, xx, gt, lt)):
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format(command_name))
        # NX, GT and LT are mutually exclusive, and NX also excludes XX.
        # XX may be combined with GT or LT.
        if (nx, gt, lt).count(True) > 1 or (nx and xx):
            raise SimpleError(msgs.NX_XX_GT_LT_ERROR_MSG)
        if not key:
            return 0
        current = key.expireat
        if (
            (xx and current is None)
            or (nx and current is not None)
            # A key without an expiry has an infinite TTL. GT can never exceed
            # it, and LT always undercuts it. Both comparisons are strict, so
            # an equal timestamp is not an update.
            or (gt and (current is None or timestamp <= current))
            or (lt and current is not None and timestamp >= current)
        ):
            return 0
        key.expireat = timestamp
        return 1

    @command((Key(),), (Key(),), name="del")
    def del_(self, *keys):
        """DEL key [key ...]: delete the given keys."""
        return delete_keys(*keys)

    @command((Key(missing_return=None),))
    def dump(self, key):
        """DUMP key: serialize the value as ``sha1(payload) + payload``.

        The payload is a pickle of the stored value, not the Redis RDB
        format. It is only meant to round-trip through ``restore``.
        """
        value = pickle.dumps(key.value)
        checksum = hashlib.sha1(value).digest()
        return checksum + value

    @command((Key(),), (Key(),))
    def exists(self, *keys):
        """EXISTS key [key ...]: count how many of the given keys exist."""
        return sum(1 for key in keys if key)

    @command((Key(), Int), (bytes,), name="expire")
    def expire(self, key, seconds, *args):
        """EXPIRE key seconds [NX|XX|GT|LT]: expire after ``seconds``."""
        return self._expireat(
            key, self._db.time + seconds, *args, command_name="expire"
        )

    @command((Key(), Int), (bytes,))
    def expireat(self, key, timestamp, *args):
        """EXPIREAT key unix-time-seconds [NX|XX|GT|LT]."""
        return self._expireat(
            key, float(timestamp), *args, command_name="expireat"
        )

    @command((bytes,))
    def keys(self, pattern):
        """KEYS pattern: list every key matching the glob ``pattern``."""
        if pattern == b"*":
            # Fast path: no need to compile a regex that matches everything.
            return list(self._db)
        regex = compile_pattern(pattern)
        return [key for key in self._db if regex.match(key)]

    @command((Key(), DbIndex))
    def move(self, key, db):
        """MOVE key db: move ``key`` into database ``db`` if absent there.

        :returns: 1 if moved; 0 if the key is missing or already exists in
            the destination database.
        :raises SimpleError: if ``db`` is the currently selected database.
        """
        if db == self._db_num:
            raise SimpleError(msgs.SRC_DST_SAME_MSG)
        if not key or key.key in self._server.dbs[db]:
            return 0
        # TODO: what is the interaction with expiry?
        # The stored entry object is copied as-is into the destination db.
        self._server.dbs[db][key.key] = self._server.dbs[self._db_num][key.key]
        key.value = None  # Causes deletion
        return 1

    @command((Key(),))
    def persist(self, key):
        """PERSIST key: clear the expiry. Returns 1 if one was removed."""
        if key.expireat is None:
            return 0
        key.expireat = None
        return 1

    @command((Key(), Int), (bytes,))
    def pexpire(self, key, ms, *args):
        """PEXPIRE key milliseconds [NX|XX|GT|LT]."""
        return self._expireat(
            key, self._db.time + ms / 1000.0, *args, command_name="pexpire"
        )

    @command((Key(), Int), (bytes,))
    def pexpireat(self, key, ms_timestamp, *args):
        """PEXPIREAT key unix-time-milliseconds [NX|XX|GT|LT]."""
        return self._expireat(
            key, ms_timestamp / 1000.0, *args, command_name="pexpireat"
        )

    @command((Key(),))
    def pttl(self, key):
        """PTTL key: remaining time to live in milliseconds."""
        return self._ttl(key, 1000.0)

    @command(())
    def randomkey(self):
        """RANDOMKEY: a random key from the current db, or None if empty."""
        keys = list(self._db.keys())
        if not keys:
            return None
        return random.choice(keys)

    @command((Key(), Key()))
    def rename(self, key, newkey):
        """RENAME key newkey: move the value and expiry to ``newkey``.

        :raises SimpleError: if ``key`` does not exist.
        """
        if not key:
            raise SimpleError(msgs.NO_KEY_MSG)
        # TODO: check interaction with WATCH
        # Renaming a key onto itself is a no-op that still replies OK.
        if newkey.key != key.key:
            newkey.value = key.value
            newkey.expireat = key.expireat
            key.value = None
        return OK

    @command((Key(), Key()))
    def renamenx(self, key, newkey):
        """RENAMENX key newkey: rename only if ``newkey`` does not exist.

        :returns: 1 if renamed, 0 if ``newkey`` already exists.
        :raises SimpleError: if ``key`` does not exist.
        """
        if not key:
            raise SimpleError(msgs.NO_KEY_MSG)
        if newkey:
            return 0
        self.rename(key, newkey)
        return 1

    @command((Key(), Int, bytes), (bytes,))
    def restore(self, key, ttl, value, *args):
        """RESTORE key ttl serialized-value [REPLACE] [ABSTTL].

        ``serialized-value`` must come from ``dump``: a 20-byte SHA-1
        checksum followed by a pickle payload.

        :param ttl: milliseconds to live, or 0 for no expiry. With ABSTTL it
            is an absolute Unix time in milliseconds instead.
        :raises SimpleError: if the key exists without REPLACE, the checksum
            does not match, or ``ttl`` is negative.
        """
        (replace, absttl), _ = extract_args(args, ("replace", "absttl"))
        if key and not replace:
            raise SimpleError(msgs.RESTORE_KEY_EXISTS)
        checksum, value = value[:20], value[20:]
        if hashlib.sha1(value).digest() != checksum:
            raise SimpleError(msgs.RESTORE_INVALID_CHECKSUM_MSG)
        if ttl < 0:
            raise SimpleError(msgs.RESTORE_INVALID_TTL_MSG)
        if ttl == 0:
            expireat = None
        elif absttl:
            expireat = ttl / 1000.0
        else:
            expireat = self._db.time + ttl / 1000.0
        # SECURITY NOTE(review): pickle.loads on a client-supplied payload can
        # execute arbitrary code. The SHA-1 check only detects corruption, not
        # tampering. Only restore payloads produced by a trusted ``dump``.
        key.value = pickle.loads(value)
        key.expireat = expireat
        return OK

    @command((Int,), (bytes, bytes))
    def scan(self, cursor, *args):
        """SCAN cursor [MATCH pattern] [COUNT count] ... over the current db."""
        return self._scan(list(self._db), cursor, *args)

    @command((Key(),), (bytes,))
    def sort(self, key, *args):
        """SORT key [BY pattern] [LIMIT offset count] [GET pattern ...]
        [ASC|DESC] [ALPHA] [STORE destination].

        Sorts the elements of a list, set or sorted set. The result is
        optionally projected through GET patterns (see ``_lookup_key``).

        :returns: the resulting elements, or, with STORE, the number of
            elements written to the destination key.
        :raises SimpleError: WRONGTYPE for other value types, or a syntax
            error for leftover arguments that are not GET <pattern> pairs.
        """
        if key.value is not None and not isinstance(key.value, (set, list, ZSet)):
            raise SimpleError(msgs.WRONGTYPE_MSG)
        (
            asc,
            desc,
            alpha,
            store,
            sortby,
            (limit_start, limit_count),
        ), left_args = extract_args(
            args,
            ("asc", "desc", "alpha", "*store", "*by", "++limit"),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )
        limit_start = limit_start or 0
        limit_count = -1 if limit_count is None else limit_count
        # A BY pattern without "*" (e.g. "nosort") disables sorting.
        dontsort = sortby is not None and b"*" not in sortby

        # Whatever extract_args did not consume must be GET <pattern> pairs.
        i = 0
        get = []
        while i < len(left_args):
            if casematch(left_args[i], b"get") and i + 1 < len(left_args):
                get.append(left_args[i + 1])
                i += 2
            else:
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        # TODO: force sorting if the object is a set and either in Lua or
        # storing to a key, to match redis behaviour.
        items = list(key.value) if key.value is not None else []

        # These transformations are based on the redis implementation, but
        # changed to produce a half-open range.
        start = max(limit_start, 0)
        end = len(items) if limit_count < 0 else start + limit_count
        if start >= len(items):
            start = end = len(items) - 1
        end = min(end, len(items))

        if not get:
            get.append(b"#")
        if sortby is None:
            sortby = b"#"

        if not dontsort:
            if alpha:

                def sort_key(val):
                    byval = self._lookup_key(val, sortby)
                    # TODO: use locale.strxfrm when not storing? But then need to decode too.
                    # Missing values sort before every real value.
                    if byval is None:
                        byval = BeforeAny()
                    return byval

            else:

                def sort_key(val):
                    byval = self._lookup_key(val, sortby)
                    # Missing values score as 0.0; ties fall back to the element itself.
                    score = SortFloat.decode(byval) if byval is not None else 0.0
                    return score, val

            items.sort(key=sort_key, reverse=desc)
        elif isinstance(key.value, (list, ZSet)):
            # NOTE(review): this reverses list/zset order on every no-sort call,
            # regardless of DESC. Verify against Redis's sortCommand, which
            # appears to reverse only when DESC is given.
            items.reverse()

        out = []
        for row in items[start:end]:
            for g in get:
                v = self._lookup_key(row, g)
                # Stored results cannot contain nil, so missing lookups become b"".
                if store is not None and v is None:
                    v = b""
                out.append(v)
        if store is not None:
            item = CommandItem(store, self._db, item=self._db.get(store))
            item.value = out
            item.writeback()
            return len(out)
        return out

    @command((Key(),))
    def ttl(self, key):
        """TTL key: remaining time to live in seconds."""
        return self._ttl(key, 1.0)

    @command((Key(),))
    def type(self, key):
        """TYPE key: name of the stored value's type."""
        return self._key_value_type(key)

    @command((Key(),), (Key(),), name="unlink")
    def unlink(self, *keys):
        """UNLINK key [key ...]: delete the given keys (same as DEL here)."""
        return delete_keys(*keys)
|