import functools
from typing import Callable, List, Tuple, Union

import fakeredis._msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, CommandItem, Int
from fakeredis._helpers import SimpleError, casematch, OK, current_time, Database
from fakeredis._stream import XStream, StreamRangeTest, StreamGroup


class StreamsCommandsMixin:
    """Stream (X*) command implementations for the fake Redis server.

    The host class must provide ``_db`` (the selected database), ``version``
    (the emulated server version tuple) and ``_blocking`` (the helper used by
    the BLOCK variants of XREAD / XREADGROUP).
    """

    _db: Database
    version: Tuple[int, ...]
    _blocking: Callable

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _group_or_raise(key: CommandItem, group_name: bytes) -> StreamGroup:
        """Return the consumer group ``group_name`` of the stream stored at ``key``.

        Raises:
            SimpleError: ``XGROUP_KEY_NOT_FOUND_MSG`` if the key holds no stream,
                or ``XGROUP_GROUP_NOT_FOUND_MSG`` if the stream has no such group.
        """
        if key.value is None:
            raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
        group = key.value.group_get(group_name)
        if group is None:
            # Format the message with the key's name (bytes), not with the
            # CommandItem wrapper object.
            raise SimpleError(
                msgs.XGROUP_GROUP_NOT_FOUND_MSG.format(
                    group_name.decode(), key.key.decode()
                )
            )
        return group

    # ------------------------------------------------------------- XADD / XTRIM

    @command(
        name="XADD",
        fixed=(Key(),),
        repeat=(bytes,),
    )
    def xadd(self, key, *args):
        """Append an entry to a stream, creating the stream unless NOMKSTREAM is set.

        Returns the id of the added entry, or None when NOMKSTREAM is given
        and the key does not exist.

        Raises:
            SimpleError: wrong argument count, an invalid id, or an id not
                greater than the stream's last id.
        """
        (nomkstream, limit, maxlen, minid), left_args = extract_args(
            args,
            ("nomkstream", "+limit", "~+maxlen", "~minid"),
            error_on_unexpected=False,
        )
        if nomkstream and key.value is None:
            return None
        # What remains must be: id field value [field value ...].  Checking the
        # total length also covers the case where nothing is left at all.
        if len(left_args) < 3 or len(left_args) % 2 != 1:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("XADD"))
        requested_id = left_args[0]
        elements = left_args[1:]

        stream = key.value if key.value is not None else XStream()
        if (
            self.version < (7,)
            and requested_id != b"*"
            and not StreamRangeTest.valid_key(requested_id)
        ):
            raise SimpleError(msgs.XADD_INVALID_ID)
        entry_key = stream.add(elements, entry_key=requested_id)
        if entry_key is None:
            # The stream refused the id: report a malformed id differently
            # from a well-formed id that is too small.
            if not StreamRangeTest.valid_key(requested_id):
                raise SimpleError(msgs.XADD_INVALID_ID)
            raise SimpleError(msgs.XADD_ID_LOWER_THAN_LAST)
        if maxlen is not None or minid is not None:
            stream.trim(max_length=maxlen, start_entry_key=minid, limit=limit)
        key.update(stream)
        return entry_key

    @command(
        name="XTRIM",
        fixed=(Key(XStream),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def xtrim(self, key, *args):
        """Trim a stream by MAXLEN or MINID (exactly one must be given).

        Returns whatever ``XStream.trim`` returns (presumably the number of
        removed entries — confirm in fakeredis._stream).
        """
        (limit, maxlen, minid), _ = extract_args(args, ("+limit", "~+maxlen", "~minid"))
        if maxlen is not None and minid is not None:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if maxlen is None and minid is None:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        stream = key.value or XStream()

        res = stream.trim(max_length=maxlen, start_entry_key=minid, limit=limit)

        key.update(stream)
        return res

    @command(name="XLEN", fixed=(Key(XStream),))
    def xlen(self, key):
        """Return the number of entries in the stream, or 0 if the key is missing."""
        return 0 if key.value is None else len(key.value)

    # ------------------------------------------------------------ range queries

    @staticmethod
    def _xrange(
        stream: XStream,
        _min: StreamRangeTest,
        _max: StreamRangeTest,
        reverse: bool,
        count: Union[int, None],
    ) -> List:
        """Return at most ``count`` entries of ``stream`` between ``_min`` and ``_max``.

        A missing stream (None) yields an empty list; ``count=None`` means no limit.
        """
        if stream is None:
            return []
        if count is None:
            count = len(stream)
        res = stream.irange(_min, _max, reverse=reverse)
        return res[:count]

    @command(
        name="XRANGE",
        fixed=(Key(XStream), StreamRangeTest, StreamRangeTest),
        repeat=(bytes,),
    )
    def xrange(self, key, _min, _max, *args):
        """XRANGE key start end [COUNT count] — entries in ascending id order."""
        (count,), _ = extract_args(args, ("+count",))
        return self._xrange(key.value, _min, _max, False, count)

    @command(
        name="XREVRANGE",
        fixed=(Key(XStream), StreamRangeTest, StreamRangeTest),
        repeat=(bytes,),
    )
    def xrevrange(self, key, _min, _max, *args):
        """XREVRANGE key end start [COUNT count] — entries in descending id order."""
        (count,), _ = extract_args(args, ("+count",))
        # The command takes the upper bound first, so the bounds are swapped.
        return self._xrange(key.value, _max, _min, True, count)

    # -------------------------------------------------------------------- XREAD

    def _xread(self, stream_start_id_list: List, count: int, first_pass: bool):
        """Collect entries newer than each start id from each requested stream.

        Args:
            stream_start_id_list: ``(CommandItem, StreamRangeTest)`` pairs, one
                per stream, as built by ``xread``.
            count: maximum number of entries per stream, or None for no limit.
            first_pass: True on the first attempt of a blocking read.

        Returns:
            ``[[stream_name, entries], ...]`` for every stream that produced
            entries, or None on the first pass when no COUNT was given (None is
            presumably treated by ``_blocking`` as "not ready yet").
        """
        # NOTE(review): this unconditionally defers the first pass when COUNT is
        # absent, even if data is already available — confirm this matches the
        # intended BLOCK semantics.
        if first_pass and count is None:
            return None
        max_inf = StreamRangeTest.decode(b"+")
        res = []
        for item, start_id in stream_start_id_list:
            # Re-read the key on every pass: XADD on a previously missing key
            # stores a brand-new XStream, which the CommandItem captured when
            # XREAD was issued would never see.
            current = CommandItem(
                item.key, self._db, item=self._db.get(item.key), default=None
            )
            stream_results = self._xrange(current.value, start_id, max_inf, False, count)
            if stream_results:
                res.append([item.key, stream_results])
        return res

    def _xreadgroup(
        self,
        consumer_name: bytes,
        group_params: List[Tuple[StreamGroup, bytes, bytes]],
        count: int,
        noack: bool,
        first_pass: bool,
    ):
        """Read entries for ``consumer_name`` from each (group, stream, start id).

        Returns ``[[stream_name, entries], ...]``. A stream is included when it
        produced entries or when its start id is not ``>``. Returns None on the
        first pass unless the first stream produced at least ``count`` entries.
        """
        res = []
        for group, stream_name, start_id in group_params:
            stream_results = group.group_read(consumer_name, start_id, count, noack)
            # NOTE(review): returning None here discards ``stream_results``; if
            # ``group_read`` records deliveries as a side effect, those entries
            # are never returned to the client — confirm against StreamGroup.
            if first_pass and (count is None or len(stream_results) < count):
                return None
            if len(stream_results) > 0 or start_id != b">":
                res.append([stream_name, stream_results])
        return res

    @staticmethod
    def _parse_start_id(key: CommandItem, s: bytes) -> StreamRangeTest:
        """Decode an XREAD start id into an exclusive lower bound.

        ``$`` means "after the stream's current last entry". For a key that
        holds no stream, the last id of an empty ``XStream`` is used instead
        of dereferencing None.
        """
        if s == b"$":
            stream = key.value if key.value is not None else XStream()
            return StreamRangeTest.decode(stream.last_item_key(), exclusive=True)
        return StreamRangeTest.decode(s, exclusive=True)

    @command(name="XREAD", fixed=(bytes,), repeat=(bytes,))
    def xread(self, *args):
        """XREAD [COUNT count] [BLOCK ms] STREAMS key [key ...] id [id ...]."""
        (count, timeout), left_args = extract_args(
            args,
            ("+count", "+block"),
            error_on_unexpected=False,
        )
        if (
            len(left_args) < 3
            or not casematch(left_args[0], b"STREAMS")
            or len(left_args) % 2 != 1
        ):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        left_args = left_args[1:]
        num_streams = len(left_args) // 2

        stream_start_id_list = []
        for stream_name, raw_id in zip(left_args[:num_streams], left_args[num_streams:]):
            item = CommandItem(
                stream_name, self._db, item=self._db.get(stream_name), default=None
            )
            stream_start_id_list.append((item, self._parse_start_id(item, raw_id)))

        if timeout is None:
            return self._xread(stream_start_id_list, count, False)
        return self._blocking(
            timeout / 1000.0,
            functools.partial(self._xread, stream_start_id_list, count),
        )

    @command(name="XREADGROUP", fixed=(bytes, bytes, bytes), repeat=(bytes,))
    def xreadgroup(self, group_const, group_name, consumer_name, *args):
        """XREADGROUP GROUP group consumer [COUNT n] [BLOCK ms] [NOACK] STREAMS key ... id ..."""
        if not casematch(b"GROUP", group_const):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        (count, timeout, noack), left_args = extract_args(
            args, ("+count", "+block", "noack"), error_on_unexpected=False
        )
        if (
            len(left_args) < 3
            or not casematch(left_args[0], b"STREAMS")
            or len(left_args) % 2 != 1
        ):
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        left_args = left_args[1:]
        num_streams = len(left_args) // 2

        # List of (group, stream_name, stream start-id)
        group_params: List[Tuple[StreamGroup, bytes, bytes]] = []
        for stream_name, start_id in zip(left_args[:num_streams], left_args[num_streams:]):
            item = CommandItem(
                stream_name, self._db, item=self._db.get(stream_name), default=None
            )
            if item.value is None:
                raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
            group: StreamGroup = item.value.group_get(group_name)
            if group is None:
                raise SimpleError(
                    msgs.XREADGROUP_KEY_OR_GROUP_NOT_FOUND_MSG.format(
                        stream_name.decode(), group_name.decode()
                    )
                )
            group_params.append((group, stream_name, start_id))

        if timeout is None:
            return self._xreadgroup(consumer_name, group_params, count, noack, False)
        return self._blocking(
            timeout / 1000.0,
            functools.partial(
                self._xreadgroup, consumer_name, group_params, count, noack
            ),
        )

    # -------------------------------------------------------------- XDEL / XACK

    @command(
        name="XDEL",
        fixed=(Key(XStream),),
        repeat=(bytes,),
    )
    def xdel(self, key, *args):
        """Delete the given ids from the stream; 0 if the key is missing."""
        if len(args) == 0:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("xdel"))
        if key.value is None:
            return 0
        return key.value.delete(args)

    @command(
        name="XACK",
        fixed=(Key(XStream), bytes),
        repeat=(bytes,),
    )
    def xack(self, key, group_name, *args):
        """Acknowledge ids for a group; 0 if the key or the group does not exist."""
        if len(args) == 0:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("xack"))
        if key.value is None:
            return 0
        group: StreamGroup = key.value.group_get(group_name)
        if group is None:
            return 0
        return group.ack(args)  # type: ignore

    @command(
        name="XPENDING",
        fixed=(Key(XStream), bytes),
        repeat=(bytes,),
    )
    def xpending(self, key, group_name, *args):
        """XPENDING key group [[IDLE min-idle] start end count [consumer]].

        Without a range, returns the group's pending summary; with a range,
        returns ``group.pending(...)`` for that range.
        """
        if key.value is None:
            return 0
        idle, start, end, count, consumer = None, None, None, None, None

        if len(args) > 4 and casematch(b"idle", args[0]):  # Idle
            idle = Int.decode(args[1])
            args = args[2:]
        if 0 < len(args) < 3:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        elif len(args) >= 3:
            start, end, count = (
                StreamRangeTest.decode(args[0]),
                StreamRangeTest.decode(args[1]),
                Int.decode(args[2]),
            )
            if len(args) > 3:
                consumer = args[3]

        group: StreamGroup = key.value.group_get(group_name)
        if group is None:
            return 0 if start is not None else []
        if start is not None:
            return group.pending(idle, start, end, count, consumer)
        return group.pending_summary()

    # ------------------------------------------------------------------- XGROUP

    @command(
        name="XGROUP CREATE",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def xgroup_create(self, key, group_name, start_key, *args):
        """Create a consumer group; with MKSTREAM a missing stream is created.

        Raises:
            SimpleError: the key is missing and MKSTREAM was not given, or the
                group already exists.
        """
        (mkstream, entries_read), _ = extract_args(args, ("mkstream", "+entriesread"))
        stream = key.value
        if stream is None:
            if not mkstream:
                raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
            # MKSTREAM: start from an empty stream instead of dereferencing None.
            stream = XStream()
        if stream.group_get(group_name) is not None:
            raise SimpleError(msgs.XGROUP_BUSYGROUP)
        stream.group_add(group_name, start_key, entries_read)
        key.update(stream)
        return OK

    @command(
        name="XGROUP SETID",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(bytes,),
    )
    def xgroup_setid(self, key, group_name, start_key, *args):
        """Set the last-delivered id (and optionally ENTRIESREAD) of a group."""
        (entries_read,), _ = extract_args(args, ("+entriesread",))
        group = self._group_or_raise(key, group_name)
        group.set_id(start_key, entries_read)
        return OK

    @command(
        name="XGROUP DESTROY",
        fixed=(Key(XStream), bytes),
        repeat=(),
    )
    def xgroup_destroy(self, key, group_name):
        """Delete a consumer group; returns ``XStream.group_delete``'s result."""
        if key.value is None:
            raise SimpleError(msgs.XGROUP_KEY_NOT_FOUND_MSG)
        return key.value.group_delete(group_name)

    @command(
        name="XGROUP CREATECONSUMER",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(),
    )
    def xgroup_createconsumer(self, key, group_name, consumer_name):
        """Add a consumer to a group; returns ``StreamGroup.add_consumer``'s result."""
        group = self._group_or_raise(key, group_name)
        return group.add_consumer(consumer_name)

    @command(
        name="XGROUP DELCONSUMER",
        fixed=(Key(XStream), bytes, bytes),
        repeat=(),
    )
    def xgroup_delconsumer(self, key, group_name, consumer_name):
        """Remove a consumer from a group; returns ``StreamGroup.del_consumer``'s result."""
        group = self._group_or_raise(key, group_name)
        return group.del_consumer(consumer_name)

    # -------------------------------------------------------------------- XINFO

    @command(
        name="XINFO GROUPS",
        fixed=(Key(XStream),),
        repeat=(),
    )
    def xinfo_groups(self, key):
        """Return information about every consumer group of the stream."""
        if key.value is None:
            raise SimpleError(msgs.NO_KEY_MSG)
        return key.value.groups_info()

    @command(
        name="XINFO STREAM",
        fixed=(Key(XStream),),
        repeat=(bytes,),
    )
    def xinfo_stream(self, key, *args):
        """Return stream information; the FULL flag is forwarded to ``stream_info``."""
        (full,), _ = extract_args(args, ("full",))
        if key.value is None:
            raise SimpleError(msgs.NO_KEY_MSG)
        return key.value.stream_info(full)

    @command(
        name="XINFO CONSUMERS",
        fixed=(Key(XStream), bytes),
        repeat=(),
    )
    def xinfo_consumers(self, key, group_name):
        """Return information about the consumers of a group."""
        group = self._group_or_raise(key, group_name)
        return group.consumers_info()

    # --------------------------------------------------------- XCLAIM / XAUTOCLAIM

    @command(
        name="XCLAIM",
        fixed=(Key(XStream), bytes, bytes, Int, bytes),
        repeat=(bytes,),
    )
    def xclaim(self, key, group_name, consumer_name, min_idle_ms, *args):
        """Transfer ownership of pending ids to ``consumer_name``.

        ``args`` holds the message ids followed by the optional IDLE / TIME /
        RETRYCOUNT / FORCE / JUSTID modifiers. RETRYCOUNT is parsed but not
        used. Returns ids (JUSTID) or formatted records.
        """
        group = self._group_or_raise(key, group_name)
        stream = key.value
        (idle, _time, _retrycount, force, justid), msg_ids = extract_args(
            args,
            ("+idle", "+time", "+retrycount", "force", "justid"),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )

        # IDLE is expressed relative to now; convert it to an absolute time.
        if idle is not None and idle > 0 and _time is None:
            _time = current_time() - idle
        msgs_claimed, _ = group.claim(min_idle_ms, msg_ids, consumer_name, _time, force)

        if justid:
            return [msg.encode() for msg in msgs_claimed]
        return [stream.format_record(msg) for msg in msgs_claimed]

    @command(
        name="XAUTOCLAIM",
        fixed=(Key(XStream), bytes, bytes, Int, bytes),
        repeat=(bytes,),
    )
    def xautoclaim(self, key, group_name, consumer_name, min_idle_ms, start, *args):
        """Claim up to COUNT (default 100) idle pending entries starting at ``start``.

        Returns ``[cursor, claimed]``, plus the list of removed ids when the
        emulated version is 7 or later. The cursor is the largest claimed id,
        or ``start`` if nothing was claimed.
        """
        (count, justid), _ = extract_args(args, ("+count", "justid"))
        count = count or 100
        group = self._group_or_raise(key, group_name)
        stream = key.value

        keys = group.read_pel_msgs(min_idle_ms, start, count)
        msgs_claimed, msgs_removed = group.claim(
            min_idle_ms, keys, consumer_name, None, False
        )

        if justid:
            claimed = [msg.encode() for msg in msgs_claimed]
        else:
            claimed = [stream.format_record(msg) for msg in msgs_claimed]
        res = [
            max(msgs_claimed).encode() if len(msgs_claimed) > 0 else start,
            claimed,
        ]
        if self.version >= (7,):
            res.append([msg.encode() for msg in msgs_removed])
        return res