youtube-summarizer/venv311/lib/python3.11/site-packages/fakeredis/commands_mixins/scripting_mixin.py

289 lines
10 KiB
Python

import functools
import hashlib
import itertools
import logging
from typing import Tuple, Callable, AnyStr, Set, Any
from fakeredis import _msgs as msgs
from fakeredis._commands import command, Int
from fakeredis._helpers import (
SimpleError,
SimpleString,
null_terminate,
OK,
encode_command,
)
# Every message emitted through ``redis.log`` from a Lua script goes to this logger.
LOGGER = logging.getLogger("fakeredis")

# Names exposed to Lua scripts as ``redis.LOG_*`` mapped to their numeric level.
# The numeric value is simply the position of the name in this sequence.
REDIS_LOG_LEVELS = {
    name: number
    for number, name in enumerate(
        (b"LOG_DEBUG", b"LOG_VERBOSE", b"LOG_NOTICE", b"LOG_WARNING")
    )
}

# Numeric script log level -> Python ``logging`` level used to emit the message.
# Levels 1 and 2 both map to ``logging.INFO``.
REDIS_LOG_LEVELS_TO_LOGGING = dict(
    enumerate((logging.DEBUG, logging.INFO, logging.INFO, logging.WARNING))
)
def _ensure_str(s: AnyStr, encoding: str, replaceerr: str) -> str:
    """Return ``s`` as text.

    :param s: Value to convert; ``bytes`` are decoded, anything else goes
        through ``str()``.
    :param encoding: Codec used to decode ``bytes`` input.
    :param replaceerr: Error handler name passed to ``bytes.decode`` as
        ``errors`` (for example ``"replace"``).
    :return: Always a ``str``.
    """
    if isinstance(s, bytes):
        return s.decode(encoding=encoding, errors=replaceerr)
    # Non-bytes input must stay text: encoding it here would hand ``bytes`` to
    # callers that join the results with a ``str`` separator.
    return str(s)
def _check_for_lua_globals(lua_runtime, expected_globals):
    """Raise ``SimpleError`` if the Lua runtime holds globals that are not expected.

    :param lua_runtime: Object whose ``globals()`` returns a mapping of the
        current Lua global names.
    :param expected_globals: Set of global names that are allowed to exist.
    :raises SimpleError: With ``msgs.GLOBAL_VARIABLE_MSG`` naming every global
        that is present but not in ``expected_globals``.
    """
    leaked = set(lua_runtime.globals().keys()) - expected_globals
    if not leaked:
        return
    names = ", ".join(_ensure_str(name, "utf-8", "replace") for name in leaked)
    raise SimpleError(msgs.GLOBAL_VARIABLE_MSG.format(names))
def _lua_redis_log(lua_runtime, expected_globals, lvl, *args):
    """Back the ``redis.log`` function that is handed to Lua scripts.

    The message parts are joined with single spaces and emitted through
    ``LOGGER`` at the Python level that ``REDIS_LOG_LEVELS_TO_LOGGING`` maps
    ``lvl`` to.

    :param lua_runtime: Runtime checked for unexpected globals before logging.
    :param expected_globals: Set of global names the script may legitimately see.
    :param lvl: Numeric script log level; must be a key of
        ``REDIS_LOG_LEVELS_TO_LOGGING``.
    :param args: Message parts. ``bytes`` are decoded as UTF-8, booleans are
        dropped, everything else is passed through ``str()``.
    :raises SimpleError: If an unexpected global exists, no message part was
        supplied, or ``lvl`` is not a known level.
    """
    _check_for_lua_globals(lua_runtime, expected_globals)
    if not args:
        raise SimpleError(msgs.REQUIRES_MORE_ARGS_MSG.format("redis.log()", "two"))
    if lvl not in REDIS_LOG_LEVELS_TO_LOGGING:
        raise SimpleError(msgs.LOG_INVALID_DEBUG_LEVEL_MSG)

    pieces = []
    for part in args:
        if isinstance(part, bool):
            # Booleans are left out of the logged message entirely.
            continue
        if isinstance(part, bytes):
            pieces.append(part.decode("utf-8"))
        else:
            pieces.append(str(part))
    LOGGER.log(REDIS_LOG_LEVELS_TO_LOGGING[lvl], " ".join(pieces))
