from __future__ import annotations import asyncio import base64 import concurrent.futures import json import multiprocessing.connection import os import pickle import threading from asyncio.events import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from typing import List import rpc class BufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): return base64.b64encode(value).decode("utf8") def deserialize(self, value, serializationContext): return base64.b64decode(value) class SidebandBufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): buffers = serializationContext.get("buffers", None) if not buffers: buffers = [] serializationContext["buffers"] = buffers buffers.append(value) return len(buffers) - 1 def deserialize(self, value, serializationContext): buffers: List = serializationContext.get("buffers", None) buffer = buffers.pop() return buffer class RpcTransport: async def prepare(self): pass async def read(self): pass def writeBuffer(self, buffer, reject): pass def writeSerialized(self, json, reject): pass class RpcFileTransport(RpcTransport): def __init__(self, readFd: int, writeFd: int) -> None: super().__init__() self.readFd = readFd self.writeFd = writeFd self.executor = ThreadPoolExecutor(1, "rpc-read") def osReadExact(self, size: int): b = bytes(0) while size: got = os.read(self.readFd, size) if not len(got): self.executor.shutdown(False) raise Exception("rpc end of stream reached") size -= len(got) b += got return b def readMessageInternal(self): lengthBytes = self.osReadExact(4) typeBytes = self.osReadExact(1) type = typeBytes[0] length = int.from_bytes(lengthBytes, "big") data = self.osReadExact(length - 1) if type == 1: return data message = json.loads(data) return message async def read(self): return await asyncio.get_event_loop().run_in_executor( self.executor, lambda: self.readMessageInternal() ) def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 lb = length.to_bytes(4, "big") try: for b in [lb, bytes([type]), buffer]: os.write(self.writeFd, b) except Exception as e: if reject: reject(e) def writeSerialized(self, j, reject): return self.writeMessage( 0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject ) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) class RpcStreamTransport(RpcTransport): def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: super().__init__() self.reader = reader self.writer = writer async def read(self): lengthBytes = await self.reader.readexactly(4) typeBytes = await self.reader.readexactly(1) type = typeBytes[0] length = int.from_bytes(lengthBytes, "big") data = await self.reader.readexactly(length - 1) if type == 1: return data message = json.loads(data) return message def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 lb = length.to_bytes(4, "big") try: for b in [lb, bytes([type]), buffer]: self.writer.write(b) except Exception as e: if reject: reject(e) def writeSerialized(self, j, reject): return self.writeMessage( 0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject ) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) class RpcPickleStreamTransport(RpcTransport): def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, pickler = pickle ) -> None: super().__init__() self.reader = reader self.writer = writer self.pickler = pickler async def read(self): lengthBytes = await self.reader.readexactly(4) typeBytes = await self.reader.readexactly(1) type = typeBytes[0] length = int.from_bytes(lengthBytes, "big") data = await self.reader.readexactly(length - 1) if type == 1: return data message = self.pickler.loads(data) return message def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 lb = length.to_bytes(4, "big") try: for b in [lb, bytes([type]), buffer]: self.writer.write(b) except Exception as e: if reject: reject(e) def writeSerialized(self, j, reject): pickled = self.pickler.dumps(j) return self.writeMessage(0, pickled, reject) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) class RpcConnectionTransport(RpcTransport): def __init__(self, connection: multiprocessing.connection.Connection) -> None: super().__init__() self.connection = connection self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) async def read(self): return await asyncio.get_event_loop().run_in_executor( self.executor, lambda: self.connection.recv() ) def writeMessage(self, json, reject): try: self.connection.send(json) except Exception as e: if reject: reject(e) def writeSerialized(self, json, reject): return self.writeMessage(json, reject) def writeBuffer(self, buffer, reject): return self.writeMessage(bytes(buffer), reject) async def readLoop(peer: rpc.RpcPeer, rpcTransport: RpcTransport): deserializationContext = {"buffers": []} while True: message = await rpcTransport.read() if type(message) != dict: deserializationContext["buffers"].append(message) continue asyncio.create_task(peer.handleMessage(message, deserializationContext)) deserializationContext = {"buffers": []} async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport): await rpcTransport.prepare() mutex = threading.Lock() def send(message, reject=None, serializationContext=None): with mutex: if serializationContext: buffers = serializationContext.get("buffers", None) if buffers: for buffer in buffers: rpcTransport.writeBuffer(buffer, reject) rpcTransport.writeSerialized(message, reject) peer = rpc.RpcPeer(send) peer.nameDeserializerMap["Buffer"] = SidebandBufferSerializer() peer.constructorSerializerMap[bytes] = "Buffer" peer.constructorSerializerMap[bytearray] = "Buffer" peer.constructorSerializerMap[memoryview] = "Buffer" async def peerReadLoop(): try: await readLoop(peer, rpcTransport) except: peer.kill() raise return peer, peerReadLoop