251 lines
9.2 KiB
Python
251 lines
9.2 KiB
Python
from typing import Tuple
|
|
|
|
import re
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._commands import (
|
|
command,
|
|
Key,
|
|
Int,
|
|
BitOffset,
|
|
BitValue,
|
|
fix_range_string,
|
|
fix_range,
|
|
)
|
|
from fakeredis._helpers import SimpleError, casematch
|
|
|
|
|
|
class BitfieldEncoding:
    """Parsed BITFIELD type argument such as ``i8`` (signed) or ``u16`` (unsigned).

    Accepted widths are 1-64 bits for signed types and 1-63 bits for
    unsigned types (there is no ``u64``), as enforced below.
    """

    signed: bool  # True for ``i<bits>``, False for ``u<bits>``
    size: int  # field width in bits

    def __init__(self, encoding):
        """Parse ``encoding`` (bytes) or raise SimpleError(INVALID_BITFIELD_TYPE)."""
        # fullmatch rather than match + "$": "$" also matches before a trailing
        # newline, so b"i8\n" used to be accepted. The width must not have
        # leading zeros, mirroring the strict integer parse Redis applies.
        match = re.fullmatch(br"([ui])([1-9][0-9]*)", encoding)
        if match is None:
            raise SimpleError(msgs.INVALID_BITFIELD_TYPE)

        self.signed = match[1] == b"i"
        self.size = int(match[2])

        # The regex already guarantees size >= 1; only the upper bound remains.
        if self.size > (64 if self.signed else 63):
            raise SimpleError(msgs.INVALID_BITFIELD_TYPE)