class ScriptingCommandsMixin:
    """Mixin providing the scripting commands (EVAL, EVALSHA, SCRIPT ...).

    Scripts are executed in a ``lupa`` Lua runtime. The class this is mixed
    into is expected to provide ``version``, ``_name_to_func`` and
    ``_run_command``.
    """

    version: Tuple[int]
    _name_to_func: Callable
    _run_command: Callable

    def __init__(self, *args, **kwargs):
        """Initialise the next class in the MRO, then create an empty script cache."""
        super().__init__(*args, **kwargs)
        # SHA1 hex digest (bytes) -> script source
        self.script_cache = {}

    def _convert_redis_arg(self, lua_runtime, value):
        """Turn an argument coming from a Lua ``redis.call`` into bytes.

        ``bytes`` are returned unchanged; ``int`` and ``float`` are formatted
        with ``"{:.17g}"`` and encoded. Any other type raises ``SimpleError``
        with a message that depends on ``self.version``.
        """
        # Exact type identity on purpose: bool is a subclass of int and must
        # not be accepted as a number here.
        kind = type(value)
        if kind is bytes:
            return value
        if kind is int or kind is float:
            return "{:.17g}".format(value).encode()
        # TODO: add the context
        if self.version < (7,):
            msg = msgs.LUA_COMMAND_ARG_MSG6
        else:
            msg = msgs.LUA_COMMAND_ARG_MSG
        raise SimpleError(msg)

    def _convert_redis_result(self, lua_runtime, result):
        """Convert the result of a command into the value handed back to Lua.

        * ``bytes`` / ``int``: returned unchanged.
        * ``SimpleString``: a Lua table ``{ok = value}``.
        * ``None``: ``False``.
        * ``list``: a Lua table of the recursively converted items.
        * ``SimpleError``: raised; a "wrong number of arguments" error is
          replaced by one carrying ``msgs.WRONG_ARGS_MSG7``.

        :raises RuntimeError: For any other result type.
        """
        if isinstance(result, (bytes, int)):
            return result
        if isinstance(result, SimpleString):
            return lua_runtime.table_from({b"ok": result.value})
        if result is None:
            return False
        if isinstance(result, list):
            return lua_runtime.table_from(
                [self._convert_redis_result(lua_runtime, entry) for entry in result]
            )
        if isinstance(result, SimpleError):
            if result.value.startswith("ERR wrong number of arguments"):
                raise SimpleError(msgs.WRONG_ARGS_MSG7)
            raise result
        raise RuntimeError(
            "Unexpected return type from redis: {}".format(type(result))
        )

    def _convert_lua_result(self, result, nested=True):
        """Convert a value returned by a Lua script into a command reply.

        :param result: Value produced by the Lua runtime.
        :param nested: When false, a table carrying an ``err`` key raises the
            ``SimpleError`` instead of returning it.
        :return: ``SimpleString`` / ``SimpleError`` for ``ok`` / ``err``
            tables, a list for other tables, ``bytes`` for ``str``, ``int``
            for ``float``, ``1`` or ``None`` for booleans, otherwise
            ``result`` itself.
        :raises SimpleError: If the ``ok`` / ``err`` payload does not convert
            to ``bytes``, or for an ``err`` table when ``nested`` is false.
        """
        from lupa import lua_type

        if lua_type(result) != "table":
            if isinstance(result, str):
                return result.encode()
            if isinstance(result, float):
                return int(result)
            if isinstance(result, bool):
                return 1 if result else None
            return result

        # Status / error replies are tables holding an ``ok`` or ``err`` key.
        for reply_key in (b"ok", b"err"):
            if reply_key not in result:
                continue
            payload = self._convert_lua_result(result[reply_key])
            if not isinstance(payload, bytes):
                raise SimpleError(msgs.LUA_WRONG_NUMBER_ARGS_MSG)
            if reply_key == b"ok":
                return SimpleString(payload)
            error = SimpleError(payload.decode("utf-8", "replace"))
            if nested:
                return error
            raise error

        # Any other table becomes a list: walk indices 1, 2, 3, ... and stop
        # at the first index that is missing (mimics StrictRedis).
        items = []
        position = 1
        while position in result:
            items.append(self._convert_lua_result(result[position]))
            position += 1
        return items

    def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
        """Run command ``op`` with ``args`` on behalf of a Lua ``redis.call``.

        Errors propagate to the caller as exceptions.
        """
        # Check if we've set any global variables before making any change.
        _check_for_lua_globals(lua_runtime, expected_globals)
        handler, signature = self._name_to_func(encode_command(op))
        converted_args = [
            self._convert_redis_arg(lua_runtime, argument) for argument in args
        ]
        outcome = self._run_command(handler, signature, converted_args, True)
        return self._convert_redis_result(lua_runtime, outcome)

    def _lua_redis_pcall(self, lua_runtime, expected_globals, op, *args):
        """Like ``_lua_redis_call``, but report any exception as an ``err`` table."""
        try:
            outcome = self._lua_redis_call(lua_runtime, expected_globals, op, *args)
        except Exception as ex:
            return lua_runtime.table_from({b"err": str(ex)})
        return outcome

    @command((bytes, Int), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def eval(self, script, numkeys, *keys_and_args):
        """Execute a Lua script (EVAL).

        :param script: Lua source as bytes; it is stored in ``script_cache``
            under its SHA1 hex digest before it runs.
        :param numkeys: How many of ``keys_and_args`` are keys (exposed to the
            script as ``KEYS``); the remainder is exposed as ``ARGV``.
        :param keys_and_args: Keys followed by additional arguments.
        :return: The script's result converted by ``_convert_lua_result``.
        :raises SimpleError: If ``numkeys`` exceeds the number of supplied
            values or is negative, if the script fails, or if it created an
            unexpected global.
        """
        from lupa import LuaError, LuaRuntime, as_attrgetter

        if numkeys > len(keys_and_args):
            raise SimpleError(msgs.TOO_MANY_KEYS_MSG)
        if numkeys < 0:
            raise SimpleError(msgs.NEGATIVE_KEYS_MSG)

        digest = hashlib.sha1(script).hexdigest()
        sha1 = digest.encode()
        self.script_cache[sha1] = script

        lua_runtime = LuaRuntime(encoding=None, unpack_returned_tuples=True)
        # Lua function that installs the ``redis`` table, ``KEYS`` and ``ARGV``
        # as globals of the fresh runtime.
        set_globals = lua_runtime.eval(
            "\n"
            "function(keys, argv, redis_call, redis_pcall, redis_log, redis_log_levels)\n"
            "redis = {}\n"
            "redis.call = redis_call\n"
            "redis.pcall = redis_pcall\n"
            "redis.log = redis_log\n"
            "for level, pylevel in python.iterex(redis_log_levels.items()) do\n"
            "redis[level] = pylevel\n"
            "end\n"
            "redis.error_reply = function(msg) return {err=msg} end\n"
            "redis.status_reply = function(msg) return {ok=msg} end\n"
            "KEYS = keys\n"
            "ARGV = argv\n"
            "end\n"
        )

        # The callbacks capture this set by reference; it is only filled in
        # once the bootstrap function has created the legitimate globals.
        expected_globals: Set[Any] = set()
        keys_table = lua_runtime.table_from(keys_and_args[:numkeys])
        argv_table = lua_runtime.table_from(keys_and_args[numkeys:])
        call_hook = functools.partial(
            self._lua_redis_call, lua_runtime, expected_globals
        )
        pcall_hook = functools.partial(
            self._lua_redis_pcall, lua_runtime, expected_globals
        )
        log_hook = functools.partial(_lua_redis_log, lua_runtime, expected_globals)
        log_levels = as_attrgetter(REDIS_LOG_LEVELS)
        set_globals(
            keys_table, argv_table, call_hook, pcall_hook, log_hook, log_levels
        )
        expected_globals.update(lua_runtime.globals().keys())

        try:
            result = lua_runtime.execute(script)
        except SimpleError as ex:
            # Before version 7 the error is wrapped in the script error
            # message; from 7 on only the error value is reported.
            if self.version < (7,):
                raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))
            raise SimpleError(ex.value)
        except LuaError as ex:
            raise SimpleError(msgs.SCRIPT_ERROR_MSG.format(sha1.decode(), ex))

        _check_for_lua_globals(lua_runtime, expected_globals)
        return self._convert_lua_result(result, nested=False)

    @command((bytes, Int), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def evalsha(self, sha1, numkeys, *keys_and_args):
        """Execute a previously cached script identified by ``sha1`` (EVALSHA).

        :raises SimpleError: With ``msgs.NO_MATCHING_SCRIPT_MSG`` when ``sha1``
            is not in ``script_cache``.
        """
        try:
            cached_script = self.script_cache[sha1]
        except KeyError:
            raise SimpleError(msgs.NO_MATCHING_SCRIPT_MSG)
        else:
            return self.eval(cached_script, numkeys, *keys_and_args)

    @command(
        name="script load", fixed=(bytes,), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT
    )
    def script_load(self, *args):
        """Cache a script without running it and return its SHA1 hex digest (bytes).

        :raises SimpleError: Unless exactly one argument is given.
        """
        if len(args) != 1:
            raise SimpleError(msgs.BAD_SUBCOMMAND_MSG.format("SCRIPT"))
        (script,) = args
        digest = hashlib.sha1(script).hexdigest()
        sha1 = digest.encode()
        self.script_cache[sha1] = script
        return sha1

    @command(name="script exists", fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def script_exists(self, *args):
        """Return a list with ``1`` / ``0`` per digest, telling whether it is cached.

        :raises SimpleError: From version 7 on when no digest is given.
        """
        if self.version >= (7,) and not args:
            raise SimpleError(msgs.WRONG_ARGS_MSG7)
        return [1 if digest in self.script_cache else 0 for digest in args]

    @command(name="script flush", fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
    def script_flush(self, *args):
        """Empty the script cache and return ``OK``.

        Accepts at most one argument, which after ``null_terminate`` must be
        ``b"sync"`` or ``b"async"``.

        :raises SimpleError: For more than one argument or an unknown mode.
        """
        if len(args) > 1:
            raise SimpleError(msgs.BAD_SUBCOMMAND_MSG.format("SCRIPT"))
        if args and null_terminate(args[0]) not in (b"sync", b"async"):
            raise SimpleError(msgs.BAD_SUBCOMMAND_MSG.format("SCRIPT"))
        self.script_cache = {}
        return OK

    @command((), flags=msgs.FLAG_NO_SCRIPT)
    def script(self, *args):
        """Handler for a bare ``SCRIPT`` command: always raises ``SimpleError``."""
        raise SimpleError(msgs.BAD_SUBCOMMAND_MSG.format("SCRIPT"))

    @command(name="SCRIPT HELP", fixed=())
    def script_help(self, *args):
        """Return the SCRIPT help text as a list of encoded lines.

        The wording of the last line depends on ``self.version``.
        """
        if self.version < (7, 1):
            help_line = " Prints this help."
        else:
            help_line = " Print this help."
        lines = (
            "SCRIPT <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
            "DEBUG (YES|SYNC|NO)",
            " Set the debug mode for subsequent scripts executed.",
            "EXISTS <sha1> [<sha1> ...]",
            " Return information about the existence of the scripts in the script cache.",
            "FLUSH [ASYNC|SYNC]",
            " Flush the Lua scripts cache. Very dangerous on replicas.",
            " When called without the optional mode argument, the behavior is determined by the",
            " lazyfree-lazy-user-flush configuration directive. Valid modes are:",
            " * ASYNC: Asynchronously flush the scripts cache.",
            " * SYNC: Synchronously flush the scripts cache.",
            "KILL",
            " Kill the currently executing Lua script.",
            "LOAD <script>",
            " Load a script into the scripts cache without executing it.",
            "HELP",
            help_line,
        )
        return [line.encode() for line in lines]