Skip to content

Commit

Permalink
Refactor redis lock to provide base class (PP-1472) (#1990)
Browse files Browse the repository at this point in the history
This refactors our redis lock class, to provide an abstract base class BaseRedisLock. It also refactors RedisLock, so that the acquire method is non-blocking by default and creates a new acquire_blocking method that provides a blocking acquire.
  • Loading branch information
jonathangreen authored Aug 20, 2024
1 parent 5b665cc commit 38f2fef
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 104 deletions.
190 changes: 124 additions & 66 deletions src/palace/manager/service/redis/models/lock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import random
import time
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from datetime import timedelta
from enum import Enum, auto
from functools import cached_property
from typing import cast
from uuid import uuid4

Expand All @@ -16,17 +17,90 @@ class LockError(BasePalaceException):
pass


class LockReturn(Enum):
failed = auto()
timeout = auto()
acquired = auto()
extended = auto()
class BaseRedisLock(ABC):
def __init__(
self,
redis_client: Redis,
random_value: str | None = None,
):
self._redis_client = redis_client
self._random_value = random_value if random_value else str(uuid4())

@abstractmethod
def acquire(self) -> bool:
"""
Acquire the lock. Always non-blocking.
:return: True if the lock was acquired, False otherwise.
"""

@abstractmethod
def release(self) -> bool:
"""
Release the lock.
:return: True if the lock was released, False if there was some error releasing the lock.
"""

@abstractmethod
def locked(self, by_us: bool = False) -> bool:
"""
Check if the lock is currently held, by us or anyone else.
:param by_us: If True, check if the lock is held by us. If False, check if the lock is held by anyone.
:return: True if the lock is held, False otherwise.
"""

@abstractmethod
def extend_timeout(self) -> bool:
"""
Extend the timeout of the lock.
def __bool__(self) -> bool:
return self in (LockReturn.acquired, LockReturn.extended)
:return: True if the timeout was extended, False otherwise.
"""

@property
@abstractmethod
def key(self) -> str:
"""
Return the key used to store the lock in Redis.
:return: The key used to store the lock in Redis.
"""

@contextmanager
def lock(
self,
release_on_error: bool = True,
release_on_exit: bool = True,
ignored_exceptions: tuple[type[BaseException], ...] = (),
) -> Generator[bool, None, None]:
"""
Context manager for acquiring and releasing the lock.
:param release_on_error: If True, release the lock if an exception occurs.
:param release_on_exit: If True, release the lock when the context manager exits.
:param ignored_exceptions: Exceptions that should not cause the lock to be released.
:return: The result of the lock acquisition. You must check the return value to see if the lock was acquired.
"""
locked = self.acquire()
exception_occurred = False
try:
yield locked
except Exception as exc:
if not issubclass(exc.__class__, ignored_exceptions):
exception_occurred = True
raise
finally:
if (release_on_error and exception_occurred) or (
release_on_exit and not exception_occurred
):
self.release()


class RedisLock:
class RedisLock(BaseRedisLock):
"""
A simple distributed lock implementation using Redis.
Expand Down Expand Up @@ -54,114 +128,98 @@ def __init__(
redis_client: Redis,
lock_name: str | Sequence[str],
random_value: str | None = None,
timeout: timedelta | None = timedelta(minutes=5),
lock_timeout: timedelta | None = timedelta(minutes=5),
retry_delay: float = 0.2,
):
self._redis_client = redis_client
super().__init__(redis_client, random_value)
if isinstance(lock_name, str):
lock_name = [lock_name]
self.lock_key = self._redis_client.get_key(self._lock_type, *lock_name)
self.random_value = random_value if random_value else str(uuid4())
self.timeout = timeout
self._lock_timeout = lock_timeout
self._retry_delay = retry_delay
self._lock_name = lock_name
self.unlock_script = self._redis_client.register_script(self._UNLOCK_SCRIPT)
self.extend_script = self._redis_client.register_script(self._EXTEND_SCRIPT)
self._retry_delay = retry_delay

@cached_property
def key(self) -> str:
return self._redis_client.get_key(self._lock_type, *self._lock_name)

@property
def _lock_type(self) -> str:
return self.__class__.__name__

def _acquire(self) -> LockReturn:
def acquire(self) -> bool:
previous_value = cast(
str | None,
self._redis_client.set(
self.lock_key, self.random_value, nx=True, px=self.timeout, get=True
self.key,
self._random_value,
nx=True,
px=self._lock_timeout,
get=True,
),
)

if (
previous_value is not None
and previous_value == self.random_value
and self.timeout is not None
and previous_value == self._random_value
and self._lock_timeout is not None
):
return LockReturn.extended if self.extend_timeout() else LockReturn.failed
return self.extend_timeout()

return (
LockReturn.acquired
if previous_value is None or previous_value == self.random_value
else LockReturn.failed
)
return previous_value is None or previous_value == self._random_value

def acquire_blocking(self, timeout: float | int = -1) -> bool:
"""
Acquire the lock. Blocks until the lock is acquired or the timeout is reached.
def acquire(self, blocking: bool = False, timeout: float | int = -1) -> LockReturn:
if not blocking and timeout != -1:
raise LockError("Cannot specify a timeout without blocking")
This is a light wrapper around acquire that adds blocking and timeout functionality.
if not blocking:
return self._acquire()
:param timeout: The maximum time to wait for the lock to be acquired. If 0, wait indefinitely.
:return: The result of the lock acquisition. You must check the return value to see if the lock was acquired.
"""
if timeout < 0:
raise LockError("Cannot specify a negative timeout")

start_time = time.time()
while timeout == -1 or (time.time() - start_time) < timeout:
acquired = self._acquire()
while timeout == 0 or (time.time() - start_time) < timeout:
acquired = self.acquire()
if acquired:
return acquired
delay = random.uniform(0, self._retry_delay)
time.sleep(delay)
return LockReturn.timeout
return False

def release(self) -> bool:
ret_val: int = self.unlock_script(
keys=(self.lock_key,), args=(self.random_value,)
)
ret_val: int = self.unlock_script(keys=(self.key,), args=(self._random_value,))
return ret_val == 1

def extend_timeout(self) -> bool:
if self.timeout is None:
if self._lock_timeout is None:
# If the lock has no timeout, we can't extend it
return False

timout_ms = int(self.timeout.total_seconds() * 1000)
timout_ms = int(self._lock_timeout.total_seconds() * 1000)
ret_val: int = self.extend_script(
keys=(self.lock_key,), args=(self.random_value, timout_ms)
keys=(self.key,), args=(self._random_value, timout_ms)
)
return ret_val == 1

def locked(self, by_us: bool = False) -> bool:
key_value = self._redis_client.get(self.lock_key)
key_value = self._redis_client.get(self.key)
if by_us:
return key_value == self.random_value
return key_value == self._random_value
return key_value is not None

@contextmanager
def lock(
self,
blocking: bool = False,
timeout: float | int = -1,
release_on_error: bool = True,
release_on_exit: bool = True,
ignored_exceptions: tuple[type[BaseException], ...] = (),
) -> Generator[LockReturn, None, None]:
locked = self.acquire(blocking=blocking, timeout=timeout)
exception_occurred = False
try:
yield locked
except Exception as exc:
if not issubclass(exc.__class__, ignored_exceptions):
exception_occurred = True
raise
finally:
if (release_on_error and exception_occurred) or (
release_on_exit and not exception_occurred
):
self.release()


class TaskLock(RedisLock):
def __init__(
self,
redis_client: Redis,
task: Task,
lock_name: str | None = None,
timeout: timedelta | None = timedelta(minutes=5),
lock_timeout: timedelta | None = timedelta(minutes=5),
retry_delay: float = 0.2,
):
random_value = task.request.root_id or task.request.id
Expand All @@ -173,4 +231,4 @@ def __init__(
name = ["Task", task.name]
else:
name = [lock_name]
super().__init__(redis_client, name, random_value, timeout, retry_delay)
super().__init__(redis_client, name, random_value, lock_timeout, retry_delay)
Loading

0 comments on commit 38f2fef

Please sign in to comment.