# youtube-summarizer/venv311/lib/python3.11/site-packages/fakeredis/commands_mixins/geo_mixin.py
#
# 347 lines
# 11 KiB
# Python

import sys
from collections import namedtuple
from typing import List, Any
from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, Float, CommandItem
from fakeredis._helpers import SimpleError, Database
from fakeredis._zset import ZSet
from fakeredis.geo import geohash
from fakeredis.geo.haversine import distance
# Factors converting a distance in meters into each supported unit: a value in
# meters multiplied by UNIT_TO_M[unit] is the same distance expressed in
# `unit` (e.g. 1000 m * 0.001 == 1 km). Despite the name, the values are
# units-per-meter, not meters-per-unit.
UNIT_TO_M = {"km": 0.001, "mi": 0.000621371, "ft": 3.28084, "m": 1}


def translate_meters_to_unit(unit_arg: bytes) -> float:
    """Return the factor that converts a distance in meters into ``unit_arg``.

    The unit name is matched case-insensitively.

    :param unit_arg: raw unit name received from the client (km, mi, ft, m)
    :returns: number of ``unit_arg`` units in one meter; multiply a distance
        in meters by this value to express it in ``unit_arg``
    :raises SimpleError: if the unit is not supported, including when the
        argument is not valid UTF-8
    """
    # errors="replace" keeps undecodable client input on the SimpleError path
    # below instead of leaking a UnicodeDecodeError out of the command.
    unit = UNIT_TO_M.get(unit_arg.decode(errors="replace").lower())
    if unit is None:
        raise SimpleError(msgs.GEO_UNSUPPORTED_UNIT)
    return unit
# One member matched by a geo search:
#   name     - the sorted-set member name (bytes)
#   long/lat - coordinates decoded from the member's stored geohash
#   hash     - the stored geohash value itself
#   distance - distance from the search center, already in the caller's unit
GeoResult = namedtuple(
    "GeoResult", ["name", "long", "lat", "hash", "distance"]
)
def _parse_results(
    items: List[GeoResult], withcoord: bool, withdist: bool
) -> List[Any]:
    """Build the client reply for a list of geo search matches.

    Each match is rendered as its bare member name when no optional field is
    requested, otherwise as ``[name, distance?, [long, lat]?]`` where each
    optional entry is present only when its flag is set.

    :param items: matches to render, in reply order
    :param withcoord: append ``[longitude, latitude]`` to every match
    :param withdist: append the distance from the search center to every match
    :returns: one reply entry per match
    """

    def render(item: GeoResult) -> Any:
        fields: List[Any] = [item.name]
        if withdist:
            fields.append(Float.encode(item.distance, False))
        if withcoord:
            fields.append(
                [Float.encode(item.long, False), Float.encode(item.lat, False)]
            )
        # Nothing extra requested: reply with the plain member name.
        return fields if len(fields) > 1 else item.name

    return [render(item) for item in items]
def _find_near(
    zset: ZSet,
    lat: float,
    long: float,
    radius: float,
    conv: float,
    count: int,
    count_any: bool,
    desc: bool,
) -> List[GeoResult]:
    """Find the members of ``zset`` lying within ``radius`` of (lat, long).

    A member exactly ``radius`` away is included (the boundary is inclusive,
    as in Redis), so a zero radius around an existing member still matches
    that member.

    :param zset: sorted set whose scores are the members' stored geohashes
    :param lat: latitude of the search center
    :param long: longitude of the search center
    :param radius: search radius, expressed in the caller's unit
    :param conv: factor converting meters into the caller's unit
        (as returned by ``translate_meters_to_unit``)
    :param count: maximum number of results to return; falsy means no limit
    :param count_any: stop scanning as soon as ``count`` matches are found
        instead of selecting the ``count`` closest ones
    :param desc: sort by decreasing distance instead of increasing
    :returns: matching GeoResults sorted by distance
    """
    results = []
    for name, _hash in zset.items():
        p_lat, p_long, _, _ = geohash.decode(_hash)
        # distance() yields meters; convert so it is comparable with radius.
        dist = distance((p_lat, p_long), (lat, long)) * conv
        if dist <= radius:
            results.append(GeoResult(name, p_long, p_lat, _hash, dist))
            if count_any and count and len(results) >= count:
                break
    results.sort(key=lambda result: result.distance, reverse=desc)
    if count:
        results = results[:count]
    return results
