295 lines
9.5 KiB
Python
295 lines
9.5 KiB
Python
import math
|
|
from typing import Tuple, Callable
|
|
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._command_args_parsing import extract_args
|
|
from fakeredis._commands import (
|
|
command,
|
|
Key,
|
|
Int,
|
|
Float,
|
|
MAX_STRING_SIZE,
|
|
delete_keys,
|
|
fix_range_string,
|
|
)
|
|
from fakeredis._helpers import OK, SimpleError, casematch, Database
|
|
|
|
|
|
def _lcs(s1, s2):
|
|
l1 = len(s1)
|
|
l2 = len(s2)
|
|
|
|
# Opt array to store the optimal solution value till ith and jth position for 2 strings
|
|
opt = [[0] * (l2 + 1) for _ in range(0, l1 + 1)]
|
|
|
|
# Pi array to store the direction when calculating the actual sequence
|
|
pi = [[0] * (l2 + 1) for _ in range(0, l1 + 1)]
|
|
|
|
# Algorithm to calculate the length of the longest common subsequence
|
|
for r in range(1, l1 + 1):
|
|
for c in range(1, l2 + 1):
|
|
if s1[r - 1] == s2[c - 1]:
|
|
opt[r][c] = opt[r - 1][c - 1] + 1
|
|
pi[r][c] = 0
|
|
elif opt[r][c - 1] >= opt[r - 1][c]:
|
|
opt[r][c] = opt[r][c - 1]
|
|
pi[r][c] = 1
|
|
else:
|
|
opt[r][c] = opt[r - 1][c]
|
|
pi[r][c] = 2
|
|
# Length of the longest common subsequence is saved at opt[n][m]
|
|
|
|
# Algorithm to calculate the longest common subsequence using the Pi array
|
|
# Also calculate the list of matches
|
|
r, c = l1, l2
|
|
result = ""
|
|
matches = list()
|
|
s1ind, s2ind, curr_length = None, None, 0
|
|
|
|
while r > 0 and c > 0:
|
|
if pi[r][c] == 0:
|
|
result = chr(s1[r - 1]) + result
|
|
r -= 1
|
|
c -= 1
|
|
curr_length += 1
|
|
elif pi[r][c] == 2:
|
|
r -= 1
|
|
else:
|
|
c -= 1
|
|
|
|
if pi[r][c] == 0 and curr_length == 1:
|
|
s1ind = r
|
|
s2ind = c
|
|
elif pi[r][c] > 0 and curr_length > 0:
|
|
matches.append([[r, s1ind], [c, s2ind], curr_length])
|
|
s1ind, s2ind, curr_length = None, None, 0
|
|
if curr_length:
|
|
matches.append([[s1ind, r], [s2ind, c], curr_length])
|
|
|
|
return opt[l1][l2], result.encode(), matches
|
|
|
|
|
|
class StringCommandsMixin:
    """Handlers for the string family of commands.

    Each handler is registered through the ``command`` decorator. The
    attributes annotated below are read by the handlers and are expected to
    be supplied by the class this mixin is combined with.
    """

    _db: Database
    version: Tuple[int]
    _encodeint: Callable
    _encodefloat: Callable

    @command((Key(bytes), bytes))
    def append(self, key, value):
        """Append ``value`` to the string at ``key``; return the new length.

        ``b""`` is used as the current value when ``key.get`` falls back to
        its default.

        Raises:
            SimpleError: if the combined length exceeds ``MAX_STRING_SIZE``.
        """
        old = key.get(b"")
        if len(old) + len(value) > MAX_STRING_SIZE:
            raise SimpleError(msgs.STRING_OVERFLOW_MSG)
        key.update(old + value)
        return len(key.value)

    @command((Key(bytes),))
    def decr(self, key):
        """Decrement the integer at ``key`` by one (delegates to ``incrby``)."""
        return self.incrby(key, -1)

    @command((Key(bytes), Int))
    def decrby(self, key, amount):
        """Decrement the integer at ``key`` by ``amount`` (delegates to ``incrby``)."""
        return self.incrby(key, -amount)

    @command((Key(bytes),))
    def get(self, key):
        """Return the value at ``key``, or ``None`` as the default."""
        return key.get(None)

    @command((Key(bytes),))
    def getdel(self, key):
        """Return the value at ``key`` (default ``None``) and delete the key."""
        res = key.get(None)
        delete_keys(key)
        return res

    @command(name=["GETRANGE", "SUBSTR"], fixed=(Key(bytes), Int, Int))
    def getrange(self, key, start, end):
        """Return the slice of the string at ``key`` selected by ``start``/``end``.

        The bounds are normalised by ``fix_range_string`` against the value's
        length before slicing.
        """
        value = key.get(b"")
        start, end = fix_range_string(start, end, len(value))
        return value[start:end]

    @command((Key(bytes), bytes))
    def getset(self, key, value):
        """Store ``value`` at ``key`` and return the value it replaced."""
        old = key.value
        key.value = value
        return old

    @command((Key(bytes), Int))
    def incrby(self, key, amount):
        """Add ``amount`` to the integer at ``key`` and return the new total.

        ``b"0"`` is used as the starting value when ``key.get`` falls back to
        its default.
        """
        c = Int.decode(key.get(b"0")) + amount
        key.update(self._encodeint(c))
        return c

    @command((Key(bytes),))
    def incr(self, key):
        """Increment the integer at ``key`` by one (delegates to ``incrby``)."""
        return self.incrby(key, 1)

    @command((Key(bytes), bytes))
    def incrbyfloat(self, key, amount):
        """Add the float ``amount`` to the value at ``key``.

        Returns:
            The encoded new value, as stored.

        Raises:
            SimpleError: if the sum is NaN or infinite.
        """
        # TODO: introduce convert_order so that we can specify amount is Float
        c = Float.decode(key.get(b"0")) + Float.decode(amount)
        if not math.isfinite(c):
            raise SimpleError(msgs.NONFINITE_MSG)
        encoded = self._encodefloat(c, True)
        key.update(encoded)
        return encoded

    @command((Key(),), (Key(),))
    def mget(self, *keys):
        """Return the value of every key; non-``bytes`` values become ``None``."""
        values = (key.value for key in keys)
        return [value if isinstance(value, bytes) else None for value in values]

    @command((Key(), bytes), (Key(), bytes))
    def mset(self, *args):
        """Assign each ``key value`` pair found in ``args`` and return ``OK``."""
        # args is a flat sequence: key, value, key, value, ...
        for i in range(0, len(args), 2):
            args[i].value = args[i + 1]
        return OK

    @command((Key(), bytes), (Key(), bytes))
    def msetnx(self, *args):
        """Assign every ``key value`` pair only if none of the keys is truthy.

        Returns:
            ``1`` when all pairs were written, ``0`` when nothing was written.
        """
        # All-or-nothing: a single truthy key aborts the whole command.
        if any(args[i] for i in range(0, len(args), 2)):
            return 0
        for i in range(0, len(args), 2):
            args[i].value = args[i + 1]
        return 1

    @command((Key(), Int, bytes))
    def psetex(self, key, ms, value):
        """Store ``value`` at ``key`` with a time-to-live of ``ms`` milliseconds.

        Raises:
            SimpleError: if ``ms`` is not positive or the resulting deadline
                in milliseconds reaches ``2 ** 63``.
        """
        if ms <= 0 or self._db.time * 1000 + ms >= 2 ** 63:
            raise SimpleError(msgs.INVALID_EXPIRE_MSG.format("psetex"))
        key.value = value
        key.expireat = self._db.time + ms / 1000.0
        return OK

    @command(name="SET", fixed=(Key(), bytes), repeat=(bytes,))
    def set_(self, key, value, *args):
        """Store ``value`` at ``key``, honouring EX/PX/XX/NX/KEEPTTL/GET.

        Returns:
            ``OK`` normally. With GET, the previous value (``None`` if there
            was none). When NX/XX prevents the write, the previous value if
            GET was given and ``None`` otherwise.

        Raises:
            SimpleError: for an invalid expire time, for conflicting options
                (XX with NX, or more than one of EX/PX/KEEPTTL), for NX with
                GET when ``self.version < (7,)``, or when GET is used on a
                key holding a non-``bytes`` value.
        """
        (ex, px, xx, nx, keepttl, get), _ = extract_args(
            args, ("+ex", "+px", "xx", "nx", "keepttl", "get")
        )
        if ex is not None and (ex <= 0 or (self._db.time + ex) * 1000 >= 2 ** 63):
            raise SimpleError(msgs.INVALID_EXPIRE_MSG.format("set"))
        if px is not None and (px <= 0 or self._db.time * 1000 + px >= 2 ** 63):
            raise SimpleError(msgs.INVALID_EXPIRE_MSG.format("set"))

        # EX, PX and KEEPTTL are mutually exclusive, as are XX and NX.
        ttl_option_count = (px is not None) + (ex is not None) + keepttl
        if (xx and nx) or ttl_option_count > 1:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if nx and get and self.version < (7,):
            # The command docs say this is allowed from Redis 7.0.
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        old_value = None
        if get:
            if key.value is not None and not isinstance(key.value, bytes):
                raise SimpleError(msgs.WRONGTYPE_MSG)
            old_value = key.value

        if nx and key:
            return old_value
        if xx and not key:
            return old_value

        if keepttl:
            key.update(value)
        else:
            key.value = value
        if ex is not None:
            key.expireat = self._db.time + ex
        if px is not None:
            key.expireat = self._db.time + px / 1000.0
        return old_value if get else OK

    @command((Key(), Int, bytes))
    def setex(self, key, seconds, value):
        """Store ``value`` at ``key`` with a time-to-live of ``seconds``.

        Raises:
            SimpleError: if ``seconds`` is not positive or the resulting
                deadline in milliseconds reaches ``2 ** 63``.
        """
        if seconds <= 0 or (self._db.time + seconds) * 1000 >= 2 ** 63:
            raise SimpleError(msgs.INVALID_EXPIRE_MSG.format("setex"))
        key.value = value
        key.expireat = self._db.time + seconds
        return OK

    @command((Key(), bytes))
    def setnx(self, key, value):
        """Store ``value`` only if ``key`` is falsy; return ``1`` if stored, else ``0``."""
        if key:
            return 0
        key.value = value
        return 1

    @command((Key(bytes), Int, bytes))
    def setrange(self, key, offset, value):
        """Overwrite the string at ``key`` with ``value`` starting at ``offset``.

        The string is padded with zero bytes when it is shorter than
        ``offset``. An empty ``value`` leaves the string unchanged.

        Returns:
            The length of the string after the operation.

        Raises:
            SimpleError: if ``offset`` is negative or the write would extend
                past ``MAX_STRING_SIZE``.
        """
        if offset < 0:
            raise SimpleError(msgs.INVALID_OFFSET_MSG)
        if not value:
            return len(key.get(b""))
        if offset + len(value) > MAX_STRING_SIZE:
            raise SimpleError(msgs.STRING_OVERFLOW_MSG)

        out = key.get(b"")
        if len(out) < offset:
            out += b"\x00" * (offset - len(out))
        out = out[0:offset] + value + out[offset + len(value):]
        key.update(out)
        return len(out)

    @command((Key(bytes),))
    def strlen(self, key):
        """Return the length of the string at ``key`` (``b""`` as default)."""
        return len(key.get(b""))

    @command((Key(bytes),), (bytes,))
    def getex(self, key, *args):
        """Return the value at ``key`` and optionally change its expiry.

        Accepts at most one of ``EX n``, ``PX n``, ``EXAT t``, ``PXAT t`` or
        ``PERSIST``. Without any option the expiry is left untouched.

        Raises:
            SimpleError: for an unknown option, an option missing its
                operand, more than one option, or an invalid expire time.
        """
        i, count_options, expire_time, diff = 0, 0, None, None

        while i < len(args):
            count_options += 1
            has_operand = i + 1 < len(args)
            if casematch(args[i], b"ex") and has_operand:
                diff = Int.decode(args[i + 1])
                expire_time = self._db.time + diff
                i += 2
            elif casematch(args[i], b"px") and has_operand:
                diff = Int.decode(args[i + 1])
                expire_time = (self._db.time * 1000 + diff) / 1000.0
                i += 2
            elif casematch(args[i], b"exat") and has_operand:
                expire_time = Int.decode(args[i + 1])
                i += 2
            elif casematch(args[i], b"pxat") and has_operand:
                expire_time = Int.decode(args[i + 1]) / 1000.0
                i += 2
            elif casematch(args[i], b"persist"):
                expire_time = None
                i += 1
            else:
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        bad_deadline = expire_time is not None and (
            expire_time <= 0 or expire_time * 1000 >= 2 ** 63
        )
        bad_delta = diff is not None and (diff <= 0 or diff * 1000 >= 2 ** 63)
        if bad_deadline or bad_delta:
            raise SimpleError(msgs.INVALID_EXPIRE_MSG.format("getex"))
        if count_options > 1:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        # Redis documents GETEX without options as a plain GET, so only touch
        # the expiry when an option was supplied (PERSIST stores None).
        if count_options:
            key.expireat = expire_time
        return key.get(None)

    @command(
        (
            Key(bytes),
            Key(bytes),
        ),
        (bytes,),
    )
    def lcs(self, k1, k2, *args):
        """Longest common subsequence of the strings at ``k1`` and ``k2``.

        Options: ``LEN`` returns only the length; ``IDX`` returns the match
        positions, optionally filtered by ``MINMATCHLEN n`` and extended with
        each match's length by ``WITHMATCHLEN``.

        Returns:
            The subsequence as ``bytes`` by default, its length with LEN, or
            ``[b"matches", matches, b"len", length]`` with IDX.

        Raises:
            SimpleError: if both IDX and LEN are given.
        """
        s1 = k1.value or b""
        s2 = k2.value or b""

        (arg_idx, arg_len, arg_minmatchlen, arg_withmatchlen), _ = extract_args(
            args, ("idx", "len", "+minmatchlen", "withmatchlen")
        )
        if arg_idx and arg_len:
            raise SimpleError(msgs.LCS_CANT_HAVE_BOTH_LEN_AND_IDX)

        lcs_len, lcs_val, matches = _lcs(s1, s2)
        if not arg_idx and not arg_len:
            return lcs_val
        if arg_len:
            return lcs_len

        min_len = arg_minmatchlen if arg_minmatchlen else 0
        results = [match for match in matches if match[2] >= min_len]
        if not arg_withmatchlen:
            # Drop the trailing run length, keeping only the two index ranges.
            results = [[match[0], match[1]] for match in results]
        return [b"matches", results, b"len", lcs_len]