Source code for ezmsg.zmq.repreq

import asyncio
import typing

import zmq
import zmq.asyncio
import ezmsg.core as ez
from zmq.utils.monitor import parse_monitor_message

from .util import ZMQMessage


[docs] class ZMQRepSettings(ez.Settings): addr: str
[docs] class ZMQRepState(ez.State): context: zmq.asyncio.Context socket: zmq.asyncio.Socket queue: asyncio.Queue
[docs] class ZMQRep(ez.Unit): OUTPUT = ez.OutputStream(ZMQMessage) SETTINGS = ZMQRepSettings STATE = ZMQRepState
[docs] def initialize(self) -> None: self.STATE.context = zmq.asyncio.Context() self.STATE.socket = self.STATE.context.socket(zmq.REP) ez.logger.debug(f"{self}:binding to {self.SETTINGS.addr}") self.STATE.socket.bind(self.SETTINGS.addr) self.STATE.queue = asyncio.Queue()
[docs] def shutdown(self) -> None: self.STATE.socket.close() self.STATE.context.term()
def _handle_req(self, data: bytes) -> bytes: return data
[docs] @ez.task async def zmq_rep(self) -> None: while True: data = await self.STATE.socket.recv() response = self._handle_req(data) await self.STATE.socket.send(response) self.STATE.queue.put_nowait(data)
[docs] @ez.publisher(OUTPUT) async def send_reqs(self) -> typing.AsyncGenerator: while True: data = await self.STATE.queue.get() yield self.OUTPUT, ZMQMessage(data)
[docs] class ZMQReqSettings(ez.Settings): addr: str
[docs] class ZMQReqState(ez.State): context: zmq.asyncio.Context socket: zmq.asyncio.Socket monitor: zmq.asyncio.Socket
[docs] class ZMQReq(ez.Unit): INPUT = ez.InputStream(ZMQMessage) OUTPUT = ez.OutputStream(ZMQMessage) SETTINGS = ZMQReqSettings STATE = ZMQReqState
[docs] def initialize(self) -> None: self.STATE.context = zmq.asyncio.Context() self.STATE.socket = self.STATE.context.socket(zmq.REQ) self.STATE.monitor = self.STATE.socket.get_monitor_socket() ez.logger.debug(f"{self}:connecting to {self.SETTINGS.addr}") self.STATE.socket.connect(self.SETTINGS.addr) self._has_server = False
[docs] def shutdown(self) -> None: self.STATE.monitor.close() self.STATE.socket.close() self.STATE.context.term()
@ez.task async def _socket_monitor(self) -> None: while True: monitor_result = await self.STATE.monitor.poll(100, zmq.POLLIN) if monitor_result: data = await self.STATE.monitor.recv_multipart() evt = parse_monitor_message(data) event = evt["event"] if event == zmq.EVENT_CONNECTED: self._has_server = True elif event == zmq.EVENT_DISCONNECTED: self._has_server = False
[docs] @ez.subscriber(INPUT, zero_copy=True) @ez.publisher(OUTPUT) async def send_req(self, msg: ZMQMessage) -> None: if self._has_server: await self.STATE.socket.send(msg.data) response = await self.STATE.socket.recv() yield self.OUTPUT, ZMQMessage(response)