599 lines
19 KiB
Python
599 lines
19 KiB
Python
"""Command mixin for emulating `redis-py`'s JSON functionality."""
|
|
|
|
import copy
|
|
import json
|
|
from json import JSONDecodeError
|
|
from typing import Any, Union, Dict, List, Optional
|
|
|
|
from jsonpath_ng import Root, JSONPath
|
|
from jsonpath_ng.exceptions import JsonPathParserError
|
|
from jsonpath_ng.ext import parse
|
|
|
|
from fakeredis import _helpers as helpers
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._command_args_parsing import extract_args
|
|
from fakeredis._commands import Key, command, delete_keys, CommandItem, Int, Float
|
|
from fakeredis._zset import ZSet
|
|
|
|
# Any value a decoded JSON document can hold: a scalar (string, number,
# boolean, null), an object (dict) or an array (list).
JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
|
|
|
|
|
|
def _format_path(path) -> str:
|
|
if isinstance(path, bytes):
|
|
path = path.decode()
|
|
if path == ".":
|
|
return "$"
|
|
elif path.startswith("."):
|
|
return "$" + path
|
|
elif path.startswith("$"):
|
|
return path
|
|
else:
|
|
return "$." + path
|
|
|
|
|
|
def _parse_jsonpath(path: Union[str, bytes]):
    """Compile ``path`` into a jsonpath expression.

    The path is first normalised with ``_format_path``. A parser failure is
    reported to the client as a ``SimpleError`` built from
    ``msgs.JSON_PATH_DOES_NOT_EXIST`` and the normalised path.
    """
    normalized = _format_path(path)
    try:
        compiled = parse(normalized)
    except JsonPathParserError:
        raise helpers.SimpleError(msgs.JSON_PATH_DOES_NOT_EXIST.format(normalized))
    return compiled
|
|
|
|
|
|
def _path_is_root(path: JSONPath) -> bool:
    """Tell whether ``path`` compares equal to the JSONPath root expression."""
    is_root = path == Root()
    return is_root
