Source code for ezmsg.zmq.pubsub

import asyncio
from typing import AsyncGenerator

import zmq
import zmq.asyncio
from zmq.utils.monitor import parse_monitor_message
import ezmsg.core as ez

from .util import ZMQMessage


# Seconds ZMQPollerUnit sleeps between checks while its SUB socket is not yet
# connected; also the default for ZMQPollerSettings.poll_time.
POLL_TIME: float = 0.1
# Seconds ZMQSenderUnit sleeps between checks while waiting for a subscriber.
STARTUP_WAIT_TIME: float = 0.1


class ZMQSenderSettings(ez.Settings):
    """Configuration for :class:`ZMQSenderUnit`."""

    # Address the PUB socket binds to (passed to ``socket.bind``).
    write_addr: str
    # Topic placed in front of every outgoing payload (encoded as UTF-8).
    zmq_topic: str
    # True: send (topic, data) as two frames; False: send one concatenated frame.
    multipart: bool = False
    # True: hold each message until the socket monitor has seen a connection.
    wait_for_sub: bool = True
class ZMQSenderState(ez.State):
    """ZMQ handles created by :class:`ZMQSenderUnit` during initialization."""

    # Context that owns both sockets below.
    context: zmq.asyncio.Context
    # The bound PUB socket that messages are written to.
    socket: zmq.asyncio.Socket
    # Socket returned by ``socket.get_monitor_socket()`` for connection events.
    monitor: zmq.asyncio.Socket
class ZMQSenderUnit(ez.Unit):
    """
    Represents a node in an ezmsg graph that receives ZMQMessage messages on
    its INPUT stream, then publishes each message by writing to a zmq.PUB
    socket.

    Args:
        write_addr: The address to which ZMQ data should be written.
        zmq_topic: The ZMQ topic being sent.
        multipart: If True, use socket.send_multipart with a topic frame
            followed by a data frame; else use socket.send with the topic
            bytes prepended to the data in a single frame.
        wait_for_sub: If True, the sender will wait for a subscriber before
            publishing. This behaves strangely and should be set False.
    """

    INPUT = ez.InputStream(ZMQMessage)

    SETTINGS = ZMQSenderSettings
    STATE = ZMQSenderState

    def initialize(self) -> None:
        """Create the PUB socket and its monitor, then bind to ``write_addr``."""
        self.STATE.context = zmq.asyncio.Context()
        self.STATE.socket = self.STATE.context.socket(zmq.PUB)
        self.STATE.monitor = self.STATE.socket.get_monitor_socket()
        ez.logger.debug(f"{self}:binding to {self.SETTINGS.write_addr}")
        self.STATE.socket.bind(self.SETTINGS.write_addr)
        # Flipped to True by _socket_monitor on the first accepted connection.
        self.has_subscribers = False

    def shutdown(self) -> None:
        """Close the monitor and PUB sockets, then terminate the context."""
        self.STATE.monitor.close()
        self.STATE.socket.close()
        self.STATE.context.term()

    @ez.task
    async def _socket_monitor(self) -> None:
        """Watch PUB socket events and record when a subscriber connects.

        The loop exits on the first DISCONNECTED, MONITOR_STOPPED or CLOSED
        event. ``has_subscribers`` is only ever set to True here and is never
        reset, so a later disconnect does not re-block ``zmq_subscriber``.
        """
        while True:
            # Timeout is in milliseconds; the loop re-polls on expiry.
            monitor_result = await self.STATE.monitor.poll(100, zmq.POLLIN)
            if not monitor_result:
                continue
            data = await self.STATE.monitor.recv_multipart()
            event = parse_monitor_message(data)["event"]
            if event == zmq.EVENT_ACCEPTED:
                ez.logger.debug(f"{self}:subscriber joined")
                self.has_subscribers = True
            elif event in (
                zmq.EVENT_DISCONNECTED,
                zmq.EVENT_MONITOR_STOPPED,
                zmq.EVENT_CLOSED,
            ):
                break

    @ez.subscriber(INPUT, zero_copy=True)
    async def zmq_subscriber(self, message: ZMQMessage) -> None:
        """Publish ``message.data`` on the PUB socket under ``zmq_topic``.

        When ``wait_for_sub`` is set, this waits (polling every
        STARTUP_WAIT_TIME seconds) until the monitor has seen a connection.
        NOTE(review): an accepted connection may not yet have sent its
        subscription, which is presumably why early messages can be lost
        ("behaves strangely") - confirm against the ZMQ slow-joiner notes.
        """
        while self.SETTINGS.wait_for_sub and not self.has_subscribers:
            await asyncio.sleep(STARTUP_WAIT_TIME)

        topic = bytes(self.SETTINGS.zmq_topic, "UTF-8")
        if self.SETTINGS.multipart:
            # Two frames: topic, then payload.
            await self.STATE.socket.send_multipart(
                (topic, message.data),
                flags=zmq.NOBLOCK,
            )
        else:
            # One frame: topic bytes immediately followed by the payload.
            await self.STATE.socket.send(
                b"".join((topic, message.data)),
                flags=zmq.NOBLOCK,
            )
