178 lines
7.0 KiB
Python
178 lines
7.0 KiB
Python
from typing import Tuple, Any, Dict, Callable
|
|
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._commands import command
|
|
from fakeredis._helpers import NoResponse, compile_pattern, SimpleError
|
|
|
|
|
|
class PubSubCommandsMixin:
    """Pub/Sub command implementations mixed into a fake Redis connection.

    The mixin relies on attributes provided by the class it is mixed into:

    * ``_server`` exposes ``subscribers`` (SUBSCRIBE), ``psubscribers``
      (PSUBSCRIBE) and ``ssubscribers`` (SSUBSCRIBE). Each maps a channel or
      pattern name (bytes) to a set-like collection of connections.
      ``_subscribe`` indexes them with ``[]`` for names that may be absent, so
      they are presumably ``defaultdict``-like -- TODO confirm on the server.
    * ``version`` is the emulated server version as a tuple of ints.
    * ``put_response`` presumably queues a message for this connection's
      client. It is used for subscription confirmations and delivered messages.
    """

    _server: Any
    # Variable-length tuple: compared against both (7,) and (7, 1) below.
    version: Tuple[int, ...]
    put_response: Callable

    def __init__(self, *args, **kwargs):
        super(PubSubCommandsMixin, self).__init__(*args, **kwargs)
        # Total subscriptions held by this connection across channels,
        # patterns and shard channels; reported in every (un)subscribe reply.
        self._pubsub = 0

    def _subscribe(self, channels, subscribers, mtype):
        """Subscribe this connection to each of *channels* in *subscribers*.

        One ``[mtype, channel, count]`` confirmation is sent per channel, even
        when the connection was already subscribed to it. ``count`` is the
        connection's total subscription count at that point.

        Returns:
            ``NoResponse()`` -- replies are delivered through ``put_response``.
        """
        for channel in channels:
            subs = subscribers[channel]
            if self not in subs:
                subs.add(self)
                self._pubsub += 1
            self.put_response([mtype, channel, self._pubsub])
        return NoResponse()

    def _unsubscribe(self, channels, subscribers, mtype):
        """Unsubscribe this connection from *channels* in *subscribers*.

        With no *channels*, every channel in *subscribers* that this connection
        is subscribed to is used instead. One ``[mtype, channel, count]``
        confirmation is sent per channel. Channels left without subscribers are
        removed from *subscribers*.

        If no channels are given and the connection is subscribed to nothing,
        a single ``[mtype, None, count]`` reply is still sent, matching Redis.

        Returns:
            ``NoResponse()`` -- replies are delivered through ``put_response``.
        """
        if not channels:
            # Materialize the list first: the loop below may delete entries
            # from ``subscribers``.
            channels = [ch for ch, subs in subscribers.items() if self in subs]
            if not channels:
                # Redis replies even when there was nothing to unsubscribe
                # from, using a nil channel name.
                self.put_response([mtype, None, self._pubsub])
                return NoResponse()
        for channel in channels:
            subs = subscribers.get(channel, set())
            if self in subs:
                subs.remove(self)
                if not subs:
                    del subscribers[channel]
                self._pubsub -= 1
            self.put_response([mtype, channel, self._pubsub])
        return NoResponse()

    def _numsub(self, subscribers: Dict[bytes, Any], *channels):
        """Return a flat ``[channel, count, channel, count, ...]`` list.

        ``count`` is the number of entries in ``subscribers`` for that channel,
        or 0 when the channel is absent.
        """
        result = []
        for ch in channels:
            result.extend((ch, len(subscribers.get(ch, []))))
        return result

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def psubscribe(self, *patterns):
        """PSUBSCRIBE pattern [pattern ...]"""
        return self._subscribe(patterns, self._server.psubscribers, b"psubscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def subscribe(self, *channels):
        """SUBSCRIBE channel [channel ...]"""
        return self._subscribe(channels, self._server.subscribers, b"subscribe")

    @command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def ssubscribe(self, *channels):
        """SSUBSCRIBE shardchannel [shardchannel ...]"""
        return self._subscribe(channels, self._server.ssubscribers, b"ssubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def punsubscribe(self, *patterns):
        """PUNSUBSCRIBE [pattern ...]"""
        return self._unsubscribe(patterns, self._server.psubscribers, b"punsubscribe")

    @command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def unsubscribe(self, *channels):
        """UNSUBSCRIBE [channel ...]"""
        return self._unsubscribe(channels, self._server.subscribers, b"unsubscribe")

    @command(fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def sunsubscribe(self, *channels):
        """SUNSUBSCRIBE [shardchannel ...]"""
        return self._unsubscribe(channels, self._server.ssubscribers, b"sunsubscribe")

    @command((bytes, bytes))
    def publish(self, channel, message):
        """PUBLISH channel message.

        Delivers ``[b"message", channel, message]`` to every exact-channel
        subscriber. It also delivers ``[b"pmessage", pattern, channel, message]``
        to every subscriber of each pattern that matches *channel*.

        Returns:
            The total number of deliveries made.
        """
        receivers = 0
        msg = [b"message", channel, message]
        for sock in self._server.subscribers.get(channel, set()):
            sock.put_response(msg)
            receivers += 1
        for pattern, socks in self._server.psubscribers.items():
            if compile_pattern(pattern).match(channel):
                pmsg = [b"pmessage", pattern, channel, message]
                for sock in socks:
                    sock.put_response(pmsg)
                    receivers += 1
        return receivers

    @command((bytes, bytes))
    def spublish(self, channel, message):
        """SPUBLISH shardchannel message.

        Delivers ``[b"smessage", channel, message]`` to SSUBSCRIBE subscribers
        of *channel* only. Redis shard pub/sub ignores pattern subscriptions,
        so PSUBSCRIBE clients are not notified.

        Returns:
            The number of deliveries made.
        """
        receivers = 0
        msg = [b"smessage", channel, message]
        for sock in self._server.ssubscribers.get(channel, set()):
            sock.put_response(msg)
            receivers += 1
        return receivers

    @command(name="PUBSUB NUMPAT", fixed=(), repeat=())
    def pubsub_numpat(self, *_):
        """PUBSUB NUMPAT: the number of patterns with an entry in ``psubscribers``."""
        return len(self._server.psubscribers)

    def _channels(self, subscribers_dict: Dict[bytes, Any], *patterns):
        """List the keys of *subscribers_dict*.

        When at least one pattern is given, only names matching the first glob
        pattern are kept; further patterns are ignored.
        """
        channels = list(subscribers_dict.keys())
        if patterns:
            regex = compile_pattern(patterns[0])
            channels = [ch for ch in channels if regex.match(ch)]
        return channels

    @command(name="PUBSUB CHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_channels(self, *args):
        """PUBSUB CHANNELS [pattern]"""
        return self._channels(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDCHANNELS", fixed=(), repeat=(bytes,))
    def pubsub_shardchannels(self, *args):
        """PUBSUB SHARDCHANNELS [pattern]"""
        return self._channels(self._server.ssubscribers, *args)

    @command(name="PUBSUB NUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_numsub(self, *args):
        """PUBSUB NUMSUB [channel ...]"""
        return self._numsub(self._server.subscribers, *args)

    @command(name="PUBSUB SHARDNUMSUB", fixed=(), repeat=(bytes,))
    def pubsub_shardnumsub(self, *args):
        """PUBSUB SHARDNUMSUB [shardchannel ...]"""
        return self._numsub(self._server.ssubscribers, *args)

    @command(name="PUBSUB", fixed=())
    def pubsub(self, *args):
        """Bare PUBSUB without a subcommand is a wrong-number-of-arguments error."""
        raise SimpleError(msgs.WRONG_ARGS_MSG6.format("pubsub"))

    @command(name="PUBSUB HELP", fixed=())
    def pubsub_help(self, *args):
        """PUBSUB HELP: version-dependent help text as a list of bytes lines.

        Servers at version 7 or later add the SHARD* subcommands. From 7.1 the
        HELP line reads "Print" instead of "Prints".
        """
        help_strings = [
            "PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
            "CHANNELS [<pattern>]",
            " Return the currently active channels matching a <pattern> (default: '*')"
            ".",
            "NUMPAT",
            " Return number of subscriptions to patterns.",
            "NUMSUB [<channel> ...]",
            " Return the number of subscribers for the specified channels, excluding",
            " pattern subscriptions(default: no channels).",
        ]
        if self.version >= (7,):
            help_strings += [
                "SHARDCHANNELS [<pattern>]",
                " Return the currently active shard level channels matching a <pattern> (d"
                "efault: '*').",
                "SHARDNUMSUB [<shardchannel> ...]",
                " Return the number of subscribers for the specified shard level channel(s"
                ")",
            ]
        help_strings += [
            "HELP",
            " Prints this help." if self.version < (7, 1) else " Print this help.",
        ]
        return [s.encode() for s in help_strings]
|