diff --git a/server/python/plugin-remote.py b/server/python/plugin-remote.py index ff1b8a150..690fb4b23 100644 --- a/server/python/plugin-remote.py +++ b/server/python/plugin-remote.py @@ -17,7 +17,7 @@ from asyncio.streams import StreamReader, StreamWriter from collections.abc import Mapping from io import StringIO from os import sys -from typing import Any, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import aiofiles import scrypted_python.scrypted_sdk.types @@ -139,13 +139,27 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager): return self.nativeIds.get(nativeId, None) class BufferSerializer(rpc.RpcSerializer): - def serialize(self, value): + def serialize(self, value, serializationContext): return base64.b64encode(value).decode('utf8') - def deserialize(self, value): + def deserialize(self, value, serializationContext): return base64.b64decode(value) +class SidebandBufferSerializer(rpc.RpcSerializer): + def serialize(self, value, serializationContext): + buffers = serializationContext.get('buffers', None) + if not buffers: + buffers = [] + serializationContext['buffers'] = buffers + buffers.append(value) + return len(buffers) - 1 + + def deserialize(self, value, serializationContext): + buffers: List = serializationContext.get('buffers', None) + buffer = buffers.pop() + return buffer + class PluginRemote: systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {} nativeIds: Mapping[str, DeviceStorage] = {} @@ -329,29 +343,69 @@ class PluginRemote: pass -async def readLoop(loop, peer, reader): - async for line in reader: +async def readLoop(loop, peer: rpc.RpcPeer, reader): + deserializationContext = { + 'buffers': [] + } + + while True: try: - message = json.loads(line) - asyncio.run_coroutine_threadsafe(peer.handleMessage(message), loop) + lengthBytes = await reader.read(4) + typeBytes = await reader.read(1) + type = typeBytes[0] + length = int.from_bytes(lengthBytes, 'big') + data = await reader.read(length - 1) + + if type == 1: + deserializationContext['buffers'].append(data) + continue + + message = json.loads(data) + asyncio.run_coroutine_threadsafe(peer.handleMessage(message, deserializationContext), loop) + + deserializationContext = { + 'buffers': [] + } except Exception as e: print('read loop error', e) sys.exit() async def async_main(loop: AbstractEventLoop): - reader = await aiofiles.open(3, mode='r') + reader = await aiofiles.open(3, mode='rb') + + def send(message, reject=None, serializationContext = None): + if serializationContext: + buffers = serializationContext.get('buffers', None) + if buffers: + for buffer in buffers: + length = len(buffer) + 1 + lb = length.to_bytes(4, 'big') + type = 1 + try: + os.write(4, lb) + os.write(4, bytes([type])) + os.write(4, buffer) + except Exception as e: + if reject: + reject(e) + return - def send(message, reject=None): jsonString = json.dumps(message) + b = bytes(jsonString, 'utf8') + length = len(b) + 1 + lb = length.to_bytes(4, 'big') + type = 0 try: - os.write(4, bytes(jsonString + '\n', 'utf8')) + os.write(4, lb) + os.write(4, bytes([type])) + os.write(4, b) except Exception as e: if reject: reject(e) peer = rpc.RpcPeer(send) - peer.nameDeserializerMap['Buffer'] = BufferSerializer() + peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer() peer.constructorSerializerMap[bytes] = 'Buffer' peer.constructorSerializerMap[bytearray] = 'Buffer' peer.params['print'] = print diff --git a/server/python/rpc.py b/server/python/rpc.py index 905c7eb4e..16db4eead 100644 --- a/server/python/rpc.py +++ b/server/python/rpc.py @@ -1,5 +1,5 @@ from asyncio.futures import Future -from typing import Any, Callable, Mapping, List +from typing import Any, Callable, Dict, Mapping, List import traceback import inspect from typing_extensions import TypedDict @@ -31,10 +31,10 @@ class RpcResultException(Exception): class RpcSerializer: - def serialize(self, value): + def serialize(self, value, serializationContext): pass - def deserialize(self, value): + def deserialize(self, value, deserializationContext): pass @@ -85,7 +85,7 @@ class RpcProxy(object): class RpcPeer: - def __init__(self, send: Callable[[object, Callable[[Exception], None]], None]) -> None: + def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None: self.send = send self.idCounter = 1 self.peerName = 'Unnamed Peer' @@ -99,9 +99,10 @@ class RpcPeer: self.nameDeserializerMap: Mapping[str, RpcSerializer] = {} def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list): + serializationContext: Dict = {} serializedArgs = [] for arg in args: - serializedArgs.append(self.serialize(arg, False)) + serializedArgs.append(self.serialize(arg, False, serializationContext)) rpcApply = { 'type': 'apply', @@ -113,25 +114,25 @@ class RpcPeer: if oneWayMethods and method in oneWayMethods: rpcApply['oneway'] = True - self.send(rpcApply) + self.send(rpcApply, None, serializationContext) future = Future() future.set_result(None) return future async def send(id: str, reject: Callable[[Exception], None]): rpcApply['id'] = id - self.send(rpcApply, reject) + self.send(rpcApply, reject, serializationContext) return self.createPendingResult(send) def kill(self): self.killed = True - def createErrorResult(self, result: any, name: str, message: str, tb: str): + def createErrorResult(self, result: Any, name: str, message: str, tb: str): result['stack'] = tb if tb else 'no stack' result['result'] = name if name else 'no name' result['message'] = message if message else 'no message' - def serialize(self, value, requireProxy): + def serialize(self, value, requireProxy, serializationContext: Dict): if (not value or (not requireProxy and type(value) in jsonSerializable)): return value @@ -164,7 +165,7 @@ class RpcPeer: if serializerMapName: __remote_constructor_name = serializerMapName serializer = self.nameDeserializerMap.get(serializerMapName, None) - serialized = serializer.serialize(value) + serialized = serializer.serialize(value, serializationContext) ret = { '__remote_proxy_id': None, '__remote_proxy_finalizer_id': None, @@ -216,7 +217,7 @@ class RpcPeer: weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry)) return proxy - def deserialize(self, value): + def deserialize(self, value, deserializationContext: Dict): if not value: return value @@ -253,11 +254,11 @@ class RpcPeer: deserializer = self.nameDeserializerMap.get( __remote_constructor_name, None) if deserializer: - return deserializer.deserialize(__serialized_value) + return deserializer.deserialize(__serialized_value, deserializationContext) return value - async def handleMessage(self, message: any): + async def handleMessage(self, message: Any, deserializationContext: Dict): try: messageType = message['type'] if messageType == 'param': @@ -266,17 +267,18 @@ class RpcPeer: 'id': message['id'], } + serializationContext: Dict = {} try: value = self.params.get(message['param'], None) value = await maybe_await(value) result['result'] = self.serialize( - value, message.get('requireProxy', None)) + value, message.get('requireProxy', None), serializationContext) except Exception as e: tb = traceback.format_exc() self.createErrorResult( result, type(e).__name, str(e), tb) - self.send(result) + self.send(result, None, serializationContext) elif messageType == 'apply': result = { @@ -286,6 +288,7 @@ class RpcPeer: method = message.get('method', None) try: + serializationContext: Dict = {} target = self.localProxyMap.get( message['proxyId'], None) if not target: @@ -294,7 +297,7 @@ class RpcPeer: args = [] for arg in (message['args'] or []): - args.append(self.deserialize(arg)) + args.append(self.deserialize(arg, deserializationContext)) value = None if method: @@ -306,7 +309,7 @@ class RpcPeer: else: value = await maybe_await(target(*args)) - result['result'] = self.serialize(value, False) + result['result'] = self.serialize(value, False, serializationContext) except Exception as e: tb = traceback.format_exc() # print('failure', method, e, tb) @@ -314,7 +317,7 @@ class RpcPeer: result, type(e).__name__, str(e), tb) if not message.get('oneway', False): - self.send(result) + self.send(result, serializationContext) elif messageType == 'result': id = message['id'] @@ -331,7 +334,7 @@ class RpcPeer: future.set_exception(e) return future.set_result(self.deserialize( - message.get('result', None))) + message.get('result', None), deserializationContext)) elif messageType == 'finalize': finalizerId = message.get('__local_proxy_finalizer_id', None) proxyId = message['__local_proxy_id'] diff --git a/server/src/plugin/plugin-host.ts b/server/src/plugin/plugin-host.ts index 9c917e756..896d79662 100644 --- a/server/src/plugin/plugin-host.ts +++ b/server/src/plugin/plugin-host.ts @@ -293,9 +293,9 @@ export class PluginHost { } } - this.peer = new RpcPeer('host', this.pluginId, (message, reject) => { + this.peer = new RpcPeer('host', this.pluginId, (message, reject, serializationContext) => { if (connected) { - this.worker.send(message, reject); + this.worker.send(message, reject, serializationContext); } else if (reject) { reject(new Error('peer disconnected')); diff --git a/server/src/plugin/runtime/python-worker.ts b/server/src/plugin/runtime/python-worker.ts index 780fa4454..e1d99a020 100644 --- a/server/src/plugin/runtime/python-worker.ts +++ b/server/src/plugin/runtime/python-worker.ts @@ -5,10 +5,12 @@ import path from 'path'; import readline from 'readline'; import { Readable, Writable } from 'stream'; import { RpcMessage, RpcPeer } from "../../rpc"; +import { createRpcDuplexSerializer } from '../../rpc-serializer'; import { ChildProcessWorker } from "./child-process-worker"; import { RuntimeWorkerOptions } from "./runtime-worker"; export class PythonRuntimeWorker extends ChildProcessWorker { + serializer: ReturnType; constructor(pluginId: string, options: RuntimeWorkerOptions) { super(pluginId, options); @@ -62,21 +64,24 @@ export class PythonRuntimeWorker extends ChildProcessWorker { const peerin = this.worker.stdio[3] as Writable; const peerout = this.worker.stdio[4] as Readable; - peerin.on('error', e => this.emit('error', e)); - peerout.on('error', e => this.emit('error', e)); - - const readInterface = readline.createInterface({ - input: peerout, - terminal: false, + const serializer = this.serializer = createRpcDuplexSerializer(peerin); + serializer.setupRpcPeer(peer); + peerout.on('data', data => serializer.onData(data)); + peerin.on('error', e => { + this.emit('error', e); + serializer.onDisconnected(); + }); + peerout.on('error', e => { + this.emit('error', e) + serializer.onDisconnected(); }); - readInterface.on('line', line => peer.handleMessage(JSON.parse(line))); } - send(message: RpcMessage, reject?: (e: Error) => void): void { + send(message: RpcMessage, reject?: (e: Error) => void, serializationContext?: any): void { try { if (!this.worker) throw new Error('worked has been killed'); - (this.worker.stdio[3] as Writable).write(JSON.stringify(message) + '\n', e => e && reject?.(e)); + this.serializer.sendMessage(message, reject, serializationContext); } catch (e) { reject?.(e);