class ZMQPollerSettings(ez.Settings):
    """Configuration for :class:`ZMQPollerUnit`."""

    # Address the SUB socket connects to (passed to ``socket.connect``).
    read_addr: str
    # Topic filter passed to ``socket.subscribe``.
    zmq_topic: str
    # Maximum seconds to wait in each socket poll (converted to ms when used).
    poll_time: float = POLL_TIME
    # True: expect (topic, data) frames and publish only the data frame.
    multipart: bool = False
class ZMQPollerState(ez.State):
    """ZMQ handles created by :class:`ZMQPollerUnit` during initialization."""

    # Context that owns both sockets below.
    context: zmq.asyncio.Context
    # The connected SUB socket that messages are read from.
    socket: zmq.asyncio.Socket
    # Socket returned by ``socket.get_monitor_socket()`` for connection events.
    monitor: zmq.asyncio.Socket
    # Poller with ``socket`` registered for POLLIN.
    # NOTE(review): reading in this module goes through socket.poll instead.
    poller: zmq.Poller
class ZMQPollerUnit(ez.Unit):
    """
    Represents a node in the graph which polls data from ZMQ.

    Data polled from ZMQ are subsequently pushed to the rest of the graph as
    a ZMQMessage.

    Args:
        read_addr: The address from which ZMQ data should be polled.
        zmq_topic: The ZMQ topic being polled.
        poll_time: The maximum amount of time (in seconds) that should be
            spent polling the ZMQ socket each time. Defaults to POLL_TIME.
        multipart: If True, expect (topic, data) multipart messages and
            publish only the data frame; messages with any other frame count
            are logged and dropped. If False, publish each received frame
            as-is.
    """

    OUTPUT = ez.OutputStream(ZMQMessage)

    SETTINGS = ZMQPollerSettings
    STATE = ZMQPollerState

    def initialize(self) -> None:
        """Create the SUB socket and its monitor, connect, and subscribe."""
        self.STATE.context = zmq.asyncio.Context()
        self.STATE.socket = self.STATE.context.socket(zmq.SUB)
        self.STATE.monitor = self.STATE.socket.get_monitor_socket()
        self.STATE.socket.connect(self.SETTINGS.read_addr)
        self.STATE.socket.subscribe(self.SETTINGS.zmq_topic)
        # Kept so STATE.poller stays populated; zmq_publisher uses socket.poll.
        self.STATE.poller = zmq.Poller()
        self.STATE.poller.register(self.STATE.socket, zmq.POLLIN)
        # Maintained by socket_monitor; zmq_publisher only reads while True.
        self.socket_open = False

    def shutdown(self) -> None:
        """Close the monitor and SUB sockets, then terminate the context."""
        self.STATE.monitor.close()
        self.STATE.socket.close()
        self.STATE.context.term()

    @ez.task
    async def socket_monitor(self) -> None:
        """Track the SUB socket's connection state in ``self.socket_open``.

        Only EVENT_MONITOR_STOPPED ends the loop; CLOSED and DISCONNECTED
        just mark the socket closed so a later reconnect is still observed.
        """
        while True:
            # Timeout is in milliseconds; the loop re-polls on expiry.
            monitor_result = await self.STATE.monitor.poll(100, zmq.POLLIN)
            if not monitor_result:
                continue
            data = await self.STATE.monitor.recv_multipart()
            event = parse_monitor_message(data)["event"]
            if event == zmq.EVENT_CONNECTED:
                self.socket_open = True
            elif event in (zmq.EVENT_CLOSED, zmq.EVENT_DISCONNECTED):
                # ZMQ seems to send a spurious CLOSED event when we try to
                # connect before the source is running. Giving up here would
                # leave zmq_publisher blocked forever, so keep monitoring.
                self.socket_open = False
            elif event == zmq.EVENT_MONITOR_STOPPED:
                self.socket_open = False
                break

    @ez.publisher(OUTPUT)
    async def zmq_publisher(self) -> AsyncGenerator:
        """Read messages from the SUB socket and publish them on OUTPUT."""
        while True:
            # Wait for the monitor to report a connection before polling.
            if not self.socket_open:
                await asyncio.sleep(POLL_TIME)
                continue

            # socket.poll takes its timeout in milliseconds.
            poll_result = await self.STATE.socket.poll(
                self.SETTINGS.poll_time * 1000, zmq.POLLIN
            )
            if not poll_result:
                # NOTE(review): with poll_time == 0 the poll may resolve
                # without suspending; yield so other tasks (e.g.
                # socket_monitor) are not starved.
                await asyncio.sleep(0)
                continue

            if self.SETTINGS.multipart:
                frames = await self.STATE.socket.recv_multipart()
                if len(frames) != 2:
                    # A malformed message must not kill the publisher task.
                    ez.logger.warning(
                        f"{self}:dropping multipart message with "
                        f"{len(frames)} frames (expected 2)"
                    )
                    continue
                data = frames[1]
            else:
                data = await self.STATE.socket.recv()
            yield self.OUTPUT, ZMQMessage(data)