from typing import Tuple, Any, Dict, Callable

from fakeredis import _msgs as msgs
from fakeredis._commands import command
from fakeredis._helpers import NoResponse, compile_pattern, SimpleError


class PubSubCommandsMixin:
    """Mixin implementing the Redis Pub/Sub command family.

    The host class must provide:

    * ``_server`` -- holds the ``subscribers`` (SUBSCRIBE),
      ``psubscribers`` (PSUBSCRIBE) and ``ssubscribers`` (SSUBSCRIBE)
      mappings from channel/pattern to the set of subscribed objects;
    * ``version`` -- the emulated server version as a tuple of ints;
    * ``put_response`` -- queues an out-of-band reply for this client.
    """

    _server: Any
    version: Tuple[int, ...]
    put_response: Callable

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._pubsub = 0  # Count of subscriptions

    def _subscribe(self, channels, subscribers, mtype):
        """Subscribe this client to each of *channels* in *subscribers*.

        One ``[mtype, channel, count]`` reply is queued per channel, where
        ``count`` is this client's subscription total after processing that
        channel. Subscribing to an already-joined channel does not bump it.
        """
        for channel in channels:
            # Indexing (not .get) relies on the mapping creating missing
            # entries -- presumably a defaultdict(set) on the server; confirm.
            subs = subscribers[channel]
            if self not in subs:
                subs.add(self)
                self._pubsub += 1
            self.put_response([mtype, channel, self._pubsub])
        return NoResponse()

    def _unsubscribe(self, channels, subscribers, mtype):
        """Unsubscribe this client from *channels* (or from all, if empty).

        A ``[mtype, channel, count]`` reply is queued for every requested
        channel, even ones the client was not subscribed to. Empty sets are
        dropped from *subscribers* so they stop showing up in listings.
        """
        if not channels:
            channels = [ch for ch, subs in subscribers.items() if self in subs]
            if not channels:
                # Redis still replies when asked to drop "all" subscriptions
                # while holding none: [mtype, nil, count]. Without this the
                # client would wait forever for a confirmation.
                self.put_response([mtype, None, self._pubsub])
                return NoResponse()
        for channel in channels:
            subs = subscribers.get(channel, set())
            if self in subs:
                subs.remove(self)
                if not subs:
                    del subscribers[channel]
                self._pubsub -= 1
            self.put_response([mtype, channel, self._pubsub])
        return NoResponse()

    def _numsub(self, subscribers: Dict[bytes, Any], *channels):
        """Return a flat ``[channel, count, channel, count, ...]`` list."""
        result = []
        for ch in channels:
            result.extend((ch, len(subscribers.get(ch, ()))))
        return result

    @staticmethod
    def _deliver_pubsub_message(socks, msg) -> int:
        """Queue *msg* on every object in *socks*; return how many got it."""
        delivered = 0
        for sock in socks:
            sock.put_response(msg)
            delivered += 1
        return delivered

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def psubscribe(self, *patterns):
        return self._subscribe(patterns, self._server.psubscribers, b"psubscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def subscribe(self, *channels):
        return self._subscribe(channels, self._server.subscribers, b"subscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def ssubscribe(self, *channels):
        return self._subscribe(channels, self._server.ssubscribers, b"ssubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def punsubscribe(self, *patterns):
        return self._unsubscribe(patterns, self._server.psubscribers, b"punsubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def unsubscribe(self, *channels):
        return self._unsubscribe(channels, self._server.subscribers, b"unsubscribe")

    @command(fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def sunsubscribe(self, *channels):
        return self._unsubscribe(channels, self._server.ssubscribers, b"sunsubscribe")

    @command((bytes, bytes))
    def publish(self, channel, message):
        """PUBLISH: deliver to channel subscribers and matching patterns.

        Returns the number of deliveries; a client subscribed both to the
        channel and to a matching pattern is counted once per subscription.
        """
        subs = self._server.subscribers.get(channel, set())
        receivers = self._deliver_pubsub_message(subs, [b"message", channel, message])
        for pattern, socks in self._server.psubscribers.items():
            if compile_pattern(pattern).match(channel):
                receivers += self._deliver_pubsub_message(
                    socks, [b"pmessage", pattern, channel, message]
                )
        return receivers

    @command((bytes, bytes))
    def spublish(self, channel, message):
        """SPUBLISH: deliver to the shard channel's SSUBSCRIBE clients only.

        Shard channels do not take part in pattern matching in Redis, so
        PSUBSCRIBE clients are intentionally not notified here.
        """
        subs = self._server.ssubscribers.get(channel, set())
        return self._deliver_pubsub_message(subs, [b"smessage", channel, message])

    @command(name="PUBSUB NUMPAT", fixed=(), repeat=())
    def pubsub_numpat(self, *_):
        # Number of distinct patterns with at least one subscriber (empty
        # entries are removed by _unsubscribe).
        return len(self._server.psubscribers)

    def _channels(self, subscribers_dict: Dict[bytes, Any], *patterns):
        """List channel names in *subscribers_dict*, filtered by patterns[0] if given."""
        channels = list(subscribers_dict.keys())
        if patterns:
            regex = compile_pattern(patterns[0])
            channels = [ch for ch in channels if regex.match(ch)]
        return channels

    @command(name="PUBSUB CHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_channels(self, *args):
        return self._channels(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDCHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_shardchannels(self, *args):
        return self._channels(self._server.ssubscribers, *args)

    @command(name="PUBSUB NUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_numsub(self, *args):
        return self._numsub(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDNUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_shardnumsub(self, *args):
        return self._numsub(self._server.ssubscribers, *args)

    @command(name="PUBSUB", fixed=())
    def pubsub(self, *args):
        raise SimpleError(msgs.WRONG_ARGS_MSG6.format("pubsub"))

    @command(name="PUBSUB HELP", fixed=())
    def pubsub_help(self, *args):
        """Return the PUBSUB HELP text, matching the emulated server version."""
        if self.version >= (7,):
            help_strings = [
                "PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
                "CHANNELS [<pattern>]",
                "    Return the currently active channels matching a <pattern> (default: '*').",
                "NUMPAT",
                "    Return number of subscriptions to patterns.",
                "NUMSUB [<channel> ...]",
                "    Return the number of subscribers for the specified channels, excluding",
                "    pattern subscriptions(default: no channels).",
                "SHARDCHANNELS [<pattern>]",
                "    Return the currently active shard level channels matching a <pattern> (default: '*').",
                "SHARDNUMSUB [<shardchannel> ...]",
                "    Return the number of subscribers for the specified shard level channel(s)",
                "HELP",
                # Redis 7.1+ changed the wording of its generic HELP line.
                "    Prints this help." if self.version < (7, 1) else "    Print this help.",
            ]
        else:
            help_strings = [
                "PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
                "CHANNELS [<pattern>]",
                "    Return the currently active channels matching a <pattern> (default: '*').",
                "NUMPAT",
                "    Return number of subscriptions to patterns.",
                "NUMSUB [<channel> ...]",
                "    Return the number of subscribers for the specified channels, excluding",
                "    pattern subscriptions(default: no channels).",
                "HELP",
                "    Prints this help.",
            ]
        return [s.encode() for s in help_strings]