141 lines
4.1 KiB
Python
141 lines
4.1 KiB
Python
import itertools
|
|
import math
|
|
import random
|
|
from typing import Callable
|
|
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._commands import command, Key, Hash, Int, Float
|
|
from fakeredis._helpers import SimpleError, OK, casematch
|
|
|
|
|
|
class HashCommandsMixin:
    """Mixin implementing the Redis hash (``H*``) commands.

    Every handler receives ``key``, whose ``value`` is used as a mutable
    mapping of field -> value, and calls ``key.updated()`` after mutating it.
    The ``_encodeint`` / ``_encodefloat`` helpers and ``_scan`` are expected
    to be supplied by the class this mixin is combined with.
    """

    _encodeint: Callable[[int, ], bytes]
    _encodefloat: Callable[[float, bool], bytes]

    @command((Key(Hash), bytes), (bytes,))
    def hdel(self, key, *fields):
        """Delete ``fields`` from the hash and return how many were present."""
        h = key.value
        removed = 0
        for field in fields:
            if field in h:
                del h[field]
                key.updated()
                removed += 1
        return removed

    @command((Key(Hash), bytes))
    def hexists(self, key, field):
        """Return 1 if ``field`` is in the hash, otherwise 0."""
        return int(field in key.value)

    @command((Key(Hash), bytes))
    def hget(self, key, field):
        """Return the value stored at ``field``, or ``None`` when absent."""
        return key.value.get(field)

    @command((Key(Hash),))
    def hgetall(self, key):
        """Return a flat ``[field, value, field, value, ...]`` list."""
        # from_iterable avoids unpacking every (field, value) pair into
        # positional arguments just to flatten them.
        return list(itertools.chain.from_iterable(key.value.items()))

    @command(fixed=(Key(Hash), bytes, bytes))
    def hincrby(self, key, field, amount):
        """Add integer ``amount`` to ``field`` (missing fields count as 0).

        The stored value is decoded with ``msgs.INVALID_HASH_MSG`` as the
        decode error. Returns the new integer value.
        """
        amount = Int.decode(amount)
        current = Int.decode(
            key.value.get(field, b"0"), decode_error=msgs.INVALID_HASH_MSG
        )
        total = current + amount
        key.value[field] = self._encodeint(total)
        key.updated()
        return total

    @command((Key(Hash), bytes, bytes))
    def hincrbyfloat(self, key, field, amount):
        """Add float ``amount`` to ``field`` (missing fields count as 0).

        Raises ``SimpleError(msgs.NONFINITE_MSG)`` when the sum is NaN or
        infinite; in that case the hash is left untouched. Returns the
        encoded new value.
        """
        total = Float.decode(key.value.get(field, b"0")) + Float.decode(amount)
        if not math.isfinite(total):
            raise SimpleError(msgs.NONFINITE_MSG)
        encoded = self._encodefloat(total, True)
        key.value[field] = encoded
        key.updated()
        return encoded

    @command((Key(Hash),))
    def hkeys(self, key):
        """Return the list of field names."""
        return list(key.value.keys())

    @command((Key(Hash),))
    def hlen(self, key):
        """Return the number of fields in the hash."""
        return len(key.value)

    @command((Key(Hash), bytes), (bytes,))
    def hmget(self, key, *fields):
        """Return the value for each requested field (``None`` if absent)."""
        return [key.value.get(field) for field in fields]

    @command((Key(Hash), bytes, bytes), (bytes, bytes))
    def hmset(self, key, *args):
        """Set the given field/value pairs via ``hset`` and return ``OK``."""
        self.hset(key, *args)
        return OK

    @command(
        (
            Key(Hash),
            Int,
        ),
        (bytes, bytes),
    )
    def hscan(self, key, cursor, *args):
        """Return ``[next_cursor, [field, value, ...]]`` for one scan step.

        Cursor handling and option parsing are delegated to ``self._scan``.
        """
        cursor, keys = self._scan(key.value, cursor, *args)
        items = []
        for k in keys:
            items.extend((k, key.value[k]))
        return [cursor, items]

    @command((Key(Hash), bytes, bytes), (bytes, bytes))
    def hset(self, key, *args):
        """Set field/value pairs and return the number of newly created fields.

        ``args`` is a flat ``field, value, field, value, ...`` sequence.
        """
        h = key.value
        size_before = len(h)
        # Zipping one iterator with itself consumes args two at a time,
        # yielding (field, value) pairs.
        # https://stackoverflow.com/a/12739974/1056460
        flat = iter(args)
        h.update(zip(flat, flat))
        created = len(h) - size_before

        key.updated()
        return created

    @command((Key(Hash), bytes, bytes))
    def hsetnx(self, key, field, value):
        """Set ``field`` only if it does not exist; return 1 if set, else 0."""
        if field in key.value:
            return 0
        return self.hset(key, field, value)

    @command((Key(Hash), bytes))
    def hstrlen(self, key, field):
        """Return the length of the value at ``field`` (0 when absent)."""
        return len(key.value.get(field, b""))

    @command((Key(Hash),))
    def hvals(self, key):
        """Return the list of values stored in the hash."""
        return list(key.value.values())

    @command(name="HRANDFIELD", fixed=(Key(Hash),), repeat=(bytes,))
    def hrandfield(self, key, *args):
        """Return random fields from the hash.

        ``args`` is ``[count [WITHVALUES]]``. A positive count selects up to
        that many distinct fields; a negative count selects ``abs(count)``
        fields with repetitions allowed. With ``WITHVALUES`` the result is a
        flat ``[field, value, ...]`` list.

        Raises ``SimpleError(msgs.SYNTAX_ERROR_MSG)`` when more than two
        arguments are given or the second argument is not ``WITHVALUES``.
        Returns ``None`` when the hash is missing or empty.
        """
        if len(args) > 2:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)

        # Validate the arguments before looking at the key, so that a bad
        # count or an unknown option is reported even for a missing hash.
        requested = Int.decode(args[0]) if len(args) >= 1 else 1
        withvalues = False
        if len(args) == 2:
            # Only WITHVALUES may follow the count; anything else used to be
            # silently ignored instead of being rejected.
            if not casematch(args[1], b"withvalues"):
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
            withvalues = True

        if key.value is None or len(key.value) == 0:
            return None

        # A positive count can never exceed the number of distinct fields.
        count = min(requested, len(key.value))
        if count == 0:
            return []

        candidates = sorted(key.value.items())
        if count < 0:  # Allow repetitions
            picked = random.choices(candidates, k=-count)
        else:  # Unique values from hash
            picked = random.sample(candidates, count)

        if withvalues:
            return [item for pair in picked for item in pair]
        return [pair[0] for pair in picked]

    def _scan(self, keys, cursor, *args):
        """Placeholder for the scan helper used by ``hscan``."""
        raise NotImplementedError  # Implemented in BaseFakeSocket
|