class GeoCommandsMixin:
    """Redis GEO* commands implemented on top of sorted sets.

    GEOADD stores each member with ``geohash.encode(lat, long, 10)`` as its
    score; the other commands decode that stored value back into
    coordinates. The host class must provide ``_db`` and ``_encodefloat``.

    The ``geohash`` *method* below shadows the ``geohash`` module only in the
    class namespace; references inside method bodies still resolve to the
    module-level import.
    """

    _db: Database

    def _store_geo_results(
        self, item_name: bytes, geo_results: List[GeoResult], scoredist: bool
    ) -> int:
        """Overwrite the sorted set at ``item_name`` with ``geo_results``.

        :param item_name: destination key (STORE / STOREDIST target)
        :param geo_results: matches to write
        :param scoredist: use each match's distance as its score instead of
            its stored geohash
        :returns: number of matches written
        """
        db_item = CommandItem(
            item_name, self._db, item=self._db.get(item_name), default=ZSet()
        )
        # The previous destination value is replaced, not merged into.
        db_item.value = ZSet()
        for item in geo_results:
            val = item.distance if scoredist else item.hash
            db_item.value.add(item.name, val)
        db_item.writeback()
        return len(geo_results)

    @command(name="GEOADD", fixed=(Key(ZSet),), repeat=(bytes,))
    def geoadd(self, key, *args):
        """GEOADD key [NX|XX] [CH] longitude latitude member [...]

        :returns: with CH, the number of members for which the sorted set
            reported a change; otherwise the number of newly added members
        :raises SimpleError: if NX and XX are both given, or the remaining
            arguments are not a non-empty sequence of
            (longitude, latitude, member) triples
        """
        (nx, xx, ch), data = extract_args(
            args,
            ("nx", "xx", "ch"),
            error_on_unexpected=False,
            left_from_first_unexpected=True,
        )
        if xx and nx:
            raise SimpleError(msgs.NX_XX_GT_LT_ERROR_MSG)
        if len(data) == 0 or len(data) % 3 != 0:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        zset = key.value
        old_len, changed_items = len(zset), 0
        for i in range(0, len(data), 3):
            long, lat, name = (
                Float.decode(data[i + 0]),
                Float.decode(data[i + 1]),
                data[i + 2],
            )
            # NX never updates existing members; XX never creates new ones.
            if (name in zset and not nx) or (name not in zset and not xx):
                if zset.add(name, geohash.encode(lat, long, 10)):
                    changed_items += 1
        if changed_items:
            key.updated()
        if ch:
            return changed_items
        return len(zset) - old_len

    @command(name="GEOHASH", fixed=(Key(ZSet), bytes), repeat=(bytes,))
    def geohash(self, key, *members):
        """GEOHASH key member [member ...]

        :returns: for each member, its stored geohash with "0" appended, as
            bytes, or None when the member is absent
        """
        hashes = map(key.value.get, members)
        geohash_list = [((x + "0").encode() if x is not None else x) for x in hashes]
        return geohash_list

    @command(name="GEOPOS", fixed=(Key(ZSet), bytes), repeat=(bytes,))
    def geopos(self, key, *members):
        """GEOPOS key member [member ...]

        :returns: for each member, ``[longitude, latitude]`` encoded with
            ``_encodefloat``, or None when the member is absent
        """
        positions = map(
            lambda x: geohash.decode(x) if x is not None else x,
            map(key.value.get, members),
        )
        res = [
            (
                [
                    self._encodefloat(x[1], humanfriendly=False),
                    self._encodefloat(x[0], humanfriendly=False),
                ]
                if x is not None
                else None
            )
            for x in positions
        ]
        return res

    @command(name="GEODIST", fixed=(Key(ZSet), bytes, bytes), repeat=(bytes,))
    def geodist(self, key, m1, m2, *args):
        """GEODIST key member1 member2 [unit]

        :returns: distance between the two members in ``unit`` (meters when
            omitted), or None if either member is absent
        :raises SimpleError: on more than one optional argument or an
            unsupported unit
        """
        # Validate arguments before looking at members, so a bad unit or a
        # stray extra argument is reported even when a member is missing.
        if len(args) > 1:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        unit = translate_meters_to_unit(args[0]) if args else 1
        geohashes = [key.value.get(m1), key.value.get(m2)]
        if any(elem is None for elem in geohashes):
            return None
        geo_locs = [geohash.decode(x) for x in geohashes]
        res = distance(
            (geo_locs[0][0], geo_locs[0][1]), (geo_locs[1][0], geo_locs[1][1])
        )
        return res * unit

    def _search(
        self,
        key,
        long,
        lat,
        radius,
        conv,
        withcoord,
        withdist,
        _,
        count,
        count_any,
        desc,
        store,
        storedist,
    ):
        """Run a radius search around (long, lat), then reply or store.

        The unnamed parameter receives the WITHHASH flag, which is accepted
        but not reflected in the reply.

        :returns: the number of matches when ``store`` or ``storedist`` names
            a destination key, otherwise the rendered match list
        """
        zset = key.value
        geo_results = _find_near(zset, lat, long, radius, conv, count, count_any, desc)
        if store:
            self._store_geo_results(store, geo_results, scoredist=False)
            return len(geo_results)
        if storedist:
            self._store_geo_results(storedist, geo_results, scoredist=True)
            return len(geo_results)
        ret = _parse_results(geo_results, withcoord, withdist)
        return ret

    @command(
        name="GEORADIUS_RO", fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,)
    )
    def georadius_ro(self, key, long, lat, radius, *args):
        """GEORADIUS_RO key longitude latitude radius unit [options]

        The first optional argument is taken as the distance unit. STORE and
        STOREDIST are not parsed here (unknown options are ignored).
        """
        (
            withcoord,
            withdist,
            withhash,
            count,
            count_any,
            desc,
        ), left_args = extract_args(
            args,
            (
                "withcoord",
                "withdist",
                "withhash",
                "+count",
                "any",
                "desc",
            ),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )
        # No COUNT means no limit.
        count = count or sys.maxsize
        conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
        return self._search(
            key,
            long,
            lat,
            radius,
            conv,
            withcoord,
            withdist,
            withhash,
            count,
            count_any,
            desc,
            False,
            False,
        )

    @command(name="GEORADIUS", fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
    def georadius(self, key, long, lat, radius, *args):
        """GEORADIUS key longitude latitude radius unit [options]

        Like GEORADIUS_RO, but also accepts STORE / STOREDIST destinations.
        The first optional argument is taken as the distance unit.
        """
        (
            withcoord,
            withdist,
            withhash,
            count,
            count_any,
            desc,
            store,
            storedist,
        ), left_args = extract_args(
            args,
            (
                "withcoord",
                "withdist",
                "withhash",
                "+count",
                "any",
                "desc",
                "*store",
                "*storedist",
            ),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )
        # No COUNT means no limit.
        count = count or sys.maxsize
        conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
        return self._search(
            key,
            long,
            lat,
            radius,
            conv,
            withcoord,
            withdist,
            withhash,
            count,
            count_any,
            desc,
            store,
            storedist,
        )

    def _member_lat_long(self, key, member_name):
        """Return the (lat, long) stored for ``member_name`` in ``key``.

        An empty/absent sorted set yields an empty search rather than an
        error: any center works because there is nothing to match.

        :raises SimpleError: if the sorted set has members but not
            ``member_name``
        """
        member_score = key.value.get(member_name)
        if member_score is None:
            if len(key.value) == 0:
                # Nothing to search; an arbitrary center gives no results.
                return 0.0, 0.0
            raise SimpleError("ERR could not decode requested zset member")
        lat, long, _, _ = geohash.decode(member_score)
        return lat, long

    @command(name="GEORADIUSBYMEMBER", fixed=(Key(ZSet), bytes, Float), repeat=(bytes,))
    def georadiusbymember(self, key, member_name, radius, *args):
        """GEORADIUSBYMEMBER key member radius unit [options]

        GEORADIUS centered on an existing member's position.
        """
        lat, long = self._member_lat_long(key, member_name)
        return self.georadius(key, long, lat, radius, *args)

    @command(
        name="GEORADIUSBYMEMBER_RO", fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)
    )
    def georadiusbymember_ro(self, key, member_name, radius, *args):
        """GEORADIUSBYMEMBER_RO key member radius unit [options]

        GEORADIUS_RO centered on an existing member's position.
        """
        lat, long = self._member_lat_long(key, member_name)
        return self.georadius_ro(key, long, lat, radius, *args)

    def _validate_search_args(self, frommember, long, radius):
        """Raise a syntax error unless exactly one of FROMMEMBER / FROMLONLAT
        and a BYRADIUS value were supplied."""
        if frommember is None and long is None:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if frommember is not None and long is not None:
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if radius is None:
            # Only BYRADIUS is parsed here, so it is mandatory.
            raise SimpleError(msgs.SYNTAX_ERROR_MSG)

    @command(name="GEOSEARCH", fixed=(Key(ZSet),), repeat=(bytes,))
    def geosearch(self, key, *args):
        """GEOSEARCH key FROMMEMBER member|FROMLONLAT long lat BYRADIUS r unit [...]

        Arguments not consumed here (unit, WITH*, COUNT, ...) are forwarded
        to the GEORADIUS*_RO implementation.
        """
        (frommember, (long, lat), radius), left_args = extract_args(
            args,
            ("*frommember", "..fromlonlat", ".byradius"),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )
        self._validate_search_args(frommember, long, radius)
        # Compare with None: an empty member name (b"") is still a member.
        if frommember is not None:
            return self.georadiusbymember_ro(key, frommember, radius, *left_args)
        return self.georadius_ro(key, long, lat, radius, *left_args)

    @command(
        name="GEOSEARCHSTORE",
        fixed=(
            bytes,
            Key(ZSet),
        ),
        repeat=(bytes,),
    )
    def geosearchstore(self, dst, src, *args):
        """GEOSEARCHSTORE dst src FROMMEMBER ...|FROMLONLAT ... BYRADIUS ... [STOREDIST]

        Runs the search on ``src`` and stores the matches in ``dst`` (scored
        by distance with STOREDIST, by geohash otherwise).

        :returns: number of stored matches
        """
        (frommember, (long, lat), radius, storedist), left_args = extract_args(
            args,
            ("*frommember", "..fromlonlat", ".byradius", "storedist"),
            error_on_unexpected=False,
            left_from_first_unexpected=False,
        )
        self._validate_search_args(frommember, long, radius)
        additional = [b"storedist", dst] if storedist else [b"store", dst]
        # Compare with None: an empty member name (b"") is still a member.
        if frommember is not None:
            return self.georadiusbymember(
                src, frommember, radius, *left_args, *additional
            )
        return self.georadius(src, long, lat, radius, *left_args, *additional)

    def _encodefloat(self, value, humanfriendly):
        raise NotImplementedError  # Implemented in BaseFakeSocket