Source code for lvmopstools.pubsub

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2024-08-21
# @Filename: pubsub.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)

from __future__ import annotations

import json
import time
import uuid
from enum import auto

from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    ClassVar,
    Literal,
    overload,
)

import aio_pika
from aio_pika.abc import AbstractIncomingMessage
from pydantic import BaseModel, Field
from strenum import UppercaseStrEnum
from typing_extensions import Self

from lvmopstools import config
from lvmopstools.retrier import Retrier


if TYPE_CHECKING:
    from aio_pika.abc import (
        AbstractChannel,
        AbstractConnection,
        AbstractExchange,
        AbstractQueue,
        ConsumerTag,
    )


# Allowed values for the ``message_type`` field of a published message.
MessageType = Literal["event", "notification", "custom"]

# Signature of a subscriber callback: an async function that receives a decoded
# ``Message`` and returns nothing. ``Message`` is given as a string forward
# reference because the class is defined further down in this module.
SubCallbackType = Callable[["Message"], Awaitable[None]]


class Event(UppercaseStrEnum):
    """Enumeration with the event types.

    Member values are generated with ``auto()`` by ``UppercaseStrEnum``.
    ``UNCATEGORISED`` is the fallback used when a received event name does not
    match any other member.
    """

    # Generic error event.
    ERROR = auto()

    # Recipe events.
    RECIPE_START = auto()
    RECIPE_END = auto()
    RECIPE_FAILED = auto()

    # Observer events.
    OBSERVER_NEW_TILE = auto()
    OBSERVER_STAGE_RUNNING = auto()
    OBSERVER_STAGE_DONE = auto()
    OBSERVER_STAGE_FAILED = auto()
    OBSERVER_ACQUISITION_START = auto()
    OBSERVER_ACQUISITION_DONE = auto()
    OBSERVER_STANDARD_ACQUISITION_FAILED = auto()

    # Dome events.
    DOME_OPENING = auto()
    DOME_OPEN = auto()
    DOME_CLOSING = auto()
    DOME_CLOSED = auto()

    # Other events.
    EMERGENCY_SHUTDOWN = auto()
    UNEXPECTED_FIBSEL_REHOME = auto()

    # Fallback for unrecognised event names.
    UNCATEGORISED = auto()
class PublishedMessageModel(BaseModel):
    """A model for messages published to the exchange.

    ``timestamp`` defaults to ``time.time()`` evaluated when the model is
    instantiated.
    """

    # One of "event", "notification" or "custom".
    message_type: MessageType
    # Identifier of the message; ``None`` is an accepted value.
    event_name: int | str | None
    # Free-form data attached to the message.
    payload: dict[str, Any] = {}
    # Creation time of the message (seconds since the epoch).
    timestamp: float = Field(default_factory=time.time)


class EventModel(PublishedMessageModel):
    """A model for event messages.

    Narrows ``event_name`` to a string and fixes ``message_type`` to
    ``"event"``.
    """

    event_name: str
    message_type: Literal["event"] = "event"
class Message:
    """A message received from the exchange, decoded from its JSON body.

    Parameters
    ----------
    message
        The raw incoming message. Its body must be a JSON-encoded object.

    Attributes
    ----------
    message
        The original incoming message (e.g., to acknowledge it manually).
    body
        The decoded JSON body.
    payload
        The ``payload`` entry of the body, or an empty dictionary if it is
        missing or null.
    message_type
        The ``message_type`` entry of the body; ``"custom"`` if missing.
    event
        For event messages, the matching ``Event`` member, or
        ``Event.UNCATEGORISED`` if the name is missing or unknown. ``None``
        for non-event messages.
    event_name
        For event messages, the upper-cased event name as a string (``None``
        if the body has no event name). ``None`` for non-event messages.

    """

    def __init__(self, message: AbstractIncomingMessage):
        self.message = message

        self.body: dict[str, Any] = json.loads(message.body)

        # An explicit ``"payload": null`` should still give a dictionary.
        self.payload: dict[str, Any] = self.body.get("payload") or {}
        self.message_type: MessageType = self.body.get("message_type", "custom")

        self.event: Event | None = None
        self.event_name: str | None = None

        if self.message_type == "event":
            raw_name = self.body.get("event_name")

            if raw_name is None:
                # Malformed event without a name: do not fail, just flag it.
                self.event = Event.UNCATEGORISED
            else:
                # The published schema allows non-string names (e.g. ints), so
                # coerce to str before normalising the case.
                self.event_name = str(raw_name).upper()
                try:
                    self.event = Event(self.event_name)
                except ValueError:
                    self.event = Event.UNCATEGORISED