|
|
|
|
|
|
class BitmapCommandsMixin:
    """Bitmap/bitfield commands: BITPOS, BITCOUNT, GETBIT, SETBIT, BITOP, BITFIELD.

    ``self.version`` is compared against tuples such as ``(7,)`` to emulate
    version-dependent argument handling.
    """

    version: Tuple[int, ...]

    # TODO: bitfield_ro

    @staticmethod
    def _bytes_as_bin_string(value):
        """Render ``value`` as '0'/'1' characters, 8 per byte, most significant bit first."""
        return "".join(format(byte, "08b") for byte in value)

    @command((Key(bytes), Int), (bytes,))
    def bitpos(self, key, bit, *args):
        """BITPOS key bit [start [end [BYTE|BIT]]].

        Return the position of the first bit equal to ``bit`` (-1 if none).
        ``start``/``end`` are byte offsets unless the BIT modifier (Redis 7+)
        is given, in which case they are bit offsets.

        Redis semantics reproduced here:
        * a missing key is an endless run of zero bits, so searching for 0
          returns 0 and searching for 1 returns -1;
        * when searching for 0 without an explicit ``end``, the string is
          considered zero-padded on the right, so a range made only of 1 bits
          yields the first bit position past the string rather than -1.
        """
        if bit != 0 and bit != 1:
            raise SimpleError(msgs.BIT_ARG_MUST_BE_ZERO_OR_ONE)
        if len(args) > 3:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if len(args) == 3 and self.version < (7,):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        bit_mode = False
        if len(args) == 3:
            bit_mode = casematch(args[2], b"bit")
            if not bit_mode and not casematch(args[2], b"byte"):
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        # Decode all numeric arguments before looking at the key so invalid
        # input is reported even for a missing key.
        start = 0 if len(args) == 0 else Int.decode(args[0])
        end_given = len(args) >= 2
        end = Int.decode(args[1]) if end_given else None

        if key.value is None:
            return -1 if bit == 1 else 0

        key_value = key.value
        if bit_mode:
            value = self._bytes_as_bin_string(key_value)
            start, end = fix_range(start, end if end_given else len(value), len(value))
            value = value[start:end]
            base = start  # already a bit position
        else:
            start, end = fix_range(start, end if end_given else len(key_value), len(key_value))
            value = self._bytes_as_bin_string(key_value[start:end])
            base = start * 8  # convert the skipped bytes into bits

        result = value.find(str(bit))
        if result != -1:
            return result + base
        if bit == 0 and not end_given and value:
            # Zero padding on the right: the first clear bit is just past the range.
            return base + len(value)
        return -1

    @command((Key(bytes, 0),), (bytes,))
    def bitcount(self, key, *args):
        """BITCOUNT key [start end [BYTE|BIT]]: count the bits set to 1.

        ``start``/``end`` are byte offsets, or bit offsets with the BIT
        modifier (Redis 7+); ranges are normalized by ``fix_range_string``.
        NOTE(review): ``Key(bytes, 0)`` presumably answers 0 for a missing key
        before this method runs, since ``key.value`` is used unguarded below.
        """
        # Redis checks the argument count before decoding integers. That's why
        # we can't declare them as Int.
        if len(args) == 0:
            value = key.value
            return bin(int.from_bytes(value, "little")).count("1")

        if not 2 <= len(args) <= 3:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        start = Int.decode(args[0])
        end = Int.decode(args[1])
        bit_mode = False
        if len(args) == 3 and self.version < (7,):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if len(args) == 3:
            bit_mode = casematch(args[2], b"bit")
            if not bit_mode and not casematch(args[2], b"byte"):
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        if bit_mode:
            value = self._bytes_as_bin_string(key.value if key.value else b"")
            start, end = fix_range_string(start, end, len(value))
            return value[start:end].count("1")

        start, end = fix_range_string(start, end, len(key.value))
        value = key.value[start:end]
        return bin(int.from_bytes(value, "little")).count("1")

    @command((Key(bytes), BitOffset))
    def getbit(self, key, offset):
        """GETBIT key offset: return the bit at ``offset`` (0 or 1).

        Offset 0 is the most significant bit of the first byte; offsets past
        the end of the value read as 0 (a missing key reads as ``b""``).
        """
        value = key.get(b"")
        byte = offset // 8
        remaining = offset % 8
        actual_bitoffset = 7 - remaining
        try:
            actual_val = value[byte]
        except IndexError:
            return 0
        return 1 if (1 << actual_bitoffset) & actual_val else 0

    @command((Key(bytes), BitOffset, BitValue))
    def setbit(self, key, offset, value):
        """SETBIT key offset value: set one bit and return its previous value.

        The string is extended with zero bytes as needed so the target byte exists.
        """
        val = key.value if key.value is not None else b"\x00"
        byte = offset // 8
        remaining = offset % 8
        actual_bitoffset = 7 - remaining
        if len(val) - 1 < byte:
            # Expand val so that the byte holding the bit exists.
            needed = byte - (len(val) - 1)
            val += b"\x00" * needed
        old_byte = val[byte]
        if value == 1:
            new_byte = old_byte | (1 << actual_bitoffset)
        else:
            new_byte = old_byte & ~(1 << actual_bitoffset)
        old_value = value if old_byte == new_byte else 1 - value
        reconstructed = bytearray(val)
        reconstructed[byte] = new_byte
        new_val = bytes(reconstructed)
        # Redis 7 only touches the key when its content changes; earlier
        # versions mark it modified unconditionally. (``self.version`` is a
        # tuple, so this must be a tuple comparison, not ``== 6``.)
        if new_val != key.value or self.version < (7,):
            key.update(new_val)
        return old_value

    @staticmethod
    def _bitop(op, *keys):
        """Fold the byte-wise binary function ``op`` over the values of ``keys``.

        As in Redis, missing keys count as empty strings and every operand is
        right-padded with zero bytes to the longest length, so the result is
        as long as the longest input (``b""`` if all inputs are empty).
        """
        values = [k.value if k.value is not None else b"" for k in keys]
        width = max(len(v) for v in values)
        padded = [v.ljust(width, b"\x00") for v in values]
        ans = padded[0]
        for operand in padded[1:]:
            ans = bytes(op(a, b) for a, b in zip(ans, operand))
        return ans

    @command((bytes, Key()), (Key(bytes),))
    def bitop(self, op_name, dst, *keys):
        """BITOP AND|OR|XOR|NOT destkey key [key ...].

        Stores the result in ``dst`` and returns its length in bytes. An empty
        result deletes ``dst`` (as Redis does) and returns 0.
        """
        if len(keys) == 0:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("bitop"))
        if casematch(op_name, b"and"):
            res = self._bitop(lambda a, b: a & b, *keys)
        elif casematch(op_name, b"or"):
            res = self._bitop(lambda a, b: a | b, *keys)
        elif casematch(op_name, b"xor"):
            res = self._bitop(lambda a, b: a ^ b, *keys)
        elif casematch(op_name, b"not"):
            if len(keys) != 1:
                raise SimpleError(msgs.BITOP_NOT_ONE_KEY_ONLY)
            val = keys[0].value if keys[0].value is not None else b""
            res = bytes(0xFF - byte for byte in val)
        else:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("bitop"))
        # Assigning None removes the destination key instead of storing b"".
        dst.value = res if res else None
        return len(res)

    def _bitfield_get(self, key, encoding, offset):
        """Read a ``encoding.size``-bit field starting at bit ``offset`` (MSB first).

        For signed encodings the leading bit has weight -2**(size-1), i.e. the
        field is decoded as two's complement.
        """
        ans = 0
        for i in range(encoding.size):
            ans <<= 1
            if self.getbit(key, offset + i):
                ans += -1 if encoding.signed and i == 0 else 1
        return ans

    def _bitfield_set(self, key, encoding, offset, overflow, value=None, incr=0):
        """Write a field for SET (``value`` given) or INCRBY (``incr``).

        ``overflow`` is b"WRAP", b"SAT" or b"FAIL". Returns the previous value
        for SET, the new value for INCRBY, or None when FAIL rejected the
        write (nothing is written in that case).
        """
        if encoding.signed:
            min_value = -(1 << (encoding.size - 1))
            max_value = (1 << (encoding.size - 1)) - 1
        else:
            min_value = 0
            max_value = (1 << encoding.size) - 1

        ans = self._bitfield_get(key, encoding, offset)
        new_value = ans if value is None else value
        if not encoding.signed:
            new_value &= (1 << 64) - 1  # force cast to uint64_t

        if overflow == b"FAIL" and not (min_value <= new_value + incr <= max_value):
            return None  # yes, failing in this context is not writing the value
        elif overflow == b"SAT":
            if new_value + incr > max_value:
                new_value, incr = max_value, 0
            # REDIS only checks for unsigned underflow on negative incr:
            if (encoding.signed or incr < 0) and new_value + incr < min_value:
                new_value, incr = min_value, 0

        new_value += incr
        new_value &= (1 << encoding.size) - 1
        # Normalize signed numbers by giving the high bit its negative weight.
        if encoding.signed and new_value > max_value:
            new_value -= 1 << encoding.size

        for i in range(encoding.size):
            bit = (new_value >> (encoding.size - i - 1)) & 1
            self.setbit(key, offset + i, bit)
        return new_value if value is None else ans

    @staticmethod
    def _bitfield_parse(args):
        """Validate BITFIELD sub-commands and return them as operation tuples.

        Each tuple is ``(opcode, encoding, offset, argument, overflow)`` where
        ``opcode`` is b"GET", b"SET" or b"INCRBY", ``argument`` is the SET value
        or INCRBY increment (None for GET) and ``overflow`` is the policy in
        force at that point. Everything is parsed before any operation runs,
        so a malformed sub-command causes no operation to be executed.
        """
        overflow = b"WRAP"
        ops = []
        i = 0
        while i < len(args):
            if casematch(args[i], b"overflow") and i + 1 < len(args):
                overflow = args[i + 1].upper()
                if overflow not in (b"WRAP", b"SAT", b"FAIL"):
                    raise SimpleError(msgs.INVALID_OVERFLOW_TYPE)
                i += 2
            elif casematch(args[i], b"get") and i + 2 < len(args):
                encoding = BitfieldEncoding(args[i + 1])
                offset = BitOffset.decode(args[i + 2])
                ops.append((b"GET", encoding, offset, None, overflow))
                i += 3
            elif casematch(args[i], b"set") and i + 3 < len(args):
                encoding = BitfieldEncoding(args[i + 1])
                offset = BitOffset.decode(args[i + 2])
                ops.append((b"SET", encoding, offset, Int.decode(args[i + 3]), overflow))
                i += 4
            elif casematch(args[i], b"incrby") and i + 3 < len(args):
                encoding = BitfieldEncoding(args[i + 1])
                offset = BitOffset.decode(args[i + 2])
                ops.append((b"INCRBY", encoding, offset, Int.decode(args[i + 3]), overflow))
                i += 4
            else:
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        return ops

    @command(fixed=(Key(bytes),), repeat=(bytes,))
    def bitfield(self, key, *args):
        """BITFIELD key [GET|SET|INCRBY type offset [value]] [OVERFLOW WRAP|SAT|FAIL] ...

        Returns one entry per GET/SET/INCRBY: the field value, the previous
        value, or the incremented value respectively (None when OVERFLOW FAIL
        blocked a write).
        """
        results = []
        for opcode, encoding, offset, argument, overflow in self._bitfield_parse(args):
            if opcode == b"GET":
                results.append(self._bitfield_get(key, encoding, offset))
            elif opcode == b"SET":
                results.append(self._bitfield_set(
                    key=key,
                    encoding=encoding,
                    offset=offset,
                    value=argument,
                    overflow=overflow,
                ))
            else:
                results.append(self._bitfield_set(
                    key=key,
                    encoding=encoding,
                    offset=offset,
                    incr=argument,
                    overflow=overflow,
                ))
        return results
|