from __future__ import annotations import asyncio import base64 import json import os import threading from asyncio.events import AbstractEventLoop from typing import List, Any import multiprocessing.connection import aiofiles import rpc import concurrent.futures import json class BufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): return base64.b64encode(value).decode('utf8') def deserialize(self, value, serializationContext): return base64.b64decode(value) class SidebandBufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): buffers = serializationContext.get('buffers', None) if not buffers: buffers = [] serializationContext['buffers'] = buffers buffers.append(value) return len(buffers) - 1 def deserialize(self, value, serializationContext): buffers: List = serializationContext.get('buffers', None) buffer = buffers.pop() return buffer class RpcTransport: async def prepare(self): pass async def read(self): pass def writeBuffer(self, buffer, reject): pass def writeJSON(self, json, reject): pass class RpcFileTransport(RpcTransport): reader: asyncio.StreamReader def __init__(self, readFd: int, writeFd: int) -> None: super().__init__() self.readFd = readFd self.writeFd = writeFd self.reader = None async def prepare(self): await super().prepare() self.reader = await aiofiles.open(self.readFd, mode='rb') async def read(self): lengthBytes = await self.reader.read(4) typeBytes = await self.reader.read(1) type = typeBytes[0] length = int.from_bytes(lengthBytes, 'big') data = await self.reader.read(length - 1) if type == 1: return data message = json.loads(data) return message def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 lb = length.to_bytes(4, 'big') try: for b in [lb, bytes([type]), buffer]: os.write(self.writeFd, b) except Exception as e: if reject: reject(e) def writeJSON(self, j, reject): return self.writeMessage(0, bytes(json.dumps(j), 'utf8'), reject) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) class RpcStreamTransport(RpcTransport): def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: super().__init__() self.reader = reader self.writer = writer async def read(self): lengthBytes = await self.reader.readexactly(4) typeBytes = await self.reader.readexactly(1) type = typeBytes[0] length = int.from_bytes(lengthBytes, 'big') data = await self.reader.readexactly(length - 1) if type == 1: return data message = json.loads(data) return message def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 lb = length.to_bytes(4, 'big') try: for b in [lb, bytes([type]), buffer]: self.writer.write(b) except Exception as e: if reject: reject(e) def writeJSON(self, j, reject): return self.writeMessage(0, bytes(json.dumps(j), 'utf8'), reject) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) class RpcConnectionTransport(RpcTransport): def __init__(self, connection: multiprocessing.connection.Connection) -> None: super().__init__() self.connection = connection self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) async def read(self): return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.connection.recv()) def writeMessage(self, json, reject): try: self.connection.send(json) except Exception as e: if reject: reject(e) def writeJSON(self, json, reject): return self.writeMessage(json, reject) def writeBuffer(self, buffer, reject): return self.writeMessage(bytes(buffer), reject) async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport): deserializationContext = { 'buffers': [] } while True: message = await rpcTransport.read() if type(message) != dict: deserializationContext['buffers'].append(message) continue asyncio.run_coroutine_threadsafe( peer.handleMessage(message, deserializationContext), loop) deserializationContext = { 'buffers': [] } async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport): await rpcTransport.prepare() mutex = threading.Lock() def send(message, reject=None, serializationContext=None): with mutex: if serializationContext: buffers = serializationContext.get('buffers', None) if buffers: for buffer in buffers: rpcTransport.writeBuffer(buffer, reject) rpcTransport.writeJSON(message, reject) peer = rpc.RpcPeer(send) peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer() peer.constructorSerializerMap[bytes] = 'Buffer' peer.constructorSerializerMap[bytearray] = 'Buffer' peer.constructorSerializerMap[memoryview] = 'Buffer' async def peerReadLoop(): try: await readLoop(loop, peer, rpcTransport) except: peer.kill() raise return peer, peerReadLoop