Source code for lvmopstools.socket
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# @Author: José Sánchez-Gallego (gallegoj@uw.edu)
# @Date: 2024-01-17
# @Filename: socket.py
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause)
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
from lvmopstools.retrier import Retrier
__all__ = ["AsyncSocketHandler"]
RequestFuncType = Callable[[asyncio.StreamReader, asyncio.StreamWriter], Awaitable[Any]]
[docs]
@dataclass
class AsyncSocketHandler:
"""Handles a socket connection and disconnection.
Handles secure connection and disconnection to a TCP server and executed a
callback. By default :obj:`.Retrier` is used to retry the connection if it
fails either during the connection phase or during callback execution.
There are two ways to use this class. The first one is to create an instance
and call it with a callback function which receives :obj:`~asyncio.StreamReader`
and :obj:`~asyncio.StreamWriter` arguments ::
async def callback(reader, writer):
...
handler = AsyncSocketHandler(host, port)
await handler(callback)
Alternatively, you can subclass ``AsyncSocketHandler`` and override the
:obj:`request` method ::
class MyHandler(AsyncSocketHandler):
async def request(self, reader, writer):
...
Parameters
----------
host
The host that is running the server.
port
The port on which the server is listening.
timeout
The timeout for connection and callback execution.
retry
Whether to retry the connection/callback if they fails.
retrier_params
Parameters to pass to the :class:`.Retrier` instance.
"""
host: str
port: int
timeout: float = 5
retry: bool = True
retrier_params: dict[str, Any] = field(default_factory=dict)
async def _connect(self):
"""Connects to the socket."""
reader, writer = await asyncio.wait_for(
asyncio.open_connection(self.host, self.port),
timeout=self.timeout,
)
return reader, writer
async def _run(self, func: RequestFuncType | None = None):
"""Internal helper to connect to the socket and run the request."""
if func is None:
func = self.request
reader, writer = await self._connect()
try:
return await func(reader, writer)
finally:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
async def __call__(self, func: RequestFuncType | None = None):
"""Connects to the socket and runs the request function."""
if self.retry:
retrier = Retrier(**self.retrier_params)
return await retrier(self._run)(func)
else:
return await self._run(func)
[docs]
async def request(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
): # pragma: no cover
"""Sends a request to the socket.
If the handler is not called with a callback function, this method must be
overridden in a subclass. It receives the :obj:`~asyncio.StreamReader` and
:obj:`~asyncio.StreamWriter` client instances after a connection has been
established.
"""
return