def callback_wrapper(func: SubCallbackType):
    """Wraps a callback to receive a ``Message`` instance.

    The returned coroutine function takes the raw incoming message, enters
    ``message.process()`` and awaits ``func`` with a ``Message`` built from it.
    """

    async def wrapper(message: AbstractIncomingMessage):
        async with message.process():
            await func(Message(message))

    return wrapper


class BasePubSub:
    """A base class to connect to a RabbitMQ exchange.

    Can be used as an async context manager: entering connects (if not already
    connected) and exiting disconnects.

    Parameters
    ----------
    connection_string
        The connection string to the RabbitMQ server. Defaults to the
        ``connection_string`` value of the ``pubsub`` configuration section.
    exchange_name
        The name of the exchange where the messages will be sent. Defaults to
        the ``exchange_name`` value of the ``pubsub`` configuration section.

    """

    def __init__(
        self,
        connection_string: str | None = None,
        exchange_name: str | None = None,
    ):
        psc = config["pubsub"]

        self.connection_string = connection_string or psc["connection_string"]
        self.exchange_name = exchange_name or psc["exchange_name"]

        self.connection: AbstractConnection | None = None
        self.channel: AbstractChannel | None = None
        self.exchange: AbstractExchange | None = None

    async def connect(self) -> Self:
        """Connects to the RabbitMQ server and declares the exchange.

        Opens a robust connection and a channel with ``prefetch_count=1``, and
        declares an auto-deleted fanout exchange named ``exchange_name``.
        """

        self.connection = await aio_pika.connect_robust(self.connection_string)
        self.channel = await self.connection.channel()
        await self.channel.set_qos(prefetch_count=1)

        self.exchange = await self.channel.declare_exchange(
            self.exchange_name,
            auto_delete=True,
            type=aio_pika.ExchangeType.FANOUT,
        )

        return self

    async def disconnect(self):
        """Disconnects from the RabbitMQ server.

        Closes the channel and then the connection. The connection is closed
        even if closing the channel raises (that error is then re-raised).
        Errors while closing the connection itself are ignored.
        """

        try:
            if self.channel and not self.channel.is_closed:
                await self.channel.close()
        finally:
            # Always release the connection, otherwise a failing channel close
            # would leave it open.
            if self.connection and not self.connection.is_closed:
                try:
                    await self.connection.close()
                except Exception:
                    # Deliberately best-effort: failures closing the
                    # connection are ignored.
                    pass

    async def __aenter__(self):
        # Reconnect only if either the connection or the channel is unusable.
        if (
            not self.connection
            or self.connection.is_closed
            or not self.channel
            or self.channel.is_closed
        ):
            await self.connect()

        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.disconnect()
class Publisher(BasePubSub):
    """A class to publish messages to a RabbitMQ exchange. A singleton.

    Every instantiation returns the same object, and only the first one
    initialises it. ``connection_string`` and ``exchange_name`` passed on
    later instantiations are ignored.
    """

    _instance: ClassVar[Publisher]

    def __new__(cls, *args, **kwargs):
        try:
            return cls._instance
        except AttributeError:
            # First instantiation: create and cache the shared object.
            cls._instance = super(Publisher, cls).__new__(cls)
            return cls._instance

    def __init__(
        self,
        connection_string: str | None = None,
        exchange_name: str | None = None,
    ):
        # ``__init__`` runs on every ``Publisher(...)`` call. The ``connection``
        # attribute is set by ``BasePubSub.__init__``, so its presence means
        # the singleton was already initialised.
        if hasattr(self, "connection"):
            return

        super().__init__(connection_string, exchange_name)

    @Retrier(max_attempts=3, delay=0.5)
    async def publish(self, message: dict, routing_key: str | None = None):
        """Publishes a message to the exchange.

        Connects on entry if needed and disconnects on exit of the context
        manager. Failures are retried by the ``Retrier`` decorator (up to three
        attempts).

        Parameters
        ----------
        message
            The message to publish. Must be a dictionary that will be encoded
            as a JSON string.
        routing_key
            The routing key to use. If not provided, uses the default routing
            key defined in the configuration.

        """

        async with self:
            assert self.exchange, "exchange not defined."

            encoded_body = json.dumps(message).encode()
            await self.exchange.publish(
                aio_pika.Message(body=encoded_body),
                routing_key=routing_key or config["pubsub.routing_key"],
            )
