Source code for lvmopstools.retrier

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2024-01-02
# @Filename: retrier.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

import asyncio
import inspect
import time
import warnings
from dataclasses import dataclass, field
from functools import wraps

from typing import (
    Any,
    Awaitable,
    Callable,
    ParamSpec,
    TypeVar,
    overload,
)

from typing_extensions import Self


__all__ = ["Retrier"]

T = TypeVar("T", bound=Any)
P = ParamSpec("P")


@dataclass
class Retrier:
    """A class that implements a retry mechanism.

    The object returned by this class can be used to wrap a function that
    will be retried ``max_attempts`` times if it fails::

        def test_function():
            ...

        retrier = Retrier(max_attempts=5)
        retrier(test_function)()

    where the wrapped function can be a coroutine, in which case the wrapped
    function will also be a coroutine.

    Most frequently this class will be used as a decorator::

        @Retrier(max_attempts=4, delay=0.1)
        async def test_function(x, y):
            ...

        await test_function(1, 2)

    Parameters
    ----------
    max_attempts
        The maximum number of attempts before giving up. This counts calls
        to the wrapped function, so the function is always called at least
        once and at most ``max_attempts`` times.
    delay
        The delay between attempts, in seconds.
    use_exponential_backoff
        Whether to use exponential backoff for the delay between attempts.
        If :obj:`True`, the delay will be
        ``delay * exponential_backoff_base ** (attempt - 1) + random_ms``
        where ``random_ms`` is a number between 0 and 100 ms used to avoid
        synchronisation issues.
    exponential_backoff_base
        The base for the exponential backoff.
    max_delay
        The maximum delay between attempts when using exponential backoff.
    on_retry
        A function that will be called when a retry is attempted, before
        sleeping. The function should accept an exception as its only
        argument.
    raise_on_exception_class
        A list of exception classes that will cause an exception to be
        raised without retrying.
    timeout
        If defined, each attempt can take at most this amount of time. An
        attempt that times out fails with :obj:`asyncio.TimeoutError`, which
        is handled like any other failure: it is retried unless the attempts
        are exhausted or the exception class is listed in
        ``raise_on_exception_class``. This only works if the wrapped function
        is a coroutine; for a regular function the value is ignored and a
        :obj:`RuntimeWarning` is issued once per call.

    """

    max_attempts: int = 3
    delay: float = 1
    use_exponential_backoff: bool = True
    exponential_backoff_base: float = 2
    max_delay: float = 32.0
    on_retry: Callable[[Exception], None] | None = None
    raise_on_exception_class: list[type[Exception]] = field(default_factory=list)
    timeout: float | None = None

    def calculate_delay(self, attempt: int) -> float:
        """Calculates the delay for a given attempt.

        Parameters
        ----------
        attempt
            The number of attempts that have failed so far (1 after the
            first failure).

        Returns
        -------
        delay
            The time to wait before the next attempt, in seconds. Without
            exponential backoff this is exactly ``delay``; with it, the
            value is capped at ``max_delay``.

        """

        if not self.use_exponential_backoff:
            return self.delay

        # Value between 0 and 100 ms, taken from the sub-second part of the
        # wall clock, to avoid synchronisation issues.
        random_ms = 0.1 * (time.time() % 1)
        backoff = self.delay * self.exponential_backoff_base ** (attempt - 1)

        return min(backoff + random_ms, self.max_delay)

    def _must_reraise(self, attempt: int, error: Exception) -> bool:
        """Returns :obj:`True` if ``error`` must propagate instead of being retried."""

        if attempt >= self.max_attempts:
            return True

        # The tuple is built at failure time so that changes made to the
        # list after the function was wrapped are honoured.
        return isinstance(error, tuple(self.raise_on_exception_class))

    @overload
    def __call__(
        self: Self,
        func: Callable[P, T],
    ) -> Callable[P, T]: ...

    @overload
    def __call__(
        self: Self,
        func: Callable[P, Awaitable[T]],
    ) -> Callable[P, Awaitable[T]]: ...

    def __call__(
        self,
        func: Callable[P, T] | Callable[P, Awaitable[T]],
    ) -> Callable[P, T] | Callable[P, Awaitable[T]]:
        """Wraps a function to retry it if it fails.

        Parameters
        ----------
        func
            The function or coroutine function to wrap.

        Returns
        -------
        wrapper
            A callable with the same signature as ``func``. It is a coroutine
            function if ``func`` is one, and a regular function otherwise.
            The last exception raised by ``func`` propagates when the
            attempts are exhausted or when its class is listed in
            ``raise_on_exception_class``.

        """

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
                attempt = 0

                while True:
                    try:
                        return await asyncio.wait_for(
                            func(*args, **kwargs),
                            timeout=self.timeout,
                        )
                    except Exception as ee:
                        attempt += 1
                        if self._must_reraise(attempt, ee):
                            raise

                        if self.on_retry:
                            self.on_retry(ee)

                        await asyncio.sleep(self.calculate_delay(attempt))

            return async_wrapper

        else:

            @wraps(func)
            def wrapper(*args: P.args, **kwargs: P.kwargs):
                # Warn once per call and outside the try block. Inside the
                # retry loop the warning would be repeated on every attempt
                # and, when warnings are escalated to errors, it would be
                # caught and retried as if the wrapped function had failed.
                if self.timeout is not None:
                    warnings.warn(
                        "The wrapped function is not a coroutine. "
                        "The timeout parameter will be ignored.",
                        RuntimeWarning,
                        stacklevel=2,
                    )

                attempt = 0

                while True:
                    try:
                        return func(*args, **kwargs)
                    except Exception as ee:
                        attempt += 1
                        if self._must_reraise(attempt, ee):
                            raise

                        if self.on_retry:
                            self.on_retry(ee)

                        time.sleep(self.calculate_delay(attempt))

            return wrapper