Source code for ezmsg.zmq.util

from dataclasses import dataclass
import json
from pickle import PickleBuffer
import typing

import ezmsg.core as ez
from ezmsg.util.messagelogger import log_object
from ezmsg.util.messagecodec import MessageDecoder


class ZeroCopyBytes(bytes):
    """``bytes`` subclass that can be pickled without copying its payload.

    With pickle protocol 5 or newer the payload is wrapped in a
    :class:`pickle.PickleBuffer`, so it may be handed out-of-band via a
    ``buffer_callback``. Older protocols fall back to a plain ``bytes`` copy.
    """

    def __reduce_ex__(self, protocol):
        rebuild = type(self)._reconstruct
        if protocol < 5:
            # PickleBuffer is forbidden with pickle protocols <= 4.
            return rebuild, (bytes(self),)
        return rebuild, (PickleBuffer(self),), None

    @classmethod
    def _reconstruct(cls, obj):
        """Rebuild an instance from any buffer-exporting ``obj``.

        If the buffer's underlying exporter is already an instance of
        ``cls``, that exact object is returned. Otherwise a new instance
        is built from it.
        """
        with memoryview(obj) as view:
            # Reach through the buffer to the object that actually owns it.
            owner = view.obj
        if isinstance(owner, cls):
            return owner
        return cls(owner)
@dataclass
class ZMQMessage:
    """Wrapper around a serialized payload, as published by SerializeMessage."""

    # Serialized message bytes (output of the configured serializer function).
    data: bytes
def serialize_msg(msg: typing.Any) -> bytes:
    """Serialize ``msg`` to UTF-8 bytes via ezmsg's ``log_object``.

    Args:
        msg: Any object accepted by ``log_object``.

    Returns:
        The string produced by ``log_object``, encoded as UTF-8.
    """
    text = log_object(msg)
    return text.encode("utf-8")
"""
The following alternative to serialize_msg might be faster because it doesn't
convert numpy arrays to ASCII. It additionally requires
``from dataclasses import asdict``. Note that its output is NOT wrapped under
an ``"obj"`` key, so it is not compatible with DeserializeBytes as written.

class NumpyArrayEncoder(json.JSONEncoder):
    def default(self, obj):
        if hasattr(obj, "tolist"):
            # Likely numpy array to list
            return obj.tolist()
        else:
            return json.JSONEncoder.default(self, obj)

def serialize_msg(msg: typing.Any) -> bytes:
    return json.dumps(asdict(msg), cls=NumpyArrayEncoder).encode("utf-8")
"""
class SerializeMessageSettings(ez.Settings):
    """Configuration for the SerializeMessage unit."""

    #: Function to serialize the message. Must take a single argument and
    #: return a bytes object.
    fun: typing.Callable = serialize_msg
class SerializeMessage(ez.Unit):
    """Unit that wraps each serialized incoming message in a ZMQMessage.

    The serializer is the callable stored in ``SETTINGS.fun``.
    """

    SETTINGS = SerializeMessageSettings

    INPUT = ez.InputStream(typing.Any)
    OUTPUT = ez.OutputStream(ZMQMessage)

    @ez.subscriber(INPUT)
    @ez.publisher(OUTPUT)
    async def on_message(self, message: typing.Any) -> typing.AsyncGenerator:
        """Serialize ``message`` and publish the result on OUTPUT."""
        serializer = self.SETTINGS.fun
        yield self.OUTPUT, ZMQMessage(data=serializer(message))
class DeserializeBytes(ez.Unit):
    """Unit that decodes serialized payloads back into Python objects.

    Each payload is parsed with ``json.loads(..., cls=MessageDecoder)``, and
    the value stored under its ``"obj"`` key is published on OUTPUT_SIGNAL.
    """

    # NOTE(review): SerializeMessage publishes ZMQMessage, but this stream is
    # declared as ``bytes``. The handler therefore accepts both raw payloads
    # and objects carrying the payload in a ``.data`` attribute.
    INPUT = ez.InputStream(bytes)
    OUTPUT_SIGNAL = ez.OutputStream(typing.Any)

    @ez.subscriber(INPUT)
    @ez.publisher(OUTPUT_SIGNAL)
    async def deserialize(
        self, msg: typing.Union[ZMQMessage, bytes]
    ) -> typing.AsyncGenerator:
        """Decode ``msg`` and publish the contained object.

        Args:
            msg: Either a raw JSON payload (``bytes``/``bytearray``/``str``,
                matching the declared input type) or a ZMQMessage-like object
                whose ``.data`` holds that payload.

        Raises:
            KeyError: If the decoded JSON has no ``"obj"`` key.
        """
        # Raw payloads previously crashed on ``msg.data``. Anything else keeps
        # the original duck-typed ``.data`` access.
        if isinstance(msg, (bytes, bytearray, str)):
            payload = msg
        else:
            payload = msg.data
        decoded_msg = json.loads(payload, cls=MessageDecoder)["obj"]
        yield self.OUTPUT_SIGNAL, decoded_msg