From 77399038e9c81fbdc56b3b8eb5bd3bb85ba2779f Mon Sep 17 00:00:00 2001 From: Koushik Dutta Date: Fri, 17 Mar 2023 22:21:07 -0700 Subject: [PATCH] server: clean up python rpc transports --- server/package-lock.json | 4 +- server/python/plugin_remote.py | 55 ++++------- server/python/rpc_reader.py | 166 ++++++++++++++++++++++----------- 3 files changed, 133 insertions(+), 92 deletions(-) diff --git a/server/package-lock.json b/server/package-lock.json index 8578bf462..ecea402b3 100644 --- a/server/package-lock.json +++ b/server/package-lock.json @@ -1,12 +1,12 @@ { "name": "@scrypted/server", - "version": "0.7.12", + "version": "0.7.13", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@scrypted/server", - "version": "0.7.12", + "version": "0.7.13", "license": "ISC", "dependencies": { "@mapbox/node-pre-gyp": "^1.0.10", diff --git a/server/python/plugin_remote.py b/server/python/plugin_remote.py index cd08b6333..6a09b8e01 100644 --- a/server/python/plugin_remote.py +++ b/server/python/plugin_remote.py @@ -35,25 +35,6 @@ class SystemDeviceState(TypedDict): stateTime: int value: any - -class StreamPipeReader: - def __init__(self, conn: multiprocessing.connection.Connection) -> None: - self.conn = conn - self.executor = concurrent.futures.ThreadPoolExecutor() - - def readBlocking(self, n): - b = bytes(0) - while len(b) < n: - self.conn.poll(None) - add = os.read(self.conn.fileno(), n - len(b)) - if not len(add): - raise Exception('unable to read requested bytes') - b += add - return b - - async def read(self, n): - return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.readBlocking(n)) - class SystemManager(scrypted_python.scrypted_sdk.types.SystemManager): def __init__(self, api: Any, systemState: Mapping[str, Mapping[str, SystemDeviceState]]) -> None: super().__init__() @@ -288,8 +269,9 @@ class PluginRemote: clusterSecret = options['clusterSecret'] async def handleClusterClient(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + rpcTransport = rpc_reader.RpcStreamTransport(reader, writer) peer: rpc.RpcPeer - peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, reader = reader, writer = writer) + peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport) async def connectRPCObject(id: str, secret: str): m = hashlib.sha256() m.update(bytes('%s%s' % (clusterPort, clusterSecret), 'utf8')) @@ -324,7 +306,8 @@ class PluginRemote: async def connectClusterPeer(): reader, writer = await asyncio.open_connection( '127.0.0.1', port) - peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, reader = reader, writer = writer) + rpcTransport = rpc_reader.RpcStreamTransport(reader, writer) + peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport) async def run_loop(): try: await peerReadLoop() @@ -485,8 +468,8 @@ class PluginRemote: schedule_exit_check() async def getFork(): - reader = StreamPipeReader(parent_conn) - forkPeer, readLoop = await rpc_reader.prepare_peer_readloop(self.loop, reader = reader, writeFd = parent_conn.fileno()) + rpcTransport = rpc_reader.RpcConnectionTransport(parent_conn) + forkPeer, readLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport) forkPeer.peerName = 'thread' async def updateStats(stats): @@ -502,7 +485,7 @@ class PluginRemote: finally: allMemoryStats.pop(forkPeer) parent_conn.close() - reader.executor.shutdown() + rpcTransport.executor.shutdown() asyncio.run_coroutine_threadsafe(forkReadLoop(), loop=self.loop) getRemote = await forkPeer.getParam('getRemote') remote: PluginRemote = await getRemote(self.api, self.pluginId, self.hostInfo) @@ -594,8 +577,8 @@ class PluginRemote: allMemoryStats = {} -async def plugin_async_main(loop: AbstractEventLoop, readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None): - peer, readLoop = await rpc_reader.prepare_peer_readloop(loop, readFd=readFd, writeFd=writeFd, reader=reader, writer=writer) +async def plugin_async_main(loop: AbstractEventLoop, rpcTransport: rpc_reader.RpcTransport): + peer, readLoop = await rpc_reader.prepare_peer_readloop(loop, rpcTransport) peer.params['print'] = print peer.params['getRemote'] = lambda api, pluginId, hostInfo: PluginRemote(peer, api, pluginId, hostInfo, loop) @@ -642,11 +625,11 @@ async def plugin_async_main(loop: AbstractEventLoop, readFd: int = None, writeFd try: await readLoop() finally: - if reader and hasattr(reader, 'executor'): - r: StreamPipeReader = reader + if type(rpcTransport) == rpc_reader.RpcConnectionTransport: + r: rpc_reader.RpcConnectionTransport = rpcTransport r.executor.shutdown() -def main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None): +def main(rpcTransport: rpc_reader.RpcTransport): loop = asyncio.new_event_loop() def gc_runner(): @@ -654,10 +637,10 @@ def main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = loop.call_later(10, gc_runner) gc_runner() - loop.run_until_complete(plugin_async_main(loop, readFd=readFd, writeFd=writeFd, reader=reader, writer=writer)) + loop.run_until_complete(plugin_async_main(loop, rpcTransport)) loop.close() -def plugin_main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None): +def plugin_main(rpcTransport: rpc_reader.RpcTransport): try: import gi gi.require_version('Gst', '1.0') @@ -666,18 +649,16 @@ def plugin_main(readFd: int = None, writeFd: int = None, reader: asyncio.StreamR loop = GLib.MainLoop() - worker = threading.Thread(target=main, args=(readFd, writeFd, reader, writer), name="asyncio-main") + worker = threading.Thread(target=main, args=(rpcTransport,), name="asyncio-main") worker.start() loop.run() except: - main(readFd=readFd, writeFd=writeFd, reader=reader, writer=writer) + main(rpcTransport) def plugin_fork(conn: multiprocessing.connection.Connection): - fd = os.dup(conn.fileno()) - reader = StreamPipeReader(conn) - plugin_main(reader=reader, writeFd=fd) + plugin_main(rpc_reader.RpcConnectionTransport(conn)) if __name__ == "__main__": - plugin_main(3, 4) + plugin_main(rpc_reader.RpcFileTransport(3, 4)) diff --git a/server/python/rpc_reader.py b/server/python/rpc_reader.py index ab0e4debb..e14348630 100644 --- a/server/python/rpc_reader.py +++ b/server/python/rpc_reader.py @@ -4,15 +4,14 @@ import asyncio import base64 import json import os -import sys import threading from asyncio.events import AbstractEventLoop -from os import sys -from typing import List - +from typing import List, Any +import multiprocessing.connection import aiofiles import rpc - +import concurrent.futures +import json class BufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): @@ -36,31 +35,118 @@ class SidebandBufferSerializer(rpc.RpcSerializer): buffer = buffers.pop() return buffer -async def readLoop(loop, peer: rpc.RpcPeer, reader: asyncio.StreamReader): +class RpcTransport: + async def prepare(self): + pass + + async def read(self): + pass + + def writeBuffer(self, buffer, reject): + pass + + def writeJSON(self, json, reject): + pass + +class RpcFileTransport(RpcTransport): + reader: asyncio.StreamReader + + def __init__(self, readFd: int, writeFd: int) -> None: + super().__init__() + self.readFd = readFd + self.writeFd = writeFd + self.reader = None + + async def prepare(self): + await super().prepare() + self.reader = await aiofiles.open(self.readFd, mode='rb') + + async def read(self): + lengthBytes = await self.reader.read(4) + typeBytes = await self.reader.read(1) + type = typeBytes[0] + length = int.from_bytes(lengthBytes, 'big') + data = await self.reader.read(length - 1) + if type == 1: + return data + message = json.loads(data) + return message + + def writeMessage(self, type: int, buffer, reject): + length = len(buffer) + 1 + lb = length.to_bytes(4, 'big') + try: + for b in [lb, bytes([type]), buffer]: + os.write(self.writeFd, b) + except Exception as e: + if reject: + reject(e) + + def writeJSON(self, j, reject): + return self.writeMessage(0, bytes(json.dumps(j), 'utf8'), reject) + + def writeBuffer(self, buffer, reject): + return self.writeMessage(1, buffer, reject) + +class RpcStreamTransport(RpcTransport): + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__() + self.reader = reader + self.writer = writer + + async def read(self, n: int): + return await self.reader.readexactly(n) + + def writeMessage(self, type: int, buffer, reject): + length = len(buffer) + 1 + lb = length.to_bytes(4, 'big') + try: + for b in [lb, bytes([type]), buffer]: + self.writer.write(b) + except Exception as e: + if reject: + reject(e) + + def writeJSON(self, j, reject): + return self.writeMessage(0, bytes(json.dumps(j), 'utf8'), reject) + + def writeBuffer(self, buffer, reject): + return self.writeMessage(1, buffer, reject) + +class RpcConnectionTransport(RpcTransport): + def __init__(self, connection: multiprocessing.connection.Connection) -> None: + super().__init__() + self.connection = connection + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + async def read(self): + return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.connection.recv()) + + def writeMessage(self, json, reject): + try: + self.connection.send(json) + except Exception as e: + if reject: + reject(e) + + def writeJSON(self, json, reject): + return self.writeMessage(json, reject) + + def writeBuffer(self, buffer, reject): + return self.writeMessage(bytes(buffer), reject) + +async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport): deserializationContext = { 'buffers': [] } - if isinstance(reader, asyncio.StreamReader): - async def read(n): - return await reader.readexactly(n) - else: - async def read(n): - return await reader.read(n) - - while True: - lengthBytes = await read(4) - typeBytes = await read(1) - type = typeBytes[0] - length = int.from_bytes(lengthBytes, 'big') - data = await read(length - 1) + message = await rpcTransport.read() - if type == 1: - deserializationContext['buffers'].append(data) + if type(message) != dict: + deserializationContext['buffers'].append(message) continue - message = json.loads(data) asyncio.run_coroutine_threadsafe( peer.handleMessage(message, deserializationContext), loop) @@ -68,46 +154,20 @@ async def readLoop(loop, peer: rpc.RpcPeer, reader: asyncio.StreamReader): 'buffers': [] } -async def prepare_peer_readloop(loop: AbstractEventLoop, readFd: int = None, writeFd: int = None, reader: asyncio.StreamReader = None, writer: asyncio.StreamWriter = None): - reader = reader or await aiofiles.open(readFd, mode='rb') +async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport): + await rpcTransport.prepare() mutex = threading.Lock() - if writer: - def write(buffers, reject): - try: - for b in buffers: - writer.write(b) - except Exception as e: - if reject: - reject(e) - return None - else: - def write(buffers, reject): - try: - for b in buffers: - os.write(writeFd, b) - except Exception as e: - if reject: - reject(e) - def send(message, reject=None, serializationContext=None): with mutex: if serializationContext: buffers = serializationContext.get('buffers', None) if buffers: for buffer in buffers: - length = len(buffer) + 1 - lb = length.to_bytes(4, 'big') - type = 1 - write([lb, bytes([type]), buffer], reject) + rpcTransport.writeBuffer(buffer, reject) - jsonString = json.dumps(message) - b = bytes(jsonString, 'utf8') - length = len(b) + 1 - lb = length.to_bytes(4, 'big') - type = 0 - write([lb, bytes([type]), b], reject) + rpcTransport.writeJSON(message, reject) peer = rpc.RpcPeer(send) peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer() @@ -117,7 +177,7 @@ async def prepare_peer_readloop(loop: AbstractEventLoop, readFd: int = None, wri async def peerReadLoop(): try: - await readLoop(loop, peer, reader) + await readLoop(loop, peer, rpcTransport) except: peer.kill() raise