class Subscriber(BasePubSub):
    """A class to subscribe to messages from a RabbitMQ exchange.

    Parameters
    ----------
    connection_string
        The connection string to the RabbitMQ server. Defaults to the value in
        the configuration.
    exchange_name
        The name of the exchange to subscribe to. Defaults to the value in the
        configuration.
    callback
        An async function called with a ``Message`` for each message received.
        If provided, it is registered as a consumer of the queue on connect.
    queue_name
        The name of the queue to declare. If not provided, a random name is
        generated on connect.

    """

    def __init__(
        self,
        connection_string: str | None = None,
        exchange_name: str | None = None,
        callback: SubCallbackType | None = None,
        queue_name: str | None = None,
    ):
        super().__init__(
            connection_string=connection_string,
            exchange_name=exchange_name,
        )

        self.queue_name: str | None = queue_name
        self.queue: AbstractQueue | None = None

        self.callback = callback
        self.consumer_tag: ConsumerTag | None = None

    async def connect(self, queue_name: str | None = None) -> Self:
        """Connects to the exchange, declares a queue, and binds the callback.

        The queue is declared exclusive and auto-deleted, and bound to the
        exchange with the configured ``pubsub.routing_key``.

        Parameters
        ----------
        queue_name
            The name of the queue to declare. If not provided, a random name
            will be generated (recommended).

        """

        await super().connect()

        assert self.channel, "channel not defined."
        assert self.exchange, "exchange not defined."

        self.queue_name = (
            queue_name
            or self.queue_name
            or f"{self.exchange_name}-{str(uuid.uuid4()).split('-')[-1]}"
        )

        self.queue = await self.channel.declare_queue(
            self.queue_name,
            auto_delete=True,
            exclusive=True,
        )

        await self.queue.bind(
            self.exchange,
            routing_key=config["pubsub.routing_key"],
        )

        if self.callback:
            self.consumer_tag = await self.queue.consume(
                callback_wrapper(self.callback)
            )

        return self

    async def disconnect(self):
        """Disconnects from the RabbitMQ server.

        Cancels the consumer (if any) and unbinds the queue from the exchange,
        then closes the channel and connection. The channel/connection cleanup
        runs even if cancelling or unbinding raises.
        """

        try:
            if self.queue:
                if self.consumer_tag:
                    await self.queue.cancel(self.consumer_tag)
                    # Avoid cancelling the same consumer twice on a later call.
                    self.consumer_tag = None
                if self.exchange:
                    # Unbind with the same routing key used for bind() in
                    # connect(); aio_pika otherwise defaults the routing key
                    # to the queue name, which does not match this binding.
                    await self.queue.unbind(
                        self.exchange,
                        routing_key=config["pubsub.routing_key"],
                    )
        finally:
            await super().disconnect()

    @overload
    async def get(
        self,
        decode: Literal[True] = True,
    ) -> Message: ...

    @overload
    async def get(
        self,
        decode: Literal[False],
    ) -> AbstractIncomingMessage: ...

    async def get(
        self,
        decode: bool = True,
    ) -> AbstractIncomingMessage | Message:
        """Gets the next message from the queue.

        This method does not call ``process()`` on the message; the caller
        presumably needs to acknowledge it (via ``Message.message`` when
        decoded).

        Parameters
        ----------
        decode
            If `True`, returns a ``Message``; otherwise the raw incoming message.

        Raises
        ------
        RuntimeError
            If the queue has not been declared (i.e., not connected).

        """

        if not self.queue:
            raise RuntimeError("queue not defined.")

        if decode:
            return Message(await self.queue.get())
        else:
            return await self.queue.get()

    # Overload stubs are plain ``def``: calling an async generator function
    # returns an AsyncGenerator directly. Only the first overload has a default
    # so that ``iterator()`` resolves unambiguously (same as ``get``).
    @overload
    def iterator(
        self,
        decode: Literal[True] = True,
    ) -> AsyncGenerator[Message, None]: ...

    @overload
    def iterator(
        self,
        decode: Literal[False],
    ) -> AsyncGenerator[AbstractIncomingMessage, None]: ...

    async def iterator(
        self,
        decode: bool = True,
    ) -> AsyncGenerator[AbstractIncomingMessage | Message, None]:
        """Iterates over a queue and yields messages.

        Connects on entry if needed and disconnects when the iteration ends.
        Each message is yielded inside ``message.process()``, so processing
        completes when the consumer resumes the iteration.

        Parameters
        ----------
        decode
            If `True`, yields ``Message`` instances; otherwise raw messages.

        """

        async with self as instance:
            assert instance.queue, "queue not defined."

            async with instance.queue.iterator() as queue_iter:
                async for message in queue_iter:
                    async with message.process():
                        if decode:
                            yield Message(message)
                        else:
                            yield message
async def send_event(event: Event | str, payload: dict[str, Any] | None = None):
    """Convenience function to publish an event to the exchange.

    Parameters
    ----------
    event
        The event to publish, as an ``Event`` member or a string name.
    payload
        Additional data to attach to the event. Defaults to an empty
        dictionary.

    """

    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    message = EventModel(event_name=event, payload=payload or {}).model_dump()
    await Publisher().publish(message)