95 lines
3.2 KiB
Python
95 lines
3.2 KiB
Python
from typing import Callable, Set, Any, List, Optional
|
|
|
|
from fakeredis import _msgs as msgs
|
|
from fakeredis._commands import command, Key, CommandItem
|
|
from fakeredis._helpers import OK, SimpleError, Database, SimpleString
|
|
|
|
|
|
class TransactionsCommandsMixin:
|
|
_db: Database
|
|
_run_command: Callable # type: ignore
|
|
|
|
def __init__(self, *args, **kwargs) -> None: # type: ignore
|
|
super(TransactionsCommandsMixin, self).__init__(*args, **kwargs)
|
|
self._watches: Set[Any] = set()
|
|
# When in a MULTI, set to a list of function calls
|
|
self._transaction: Optional[List[Any]] = None
|
|
self._transaction_failed = False
|
|
# Set when executing the commands from EXEC
|
|
self._in_transaction = False
|
|
self._watch_notified = False
|
|
|
|
def _clear_watches(self) -> None:
|
|
self._watch_notified = False
|
|
while self._watches:
|
|
(key, db) = self._watches.pop()
|
|
db.remove_watch(key, self)
|
|
|
|
# Transaction commands
|
|
@command((), flags=[msgs.FLAG_NO_SCRIPT, msgs.FLAG_TRANSACTION])
|
|
def discard(self) -> SimpleString:
|
|
if self._transaction is None:
|
|
raise SimpleError(msgs.WITHOUT_MULTI_MSG.format("DISCARD"))
|
|
self._transaction = None
|
|
self._transaction_failed = False
|
|
self._clear_watches()
|
|
return OK
|
|
|
|
@command(
|
|
name="exec",
|
|
fixed=(),
|
|
repeat=(),
|
|
flags=[msgs.FLAG_NO_SCRIPT, msgs.FLAG_TRANSACTION],
|
|
)
|
|
def exec_(self) -> Any:
|
|
if self._transaction is None:
|
|
raise SimpleError(msgs.WITHOUT_MULTI_MSG.format("EXEC"))
|
|
if self._transaction_failed:
|
|
self._transaction = None
|
|
self._clear_watches()
|
|
raise SimpleError(msgs.EXECABORT_MSG)
|
|
transaction = self._transaction
|
|
self._transaction = None
|
|
self._transaction_failed = False
|
|
watch_notified = self._watch_notified
|
|
self._clear_watches()
|
|
if watch_notified:
|
|
return None
|
|
result = []
|
|
for func, sig, args in transaction:
|
|
try:
|
|
self._in_transaction = True
|
|
ans = self._run_command(func, sig, args, False)
|
|
except SimpleError as exc:
|
|
ans = exc
|
|
finally:
|
|
self._in_transaction = False
|
|
result.append(ans)
|
|
return result
|
|
|
|
@command((), flags=[msgs.FLAG_NO_SCRIPT, msgs.FLAG_TRANSACTION])
|
|
def multi(self) -> SimpleString:
|
|
if self._transaction is not None:
|
|
raise SimpleError(msgs.MULTI_NESTED_MSG)
|
|
self._transaction = []
|
|
self._transaction_failed = False
|
|
return OK
|
|
|
|
@command((), flags=msgs.FLAG_NO_SCRIPT)
|
|
def unwatch(self) -> SimpleString:
|
|
self._clear_watches()
|
|
return OK
|
|
|
|
@command((Key(),), (Key(),), flags=[msgs.FLAG_NO_SCRIPT, msgs.FLAG_TRANSACTION])
|
|
def watch(self, *keys: CommandItem) -> SimpleString:
|
|
if self._transaction is not None:
|
|
raise SimpleError(msgs.WATCH_INSIDE_MULTI_MSG)
|
|
for key in keys:
|
|
if key not in self._watches:
|
|
self._watches.add((key.key, self._db))
|
|
self._db.add_watch(key.key, self)
|
|
return OK
|
|
|
|
def notify_watch(self) -> None:
|
|
self._watch_notified = True
|