# youtube-summarizer/venv311/lib/python3.11/site-packages/fakeredis/commands_mixins/streams_mixin.py
#
# 457 lines
# 16 KiB
# Python

import functools
from typing import List, Union, Tuple, Callable
import fakeredis._msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, CommandItem, Int
from fakeredis._helpers import SimpleError, casematch, OK, current_time, Database
from fakeredis._stream import XStream, StreamRangeTest, StreamGroup
class StreamsCommandsMixin:
    """Implementation of the Redis stream commands (XADD, XREAD, XGROUP, ...).

    Mixin for the fake socket class. The host class supplies ``_db`` (the
    currently selected ``Database``), ``version`` (the emulated server
    version, compared against tuples such as ``(7,)``) and ``_blocking``
    (called with a timeout in seconds and a reader callable to implement
    ``BLOCK <milliseconds>``).
    """

    _db: Database
    version: Tuple[int, ...]
    _blocking: Callable

    @staticmethod
    def _require_group(key: CommandItem, group_name: bytes) -> StreamGroup:
        """Return consumer group *group_name* of the stream at *key*.

        Raises SimpleError with the XGROUP "key not found" message when the
        stream is missing, or the "group not found" message when the group is.
        """
        if key.value is None:
            raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
        group: StreamGroup = key.value.group_get(group_name)
        if not group:
            raise SimpleError(
                msgs.XGROUP_GROUP_NOT_FOUND_MSG.format(group_name.decode(), key)
            )
        return group

    @command(
        name="XADD",
        fixed=(Key(),),
        repeat=(bytes,),
    )
    def xadd(self, key, *args):
        """XADD key [NOMKSTREAM] [MAXLEN|MINID [=|~] threshold [LIMIT count]] id field value [field value ...]

        Append one entry to the stream at *key*, then optionally trim it.
        The stream is created if missing, unless NOMKSTREAM is given.

        :returns: the ID assigned to the new entry, or None when NOMKSTREAM
            is given and the key does not exist.
        :raises SimpleError: on a wrong argument count, an invalid ID, or an
            ID not greater than the stream's last ID.
        """
        (nomkstream, limit, maxlen, minid), left_args = extract_args(
            args,
            ("nomkstream", "+limit", "~+maxlen", "~minid"),
            error_on_unexpected=False,
        )
        if nomkstream and key.value is None:
            return None
        # What remains must be: id, followed by one or more field/value pairs.
        # Checking the length first avoids an IndexError when nothing is left.
        if len(left_args) < 3 or len(left_args) % 2 != 1:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("XADD"))
        requested_id = left_args[0]
        elements = left_args[1:]
        stream = key.value if key.value is not None else XStream()
        # Before 7.0 an explicit ID must pass StreamRangeTest.valid_key up
        # front (Redis 7 added partially auto-generated IDs such as "<ms>-*").
        if (
            self.version < (7,)
            and requested_id != b"*"
            and not StreamRangeTest.valid_key(requested_id)
        ):
            raise SimpleError(msgs.XADD_INVALID_ID)
        entry_key = stream.add(elements, entry_key=requested_id)
        if entry_key is None:
            # add() refused the ID: it is either malformed or not greater
            # than the current last ID of the stream.
            if not StreamRangeTest.valid_key(requested_id):
                raise SimpleError(msgs.XADD_INVALID_ID)
            raise SimpleError(msgs.XADD_ID_LOWER_THAN_LAST)
        if maxlen is not None or minid is not None:
            stream.trim(max_length=maxlen, start_entry_key=minid, limit=limit)
        key.update(stream)
        return entry_key

    @command(
        name="XTRIM",
        fixed=(Key(XStream),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def xtrim(self, key, *args):
        """XTRIM key MAXLEN|MINID [=|~] threshold [LIMIT count]

        Exactly one of MAXLEN / MINID must be given. Returns the result of
        ``XStream.trim`` (presumably the number of evicted entries).
        """
        (limit, maxlen, minid), _ = extract_args(args, ("+limit", "~+maxlen", "~minid"))
        if maxlen is not None and minid is not None:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if maxlen is None and minid is None:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        stream = key.value or XStream()
        res = stream.trim(max_length=maxlen, start_entry_key=minid, limit=limit)
        key.update(stream)
        return res

    @command(name="XLEN", fixed=(Key(XStream),))
    def xlen(self, key):
        """XLEN key -- number of entries in the stream."""
        return len(key.value)

    @staticmethod
    def _xrange(
        stream: XStream,
        _min: StreamRangeTest,
        _max: StreamRangeTest,
        reverse: bool,
        count: Union[int, None],
    ) -> List:
        """Return up to *count* entries of *stream* between *_min* and *_max*.

        A missing stream (None) yields an empty list; ``count=None`` means no
        limit. *reverse* is forwarded to ``XStream.irange``.
        """
        if stream is None:
            return []
        if count is None:
            count = len(stream)
        res = stream.irange(_min, _max, reverse=reverse)
        return res[:count]

    @command(
        name="XRANGE",
        fixed=(Key(XStream), StreamRangeTest, StreamRangeTest),
        repeat=(bytes,),
    )
    def xrange(self, key, _min, _max, *args):
        """XRANGE key start end [COUNT count]"""
        (count,), _ = extract_args(args, ("+count",))
        return self._xrange(key.value, _min, _max, False, count)

    @command(
        name="XREVRANGE",
        fixed=(Key(XStream), StreamRangeTest, StreamRangeTest),
        repeat=(bytes,),
    )
    def xrevrange(self, key, _min, _max, *args):
        """XREVRANGE key end start [COUNT count]"""
        (count,), _ = extract_args(args, ("+count",))
        # XREVRANGE takes the upper bound first; swap so _xrange gets (low, high).
        return self._xrange(key.value, _max, _min, True, count)

    def _xread(
        self,
        stream_start_id_list: List,
        count: int,
        first_pass: bool,
        blocking: bool = False,
    ):
        """Collect entries after each start ID, for XREAD.

        :param stream_start_id_list: (CommandItem, StreamRangeTest) pairs
            built by ``xread``; only the item's key is used here.
        :param count: per-stream maximum number of entries, or None.
        :param first_pass: supplied by ``_blocking``; kept for its calling
            convention, the result no longer depends on it.
        :param blocking: True when invoked through ``_blocking`` (BLOCK given).
        :returns: ``[[key, entries], ...]`` for streams that have entries.
            While blocking and nothing is available yet, returns None so
            ``_blocking`` keeps waiting instead of replying with an empty list.
        """
        max_inf = StreamRangeTest.decode(b"+")
        res = []
        for item, start_id in stream_start_id_list:
            # Re-read the key on every pass: a stream created after the
            # command was issued must become visible to a blocked reader.
            current = CommandItem(
                item.key, self._db, item=self._db.get(item.key), default=None
            )
            stream_results = self._xrange(current.value, start_id, max_inf, False, count)
            if len(stream_results) > 0:
                res.append([item.key, stream_results])
        if blocking and not res:
            return None
        return res

    def _xreadgroup(
        self,
        consumer_name: bytes,
        group_params: List[Tuple[StreamGroup, bytes, bytes]],
        count: int,
        noack: bool,
        first_pass: bool,
        blocking: bool = False,
    ):
        """Read from each (group, stream) pair on behalf of *consumer_name*.

        :param first_pass: supplied by ``_blocking``; kept for its calling
            convention, the result no longer depends on it.
        :param blocking: True when invoked through ``_blocking`` (BLOCK given).
        :returns: ``[[stream_name, entries], ...]``, or None while blocking
            and nothing was produced. Every stream is read before deciding,
            so entries already handed out by ``group_read`` are never dropped.
        """
        res = []
        for group, stream_name, start_id in group_params:
            stream_results = group.group_read(consumer_name, start_id, count, noack)
            # ">" asks for new messages: omit streams with nothing new. Any
            # other ID is always reported, even when its result is empty.
            if len(stream_results) > 0 or start_id != b">":
                res.append([stream_name, stream_results])
        if blocking and not res:
            return None
        return res

    @staticmethod
    def _parse_start_id(key: CommandItem, s: bytes) -> StreamRangeTest:
        """Parse an XREAD start ID as an exclusive lower bound.

        ``$`` means "after the stream's current last entry". A stream that
        does not exist yet has no entries, so ``$`` then starts after 0-0.
        """
        if s == b"$":
            if key.value is None:
                return StreamRangeTest.decode(b"0-0", exclusive=True)
            return StreamRangeTest.decode(key.value.last_item_key(), exclusive=True)
        return StreamRangeTest.decode(s, exclusive=True)

    @command(name="XREAD", fixed=(bytes,), repeat=(bytes,))
    def xread(self, *args):
        """XREAD [COUNT count] [BLOCK milliseconds] STREAMS key [key ...] id [id ...]"""
        (count, timeout), left_args = extract_args(
            args, ("+count", "+block"), error_on_unexpected=False
        )
        if (
            len(left_args) < 3
            or not casematch(left_args[0], b"STREAMS")
            or len(left_args) % 2 != 1
        ):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        left_args = left_args[1:]
        num_streams = len(left_args) // 2
        # All keys come first, then one ID per key in the same order.
        stream_start_id_list = []
        for stream_name, raw_id in zip(left_args[:num_streams], left_args[num_streams:]):
            item = CommandItem(
                stream_name, self._db, item=self._db.get(stream_name), default=None
            )
            stream_start_id_list.append((item, self._parse_start_id(item, raw_id)))
        if timeout is None:
            return self._xread(stream_start_id_list, count, False)
        return self._blocking(
            timeout / 1000.0,
            functools.partial(self._xread, stream_start_id_list, count, blocking=True),
        )

    @command(name="XREADGROUP", fixed=(bytes, bytes, bytes), repeat=(bytes,))
    def xreadgroup(self, group_const, group_name, consumer_name, *args):
        """XREADGROUP GROUP group consumer [COUNT count] [BLOCK ms] [NOACK] STREAMS key [key ...] id [id ...]"""
        if not casematch(b"GROUP", group_const):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        (count, timeout, noack), left_args = extract_args(
            args, ("+count", "+block", "noack"), error_on_unexpected=False
        )
        if (
            len(left_args) < 3
            or not casematch(left_args[0], b"STREAMS")
            or len(left_args) % 2 != 1
        ):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        left_args = left_args[1:]
        num_streams = len(left_args) // 2
        # List of (group, stream_name, raw start-id)
        group_params: List[Tuple[StreamGroup, bytes, bytes]] = []
        for stream_name, raw_id in zip(left_args[:num_streams], left_args[num_streams:]):
            item = CommandItem(
                stream_name, self._db, item=self._db.get(stream_name), default=None
            )
            if item.value is None:
                raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
            group: StreamGroup = item.value.group_get(group_name)
            if not group:
                raise SimpleError(
                    msgs.XREADGROUP_KEY_OR_GROUP_NOT_FOUND_MSG.format(
                        stream_name.decode(), group_name.decode()
                    )
                )
            group_params.append((group, stream_name, raw_id))
        if timeout is None:
            return self._xreadgroup(consumer_name, group_params, count, noack, False)
        return self._blocking(
            timeout / 1000.0,
            functools.partial(
                self._xreadgroup, consumer_name, group_params, count, noack, blocking=True
            ),
        )

    @command(
        name="XDEL",
        fixed=(Key(XStream),),
        repeat=(bytes,),
    )
    def xdel(self, key, *args):
        """XDEL key id [id ...] -- returns the result of ``XStream.delete``."""
        if len(args) == 0:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("xdel"))
        return key.value.delete(args)

    @command(
        name="XACK",
        fixed=(Key(XStream), bytes),
        repeat=(bytes,),
    )
    def xack(self, key, group_name, *args):
        """XACK key group id [id ...]

        Returns 0 when the stream or group is missing, otherwise the result
        of ``StreamGroup.ack``.
        """
        if len(args) == 0:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("xack"))
        if key.value is None:
            return 0
        group: StreamGroup = key.value.group_get(group_name)
        if not group:
            return 0
        return group.ack(args)  # type: ignore

    @command(
        name="XPENDING",
        fixed=(Key(XStream), bytes),
        repeat=(bytes,),
    )
    def xpending(self, key, group_name, *args):
        """XPENDING key group [[IDLE min-idle-time] start end count [consumer]]

        Without a range, returns the group's pending summary. With a range,
        returns ``group.pending(...)`` for the selected entries.
        """
        if key.value is None:
            return 0
        idle = start = end = count = consumer = None
        if len(args) > 4 and casematch(b"idle", args[0]):
            idle = Int.decode(args[1])
            args = args[2:]
        # A range needs all of start, end and count.
        if 0 < len(args) < 3:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if len(args) >= 3:
            start = StreamRangeTest.decode(args[0])
            end = StreamRangeTest.decode(args[1])
            count = Int.decode(args[2])
            if len(args) > 3:
                consumer = args[3]
        group: StreamGroup = key.value.group_get(group_name)
        if not group:
            # NOTE(review): real Redis replies with a NOGROUP error here.
            return 0 if start is not None else []
        if start is not None:
            return group.pending(idle, start, end, count, consumer)
        return group.pending_summary()

    @command(
        name="XGROUP CREATE",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def xgroup_create(self, key, group_name, start_key, *args):
        """XGROUP CREATE key group id|$ [MKSTREAM] [ENTRIESREAD entries-read]

        With MKSTREAM, a missing stream is created empty. Without it, a
        missing stream is an error.
        """
        (mkstream, entries_read), _ = extract_args(args, ("mkstream", "+entriesread"))
        stream = key.value
        if stream is None:
            if not mkstream:
                raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
            # Previously this path dereferenced None; create and store the stream.
            stream = XStream()
            key.update(stream)
        if stream.group_get(group_name) is not None:
            raise SimpleError(msgs.XGROUP_BUSYGROUP)
        stream.group_add(group_name, start_key, entries_read)
        key.updated()
        return OK

    @command(
        name="XGROUP SETID",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(bytes,),
    )
    def xgroup_setid(self, key, group_name, start_key, *args):
        """XGROUP SETID key group id|$ [ENTRIESREAD entries-read]"""
        (entries_read,), _ = extract_args(args, ("+entriesread",))
        group = self._require_group(key, group_name)
        group.set_id(start_key, entries_read)
        return OK

    @command(
        name="XGROUP DESTROY",
        fixed=(
            Key(XStream),
            bytes,
        ),
        repeat=(),
    )
    def xgroup_destroy(self, key, group_name):
        """XGROUP DESTROY key group -- returns ``XStream.group_delete``'s result."""
        if key.value is None:
            raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
        return key.value.group_delete(group_name)

    @command(
        name="XGROUP CREATECONSUMER",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(),
    )
    def xgroup_createconsumer(self, key, group_name, consumer_name):
        """XGROUP CREATECONSUMER key group consumer"""
        return self._require_group(key, group_name).add_consumer(consumer_name)

    @command(
        name="XGROUP DELCONSUMER",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(),
    )
    def xgroup_delconsumer(self, key, group_name, consumer_name):
        """XGROUP DELCONSUMER key group consumer"""
        return self._require_group(key, group_name).del_consumer(consumer_name)

    @command(
        name="XINFO GROUPS",
        fixed=(Key(XStream),),
        repeat=(),
    )
    def xinfo_groups(self, key):
        """XINFO GROUPS key"""
        if key.value is None:
            raise SimpleError(msgs.NO_KEY_MSG)
        return key.value.groups_info()

    @command(
        name="XINFO STREAM",
        fixed=(Key(XStream),),
        repeat=(bytes,),
    )
    def xinfo_stream(self, key, *args):
        """XINFO STREAM key [FULL]"""
        (full,), _ = extract_args(args, ("full",))
        if key.value is None:
            raise SimpleError(msgs.NO_KEY_MSG)
        return key.value.stream_info(full)

    @command(
        name="XINFO CONSUMERS",
        fixed=(Key(XStream), bytes),
        repeat=(),
    )
    def xinfo_consumers(self, key, group_name):
        """XINFO CONSUMERS key group"""
        return self._require_group(key, group_name).consumers_info()

    @command(
        name="XCLAIM",
        fixed=(Key(XStream), bytes, bytes, Int, bytes),
        repeat=(bytes,),
    )
    def xclaim(self, key, group_name, consumer_name, min_idle_ms, *args):
        """XCLAIM key group consumer min-idle-time id [id ...] [IDLE ms] [TIME ms] [RETRYCOUNT n] [FORCE] [JUSTID]

        Returns the claimed records, or only their IDs with JUSTID.
        RETRYCOUNT is parsed but not applied by this implementation.
        """
        group = self._require_group(key, group_name)
        stream = key.value
        (idle, _time, _retrycount, force, justid), msg_ids = extract_args(
            args,
            ("+idle", "+time", "+retrycount", "force", "justid"),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )
        # IDLE is turned into an absolute TIME unless TIME was given explicitly.
        if idle is not None and idle > 0 and _time is None:
            _time = current_time() - idle
        msgs_claimed, _ = group.claim(min_idle_ms, msg_ids, consumer_name, _time, force)
        if justid:
            return [msg.encode() for msg in msgs_claimed]
        return [stream.format_record(msg) for msg in msgs_claimed]

    @command(
        name="XAUTOCLAIM",
        fixed=(Key(XStream), bytes, bytes, Int, bytes),
        repeat=(bytes,),
    )
    def xautoclaim(self, key, group_name, consumer_name, min_idle_ms, start, *args):
        """XAUTOCLAIM key group consumer min-idle-time start [COUNT count] [JUSTID]

        Returns ``[cursor, claimed]``, plus the list of deleted IDs when the
        emulated version is 7 or newer. COUNT defaults to 100.
        """
        (count, justid), _ = extract_args(args, ("+count", "justid"))
        count = count or 100
        group = self._require_group(key, group_name)
        stream = key.value
        keys = group.read_pel_msgs(min_idle_ms, start, count)
        msgs_claimed, msgs_removed = group.claim(
            min_idle_ms, keys, consumer_name, None, False
        )
        cursor = max(msgs_claimed).encode() if len(msgs_claimed) > 0 else start
        if justid:
            claimed = [msg.encode() for msg in msgs_claimed]
        else:
            claimed = [stream.format_record(msg) for msg in msgs_claimed]
        res = [cursor, claimed]
        if self.version >= (7,):
            res.append([msg.encode() for msg in msgs_removed])
        return res