Source code for ezmsg.websocket.units

import asyncio
import ssl
from dataclasses import field
from typing import AsyncGenerator, Optional, Union

import ezmsg.core as ez
import websockets.exceptions
import websockets.server
from websockets.legacy.client import WebSocketClientProtocol, connect


class WebsocketSettings(ez.Settings):
    """
    Connection parameters shared by the websocket server and client units.

    ``host`` and ``port`` identify the endpoint: the server binds to them and
    the client builds its connection URI from them. ``cert_path`` is optional;
    when set, the server loads it as its TLS certificate chain and the client
    switches to the ``wss`` scheme.
    """

    # Hostname or address of the websocket endpoint.
    host: str
    # TCP port of the websocket endpoint.
    port: int
    # Path to a certificate file; None (the default) means no TLS.
    cert_path: Optional[str] = None
class WebsocketState(ez.State):
    """
    Per-unit queues that decouple the websocket tasks from the ezmsg streams.

    Each field gets its own fresh ``asyncio.Queue`` through ``default_factory``.
    """

    # Payloads read from the websocket, waiting to be published inside ezmsg.
    incoming_queue: "asyncio.Queue[Union[str,bytes]]" = field(
        default_factory=asyncio.Queue
    )
    # Messages received from ezmsg, waiting to be sent over the websocket.
    outgoing_queue: "asyncio.Queue[Union[str,bytes]]" = field(
        default_factory=asyncio.Queue
    )
class WebsocketServer(ez.Unit):
    """
    Websocket server unit bridging external clients and the ezmsg graph.

    Payloads received from any connected client are placed on the incoming
    queue and published on ``OUTPUT``. Messages arriving on ``INPUT`` are
    placed on the outgoing queue; each queued message is sent by whichever
    connection's transmit loop dequeues it.
    """

    SETTINGS = WebsocketSettings
    STATE = WebsocketState

    INPUT = ez.InputStream(bytes)
    OUTPUT = ez.OutputStream(bytes)

    @ez.task
    async def start_server(self):
        """
        Start the websocket server and serve connections until it is closed.

        If ``SETTINGS.cert_path`` is set, it is loaded as the certificate
        chain of a TLS server context; otherwise the server runs without TLS.
        The listening server is closed when this task exits, including on
        cancellation.
        """
        ez.logger.info(f"Starting WS Input Server @ ws://{self.SETTINGS.host}:{self.SETTINGS.port}")

        async def connection(websocket: websockets.server.WebSocketServerProtocol, path):
            """Run the receive and transmit loops for one client connection."""

            async def loop(mode):
                """Pump data in one direction: "rx" (socket -> queue) or "tx" (queue -> socket)."""
                try:
                    if mode == "rx":
                        while True:
                            data = await websocket.recv()
                            self.STATE.incoming_queue.put_nowait(data)
                    elif mode == "tx":
                        while True:
                            data = await self.STATE.outgoing_queue.get()
                            await websocket.send(data)
                except websockets.exceptions.ConnectionClosedOK:
                    # Normal closure: just end this loop.
                    pass
                except asyncio.CancelledError:
                    # Cancelled by the handler below once the peer loop ended.
                    pass
                except Exception as e:
                    ez.logger.warning(f"Error in websocket server: {e}")

            # asyncio.wait() no longer accepts bare coroutines (removed in
            # Python 3.11), so wrap each loop in a task explicitly.
            tasks = [
                asyncio.ensure_future(loop(mode="tx")),
                asyncio.ensure_future(loop(mode="rx")),
            ]
            try:
                # Stop as soon as either direction ends. Otherwise the tx loop
                # of a closed connection would stay blocked on the shared
                # outgoing queue and swallow a message meant for a live one.
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Also reached when the handler itself is cancelled; never
                # leave the per-connection tasks running in the background.
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)

        if self.SETTINGS.cert_path:
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(self.SETTINGS.cert_path)
        else:
            ssl_context = None

        server = await websockets.server.serve(
            connection,
            self.SETTINGS.host,
            self.SETTINGS.port,
            ssl=ssl_context,
        )
        try:
            await server.wait_closed()
        finally:
            # Release the listening socket if this task is cancelled or fails.
            server.close()

    @ez.publisher(OUTPUT)
    async def publish_incoming(self):
        """Publish every payload received from websocket clients on ``OUTPUT``."""
        while True:
            data = await self.STATE.incoming_queue.get()
            yield self.OUTPUT, data

    @ez.subscriber(INPUT)
    async def transmit_outgoing(self, message: bytes):
        """Queue a message from ``INPUT`` for transmission to a websocket client."""
        self.STATE.outgoing_queue.put_nowait(message)
class WebsocketClient(ez.Unit):
    """
    Websocket client unit bridging a remote websocket server and the ezmsg graph.

    Messages received from the server are published on ``OUTPUT``; messages
    arriving on ``INPUT`` are sent to the server.
    """

    SETTINGS = WebsocketSettings
    STATE = WebsocketState

    INPUT = ez.InputStream(bytes)
    OUTPUT = ez.OutputStream(bytes)

    async def rx_from(self, websocket: WebSocketClientProtocol):
        """Move every message received on ``websocket`` onto the incoming queue."""
        # Await incoming data from the websocket and post it to the
        # incoming queue for publication within ezmsg.
        async for message in websocket:
            self.STATE.incoming_queue.put_nowait(message)

    async def tx_to(self, websocket: WebSocketClientProtocol):
        """Send every message from the outgoing queue over ``websocket``."""
        # Await messages from the subscription within ezmsg and post
        # them to the outgoing websocket.
        while True:
            message = await self.STATE.outgoing_queue.get()
            await websocket.send(message)

    @ez.task
    async def connection(self):
        """
        Connect to the server and relay messages until either direction ends.

        The connection is attempted up to 10 times, sleeping 0.5 s after each
        failure. Once connected, the receive and transmit loops run
        concurrently; when the first one finishes, the other is cancelled and
        the websocket is closed.

        Raises:
            Exception: if no connection could be established; the last
                connection error is attached as ``__cause__``.
        """
        prefix = "wss" if self.SETTINGS.cert_path else "ws"
        uri = f"{prefix}://{self.SETTINGS.host}:{self.SETTINGS.port}"

        websocket = None
        last_error = None
        for _ in range(10):
            try:
                websocket = await connect(uri)
                break
            except Exception as e:
                # Remember why the attempt failed so the final error is useful.
                last_error = e
                await asyncio.sleep(0.5)

        if websocket is None:
            raise Exception(f"Could not connect to {uri}") from last_error

        receive_task = asyncio.ensure_future(self.rx_from(websocket))
        transmit_task = asyncio.ensure_future(self.tx_to(websocket))
        tasks = [receive_task, transmit_task]

        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                # Retrieve the result so a failure is reported instead of
                # surfacing later as "Task exception was never retrieved".
                if not task.cancelled() and task.exception() is not None:
                    ez.logger.warning(
                        f"Websocket connection to {uri} ended with error: {task.exception()!r}"
                    )
        finally:
            # Also reached when this task is cancelled: stop both loops and
            # close the socket so nothing is left running in the background.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            await websocket.close()

    @ez.publisher(OUTPUT)
    async def receive(self) -> AsyncGenerator:
        """Publish every message received from the server on ``OUTPUT``."""
        while True:
            message = await self.STATE.incoming_queue.get()
            yield self.OUTPUT, message

    @ez.subscriber(INPUT)
    async def transmit(self, message: bytes) -> None:
        """Queue a message from ``INPUT`` for transmission to the server."""
        self.STATE.outgoing_queue.put_nowait(message)