diff --git a/server/python/cluster_labels.py b/server/python/cluster_labels.py index 6250d1af6..8cde632fd 100644 --- a/server/python/cluster_labels.py +++ b/server/python/cluster_labels.py @@ -47,5 +47,8 @@ def needs_cluster_fork_worker(options: ClusterForkOptions) -> bool: return ( os.environ.get("SCRYPTED_CLUSTER_ADDRESS") and options - and (not matches_cluster_labels(options, get_cluster_labels()) or options.get("clusterWorkerId", None)) + and ( + not matches_cluster_labels(options, get_cluster_labels()) + or options.get("clusterWorkerId", None) + ) ) diff --git a/server/python/cluster_setup.py b/server/python/cluster_setup.py index 7433795fc..8edea62d9 100644 --- a/server/python/cluster_setup.py +++ b/server/python/cluster_setup.py @@ -12,6 +12,7 @@ import rpc import rpc_reader from typing import TypedDict + class ClusterObject(TypedDict): id: str address: str @@ -20,12 +21,16 @@ class ClusterObject(TypedDict): sourceKey: str sha256: str + def isClusterAddress(address: str): return not address or address == os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None) + def getClusterPeerKey(address: str, port: int): return f"{address}:{port}" -class ClusterSetup(): + + +class ClusterSetup: def __init__(self, loop: AbstractEventLoop, peer: rpc.RpcPeer): self.loop = loop self.peer = peer @@ -50,9 +55,13 @@ class ClusterSetup(): sha256 = self.computeClusterObjectHash(o) if sha256 != o["sha256"]: raise Exception("secret incorrect") - return await self.resolveObject(o.get('proxyId', None), o.get('sourceKey', None)) + return await self.resolveObject( + o.get("proxyId", None), o.get("sourceKey", None) + ) - def onProxySerialization(self, peer: rpc.RpcPeer, value: Any, sourceKey: str = None): + def onProxySerialization( + self, peer: rpc.RpcPeer, value: Any, sourceKey: str = None + ): properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {} clusterEntry = properties.get("__cluster", None) proxyId: str @@ -126,7 +135,9 @@ class ClusterSetup(): handleClusterClient, listenAddress, 0 ) self.clusterPort = clusterRpcServer.sockets[0].getsockname()[1] - self.peer.onProxySerialization = lambda value: self.onProxySerialization(self.peer, value, None) + self.peer.onProxySerialization = lambda value: self.onProxySerialization( + self.peer, value, None + ) del self.peer.params["initializeCluster"] def computeClusterObjectHash(self, o: ClusterObject) -> str: @@ -215,9 +226,7 @@ class ClusterSetup(): peerConnectRPCObject = clusterPeer.tags.get("connectRPCObject") if not peerConnectRPCObject: - peerConnectRPCObject = await clusterPeer.getParam( - "connectRPCObject" - ) + peerConnectRPCObject = await clusterPeer.getParam("connectRPCObject") clusterPeer.tags["connectRPCObject"] = peerConnectRPCObject newValue = await peerConnectRPCObject(clusterObject) if not newValue: diff --git a/server/python/plugin_console.py b/server/python/plugin_console.py index ae2d4d68e..a46751dab 100644 --- a/server/python/plugin_console.py +++ b/server/python/plugin_console.py @@ -1,5 +1,6 @@ import typing + async def writeWorkerGenerator(gen, out: typing.TextIO): try: async for item in gen: diff --git a/server/python/plugin_pip.py b/server/python/plugin_pip.py index babac30b0..381229841 100644 --- a/server/python/plugin_pip.py +++ b/server/python/plugin_pip.py @@ -4,22 +4,25 @@ import sys from typing import Any import shutil + def get_requirements_files(requirements: str): - want_requirements = requirements + '.txt' - installed_requirementstxt = requirements + '.installed.txt' + want_requirements = requirements + ".txt" + installed_requirementstxt = requirements + ".installed.txt" return want_requirements, installed_requirementstxt + def need_requirements(requirements_basename: str, requirements_str: str): _, installed_requirementstxt = get_requirements_files(requirements_basename) if not os.path.exists(installed_requirementstxt): return True try: f = open(installed_requirementstxt, "rb") - installed_requirements = f.read().decode('utf8') + installed_requirements = f.read().decode("utf8") return requirements_str != installed_requirements except: return True + def remove_pip_dirs(plugin_volume: str): try: for de in os.listdir(plugin_volume): @@ -48,7 +51,9 @@ def install_with_pip( ignore_error: bool = False, site_packages: str = None, ): - requirementstxt, installed_requirementstxt = get_requirements_files(requirements_basename) + requirementstxt, installed_requirementstxt = get_requirements_files( + requirements_basename + ) os.makedirs(python_prefix, exist_ok=True) @@ -81,15 +86,16 @@ def install_with_pip( # force reinstall even if it exists in system packages. pipArgs.append("--force-reinstall") - env = None if site_packages: env = dict(os.environ) - PYTHONPATH = env['PYTHONPATH'] or '' - PYTHONPATH += ':' + site_packages + PYTHONPATH = env["PYTHONPATH"] or "" + PYTHONPATH += ":" + site_packages env["PYTHONPATH"] = PYTHONPATH print("PYTHONPATH", env["PYTHONPATH"]) - p = subprocess.Popen(pipArgs, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env) + p = subprocess.Popen( + pipArgs, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env + ) while True: line = p.stdout.readline() diff --git a/server/python/plugin_remote.py b/server/python/plugin_remote.py index 40c733070..8482d2c2c 100644 --- a/server/python/plugin_remote.py +++ b/server/python/plugin_remote.py @@ -27,18 +27,22 @@ from cluster_setup import ClusterSetup import cluster_labels from plugin_pip import install_with_pip, need_requirements, remove_pip_dirs from scrypted_python.scrypted_sdk import PluginFork, ScryptedStatic -from scrypted_python.scrypted_sdk.types import (Device, DeviceManifest, - EventDetails, - ScryptedInterface, - ScryptedInterfaceMethods, - ScryptedInterfaceProperty, - Storage) +from scrypted_python.scrypted_sdk.types import ( + Device, + DeviceManifest, + EventDetails, + ScryptedInterface, + ScryptedInterfaceMethods, + ScryptedInterfaceProperty, + Storage, +) SCRYPTED_REQUIREMENTS = """ ptpython wheel """.strip() + class SystemDeviceState(TypedDict): lastEventTime: int stateTime: int @@ -47,8 +51,10 @@ class SystemDeviceState(TypedDict): def ensure_not_coroutine(fn: Callable | Coroutine) -> Callable: if inspect.iscoroutinefunction(fn): + def wrapper(*args, **kwargs): return asyncio.create_task(fn(*args, **kwargs)) + return wrapper return fn @@ -96,25 +102,29 @@ class DeviceProxy(object): class EventListenerRegisterImpl(scrypted_python.scrypted_sdk.EventListenerRegister): removeListener: Callable[[], None] - def __init__(self, removeListener: Callable[[], None] | Coroutine[Any, None, None]) -> None: + def __init__( + self, removeListener: Callable[[], None] | Coroutine[Any, None, None] + ) -> None: self.removeListener = ensure_not_coroutine(removeListener) class EventRegistry(object): systemListeners: Set[scrypted_python.scrypted_sdk.EventListener] - listeners: Mapping[str, Set[Callable[[scrypted_python.scrypted_sdk.EventDetails, Any], None]]] + listeners: Mapping[ + str, Set[Callable[[scrypted_python.scrypted_sdk.EventDetails, Any], None]] + ] - __allowedEventInterfaces = set([ - ScryptedInterface.ScryptedDevice.value, - 'Logger', - 'Storage' - ]) + __allowedEventInterfaces = set( + [ScryptedInterface.ScryptedDevice.value, "Logger", "Storage"] + ) def __init__(self) -> None: self.systemListeners = set() self.listeners = {} - def __getMixinEventName(self, options: str | scrypted_python.scrypted_sdk.EventListenerOptions) -> str: + def __getMixinEventName( + self, options: str | scrypted_python.scrypted_sdk.EventListenerOptions + ) -> str: mixinId = None if type(options) == str: event = options @@ -155,7 +165,15 @@ class EventRegistry(object): self.listeners[id].add(callback) return EventListenerRegisterImpl(lambda: self.listeners[id].remove(callback)) - def notify(self, id: str, eventTime: int, eventInterface: str, property: str, value: Any, options: dict = None): + def notify( + self, + id: str, + eventTime: int, + eventInterface: str, + property: str, + value: Any, + options: dict = None, + ): options = options or {} changed = options.get("changed") mixinId = options.get("mixinId") @@ -174,7 +192,13 @@ class EventRegistry(object): return self.notifyEventDetails(id, eventDetails, value) - def notifyEventDetails(self, id: str, eventDetails: scrypted_python.scrypted_sdk.EventDetails, value: Any, eventInterface: str = None): + def notifyEventDetails( + self, + id: str, + eventDetails: scrypted_python.scrypted_sdk.EventDetails, + value: Any, + eventInterface: str = None, + ): if not eventDetails.get("eventId"): eventDetails["eventId"] = self.__generateBase36Str() if not eventInterface: @@ -183,8 +207,9 @@ class EventRegistry(object): # system listeners only get state changes. # there are many potentially noisy stateless events, like # object detection and settings changes - if (eventDetails.get("property") and not eventDetails.get("mixinId")) or \ - (eventInterface in EventRegistry.__allowedEventInterfaces): + if (eventDetails.get("property") and not eventDetails.get("mixinId")) or ( + eventInterface in EventRegistry.__allowedEventInterfaces + ): for listener in self.systemListeners: listener(id, eventDetails, value) @@ -202,6 +227,7 @@ class EventRegistry(object): return True + class ClusterManager(scrypted_python.scrypted_sdk.types.ClusterManager): def __init__(self, api: Any): self.api = api @@ -213,11 +239,16 @@ class ClusterManager(scrypted_python.scrypted_sdk.types.ClusterManager): def getClusterWorkerId(self) -> str: return os.getenv("SCRYPTED_CLUSTER_WORKER_ID", None) - async def getClusterWorkers(self) -> Mapping[str, scrypted_python.scrypted_sdk.types.ClusterWorker]: - self.clusterService = self.clusterService or asyncio.ensure_future(self.api.getComponent("cluster-fork")) + async def getClusterWorkers( + self, + ) -> Mapping[str, scrypted_python.scrypted_sdk.types.ClusterWorker]: + self.clusterService = self.clusterService or asyncio.ensure_future( + self.api.getComponent("cluster-fork") + ) cs = await self.clusterService return await cs.getClusterWorkers() + class SystemManager(scrypted_python.scrypted_sdk.types.SystemManager): def __init__( self, api: Any, systemState: Mapping[str, Mapping[str, SystemDeviceState]] @@ -306,19 +337,27 @@ class SystemManager(scrypted_python.scrypted_sdk.types.SystemManager): callback = ensure_not_coroutine(callback) if type(options) != str and options.get("watch"): return self.events.listenDevice( - id, options, - lambda eventDetails, eventData: callback(self.getDeviceById(id), eventDetails, eventData) + id, + options, + lambda eventDetails, eventData: callback( + self.getDeviceById(id), eventDetails, eventData + ), ) register_fut = asyncio.ensure_future( self.api.listenDevice( - id, options, - lambda eventDetails, eventData: callback(self.getDeviceById(id), eventDetails, eventData) + id, + options, + lambda eventDetails, eventData: callback( + self.getDeviceById(id), eventDetails, eventData + ), ) ) + async def unregister(): register = await register_fut await register.removeListener() + return EventListenerRegisterImpl(lambda: asyncio.ensure_future(unregister())) async def removeDevice(self, id: str) -> None: @@ -555,6 +594,7 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager): def getDeviceStorage(self, nativeId: str = None) -> Storage: return self.nativeIds.get(nativeId, None) + class PeerLiveness: def __init__(self, loop: AbstractEventLoop): self.killed = Future(loop=loop) @@ -562,15 +602,22 @@ class PeerLiveness: async def waitKilled(self): await self.killed + def safe_set_result(fut: Future, result: Any): try: fut.set_result(result) except: pass + class PluginRemote: def __init__( - self, clusterSetup: ClusterSetup, api, pluginId: str, hostInfo, loop: AbstractEventLoop + self, + clusterSetup: ClusterSetup, + api, + pluginId: str, + hostInfo, + loop: AbstractEventLoop, ): self.systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {} self.nativeIds: Mapping[str, DeviceStorage] = {} @@ -606,7 +653,9 @@ class PluginRemote: consoleFuture = Future() self.consoles[nativeId] = consoleFuture plugins = await self.api.getComponent("plugins") - port, hostname = await plugins.getRemoteServicePort(self.pluginId, "console-writer") + port, hostname = await plugins.getRemoteServicePort( + self.pluginId, "console-writer" + ) connection = await asyncio.open_connection(host=hostname, port=port) _, writer = connection if not nativeId: @@ -682,7 +731,7 @@ class PluginRemote: if not forkMain: multiprocessing.set_start_method("spawn") - + # forkMain may be set to true, but the environment may not be initialized # if the plugin is loaded in another cluster worker. # instead rely on a environemnt variable that will be passed to @@ -819,10 +868,13 @@ class PluginRemote: async def getZip(self): return await zipAPI.getZip() - return await remote.loadZip(packageJson, PluginZipAPI(), forkOptions) + return await remote.loadZip( + packageJson, PluginZipAPI(), forkOptions + ) if cluster_labels.needs_cluster_fork_worker(options): peerLiveness = PeerLiveness(self.loop) + async def getClusterFork(): runtimeWorkerOptions = { "packageJson": packageJson, @@ -835,14 +887,17 @@ class PluginRemote: forkComponent = await self.api.getComponent("cluster-fork") sanitizedOptions = options.copy() - sanitizedOptions["runtime"] = sanitizedOptions.get("runtime", "python") + sanitizedOptions["runtime"] = sanitizedOptions.get( + "runtime", "python" + ) sanitizedOptions["zipHash"] = zipHash clusterForkResult = await forkComponent.fork( runtimeWorkerOptions, sanitizedOptions, - peerLiveness, lambda: zipAPI.getZip() + peerLiveness, + lambda: zipAPI.getZip(), ) - + async def waitPeerLiveness(): try: await peerLiveness.waitKilled() @@ -851,6 +906,7 @@ class PluginRemote: await clusterForkResult.kill() except: pass + asyncio.ensure_future(waitPeerLiveness(), loop=self.loop) async def waitClusterForkKilled(): @@ -859,30 +915,48 @@ class PluginRemote: except: pass safe_set_result(peerLiveness.killed, None) + asyncio.ensure_future(waitClusterForkKilled(), loop=self.loop) - clusterGetRemote = await self.clusterSetup.connectRPCObject(await clusterForkResult.getResult()) + clusterGetRemote = await self.clusterSetup.connectRPCObject( + await clusterForkResult.getResult() + ) remoteDict = await clusterGetRemote() - asyncio.ensure_future(plugin_console.writeWorkerGenerator(remoteDict["stdout"], sys.stdout)) - asyncio.ensure_future(plugin_console.writeWorkerGenerator(remoteDict["stderr"], sys.stderr)) + asyncio.ensure_future( + plugin_console.writeWorkerGenerator( + remoteDict["stdout"], sys.stdout + ) + ) + asyncio.ensure_future( + plugin_console.writeWorkerGenerator( + remoteDict["stderr"], sys.stderr + ) + ) getRemote = remoteDict["getRemote"] - directGetRemote = await self.clusterSetup.connectRPCObject(getRemote) + directGetRemote = await self.clusterSetup.connectRPCObject( + getRemote + ) if directGetRemote is getRemote: raise Exception("cluster fork peer not direct connected") - forkPeer = getattr(directGetRemote, rpc.RpcPeer.PROPERTY_PROXY_PEER) + forkPeer = getattr( + directGetRemote, rpc.RpcPeer.PROPERTY_PROXY_PEER + ) return await finishFork(forkPeer) - pluginFork = PluginFork() pluginFork.result = asyncio.create_task(getClusterFork()) + async def waitKilled(): await peerLiveness.killed + pluginFork.exit = asyncio.create_task(waitKilled()) + def terminate(): safe_set_result(peerLiveness.killed, None) pluginFork.worker.terminate() + pluginFork.terminate = terminate pluginFork.worker = None @@ -902,12 +976,16 @@ class PluginRemote: pluginFork = PluginFork() killed = Future(loop=self.loop) + async def waitKilled(): await killed + pluginFork.exit = asyncio.create_task(waitKilled()) + def terminate(): safe_set_result(killed, None) pluginFork.worker.kill() + pluginFork.terminate = terminate pluginFork.worker = multiprocessing.Process( @@ -956,6 +1034,7 @@ class PluginRemote: # sdk. from scrypted_sdk import sdk_init2 # type: ignore + sdk_init2(sdk) except: from scrypted_sdk import sdk_init # type: ignore @@ -1048,7 +1127,9 @@ async def plugin_async_main( peer.params["print"] = print clusterSetup = ClusterSetup(loop, peer) - peer.params["initializeCluster"] = lambda options: clusterSetup.initializeCluster(options) + peer.params["initializeCluster"] = lambda options: clusterSetup.initializeCluster( + options + ) async def ping(time: int): return time @@ -1077,6 +1158,7 @@ def main(rpcTransport: rpc_reader.RpcTransport): loop.run_until_complete(plugin_async_main(loop, rpcTransport)) loop.close() + def plugin_fork(conn: multiprocessing.connection.Connection): main(rpc_reader.RpcConnectionTransport(conn)) diff --git a/server/python/plugin_repl.py b/server/python/plugin_repl.py index 60ee2e18c..04ef2c2ce 100644 --- a/server/python/plugin_repl.py +++ b/server/python/plugin_repl.py @@ -39,6 +39,8 @@ ColorDepth.default = lambda *args, **kwargs: ColorDepth.DEPTH_4_BIT # the library. The patches here allow us to scope a particular call stack # to a particular REPL, and to get the current Application from the stack. default_get_app = prompt_toolkit.application.current.get_app + + def get_app_patched() -> Application[Any]: stack = inspect.stack() for frame in stack: @@ -46,6 +48,8 @@ def get_app_patched() -> Application[Any]: if self_var is not None and isinstance(self_var, Application): return self_var return default_get_app() + + prompt_toolkit.application.current.get_app = get_app_patched prompt_toolkit.key_binding.key_processor.get_app = get_app_patched prompt_toolkit.contrib.telnet.server.get_app = get_app_patched @@ -141,7 +145,9 @@ async def eval_async_patched(self: PythonRepl, line: str) -> object: def eval_across_loops(code, *args, **kwargs): future = concurrent.futures.Future() - scrypted_loop.call_soon_threadsafe(partial(eval_in_scrypted, future), code, *args, **kwargs) + scrypted_loop.call_soon_threadsafe( + partial(eval_in_scrypted, future), code, *args, **kwargs + ) return future.result() # WORKAROUND: Due to a bug in Jedi, the current directory is removed @@ -192,7 +198,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.settimeout(None) - sock.bind(('localhost', 0)) + sock.bind(("localhost", 0)) sock.listen() scrypted_loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() @@ -222,7 +228,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: # Select a free port for the telnet server s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('localhost', 0)) + s.bind(("localhost", 0)) telnet_port = s.getsockname()[1] s.close() @@ -230,14 +236,19 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: # repl_loop owns the print capabilities, but the prints will # be executed in scrypted_loop. We need to bridge the two here repl_print = partial(print_formatted_text, output=connection.vt100_output) + def print_across_loops(*args, **kwargs): repl_loop.call_soon_threadsafe(repl_print, *args, **kwargs) global_dict = { **globals(), "print": print_across_loops, - "help": lambda *args, **kwargs: print_across_loops("Help is not available in this environment"), - "input": lambda *args, **kwargs: print_across_loops("Input is not available in this environment"), + "help": lambda *args, **kwargs: print_across_loops( + "Help is not available in this environment" + ), + "input": lambda *args, **kwargs: print_across_loops( + "Input is not available in this environment" + ), } locals_dict = { "device": device, @@ -245,19 +256,32 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: "deviceManager": deviceManager, "mediaManager": mediaManager, "sdk": sdk, - "realDevice": realDevice + "realDevice": realDevice, } - vars_prompt = '\n'.join([f" {k}" for k in locals_dict.keys()]) + vars_prompt = "\n".join([f" {k}" for k in locals_dict.keys()]) banner = f"Python REPL variables:\n{vars_prompt}" print_formatted_text(banner) - await embed(return_asyncio_coroutine=True, globals=global_dict, locals=locals_dict, configure=partial(configure, scrypted_loop)) + await embed( + return_asyncio_coroutine=True, + globals=global_dict, + locals=locals_dict, + configure=partial(configure, scrypted_loop), + ) server_task: asyncio.Task = None + def ready_cb(): - future.set_result((telnet_port, lambda: repl_loop.call_soon_threadsafe(server_task.cancel))) + future.set_result( + ( + telnet_port, + lambda: repl_loop.call_soon_threadsafe(server_task.cancel), + ) + ) # Start the REPL server - telnet_server = TelnetServer(interact=interact, port=telnet_port, enable_cpr=False) + telnet_server = TelnetServer( + interact=interact, port=telnet_port, enable_cpr=False + ) server_task = asyncio.create_task(telnet_server.run(ready_cb=ready_cb)) try: await server_task @@ -277,16 +301,19 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: def finish_setup(): telnet_port, exit_server = server_started_future.result() - telnet_client = telnetlib.Telnet('localhost', telnet_port, timeout=None) + telnet_client = telnetlib.Telnet("localhost", telnet_port, timeout=None) def telnet_negotiation_cb(telnet_socket, command, option): pass # ignore telnet negotiation + telnet_client.set_option_negotiation_callback(telnet_negotiation_cb) # initialize telnet terminal # this tells the telnet server we are a vt100 terminal - telnet_client.get_socket().sendall(b'\xff\xfb\x18\xff\xfa\x18\x00\x61\x6e\x73\x69\xff\xf0') - telnet_client.get_socket().sendall(b'\r\n') + telnet_client.get_socket().sendall( + b"\xff\xfb\x18\xff\xfa\x18\x00\x61\x6e\x73\x69\xff\xf0" + ) + telnet_client.get_socket().sendall(b"\r\n") # Bridge the connection to the telnet server, two way def forward_to_telnet(): @@ -303,7 +330,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: while True: data = telnet_client.read_some() if not data: - conn.sendall('REPL exited'.encode()) + conn.sendall("REPL exited".encode()) break if b">>>" in data: # This is an ugly hack - somewhere in ptpython, the @@ -333,4 +360,4 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int: threading.Thread(target=accept_connection).start() proxy_port = sock.getsockname()[1] - return proxy_port \ No newline at end of file + return proxy_port diff --git a/server/python/plugin_volume.py b/server/python/plugin_volume.py index 9d6d34627..ab0a86f10 100644 --- a/server/python/plugin_volume.py +++ b/server/python/plugin_volume.py @@ -1,20 +1,24 @@ import os from pathlib import Path + def get_scrypted_volume(): - volume_dir = os.getenv('SCRYPTED_VOLUME') or Path.home() / '.scrypted' / 'volume' + volume_dir = os.getenv("SCRYPTED_VOLUME") or Path.home() / ".scrypted" / "volume" return str(volume_dir) + def get_plugins_volume(): volume = get_scrypted_volume() - plugins_volume = Path(volume) / 'plugins' + plugins_volume = Path(volume) / "plugins" return str(plugins_volume) + def get_plugin_volume(plugin_id): volume = get_plugins_volume() plugin_volume = Path(volume) / plugin_id return str(plugin_volume) + def ensure_plugin_volume(plugin_id): plugin_volume = get_plugin_volume(plugin_id) try: @@ -23,23 +27,25 @@ def ensure_plugin_volume(plugin_id): pass return plugin_volume + def create_adm_zip_hash(hash): extract_version = "1-" return extract_version + hash + def prep(plugin_volume, hash): hash = create_adm_zip_hash(hash) zip_filename = f"{hash}.zip" - zip_dir = os.path.join(plugin_volume, 'zip') + zip_dir = os.path.join(plugin_volume, "zip") zip_file = os.path.join(zip_dir, zip_filename) - unzipped_path = os.path.join(zip_dir, 'unzipped') - zip_dir_tmp = zip_dir + '.tmp' + unzipped_path = os.path.join(zip_dir, "unzipped") + zip_dir_tmp = zip_dir + ".tmp" return { - 'unzipped_path': unzipped_path, - 'zip_filename': zip_filename, - 'zip_dir': zip_dir, - 'zip_file': zip_file, - 'zip_dir_tmp': zip_dir_tmp, - } \ No newline at end of file + "unzipped_path": unzipped_path, + "zip_filename": zip_filename, + "zip_dir": zip_dir, + "zip_file": zip_file, + "zip_dir_tmp": zip_dir_tmp, + } diff --git a/server/python/rpc-iterator-test.py b/server/python/rpc-iterator-test.py index ac333aff5..3181f071f 100644 --- a/server/python/rpc-iterator-test.py +++ b/server/python/rpc-iterator-test.py @@ -1,22 +1,24 @@ import asyncio import rpc -from rpc_reader import prepare_peer_readloop, RpcFileTransport +from rpc_reader import prepare_peer_readloop, RpcFileTransport import traceback + class Bar: pass + async def main(): peer, peerReadLoop = await prepare_peer_readloop(loop, RpcFileTransport(4, 3)) - peer.params['foo'] = 3 + peer.params["foo"] = 3 jsoncopy = {} jsoncopy[rpc.RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN] = True - jsoncopy['bar'] = Bar() - peer.params['bar'] = jsoncopy + jsoncopy["bar"] = Bar() + peer.params["bar"] = jsoncopy # reader, writer = await asyncio.open_connection( # '127.0.0.1', 6666) - + # writer.write(bytes('abcd', 'utf8')) # async def ticker(delay, to): @@ -27,12 +29,12 @@ async def main(): # peer.params['ticker'] = ticker(0, 3) - print('python starting') + print("python starting") # await peerReadLoop() asyncio.ensure_future(peerReadLoop()) # print('getting param') - test = await peer.getParam('test') + test = await peer.getParam("test") print(test) try: i = 0 @@ -43,7 +45,8 @@ async def main(): i = i + 1 except: traceback.print_exc() - print('all done iterating') + print("all done iterating") + loop = asyncio.new_event_loop() loop.run_until_complete(main()) diff --git a/server/python/rpc.py b/server/python/rpc.py index e9f6603e2..63cc04b4c 100644 --- a/server/python/rpc.py +++ b/server/python/rpc.py @@ -16,7 +16,7 @@ jsonSerializable.add(list) async def maybe_await(value): - if (inspect.isawaitable(value)): + if inspect.isawaitable(value): return await value return value @@ -58,69 +58,112 @@ class LocalProxiedEntry(TypedDict): class RpcProxy(object): - def __init__(self, peer, entry: LocalProxiedEntry, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]): - self.__dict__['__proxy_id'] = entry['id'] - self.__dict__['__proxy_entry'] = entry - self.__dict__['__proxy_constructor'] = proxyConstructorName + def __init__( + self, + peer, + entry: LocalProxiedEntry, + proxyConstructorName: str, + proxyProps: any, + proxyOneWayMethods: List[str], + ): + self.__dict__["__proxy_id"] = entry["id"] + self.__dict__["__proxy_entry"] = entry + self.__dict__["__proxy_constructor"] = proxyConstructorName self.__dict__[RpcPeer.PROPERTY_PROXY_PEER] = peer self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] = proxyProps - self.__dict__['__proxy_oneway_methods'] = proxyOneWayMethods + self.__dict__["__proxy_oneway_methods"] = proxyOneWayMethods def __aiter__(self): - if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and 'Symbol(Symbol.asyncIterator)' in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]: + if ( + self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + and "Symbol(Symbol.asyncIterator)" + in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + ): return self - raise Exception('RpcProxy is not an async iterable') + raise Exception("RpcProxy is not an async iterable") async def __anext__(self): - if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and 'Symbol(Symbol.asyncIterator)' in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]: + if ( + self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + and "Symbol(Symbol.asyncIterator)" + in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + ): try: - return await RpcProxyMethod(self, self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]['Symbol(Symbol.asyncIterator)']['next'])() + return await RpcProxyMethod( + self, + self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES][ + "Symbol(Symbol.asyncIterator)" + ]["next"], + )() except RPCResultError as e: - if e.name == 'StopAsyncIteration': + if e.name == "StopAsyncIteration": raise StopAsyncIteration() raise - raise Exception('RpcProxy is not an async iterable') + raise Exception("RpcProxy is not an async iterable") async def aclose(self): - if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and 'Symbol(Symbol.asyncIterator)' in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]: + if ( + self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + and "Symbol(Symbol.asyncIterator)" + in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + ): try: - return await RpcProxyMethod(self, self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]['Symbol(Symbol.asyncIterator)']['return'])() + return await RpcProxyMethod( + self, + self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES][ + "Symbol(Symbol.asyncIterator)" + ]["return"], + )() except Exception: return - raise Exception('RpcProxy is not an async iterable') + raise Exception("RpcProxy is not an async iterable") def __getattr__(self, name): - if name == '__proxy_finalizer_id': - return self.dict['__proxy_entry']['finalizerId'] + if name == "__proxy_finalizer_id": + return self.dict["__proxy_entry"]["finalizerId"] if name in self.__dict__: return self.__dict__[name] - if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and name in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]: + if ( + self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + and name in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] + ): return self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES][name] return RpcProxyMethod(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name == '__proxy_finalizer_id': - self.__dict__['__proxy_entry']['finalizerId'] = value + if name == "__proxy_finalizer_id": + self.__dict__["__proxy_entry"]["finalizerId"] = value return super().__setattr__(name, value) def __call__(self, *args, **kwargs): - return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__(self.__dict__['__proxy_id'], self.__dict__['__proxy_oneway_methods'], None, args) - + return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__( + self.__dict__["__proxy_id"], + self.__dict__["__proxy_oneway_methods"], + None, + args, + ) def __apply__(self, method: str, args: list): - return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__(self.__dict__['__proxy_id'], self.__dict__['__proxy_oneway_methods'], method, args) + return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__( + self.__dict__["__proxy_id"], + self.__dict__["__proxy_oneway_methods"], + method, + args, + ) class RpcPeer: - RPC_RESULT_ERROR_NAME = 'RPCResultError' - PROPERTY_PROXY_PROPERTIES = '__proxy_props' - PROPERTY_JSON_COPY_SERIALIZE_CHILDREN = '__json_copy_serialize_children' - PROPERTY_PROXY_PEER = '__proxy_peer' + RPC_RESULT_ERROR_NAME = "RPCResultError" + PROPERTY_PROXY_PROPERTIES = "__proxy_props" + PROPERTY_JSON_COPY_SERIALIZE_CHILDREN = "__json_copy_serialize_children" + PROPERTY_PROXY_PEER = "__proxy_peer" - def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None: + def __init__( + self, send: Callable[[object, Callable[[Exception], None], Dict], None] + ) -> None: self.send = send - self.peerName = 'Unnamed Peer' + self.peerName = "Unnamed Peer" self.params: Mapping[str, any] = {} self.localProxied: Mapping[any, LocalProxiedEntry] = {} self.localProxyMap: Mapping[str, any] = {} @@ -132,7 +175,9 @@ class RpcPeer: self.killed = False self.tags = {} - def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list): + def __apply__( + self, proxyId: str, oneWayMethods: List[str], method: str, args: list + ): oneway = oneWayMethods and method in oneWayMethods if self.killed: @@ -140,7 +185,9 @@ class RpcPeer: if oneway: future.set_result(None) return future - future.set_exception(RPCResultError(None, 'RpcPeer has been killed (apply) ' + str(method))) + future.set_exception( + RPCResultError(None, "RpcPeer has been killed (apply) " + str(method)) + ) return future serializationContext: Dict = {} @@ -149,23 +196,24 @@ class RpcPeer: serializedArgs.append(self.serialize(arg, serializationContext)) rpcApply = { - 'type': 'apply', - 'id': None, - 'proxyId': proxyId, - 'args': serializedArgs, - 'method': method, + "type": "apply", + "id": None, + "proxyId": proxyId, + "args": serializedArgs, + "method": method, } if oneway: - rpcApply['oneway'] = True + rpcApply["oneway"] = True self.send(rpcApply, None, serializationContext) future = Future() future.set_result(None) return future async def send(id: str, reject: Callable[[Exception], None]): - rpcApply['id'] = id + rpcApply["id"] = id self.send(rpcApply, reject, serializationContext) + return self.createPendingResult(send) def kill(self, message: str = None): @@ -174,7 +222,7 @@ class RpcPeer: return self.killed = True - error = RPCResultError(None, message or 'peer was killed') + error = RPCResultError(None, message or "peer was killed") # this.killedDeferred.reject(error); for str, future in self.pendingResults.items(): future.set_exception(error) @@ -187,14 +235,14 @@ class RpcPeer: def createErrorResult(self, result: Any, e: Exception): s = self.serializeError(e) - result['result'] = s - result['throw'] = True + result["result"] = s + result["throw"] = True return result def deserializeError(e: Dict) -> RPCResultError: - error = RPCResultError(None, e.get('message')) - error.stack = e.get('stack') - error.name = e.get('name') + error = RPCResultError(None, e.get("message")) + error.stack = e.get("stack") + error.name = e.get("name") return error def serializeError(self, e: Exception): @@ -203,16 +251,16 @@ class RpcPeer: message = str(e) serialized = { - 'stack': tb or '[no stack]', - 'name': name or '[no name]', - 'message': message or '[no message]', + "stack": tb or "[no stack]", + "name": name or "[no name]", + "message": message or "[no message]", } return { - '__remote_constructor_name': RpcPeer.RPC_RESULT_ERROR_NAME, - '__serialized_value': serialized, + "__remote_constructor_name": RpcPeer.RPC_RESULT_ERROR_NAME, + "__serialized_value": serialized, } - + # def getProxyProperties(value): # return getattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, None) @@ -220,15 +268,15 @@ class RpcPeer: # setattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, properties) def prepareProxyProperties(value): - if not hasattr(value, '__aiter__') or not hasattr(value, '__anext__'): + if not hasattr(value, "__aiter__") or not hasattr(value, "__anext__"): return getattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, None) props = getattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, None) or {} - if not props.get('Symbol(Symbol.asyncIterator)'): - props['Symbol(Symbol.asyncIterator)'] = { - 'next': '__anext__', - 'throw': 'athrow', - 'return': 'aclose', + if not props.get("Symbol(Symbol.asyncIterator)"): + props["Symbol(Symbol.asyncIterator)"] = { + "next": "__anext__", + "throw": "athrow", + "return": "aclose", } return props @@ -236,34 +284,44 @@ class RpcPeer: return not value or (type(value) in jsonSerializable) def serialize(self, value, serializationContext: Dict): - if type(value) == dict and value.get(RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None): + if type(value) == dict and value.get( + RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None + ): ret = {} - for (key, val) in value.items(): + for key, val in value.items(): ret[key] = self.serialize(val, serializationContext) return ret - if (RpcPeer.isTransportSafe(value)): + if RpcPeer.isTransportSafe(value): return value - __remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr( - value, '__proxy_constructor') else type(value).__name__ + __remote_constructor_name = ( + "Function" + if callable(value) + else ( + value.__proxy_constructor + if hasattr(value, "__proxy_constructor") + else type(value).__name__ + ) + ) if isinstance(value, Exception): return self.serializeError(value) - serializerMapName = self.constructorSerializerMap.get( - type(value), None) + serializerMapName = self.constructorSerializerMap.get(type(value), None) if serializerMapName: __remote_constructor_name = serializerMapName serializer = self.nameDeserializerMap.get(serializerMapName, None) serialized = serializer.serialize(value, serializationContext) ret = { - '__remote_proxy_id': None, - '__remote_proxy_finalizer_id': None, - '__remote_constructor_name': __remote_constructor_name, - '__remote_proxy_props': RpcPeer.prepareProxyProperties(value), - '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), - '__serialized_value': serialized, + "__remote_proxy_id": None, + "__remote_proxy_finalizer_id": None, + "__remote_constructor_name": __remote_constructor_name, + "__remote_proxy_props": RpcPeer.prepareProxyProperties(value), + "__remote_proxy_oneway_methods": getattr( + value, "__proxy_oneway_methods", None + ), + "__serialized_value": serialized, } return ret @@ -273,26 +331,28 @@ class RpcPeer: proxyId, __remote_proxy_props = self.onProxySerialization(value) else: __remote_proxy_props = RpcPeer.prepareProxyProperties(value) - proxyId = proxiedEntry['id'] + proxyId = proxiedEntry["id"] - if proxyId != proxiedEntry['id']: - raise Exception('onProxySerialization proxy id mismatch') + if proxyId != proxiedEntry["id"]: + raise Exception("onProxySerialization proxy id mismatch") - proxiedEntry['finalizerId'] = RpcPeer.generateId() + proxiedEntry["finalizerId"] = RpcPeer.generateId() ret = { - '__remote_proxy_id': proxyId, - '__remote_proxy_finalizer_id': proxiedEntry['finalizerId'], - '__remote_constructor_name': __remote_constructor_name, - '__remote_proxy_props': __remote_proxy_props, - '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), + "__remote_proxy_id": proxyId, + "__remote_proxy_finalizer_id": proxiedEntry["finalizerId"], + "__remote_constructor_name": __remote_constructor_name, + "__remote_proxy_props": __remote_proxy_props, + "__remote_proxy_oneway_methods": getattr( + value, "__proxy_oneway_methods", None + ), } return ret - __proxy_id = getattr(value, '__proxy_id', None) + __proxy_id = getattr(value, "__proxy_id", None) __proxy_peer = getattr(value, RpcPeer.PROPERTY_PROXY_PEER, None) if __proxy_id and __proxy_peer == self: ret = { - '__local_proxy_id': __proxy_id, + "__local_proxy_id": __proxy_id, } return ret @@ -303,18 +363,20 @@ class RpcPeer: proxyId = RpcPeer.generateId() proxiedEntry = { - 'id': proxyId, - 'finalizerId': proxyId, + "id": proxyId, + "finalizerId": proxyId, } self.localProxied[value] = proxiedEntry self.localProxyMap[proxyId] = value ret = { - '__remote_proxy_id': proxyId, - '__remote_proxy_finalizer_id': proxyId, - '__remote_constructor_name': __remote_constructor_name, - '__remote_proxy_props': __remote_proxy_props, - '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), + "__remote_proxy_id": proxyId, + "__remote_proxy_finalizer_id": proxyId, + "__remote_constructor_name": __remote_constructor_name, + "__remote_proxy_props": __remote_proxy_props, + "__remote_proxy_oneway_methods": getattr( + value, "__proxy_oneway_methods", None + ), } return ret @@ -322,22 +384,33 @@ class RpcPeer: def finalize(self, localProxiedEntry: LocalProxiedEntry): if self.killed: return - id = localProxiedEntry['id'] + id = localProxiedEntry["id"] self.remoteWeakProxies.pop(id, None) rpcFinalize = { - '__local_proxy_id': id, - '__local_proxy_finalizer_id': localProxiedEntry['finalizerId'], - 'type': 'finalize', + "__local_proxy_id": id, + "__local_proxy_finalizer_id": localProxiedEntry["finalizerId"], + "type": "finalize", } self.send(rpcFinalize) - def newProxy(self, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]): + def newProxy( + self, + proxyId: str, + proxyConstructorName: str, + proxyProps: any, + proxyOneWayMethods: List[str], + ): localProxiedEntry: LocalProxiedEntry = { - 'id': proxyId, - 'finalizerId': None, + "id": proxyId, + "finalizerId": None, } - proxy = RpcProxy(self, localProxiedEntry, proxyConstructorName, - proxyProps, proxyOneWayMethods) + proxy = RpcProxy( + self, + localProxiedEntry, + proxyConstructorName, + proxyProps, + proxyOneWayMethods, + ) wr = weakref.ref(proxy) self.remoteWeakProxies[proxyId] = wr weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry)) @@ -350,23 +423,22 @@ class RpcPeer: if type(value) != dict: return value - copySerializeChildren = value.get(RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None) + copySerializeChildren = value.get( + RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None + ) if copySerializeChildren: ret = {} - for (key, val) in value.items(): + for key, val in value.items(): ret[key] = self.deserialize(val, deserializationContext) return ret - __remote_proxy_id = value.get('__remote_proxy_id', None) - __remote_proxy_finalizer_id = value.get( - '__remote_proxy_finalizer_id', None) - __local_proxy_id = value.get('__local_proxy_id', None) - __remote_constructor_name = value.get( - '__remote_constructor_name', None) - __serialized_value = value.get('__serialized_value', None) - __remote_proxy_props = value.get('__remote_proxy_props', None) - __remote_proxy_oneway_methods = value.get( - '__remote_proxy_oneway_methods', None) + __remote_proxy_id = value.get("__remote_proxy_id", None) + __remote_proxy_finalizer_id = value.get("__remote_proxy_finalizer_id", None) + __local_proxy_id = value.get("__local_proxy_id", None) + __remote_constructor_name = value.get("__remote_constructor_name", None) + __serialized_value = value.get("__serialized_value", None) + __remote_proxy_props = value.get("__remote_proxy_props", None) + __remote_proxy_oneway_methods = value.get("__remote_proxy_oneway_methods", None) if __remote_constructor_name == RpcPeer.RPC_RESULT_ERROR_NAME: return RpcPeer.deserializeError(__serialized_value) @@ -375,66 +447,71 @@ class RpcPeer: weakref = self.remoteWeakProxies.get(__remote_proxy_id, None) proxy = weakref() if weakref else None if not proxy: - proxy = self.newProxy(__remote_proxy_id, __remote_constructor_name, - __remote_proxy_props, __remote_proxy_oneway_methods) - setattr(proxy, '__proxy_finalizer_id', __remote_proxy_finalizer_id) + proxy = self.newProxy( + __remote_proxy_id, + __remote_constructor_name, + __remote_proxy_props, + __remote_proxy_oneway_methods, + ) + setattr(proxy, "__proxy_finalizer_id", __remote_proxy_finalizer_id) return proxy if __local_proxy_id: ret = self.localProxyMap.get(__local_proxy_id, None) if not ret: raise RPCResultError( - None, 'invalid local proxy id %s' % __local_proxy_id) + None, "invalid local proxy id %s" % __local_proxy_id + ) return ret - deserializer = self.nameDeserializerMap.get( - __remote_constructor_name, None) + deserializer = self.nameDeserializerMap.get(__remote_constructor_name, None) if deserializer: return deserializer.deserialize(__serialized_value, deserializationContext) return value def sendResult(self, result: Dict, serializationContext: Dict): - self.send(result, lambda e: self.send(self.createErrorResult(result, e, None), None), serializationContext) + self.send( + result, + lambda e: self.send(self.createErrorResult(result, e, None), None), + serializationContext, + ) async def handleMessage(self, message: Dict, deserializationContext: Dict): try: - messageType = message['type'] - if messageType == 'param': + messageType = message["type"] + if messageType == "param": result = { - 'type': 'result', - 'id': message['id'], + "type": "result", + "id": message["id"], } serializationContext: Dict = {} try: - value = self.params.get(message['param'], None) + value = self.params.get(message["param"], None) value = await maybe_await(value) - result['result'] = self.serialize(value, serializationContext) + result["result"] = self.serialize(value, serializationContext) except Exception as e: tb = traceback.format_exc() - self.createErrorResult( - result, type(e).__name, str(e), tb) + self.createErrorResult(result, type(e).__name, str(e), tb) self.sendResult(result, serializationContext) - elif messageType == 'apply': + elif messageType == "apply": result = { - 'type': 'result', - 'id': message.get('id', None), + "type": "result", + "id": message.get("id", None), } - method = message.get('method', None) + method = message.get("method", None) try: serializationContext: Dict = {} - target = self.localProxyMap.get( - message['proxyId'], None) + target = self.localProxyMap.get(message["proxyId"], None) if not target: - raise Exception('proxy id %s not found' % - message['proxyId']) + raise Exception("proxy id %s not found" % message["proxyId"]) args = [] - for arg in (message['args'] or []): + for arg in message["args"] or []: args.append(self.deserialize(arg, deserializationContext)) # if method == 'asend' and hasattr(target, '__aiter__') and hasattr(target, '__anext__') and not len(args): @@ -444,47 +521,53 @@ class RpcPeer: if method: if not hasattr(target, method): raise Exception( - 'target %s does not have method %s' % (type(target), method)) + "target %s does not have method %s" + % (type(target), method) + ) invoke = getattr(target, method) value = await maybe_await(invoke(*args)) else: value = await maybe_await(target(*args)) - result['result'] = self.serialize(value, serializationContext) + result["result"] = self.serialize(value, serializationContext) except StopAsyncIteration as e: self.createErrorResult(result, e) except Exception as e: self.createErrorResult(result, e) - if not message.get('oneway', False): + if not message.get("oneway", False): self.sendResult(result, serializationContext) - elif messageType == 'result': - id = message['id'] + elif messageType == "result": + id = message["id"] future = self.pendingResults.get(id, None) if not future: - raise RPCResultError( - None, 'unknown result %s' % id) + raise RPCResultError(None, "unknown result %s" % id) del self.pendingResults[id] - deserialized = self.deserialize(message.get('result', None), deserializationContext) - if message.get('throw'): + deserialized = self.deserialize( + message.get("result", None), deserializationContext + ) + if message.get("throw"): future.set_exception(deserialized) else: future.set_result(deserialized) - elif messageType == 'finalize': - finalizerId = message.get('__local_proxy_finalizer_id', None) - proxyId = message['__local_proxy_id'] + elif messageType == "finalize": + finalizerId = message.get("__local_proxy_finalizer_id", None) + proxyId = message["__local_proxy_id"] local = self.localProxyMap.get(proxyId, None) if local: localProxiedEntry = self.localProxied.get(local) - if localProxiedEntry and finalizerId and localProxiedEntry['finalizerId'] != finalizerId: + if ( + localProxiedEntry + and finalizerId + and localProxiedEntry["finalizerId"] != finalizerId + ): # print('mismatch finalizer id', file=sys.stderr) return self.localProxied.pop(local, None) local = self.localProxyMap.pop(proxyId, None) else: - raise RPCResultError( - None, 'unknown rpc message type %s' % type) + raise RPCResultError(None, "unknown rpc message type %s" % type) except Exception as e: print("unhandled rpc error", self.peerName, e) pass @@ -492,12 +575,16 @@ class RpcPeer: randomDigits = string.ascii_uppercase + string.ascii_lowercase + string.digits def generateId(): - return ''.join(random.choices(RpcPeer.randomDigits, k=8)) + return "".join(random.choices(RpcPeer.randomDigits, k=8)) - async def createPendingResult(self, cb: Callable[[str, Callable[[Exception], None]], None]): + async def createPendingResult( + self, cb: Callable[[str, Callable[[Exception], None]], None] + ): future = Future() if self.killed: - future.set_exception(RPCResultError(None, 'RpcPeer has been killed (createPendingResult)')) + future.set_exception( + RPCResultError(None, "RpcPeer has been killed (createPendingResult)") + ) return future id = RpcPeer.generateId() @@ -508,9 +595,10 @@ class RpcPeer: async def getParam(self, param): async def send(id: str, reject: Callable[[Exception], None]): paramMessage = { - 'id': id, - 'type': 'param', - 'param': param, + "id": id, + "type": "param", + "param": param, } self.send(paramMessage, reject) + return await self.createPendingResult(send) diff --git a/server/python/rpc_reader.py b/server/python/rpc_reader.py index f6a5d3e1e..675aaf5a8 100644 --- a/server/python/rpc_reader.py +++ b/server/python/rpc_reader.py @@ -16,7 +16,7 @@ import json class BufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): - return base64.b64encode(value).decode('utf8') + return base64.b64encode(value).decode("utf8") def deserialize(self, value, serializationContext): return base64.b64decode(value) @@ -24,15 +24,15 @@ class BufferSerializer(rpc.RpcSerializer): class SidebandBufferSerializer(rpc.RpcSerializer): def serialize(self, value, serializationContext): - buffers = serializationContext.get('buffers', None) + buffers = serializationContext.get("buffers", None) if not buffers: buffers = [] - serializationContext['buffers'] = buffers + serializationContext["buffers"] = buffers buffers.append(value) return len(buffers) - 1 def deserialize(self, value, serializationContext): - buffers: List = serializationContext.get('buffers', None) + buffers: List = serializationContext.get("buffers", None) buffer = buffers.pop() return buffer @@ -56,7 +56,7 @@ class RpcFileTransport(RpcTransport): super().__init__() self.readFd = readFd self.writeFd = writeFd - self.executor = ThreadPoolExecutor(1, 'rpc-read') + self.executor = ThreadPoolExecutor(1, "rpc-read") def osReadExact(self, size: int): b = bytes(0) @@ -64,7 +64,7 @@ class RpcFileTransport(RpcTransport): got = os.read(self.readFd, size) if not len(got): self.executor.shutdown(False) - raise Exception('rpc end of stream reached') + raise Exception("rpc end of stream reached") size -= len(got) b += got return b @@ -73,7 +73,7 @@ class RpcFileTransport(RpcTransport): lengthBytes = self.osReadExact(4) typeBytes = self.osReadExact(1) type = typeBytes[0] - length = int.from_bytes(lengthBytes, 'big') + length = int.from_bytes(lengthBytes, "big") data = self.osReadExact(length - 1) if type == 1: return data @@ -81,11 +81,13 @@ class RpcFileTransport(RpcTransport): return message async def read(self): - return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.readMessageInternal()) + return await asyncio.get_event_loop().run_in_executor( + self.executor, lambda: self.readMessageInternal() + ) def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 - lb = length.to_bytes(4, 'big') + lb = length.to_bytes(4, "big") try: for b in [lb, bytes([type]), buffer]: os.write(self.writeFd, b) @@ -94,14 +96,18 @@ class RpcFileTransport(RpcTransport): reject(e) def writeJSON(self, j, reject): - return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject) + return self.writeMessage( + 0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject + ) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) class RpcStreamTransport(RpcTransport): - def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + def __init__( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: super().__init__() self.reader = reader self.writer = writer @@ -110,7 +116,7 @@ class RpcStreamTransport(RpcTransport): lengthBytes = await self.reader.readexactly(4) typeBytes = await self.reader.readexactly(1) type = typeBytes[0] - length = int.from_bytes(lengthBytes, 'big') + length = int.from_bytes(lengthBytes, "big") data = await self.reader.readexactly(length - 1) if type == 1: return data @@ -119,7 +125,7 @@ class RpcStreamTransport(RpcTransport): def writeMessage(self, type: int, buffer, reject): length = len(buffer) + 1 - lb = length.to_bytes(4, 'big') + lb = length.to_bytes(4, "big") try: for b in [lb, bytes([type]), buffer]: self.writer.write(b) @@ -128,7 +134,9 @@ class RpcStreamTransport(RpcTransport): reject(e) def writeJSON(self, j, reject): - return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject) + return self.writeMessage( + 0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject + ) def writeBuffer(self, buffer, reject): return self.writeMessage(1, buffer, reject) @@ -141,7 +149,9 @@ class RpcConnectionTransport(RpcTransport): self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) async def read(self): - return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.connection.recv()) + return await asyncio.get_event_loop().run_in_executor( + self.executor, lambda: self.connection.recv() + ) def writeMessage(self, json, reject): try: @@ -158,23 +168,20 @@ class RpcConnectionTransport(RpcTransport): async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport): - deserializationContext = { - 'buffers': [] - } + deserializationContext = {"buffers": []} while True: message = await rpcTransport.read() if type(message) != dict: - deserializationContext['buffers'].append(message) + deserializationContext["buffers"].append(message) continue asyncio.run_coroutine_threadsafe( - peer.handleMessage(message, deserializationContext), loop) + peer.handleMessage(message, deserializationContext), loop + ) - deserializationContext = { - 'buffers': [] - } + deserializationContext = {"buffers": []} async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport): @@ -185,7 +192,7 @@ async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransp def send(message, reject=None, serializationContext=None): with mutex: if serializationContext: - buffers = serializationContext.get('buffers', None) + buffers = serializationContext.get("buffers", None) if buffers: for buffer in buffers: rpcTransport.writeBuffer(buffer, reject) @@ -193,10 +200,10 @@ async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransp rpcTransport.writeJSON(message, reject) peer = rpc.RpcPeer(send) - peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer() - peer.constructorSerializerMap[bytes] = 'Buffer' - peer.constructorSerializerMap[bytearray] = 'Buffer' - peer.constructorSerializerMap[memoryview] = 'Buffer' + peer.nameDeserializerMap["Buffer"] = SidebandBufferSerializer() + peer.constructorSerializerMap[bytes] = "Buffer" + peer.constructorSerializerMap[bytearray] = "Buffer" + peer.constructorSerializerMap[memoryview] = "Buffer" async def peerReadLoop(): try: