"""Command mixin for emulating `redis-py`'s JSON functionality."""

import copy
import json
from json import JSONDecodeError
from typing import Any, Union, Dict, List, Optional

from jsonpath_ng import Root, JSONPath
from jsonpath_ng.exceptions import JsonPathParserError
from jsonpath_ng.ext import parse

from fakeredis import _helpers as helpers
from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import Key, command, delete_keys, CommandItem, Int, Float
from fakeredis._zset import ZSet

JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]


def _format_path(path) -> str:
    """Normalize a RedisJSON path into a ``$``-rooted JSONPath string.

    ``bytes`` input is decoded first. ``.`` becomes ``$``, legacy paths that
    start with ``.`` get a ``$`` prefix, paths already starting with ``$`` are
    returned unchanged, and anything else (e.g. ``a.b``) is treated as a child
    of the root (``$.a.b``).
    """
    if isinstance(path, bytes):
        path = path.decode()
    if path == ".":
        return "$"
    elif path.startswith("."):
        return "$" + path
    elif path.startswith("$"):
        return path
    else:
        return "$." + path


def _parse_jsonpath(path: Union[str, bytes]):
    """Compile ``path`` (str or bytes) into a ``jsonpath_ng`` expression.

    Raises:
        SimpleError: ``JSON_PATH_DOES_NOT_EXIST`` if the path does not parse.
    """
    path = _format_path(path)
    try:
        return parse(path)
    except JsonPathParserError:
        raise helpers.SimpleError(msgs.JSON_PATH_DOES_NOT_EXIST.format(path))


def _path_is_root(path: JSONPath) -> bool:
    """Return True when the compiled ``path`` is the bare root (``$``)."""
    return path == Root()


def _dict_deep_merge(source: JsonType, destination: Dict) -> Dict:
    """Deep-merge ``source`` into ``destination`` in place (RFC 7396 style).

    For every key in ``source``:

    - a ``None`` value removes the key from ``destination`` (a key that is
      absent stays absent rather than being set to ``null``);
    - a dict value is merged recursively; if the existing value under that key
      is not a dict it is first replaced by a fresh, empty dict;
    - any other value replaces the existing one. A deep copy is stored so the
      same patch can be applied to several destinations without aliasing.

    If ``source`` is not a dict, ``destination`` is returned unchanged.

    Returns:
        ``destination`` (mutated in place).
    """
    if not isinstance(source, dict):
        return destination
    for key, value in source.items():
        if value is None:
            destination.pop(key, None)
        elif isinstance(value, dict):
            node = destination.get(key)
            if not isinstance(node, dict):
                node = {}
                destination[key] = node
            _dict_deep_merge(value, node)
        else:
            destination[key] = copy.deepcopy(value)
    return destination