|
|
|
|
|
|
def _dict_deep_merge(source: JsonType, destination: Dict) -> Dict:
|
|
"""Deep merge of two dictionaries"""
|
|
if not isinstance(source, dict):
|
|
return destination
|
|
for key, value in source.items():
|
|
if value is None and key in destination:
|
|
del destination[key]
|
|
elif isinstance(value, dict):
|
|
node = destination.setdefault(key, {})
|
|
_dict_deep_merge(value, node)
|
|
else:
|
|
destination[key] = value
|
|
|
|
return destination
|
|
|
|
|
|
class JSONObject:
    """Argument converter for JSON objects."""

    DECODE_ERROR = msgs.JSON_WRONG_REDIS_TYPE
    ENCODE_ERROR = msgs.JSON_WRONG_REDIS_TYPE

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Deserialize the supplied bytes into a valid Python object.

        Raises ``helpers.SimpleError(cls.DECODE_ERROR)`` when the payload is
        not valid JSON, including payloads whose bytes cannot be decoded as
        text (``json.loads`` raises ``UnicodeDecodeError`` for those, which is
        not a ``JSONDecodeError``).
        """
        try:
            return json.loads(value)
        except (JSONDecodeError, UnicodeDecodeError):
            raise helpers.SimpleError(cls.DECODE_ERROR)

    @classmethod
    def encode(cls, value: Any) -> Optional[bytes]:
        """Serialize the supplied Python object into a valid, JSON-formatted
        byte-encoded string.

        ``None`` is passed through as ``None`` instead of being encoded.
        Objects the ``json`` module cannot serialize natively are rendered
        with ``str``.
        """
        if value is None:
            return None
        return json.dumps(value, default=str).encode()
|
|
|
|
|
|
def _json_write_iterate(method, key, path_str, **kwargs):
    """Implement json.* write commands.

    Applies ``method`` to every value matched by ``path_str`` inside
    ``key.value``. ``method`` receives the matched value and returns a
    ``(new_value, reply, should_update)`` triple; when ``should_update`` is
    true, ``new_value`` is written at the match's full path in a deep copy of
    the document, which is stored with ``key.update`` once all matches have
    been processed.

    Reply shape:

    * legacy path (longer than one element and starting with ``.``): the last
      reply when ``allow_result_none=True`` is passed, otherwise the last
      reply that is not ``None`` (``None`` if there is none);
    * exactly one match and a path not starting with ``$``: that single reply;
    * otherwise the list of replies, one per match.

    Raises ``helpers.SimpleError`` when the key holds no value or the path
    matches nothing.
    """
    if key.value is None:
        raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
    path = _parse_jsonpath(path_str)
    found_matches = path.find(key.value)
    if len(found_matches) == 0:
        raise helpers.SimpleError(
            msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str)
        )

    curr_value = copy.deepcopy(key.value)
    res = []
    for item in found_matches:
        new_value, res_val, update = method(item.value)
        if update:
            curr_value = item.full_path.update(curr_value, new_value)
        res.append(res_val)

    key.update(curr_value)

    # Indexing a bytes path yields ints, hence the comparisons against ord().
    if len(path_str) > 1 and path_str[0] == ord(b"."):
        if kwargs.get("allow_result_none", False):
            return res[-1]
        # A default is required here: when no match produced a reply, a bare
        # next() would leak StopIteration to the caller.
        return next((x for x in reversed(res) if x is not None), None)
    if len(res) == 1 and path_str[0] != ord(b"$"):
        return res[0]
    return res
|
|
|
|
|
|
def _json_read_iterate(method, key, *args, error_on_zero_matches=False):
|
|
path_str = args[0] if len(args) > 0 else "$"
|
|
if key.value is None:
|
|
if path_str[0] == 36:
|
|
raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
|
|
else:
|
|
return None
|
|
|
|
path = _parse_jsonpath(path_str)
|
|
found_matches = path.find(key.value)
|
|
if error_on_zero_matches and len(found_matches) == 0 and path_str[0] != 36:
|
|
raise helpers.SimpleError(
|
|
msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str)
|
|
)
|
|
res = list()
|
|
for item in found_matches:
|
|
res.append(method(item.value))
|
|
|
|
if path_str[0] == 46:
|
|
return res[0] if len(res) > 0 else None
|
|
if len(res) == 1 and (len(args) == 0 or (len(args) == 1 and args[0][0] == 46)):
|
|
return res[0]
|
|
|
|
return res
|
|
|
|
|
|
class JSONCommandsMixin:
    """`CommandsMixin` for enabling RedisJSON compatibility in `fakeredis`."""

    # Replacement values used by JSON.CLEAR, keyed by the exact Python type of
    # the matched value. The containers are templates only: copy them before
    # storing, otherwise every cleared document would share these objects.
    TYPES_EMPTY_VAL_DICT = {
        dict: {},
        int: 0,
        float: 0.0,
        list: [],
    }
    # Names reported by JSON.TYPE, keyed by the exact Python type of the value.
    TYPE_NAMES = {
        dict: b"object",
        int: b"integer",
        float: b"number",
        bytes: b"string",
        list: b"array",
        set: b"set",
        str: b"string",
        bool: b"boolean",
        type(None): b"null",
        ZSet: b"zset",
    }

    _db: helpers.Database

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    @staticmethod
    def _get_single(
        key,
        path_str: str,
        always_return_list: bool = False,
        empty_list_as_none: bool = False,
    ) -> Any:
        """Return the value(s) ``path_str`` matches in ``key.value``.

        A single match is returned bare unless ``always_return_list`` is set;
        otherwise the list of matched values is returned. With
        ``empty_list_as_none`` an empty match list becomes ``None``.
        """
        path = _parse_jsonpath(path_str)
        values = [match.value for match in path.find(key.value)]
        if empty_list_as_none and len(values) == 0:
            return None
        if len(values) == 1 and not always_return_list:
            return values[0]
        return values

    @command(
        name=["JSON.DEL", "JSON.FORGET"],
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_del(self, key, path_str=b"$") -> int:
        """Delete the values matched by ``path_str`` and return how many were removed.

        The path is optional in the command signature, so it defaults to the
        root. Deleting the root removes the key itself and returns 1; a key
        without a value returns 0.
        """
        if key.value is None:
            return 0

        path = _parse_jsonpath(path_str)
        if _path_is_root(path):
            delete_keys(key)
            return 1

        curr_value = copy.deepcopy(key.value)
        deleted = 0
        # Remove one match at a time and search again, since removing an
        # element can change what the remaining matches refer to.
        found_matches = path.find(curr_value)
        while len(found_matches) > 0:
            curr_value = found_matches[0].full_path.filter(lambda _: True, curr_value)
            deleted += 1
            found_matches = path.find(curr_value)

        key.update(curr_value)
        return deleted

    @staticmethod
    def _json_set(key: CommandItem, path_str: bytes, value: JsonType, *args):
        """Store ``value`` at ``path_str`` in ``key``; shared by JSON.SET and JSON.MSET.

        ``args`` may carry the ``NX`` / ``XX`` flags. Returns ``helpers.OK``,
        or ``None`` when the NX/XX condition is not met. Raises
        ``SimpleError`` when both flags are given, or when a non-root path is
        used on a key whose existing value is not an object.
        """
        path = _parse_jsonpath(path_str)
        if (
            key.value is not None
            and (type(key.value) is not dict)
            and not _path_is_root(path)
        ):
            raise helpers.SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
        old_value = path.find(key.value)
        (nx, xx), _ = extract_args(args, ("nx", "xx"))
        if xx and nx:
            raise helpers.SimpleError(msgs.SYNTAX_ERROR_MSG)
        if (nx and old_value) or (xx and not old_value):
            return None
        new_value = path.update_or_create(key.value, value)
        key.update(new_value)
        return helpers.OK

    @command(
        name="JSON.SET",
        fixed=(Key(), bytes, JSONObject),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_set(self, key, path_str: bytes, value: JsonType, *args):
        """Set the JSON value at key `name` under the `path` to `obj`.

        For more information see `JSON.SET <https://redis.io/commands/json.set>`_.
        """
        return JSONCommandsMixin._json_set(key, path_str, value, *args)

    @command(
        name="JSON.GET",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_get(self, key, *args) -> Optional[bytes]:
        """Return the JSON-encoded value(s) found at the requested path(s).

        ``NOESCAPE`` arguments are ignored. Without any path the legacy root
        path ``.`` is used, so the whole document is returned. A key without a
        value yields ``None``.
        """
        if key.value is None:
            return None
        paths = [arg for arg in args if not helpers.casematch(b"noescape", arg)]
        if not paths:
            # No path supplied: behave as if the legacy root had been requested
            # instead of falling through to an empty ``{}`` reply.
            paths = [b"."]
        # Slice rather than index so an empty path argument cannot IndexError.
        no_wrapping_array = len(paths) == 1 and paths[0][:1] == b"."

        formatted_paths = [_format_path(raw_path) for raw_path in paths]
        several_paths = len(formatted_paths) > 1
        path_values = [
            self._get_single(key, path, several_paths) for path in formatted_paths
        ]

        # Emulate the behavior of `redis-py`:
        # - if only one path was supplied => return a single value
        # - if more than one path was specified => return one value for each specified path
        # NOTE(review): a single `$` path matching exactly one array cannot be
        # told apart from a multi-match list here, so it is returned unwrapped.
        if no_wrapping_array or (
            len(path_values) == 1 and isinstance(path_values[0], list)
        ):
            return JSONObject.encode(path_values[0])
        if len(path_values) == 1:
            return JSONObject.encode(path_values)
        return JSONObject.encode(dict(zip(formatted_paths, path_values)))

    @command(
        name="JSON.MGET",
        fixed=(bytes,),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_mget(self, *args):
        """Return, for each key, the encoded value at the path given as last argument.

        Keys that hold no value, and keys where the path matches nothing,
        contribute ``None``.
        """
        if len(args) < 2:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.mget"))
        path_str = args[-1]
        # default=None keeps a missing key distinguishable from a stored empty
        # array; a list default would make root paths match for missing keys.
        keys = [
            CommandItem(key, self._db, item=self._db.get(key), default=None)
            for key in args[:-1]
        ]

        result = [
            None
            if key.value is None
            else JSONObject.encode(
                self._get_single(key, path_str, empty_list_as_none=True)
            )
            for key in keys
        ]
        return result

    @command(
        name="JSON.TOGGLE",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_toggle(self, key, *args):
        """Invert every boolean matched by the path (default ``$``).

        Returns the new boolean per match, with ``None`` for non-boolean
        matches; a single result is returned bare when no path, or exactly
        ``b"."``, was given. Raises ``SimpleError`` when the key has no value
        or when no matched value is a boolean.
        """
        if key.value is None:
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        path_str = args[0] if len(args) > 0 else "$"
        path = _parse_jsonpath(path_str)
        found_matches = path.find(key.value)

        curr_value = copy.deepcopy(key.value)
        res: List[Optional[bool]] = []
        for item in found_matches:
            if type(item.value) is bool:
                flipped = not item.value
                curr_value = item.full_path.update(curr_value, flipped)
                res.append(flipped)
            else:
                res.append(None)
        if all(x is None for x in res):
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        key.update(curr_value)

        if len(res) == 1 and (len(args) == 0 or (len(args) == 1 and args[0] == b".")):
            return res[0]

        return res

    @command(
        name="JSON.CLEAR",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_clear(
        self,
        key,
        *args,
    ):
        """Reset matched objects, arrays and numbers to their empty value.

        Values of any other type are left alone. Returns the number of values
        that were reset. Raises ``SimpleError`` when the key has no value.
        """
        if key.value is None:
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        path_str = args[0] if len(args) > 0 else "$"
        path = _parse_jsonpath(path_str)
        found_matches = path.find(key.value)
        curr_value = copy.deepcopy(key.value)
        res = 0
        for item in found_matches:
            new_val = self.TYPES_EMPTY_VAL_DICT.get(type(item.value), None)
            if new_val is not None:
                # Store a copy: the class-level {} / [] must never end up
                # inside a document, where in-place edits would alter them.
                curr_value = item.full_path.update(curr_value, copy.copy(new_val))
                res += 1

        key.update(curr_value)
        return res

    @command(
        name="JSON.STRAPPEND",
        fixed=(Key(), bytes),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_strappend(self, key, path_str, *args):
        """Append the JSON-decoded ``args[0]`` to every matched string.

        The reply for each string is its new length; non-string matches reply
        ``None``.
        """
        if len(args) == 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.strappend"))
        addition = JSONObject.decode(args[0])

        def strappend(val):
            if type(val) is str:
                new_value = val + addition
                return new_value, len(new_value), True
            return None, None, False

        return _json_write_iterate(strappend, key, path_str)

    @command(
        name="JSON.ARRAPPEND",
        fixed=(
            Key(),
            bytes,
        ),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrappend(self, key, path_str, *args):
        """Append the JSON-decoded ``args`` to every matched array.

        The reply for each array is its new length; non-array matches reply
        ``None``.
        """
        if len(args) == 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.arrappend"))

        addition = [JSONObject.decode(item) for item in args]

        def arrappend(val):
            if type(val) is list:
                new_value = val + addition
                return new_value, len(new_value), True
            return None, None, False

        return _json_write_iterate(arrappend, key, path_str)

    @command(
        name="JSON.ARRINSERT",
        fixed=(Key(), bytes, Int),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrinsert(self, key, path_str, index, *args):
        """Insert the JSON-decoded ``args`` before position ``index`` of every matched array.

        ``index`` follows Python slice rules (negative counts from the end,
        out-of-range values clamp). The reply for each array is its new
        length; non-array matches reply ``None``.
        """
        if len(args) == 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.arrinsert"))

        addition = [JSONObject.decode(item) for item in args]

        def arrinsert(val):
            if type(val) is list:
                new_value = val[:index] + addition + val[index:]
                return new_value, len(new_value), True
            return None, None, False

        return _json_write_iterate(arrinsert, key, path_str)

    @command(
        name="JSON.ARRPOP",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrpop(self, key, *args):
        """Remove and return one element from every matched non-empty array.

        ``args`` is ``[path [index]]``; the path defaults to ``"$"`` and the
        index to ``-1`` (last element). Indexes outside the array are rounded
        to the nearest end. The reply for each array is the JSON-encoded
        popped element; other matches reply ``None``.
        """
        path_str = args[0] if len(args) > 0 else "$"
        index = Int.decode(args[1]) if len(args) > 1 else -1

        def arrpop(val):
            if type(val) is list and len(val) > 0:
                if index >= len(val):
                    ind = -1
                elif index < -len(val):
                    # Too far to the left: pop the first element rather than
                    # letting list.pop raise IndexError.
                    ind = 0
                else:
                    ind = index
                popped = val.pop(ind)
                return val, JSONObject.encode(popped), True
            return None, None, False

        return _json_write_iterate(arrpop, key, path_str, allow_result_none=True)

    @command(
        name="JSON.ARRTRIM",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrtrim(self, key, *args):
        """Trim every matched array to the inclusive range ``[start, stop]``.

        ``args`` is ``[path [start [stop]]]``; the path defaults to ``"$"``,
        ``start`` to 0 and ``stop`` to the last element. Negative values count
        from the end of the array. The reply for each array is its new
        length; non-array matches reply ``None``.
        """
        path_str = args[0] if len(args) > 0 else "$"
        start = Int.decode(args[1]) if len(args) > 1 else 0
        stop = Int.decode(args[2]) if len(args) > 2 else None

        def arrtrim(val):
            if type(val) is list:
                length = len(val)
                start_ind = min(start, length)
                # ``stop`` is inclusive, the slice end is exclusive: resolve
                # negative values to an absolute position first, then add one.
                if stop is None:
                    stop_ind = length
                elif stop < 0:
                    stop_ind = max(length + stop, 0) + 1
                else:
                    stop_ind = stop + 1
                new_val = val[start_ind:stop_ind]
                return new_val, len(new_val), True
            return None, None, False

        return _json_write_iterate(arrtrim, key, path_str)

    @command(
        name="JSON.NUMINCRBY",
        fixed=(Key(), bytes, Float),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_numincrby(self, key, path_str, inc_by, *_):
        """Add ``inc_by`` to every matched int or float and reply with the new values.

        Matches of any other type (booleans included) reply ``None``.
        """

        def numincrby(val):
            if type(val) in {int, float}:
                new_value = val + inc_by
                return new_value, new_value, True
            return None, None, False

        return _json_write_iterate(numincrby, key, path_str)

    @command(
        name="JSON.NUMMULTBY",
        fixed=(Key(), bytes, Float),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_nummultby(self, key, path_str, mult_by, *_):
        """Multiply every matched int or float by ``mult_by`` and reply with the new values.

        Matches of any other type (booleans included) reply ``None``.
        """

        def nummultby(val):
            if type(val) in {int, float}:
                new_value = val * mult_by
                return new_value, new_value, True
            return None, None, False

        return _json_write_iterate(nummultby, key, path_str)

    # Read operations
    @command(
        name="JSON.ARRINDEX",
        fixed=(Key(), bytes, bytes),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrindex(self, key, path_str, encoded_value, *args):
        """Find the first position of a JSON value in every matched array.

        ``args`` is ``[start [stop]]``. ``start`` defaults to 0 and negative
        values are treated as 0. ``stop`` is exclusive and defaults to 0,
        which means "search through the last element"; a negative ``stop``
        counts from the end. Elements must match in both value and exact
        type. The reply per match is the index, ``-1`` when not found, or
        ``None`` when the matched value is not an array.
        """
        start = max(0, Int.decode(args[0]) if len(args) > 0 else 0)
        end = Int.decode(args[1]) if len(args) > 1 else 0
        # 0 selects the whole tail; mapping it to -1 would drop the last element.
        stop = end if end != 0 else None
        expected_value = JSONObject.decode(encoded_value)

        def check_index(value):
            if type(value) is not list:
                return None
            for offset, candidate in enumerate(value[start:stop]):
                # The type check keeps 1, 1.0 and True from matching each other.
                if candidate == expected_value and type(candidate) is type(
                    expected_value
                ):
                    return offset + start
            return -1

        return _json_read_iterate(
            check_index, key, path_str, *args, error_on_zero_matches=True
        )

    @command(name="JSON.STRLEN", fixed=(Key(),), repeat=(bytes,))
    def json_strlen(self, key, *args):
        """Reply with the length of each matched string (``None`` for other types)."""
        return _json_read_iterate(
            lambda val: len(val) if type(val) is str else None, key, *args
        )

    @command(name="JSON.ARRLEN", fixed=(Key(),), repeat=(bytes,))
    def json_arrlen(self, key, *args):
        """Reply with the length of each matched array (``None`` for other types)."""
        return _json_read_iterate(
            lambda val: len(val) if type(val) is list else None, key, *args
        )

    @command(name="JSON.OBJLEN", fixed=(Key(),), repeat=(bytes,))
    def json_objlen(self, key, *args):
        """Reply with the number of members of each matched object (``None`` for other types)."""
        return _json_read_iterate(
            lambda val: len(val) if type(val) is dict else None, key, *args
        )

    @command(
        name="JSON.TYPE",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_type(
        self,
        key,
        *args,
    ):
        """Reply with the ``TYPE_NAMES`` entry for each matched value (``None`` if unknown)."""
        return _json_read_iterate(
            lambda val: self.TYPE_NAMES.get(type(val), None), key, *args
        )

    @command(name="JSON.OBJKEYS", fixed=(Key(),), repeat=(bytes,))
    def json_objkeys(self, key, *args):
        """Reply with the encoded member names of each matched object (``None`` for other types)."""
        return _json_read_iterate(
            lambda val: [i.encode() for i in val.keys()] if type(val) is dict else None,
            key,
            *args,
        )

    @command(
        name="JSON.MSET",
        fixed=(),
        repeat=(Key(), bytes, JSONObject),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_mset(self, *args):
        """Apply ``_json_set`` to each ``key path value`` triplet in ``args``.

        Raises ``SimpleError`` unless ``args`` holds one or more complete
        triplets.
        """
        if len(args) < 3 or len(args) % 3 != 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.mset"))
        for i in range(0, len(args), 3):
            key, path_str, value = args[i], args[i + 1], args[i + 2]
            JSONCommandsMixin._json_set(key, path_str, value)
        return helpers.OK

    @command(
        name="JSON.MERGE",
        fixed=(Key(), bytes, JSONObject),
        repeat=(),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_merge(self, key, path_str: bytes, value: JsonType):
        """Merge ``value`` into every value matched by ``path_str``.

        When both the matched value and ``value`` are objects they are merged
        in place with ``_dict_deep_merge``. Otherwise a non-null ``value``
        replaces the matched value; this also covers a key that has no value
        yet when the root path is used. Paths that match nothing are left
        untouched. Raises ``SimpleError`` when a non-root path is used on a
        key whose existing value is not an object.
        """
        path: JSONPath = _parse_jsonpath(path_str)
        if (
            key.value is not None
            and (type(key.value) is not dict)
            and not _path_is_root(path)
        ):
            raise helpers.SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
        matching = path.find(key.value)
        document = key.value
        replaced = False
        for item in matching:
            if isinstance(value, dict) and isinstance(item.value, dict):
                _dict_deep_merge(value, item.value)
            elif value is not None:
                # Nothing to merge member-wise: the patch takes the place of
                # the target. Each target gets its own copy, and object
                # patches go through the merge helper applied to an empty
                # object.
                patch = copy.deepcopy(value)
                if isinstance(patch, dict):
                    patch = _dict_deep_merge(patch, {})
                document = item.full_path.update(document, patch)
                replaced = True
            # NOTE(review): RFC 7396 treats a null patch as "delete the
            # target"; that case is still a no-op here - confirm desired
            # behaviour before changing it.
        if replaced:
            key.update(document)
        elif len(matching) > 0:
            key.updated()
        return helpers.OK
|