class JSONObject:
    """Argument converter for JSON objects."""

    DECODE_ERROR = msgs.JSON_WRONG_REDIS_TYPE
    ENCODE_ERROR = msgs.JSON_WRONG_REDIS_TYPE

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Deserialize the supplied bytes into a valid Python object.

        Raises:
            SimpleError: ``DECODE_ERROR`` if ``value`` is not valid JSON.
        """
        try:
            return json.loads(value)
        except JSONDecodeError:
            raise helpers.SimpleError(cls.DECODE_ERROR)

    @classmethod
    def encode(cls, value: Any) -> Optional[bytes]:
        """Serialize the supplied Python object into a valid, JSON-formatted
        byte-encoded string.

        ``None`` is passed through as ``None`` (a nil reply) rather than being
        encoded as ``null``; objects ``json`` cannot serialize are stringified
        via ``default=str``.
        """
        return json.dumps(value, default=str).encode() if value is not None else None


def _json_write_iterate(method, key, path_str, **kwargs):
    """Implement json.* write commands.

    Iterate over the values matched by ``path_str`` in ``key`` and call
    ``method(value) -> (new_value, result, update)`` on each one. When
    ``update`` is true, ``new_value`` is written back at that match's path in
    a deep copy of the key's value, which is then stored with ``key.update``.

    Return shape:

    - legacy path (starts with ``.``, longer than ``.``): a single result, the
      last one, or the last non-``None`` one unless ``allow_result_none=True``;
    - exactly one match on a non-``$`` path: that single result;
    - otherwise: the list of results, one per match.

    Raises:
        SimpleError: ``JSON_KEY_NOT_FOUND`` if the key does not exist;
            ``JSON_PATH_NOT_FOUND_OR_NOT_STRING`` if the path matches nothing,
            or if a legacy path produced no non-``None`` result.
    """
    if key.value is None:
        raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
    path = _parse_jsonpath(path_str)
    found_matches = path.find(key.value)
    if len(found_matches) == 0:
        raise helpers.SimpleError(
            msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str)
        )

    curr_value = copy.deepcopy(key.value)
    res = list()
    for item in found_matches:
        new_value, res_val, update = method(item.value)
        if update:
            curr_value = item.full_path.update(curr_value, new_value)
        res.append(res_val)

    key.update(curr_value)

    if len(path_str) > 1 and path_str[0] == ord(b"."):
        if kwargs.get("allow_result_none", False):
            return res[-1]
        non_null = [x for x in res if x is not None]
        if not non_null:
            # Previously a bare ``next()`` leaked StopIteration here; report a
            # proper error reply instead (every match had the wrong type).
            raise helpers.SimpleError(
                msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str)
            )
        return non_null[-1]
    # NOTE: a default path passed as the *str* "$" never equals ord(b"$"),
    # so a single match on the default path is returned unwrapped.
    if len(res) == 1 and path_str[0] != ord(b"$"):
        return res[0]
    return res


def _json_read_iterate(method, key, *args, error_on_zero_matches=False):
    """Implement read-only json.* commands.

    ``args[0]`` (when present) is the path, defaulting to the str ``"$"``.
    ``method`` is applied to each matched value and the results collected.

    NOTE: the byte comparisons below (36 == ``$``, 46 == ``.``) never match
    the str default ``"$"``. With no path, a missing key therefore yields
    ``None`` and a single match is returned unwrapped.

    Raises:
        SimpleError: ``JSON_KEY_NOT_FOUND`` for a missing key with a ``$``
            path; ``JSON_PATH_NOT_FOUND_OR_NOT_STRING`` when
            ``error_on_zero_matches`` is set, nothing matched and the path is
            not a ``$`` path.
    """
    path_str = args[0] if len(args) > 0 else "$"
    if key.value is None:
        if path_str[0] == 36:
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        else:
            return None

    path = _parse_jsonpath(path_str)
    found_matches = path.find(key.value)
    if error_on_zero_matches and len(found_matches) == 0 and path_str[0] != 36:
        raise helpers.SimpleError(
            msgs.JSON_PATH_NOT_FOUND_OR_NOT_STRING.format(path_str)
        )
    res = list()
    for item in found_matches:
        res.append(method(item.value))

    # Legacy ('.'-prefixed) paths reply with a single value.
    if path_str[0] == 46:
        return res[0] if len(res) > 0 else None
    if len(res) == 1 and (len(args) == 0 or (len(args) == 1 and args[0][0] == 46)):
        return res[0]

    return res


class JSONCommandsMixin:
    """`CommandsMixin` for enabling RedisJSON compatibility in `fakeredis`."""

    # Value each clearable type is reset to by JSON.CLEAR. These are shared
    # class-level objects: always copy before storing them in a key.
    TYPES_EMPTY_VAL_DICT = {
        dict: {},
        int: 0,
        float: 0.0,
        list: [],
    }
    # Python type -> RedisJSON type name replied by JSON.TYPE.
    TYPE_NAMES = {
        dict: b"object",
        int: b"integer",
        float: b"number",
        bytes: b"string",
        list: b"array",
        set: b"set",
        str: b"string",
        bool: b"boolean",
        type(None): b"null",
        ZSet: b"zset",
    }
    _db: helpers.Database

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    @staticmethod
    def _get_single(
        key,
        path_str: str,
        always_return_list: bool = False,
        empty_list_as_none: bool = False,
    ) -> Any:
        """Return the value(s) matched by ``path_str`` in ``key``.

        A single match is unwrapped unless ``always_return_list`` is set; no
        match returns ``None`` when ``empty_list_as_none`` is set, otherwise
        the (possibly empty) list of matched values.
        """
        path = _parse_jsonpath(path_str)
        path_value = path.find(key.value)
        val = [i.value for i in path_value]
        if empty_list_as_none and len(val) == 0:
            return None
        elif len(val) == 1 and not always_return_list:
            return val[0]
        return val

    @command(
        name=["JSON.DEL", "JSON.FORGET"],
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_del(self, key, path_str: bytes = b"$") -> int:
        """Delete the values matched by ``path_str`` (default: root).

        Deleting the root removes the whole key. Returns the number of values
        deleted (0 if the key does not exist).
        """
        if key.value is None:
            return 0

        path = _parse_jsonpath(path_str)
        if _path_is_root(path):
            delete_keys(key)
            return 1
        curr_value = copy.deepcopy(key.value)

        found_matches = path.find(curr_value)
        res = 0
        # Remove one match at a time and re-query, because deleting an array
        # element shifts the indices of the remaining matches.
        while len(found_matches) > 0:
            item = found_matches[0]
            curr_value = item.full_path.filter(lambda _: True, curr_value)
            res += 1
            found_matches = path.find(curr_value)

        key.update(curr_value)
        return res

    @staticmethod
    def _json_set(key: CommandItem, path_str: bytes, value: JsonType, *args):
        """Shared implementation of JSON.SET / JSON.MSET.

        ``args`` may contain NX / XX. Returns ``helpers.OK`` on success or
        ``None`` when the NX / XX condition prevents the write.

        Raises:
            SimpleError: ``JSON_WRONG_REDIS_TYPE`` when a non-root path is set
                on a key whose value is not an object; ``SYNTAX_ERROR_MSG``
                when both NX and XX are given.
        """
        path = _parse_jsonpath(path_str)
        if (
            key.value is not None
            and (type(key.value) is not dict)
            and not _path_is_root(path)
        ):
            raise helpers.SimpleError(msgs.JSON_WRONG_REDIS_TYPE)
        old_value = path.find(key.value)
        (nx, xx), _ = extract_args(args, ("nx", "xx"))
        if xx and nx:
            raise helpers.SimpleError(msgs.SYNTAX_ERROR_MSG)
        if (nx and old_value) or (xx and not old_value):
            return None
        new_value = path.update_or_create(key.value, value)
        key.update(new_value)

        return helpers.OK

    @command(
        name="JSON.SET",
        fixed=(Key(), bytes, JSONObject),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_set(self, key, path_str: bytes, value: JsonType, *args):
        """Set the JSON value at key `name` under the `path` to `obj`.

        For more information see `JSON.SET <https://redis.io/commands/json.set>`_.
        """
        return JSONCommandsMixin._json_set(key, path_str, value, *args)

    @command(
        name="JSON.GET",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_get(self, key, *args) -> Optional[bytes]:
        """Return the JSON-encoded value(s) at the given paths.

        ``NOESCAPE`` arguments are ignored. With no path the legacy root
        ``.`` is used, i.e. the whole document is returned unwrapped.
        """
        if key.value is None:
            return None

        paths = [arg for arg in args if not helpers.casematch(b"noescape", arg)]
        if not paths:
            # No path given: behave like the legacy root path ".".
            paths = [b"."]
        no_wrapping_array = len(paths) == 1 and paths[0][0] == ord(b".")

        formatted_paths = [_format_path(arg) for arg in paths]
        path_values = [
            self._get_single(key, path, len(formatted_paths) > 1)
            for path in formatted_paths
        ]

        # Emulate the behavior of `redis-py`:
        #   - if only one path was supplied => return a single value
        #   - if more than one path was specified => return one value for each specified path
        if no_wrapping_array or (
            len(path_values) == 1 and isinstance(path_values[0], list)
        ):
            return JSONObject.encode(path_values[0])
        if len(path_values) == 1:
            return JSONObject.encode(path_values)
        return JSONObject.encode(dict(zip(formatted_paths, path_values)))

    @command(
        name="JSON.MGET",
        fixed=(bytes,),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_mget(self, *args):
        """Return the value at the last argument's path for each key given.

        Missing keys yield ``None`` (nil) whatever the path.

        Raises:
            SimpleError: ``WRONG_ARGS_MSG6`` with fewer than two arguments.
        """
        if len(args) < 2:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.mget"))
        path_str = args[-1]
        # Validate the path even when every key turns out to be missing.
        _parse_jsonpath(path_str)

        result = []
        for raw_key in args[:-1]:
            item = self._db.get(raw_key)
            if item is None:
                result.append(None)
                continue
            key = CommandItem(raw_key, self._db, item=item, default=[])
            result.append(
                JSONObject.encode(
                    self._get_single(key, path_str, empty_list_as_none=True)
                )
            )
        return result

    @command(
        name="JSON.TOGGLE",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_toggle(self, key, *args):
        """Flip every boolean matched by the path (default ``$``).

        Non-boolean matches yield ``None``. A single result is unwrapped when
        the path is omitted or is exactly ``.``.

        Raises:
            SimpleError: ``JSON_KEY_NOT_FOUND`` if the key is missing or no
                matched value is a boolean.
        """
        if key.value is None:
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        path_str = args[0] if len(args) > 0 else "$"
        path = _parse_jsonpath(path_str)
        found_matches = path.find(key.value)

        curr_value = copy.deepcopy(key.value)
        res: List[Optional[bool]] = list()
        for item in found_matches:
            if type(item.value) is bool:
                curr_value = item.full_path.update(curr_value, not item.value)
                res.append(not item.value)
            else:
                res.append(None)
        if all([x is None for x in res]):
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        key.update(curr_value)

        if len(res) == 1 and (len(args) == 0 or (len(args) == 1 and args[0] == b".")):
            return res[0]

        return res

    @command(
        name="JSON.CLEAR",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_clear(
        self,
        key,
        *args,
    ):
        """Reset objects/arrays to empty and numbers to 0 at the path.

        Returns the number of values cleared; other types are left untouched.

        Raises:
            SimpleError: ``JSON_KEY_NOT_FOUND`` if the key does not exist.
        """
        if key.value is None:
            raise helpers.SimpleError(msgs.JSON_KEY_NOT_FOUND)
        path_str = args[0] if len(args) > 0 else "$"
        path = _parse_jsonpath(path_str)
        found_matches = path.find(key.value)
        curr_value = copy.deepcopy(key.value)
        res = 0
        for item in found_matches:
            empty_value = self.TYPES_EMPTY_VAL_DICT.get(type(item.value), None)
            if empty_value is not None:
                # Copy the shared class-level {} / [] so later in-place
                # mutations of this key cannot leak into other keys.
                curr_value = item.full_path.update(
                    curr_value, copy.deepcopy(empty_value)
                )
                res += 1

        key.update(curr_value)
        return res

    @command(
        name="JSON.STRAPPEND",
        fixed=(Key(), bytes),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_strappend(self, key, path_str, *args):
        """Append the JSON-encoded string ``args[0]`` to matched strings.

        Returns the new length(s); non-string matches yield ``None``.
        """
        if len(args) == 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.strappend"))
        addition = JSONObject.decode(args[0])

        def strappend(val):
            if type(val) is str:
                new_value = val + addition
                return new_value, len(new_value), True
            else:
                return None, None, False

        return _json_write_iterate(strappend, key, path_str)

    @command(
        name="JSON.ARRAPPEND",
        fixed=(
            Key(),
            bytes,
        ),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrappend(self, key, path_str, *args):
        """Append the JSON-encoded ``args`` to every matched array.

        Returns the new length(s); non-array matches yield ``None``.
        """
        if len(args) == 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.arrappend"))

        addition = [JSONObject.decode(item) for item in args]

        def arrappend(val):
            if type(val) is list:
                new_value = val + addition
                return new_value, len(new_value), True
            else:
                return None, None, False

        return _json_write_iterate(arrappend, key, path_str)

    @command(
        name="JSON.ARRINSERT",
        fixed=(Key(), bytes, Int),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrinsert(self, key, path_str, index, *args):
        """Insert the JSON-encoded ``args`` before ``index`` in matched arrays.

        Uses Python slice semantics, so out-of-range indexes are clamped.
        Returns the new length(s); non-array matches yield ``None``.
        """
        if len(args) == 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.arrinsert"))

        addition = [JSONObject.decode(item) for item in args]

        def arrinsert(val):
            if type(val) is list:
                new_value = val[:index] + addition + val[index:]
                return new_value, len(new_value), True
            else:
                return None, None, False

        return _json_write_iterate(arrinsert, key, path_str)

    @command(
        name="JSON.ARRPOP",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrpop(self, key, *args):
        """Remove and return (JSON-encoded) the element at ``index``.

        ``args`` is ``[path] [index]``; defaults are ``$`` and ``-1``.
        Out-of-range indexes are clamped to the first/last element. Empty or
        non-array matches yield ``None``.
        """
        path_str = args[0] if len(args) > 0 else "$"
        index = Int.decode(args[1]) if len(args) > 1 else -1

        def arrpop(val):
            if type(val) is list and len(val) > 0:
                size = len(val)
                if index >= size:
                    ind = -1
                elif index < -size:
                    # Previously raised IndexError; clamp to the first element.
                    ind = 0
                else:
                    ind = index
                # Pop from a copy so the stored value is not mutated in place.
                new_val = list(val)
                popped = new_val.pop(ind)
                return new_val, JSONObject.encode(popped), True
            else:
                return None, None, False

        return _json_write_iterate(arrpop, key, path_str, allow_result_none=True)

    @command(
        name="JSON.ARRTRIM",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrtrim(self, key, *args):
        """Trim matched arrays to the inclusive range ``[start, stop]``.

        ``args`` is ``[path] [start] [stop]``. A missing stop keeps the array
        through its end; a negative stop counts from the end (``-1`` is the
        last element). Returns the new length(s); non-array matches yield
        ``None``.
        """
        path_str = args[0] if len(args) > 0 else "$"
        start = Int.decode(args[1]) if len(args) > 1 else 0
        stop = Int.decode(args[2]) if len(args) > 2 else None

        def arrtrim(val):
            if type(val) is list:
                size = len(val)
                start_ind = min(start, size)
                if stop is None:
                    stop_ind = size
                elif stop < 0:
                    # Inclusive stop counted from the end -> exclusive index;
                    # clamp so very negative stops give an empty slice.
                    stop_ind = max(0, size + stop + 1)
                else:
                    stop_ind = stop + 1
                new_val = val[start_ind:stop_ind]
                return new_val, len(new_val), True
            else:
                return None, None, False

        return _json_write_iterate(arrtrim, key, path_str)

    @command(
        name="JSON.NUMINCRBY",
        fixed=(Key(), bytes, Float),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_numincrby(self, key, path_str, inc_by, *_):
        """Add ``inc_by`` to every matched number (booleans excluded)."""

        def numincrby(val):
            if type(val) in {int, float}:
                new_value = val + inc_by
                return new_value, new_value, True
            else:
                return None, None, False

        return _json_write_iterate(numincrby, key, path_str)

    @command(
        name="JSON.NUMMULTBY",
        fixed=(Key(), bytes, Float),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_nummultby(self, key, path_str, mult_by, *_):
        """Multiply every matched number (booleans excluded) by ``mult_by``."""

        def nummultby(val):
            if type(val) in {int, float}:
                new_value = val * mult_by
                return new_value, new_value, True
            else:
                return None, None, False

        return _json_write_iterate(nummultby, key, path_str)

    # Read operations
    @command(
        name="JSON.ARRINDEX",
        fixed=(Key(), bytes, bytes),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_arrindex(self, key, path_str, encoded_value, *args):
        """Return the index of the first element equal to ``encoded_value``.

        ``args`` is ``[start] [stop]``: start is inclusive (negative clamps to
        0), stop is exclusive. A stop of 0 (the default) searches through the
        last element, and a negative stop counts from the end. Equality also
        requires the same Python type, so ``1``, ``1.0`` and ``true`` differ.
        Returns -1 when not found and ``None`` for non-array matches.
        """
        start = max(0, Int.decode(args[0]) if len(args) > 0 else 0)
        stop = Int.decode(args[1]) if len(args) > 1 else 0
        # The old code mapped the default/0 stop to -1, which silently
        # excluded the last element from the search.
        end = None if stop == 0 else stop
        expected_value = JSONObject.decode(encoded_value)

        def check_index(value):
            if type(value) is not list:
                return None
            for offset, element in enumerate(value[start:end]):
                if element == expected_value and type(element) is type(
                    expected_value
                ):
                    return offset + start
            return -1

        return _json_read_iterate(
            check_index, key, path_str, *args, error_on_zero_matches=True
        )

    @command(name="JSON.STRLEN", fixed=(Key(),), repeat=(bytes,))
    def json_strlen(self, key, *args):
        """Return the length of matched strings (``None`` for non-strings)."""
        return _json_read_iterate(
            lambda val: len(val) if type(val) is str else None, key, *args
        )

    @command(name="JSON.ARRLEN", fixed=(Key(),), repeat=(bytes,))
    def json_arrlen(self, key, *args):
        """Return the length of matched arrays (``None`` for non-arrays)."""
        return _json_read_iterate(
            lambda val: len(val) if type(val) is list else None, key, *args
        )

    @command(name="JSON.OBJLEN", fixed=(Key(),), repeat=(bytes,))
    def json_objlen(self, key, *args):
        """Return the key count of matched objects (``None`` otherwise)."""
        return _json_read_iterate(
            lambda val: len(val) if type(val) is dict else None, key, *args
        )

    @command(
        name="JSON.TYPE",
        fixed=(Key(),),
        repeat=(bytes,),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_type(
        self,
        key,
        *args,
    ):
        """Return the RedisJSON type name of each matched value."""
        return _json_read_iterate(
            lambda val: self.TYPE_NAMES.get(type(val), None), key, *args
        )

    @command(name="JSON.OBJKEYS", fixed=(Key(),), repeat=(bytes,))
    def json_objkeys(self, key, *args):
        """Return the (UTF-8 encoded) keys of matched objects."""
        return _json_read_iterate(
            lambda val: [i.encode() for i in val.keys()] if type(val) is dict else None,
            key,
            *args,
        )

    @command(
        name="JSON.MSET",
        fixed=(),
        repeat=(Key(), bytes, JSONObject),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_mset(self, *args):
        """Set several ``key path value`` triplets; returns OK.

        Raises:
            SimpleError: ``WRONG_ARGS_MSG6`` unless given one or more complete
                triplets.
        """
        if len(args) < 3 or len(args) % 3 != 0:
            raise helpers.SimpleError(msgs.WRONG_ARGS_MSG6.format("json.mset"))
        for i in range(0, len(args), 3):
            key, path_str, value = args[i], args[i + 1], args[i + 2]
            JSONCommandsMixin._json_set(key, path_str, value)
        return helpers.OK

    @command(
        name="JSON.MERGE",
        fixed=(Key(), bytes, JSONObject),
        repeat=(),
        flags=msgs.FLAG_LEAVE_EMPTY_VAL,
    )
    def json_merge(self, key, path_str: bytes, value: JsonType):
        """Merge ``value`` into every value matched by ``path_str``.

        Follows RFC 7396 merge-patch rules:

        - an object patch is deep-merged into an object target (in place);
        - a non-object target receives the patch applied to ``{}``;
        - a non-object patch replaces the target.

        NOTE(review): a ``null`` patch should delete the target per RFC 7396,
        but deletion at a path is not emulated here; matched values are left
        unchanged.

        Raises:
            SimpleError: ``JSON_WRONG_REDIS_TYPE`` when a non-root path is
                merged into a key whose value is not an object.
        """
        path: JSONPath = _parse_jsonpath(path_str)
        if (
            key.value is not None
            and (type(key.value) is not dict)
            and not _path_is_root(path)
        ):
            raise helpers.SimpleError(msgs.JSON_WRONG_REDIS_TYPE)

        matching = path.find(key.value)
        curr_value = key.value
        for item in matching:
            if value is None:
                continue
            if isinstance(value, dict) and isinstance(item.value, dict):
                # The target is still attached to curr_value, so merging in
                # place is enough.
                _dict_deep_merge(value, item.value)
                continue
            if isinstance(value, dict):
                new_value = _dict_deep_merge(value, {})
            else:
                new_value = copy.deepcopy(value)
            curr_value = item.full_path.update(curr_value, new_value)

        if len(matching) > 0 and curr_value is not None:
            key.update(curr_value)
        return helpers.OK