From 096c036ea2e310c116f4c850b15e2d0abb4032db Mon Sep 17 00:00:00 2001 From: Koushik Dutta Date: Thu, 2 Mar 2023 21:03:29 -0800 Subject: [PATCH] rpc: implement python async iterator --- server/python/plugin_remote.py | 27 +------- server/python/rpc-iterator-test.py | 21 ++++++ server/python/rpc.py | 103 ++++++++++++++++++++++------- server/python/rpc_reader.py | 21 +----- server/src/rpc.ts | 79 ++++++++++++++++++---- server/test/rpc-iterator-test.ts | 26 +++++--- server/test/rpc-python-test.ts | 27 ++++++++ 7 files changed, 211 insertions(+), 93 deletions(-) create mode 100644 server/python/rpc-iterator-test.py create mode 100644 server/test/rpc-python-test.ts diff --git a/server/python/plugin_remote.py b/server/python/plugin_remote.py index 7d0148eb8..5e700aa63 100644 --- a/server/python/plugin_remote.py +++ b/server/python/plugin_remote.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import base64 import gc import sys import os @@ -18,7 +17,7 @@ from asyncio.streams import StreamReader, StreamWriter from collections.abc import Mapping from io import StringIO from os import sys -from typing import Any, List, Optional, Set, Tuple +from typing import Any, Optional, Set, Tuple import scrypted_python.scrypted_sdk.types from scrypted_python.scrypted_sdk import ScryptedStatic, PluginFork @@ -202,30 +201,6 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager): def getDeviceStorage(self, nativeId: str = None) -> Storage: return self.nativeIds.get(nativeId, None) - -class BufferSerializer(rpc.RpcSerializer): - def serialize(self, value, serializationContext): - return base64.b64encode(value).decode('utf8') - - def deserialize(self, value, serializationContext): - return base64.b64decode(value) - - -class SidebandBufferSerializer(rpc.RpcSerializer): - def serialize(self, value, serializationContext): - buffers = serializationContext.get('buffers', None) - if not buffers: - buffers = [] - serializationContext['buffers'] = buffers - buffers.append(value) - return len(buffers) - 1 - - def deserialize(self, value, serializationContext): - buffers: List = serializationContext.get('buffers', None) - buffer = buffers.pop() - return buffer - - class PluginRemote: systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {} nativeIds: Mapping[str, DeviceStorage] = {} diff --git a/server/python/rpc-iterator-test.py b/server/python/rpc-iterator-test.py new file mode 100644 index 000000000..d4e9bcf6f --- /dev/null +++ b/server/python/rpc-iterator-test.py @@ -0,0 +1,21 @@ +import sys +import asyncio +from rpc_reader import prepare_peer_readloop + +async def main(): + peer, peerReadLoop = await prepare_peer_readloop(loop, 4, 3) + peer.params['foo'] = 3 + + async def ticker(delay, to): + for i in range(to): + # print(i) + yield i + await asyncio.sleep(delay) + + peer.params['ticker'] = ticker(0, 3) + + print('python starting') + await peerReadLoop() + +loop = asyncio.new_event_loop() +loop.run_until_complete(main()) diff --git a/server/python/rpc.py b/server/python/rpc.py index cee0c1538..695defdb9 100644 --- a/server/python/rpc.py +++ b/server/python/rpc.py @@ -4,7 +4,6 @@ import traceback import inspect from typing_extensions import TypedDict import weakref -import sys jsonSerializable = set() jsonSerializable.add(float) @@ -16,14 +15,16 @@ jsonSerializable.add(list) async def maybe_await(value): - if (inspect.iscoroutinefunction(value) or inspect.iscoroutine(value)): + if (inspect.isawaitable(value)): return await value return value -class RpcResultException(Exception): - name = None - stack = None +class RPCResultError(Exception): + name: str + stack: str + message: str + caught: Exception def __init__(self, caught, message): self.caught = caught @@ -85,6 +86,8 @@ class RpcProxy(object): class RpcPeer: + RPC_RESULT_ERROR_NAME = 'RPCResultError' + def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None: self.send = send self.idCounter = 1 @@ -127,10 +130,52 @@ class RpcPeer: def kill(self): self.killed = True - def createErrorResult(self, result: Any, name: str, message: str, tb: str): - result['stack'] = tb if tb else 'no stack' - result['result'] = name if name else 'no name' - result['message'] = message if message else 'no message' + def createErrorResult(self, result: Any, e: Exception): + s = self.serializeError(e) + result['result'] = s + result['throw'] = True + + # TODO 3/2/2023 deprecate these properties + tb = traceback.format_exc() + message = str(e) + result['stack'] = tb or '[no stack]', + result['message'] = message or '[no message]', + # END TODO + + + def deserializeError(e: Dict) -> RPCResultError: + error = RPCResultError(None, e.get('message')) + error.stack = e.get('stack') + error.name = e.get('name') + return error + + def serializeError(self, e: Exception): + tb = traceback.format_exc() + name = type(e).__name__ + message = str(e) + + serialized = { + 'stack': tb or '[no stack]', + 'name': name or '[no name]', + 'message': message or '[no message]', + } + + return { + '__remote_constructor_name': RpcPeer.RPC_RESULT_ERROR_NAME, + '__serialized_value': serialized, + } + + def getProxyProperties(self, value): + if not hasattr(value, '__aiter__') or not hasattr(value, '__anext__'): + return getattr(value, '__proxy_props', None) + + props = getattr(value, '__proxy_props', None) or {} + props['Symbol(Symbol.asyncIterator)'] = { + 'next': '__anext__', + 'throw': 'athrow', + 'return': 'asend', + } + return props def serialize(self, value, requireProxy, serializationContext: Dict): if (not value or (not requireProxy and type(value) in jsonSerializable)): @@ -139,6 +184,9 @@ class RpcPeer: __remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr( value, '__proxy_constructor') else type(value).__name__ + if isinstance(value, Exception): + return self.serializeError(value) + proxiedEntry = self.localProxied.get(value, None) if proxiedEntry: proxiedEntry['finalizerId'] = str(self.proxyCounter) @@ -147,7 +195,7 @@ class RpcPeer: '__remote_proxy_id': proxiedEntry['id'], '__remote_proxy_finalizer_id': proxiedEntry['finalizerId'], '__remote_constructor_name': __remote_constructor_name, - '__remote_proxy_props': getattr(value, '__proxy_props', None), + '__remote_proxy_props': self.getProxyProperties(value), '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), } return ret @@ -170,7 +218,7 @@ class RpcPeer: '__remote_proxy_id': None, '__remote_proxy_finalizer_id': None, '__remote_constructor_name': __remote_constructor_name, - '__remote_proxy_props': getattr(value, '__proxy_props', None), + '__remote_proxy_props': self.getProxyProperties(value), '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), '__serialized_value': serialized, } @@ -189,7 +237,7 @@ class RpcPeer: '__remote_proxy_id': proxyId, '__remote_proxy_finalizer_id': proxyId, '__remote_constructor_name': __remote_constructor_name, - '__remote_proxy_props': getattr(value, '__proxy_props', None), + '__remote_proxy_props': self.getProxyProperties(value), '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), } @@ -235,6 +283,9 @@ class RpcPeer: __remote_proxy_oneway_methods = value.get( '__remote_proxy_oneway_methods', None) + if __remote_constructor_name == RpcPeer.RPC_RESULT_ERROR_NAME: + return self.deserializeError(__serialized_value); + if __remote_proxy_id: weakref = self.remoteWeakProxies.get('__remote_proxy_id', None) proxy = weakref() if weakref else None @@ -247,7 +298,7 @@ class RpcPeer: if __local_proxy_id: ret = self.localProxyMap.get(__local_proxy_id, None) if not ret: - raise RpcResultException( + raise RPCResultError( None, 'invalid local proxy id %s' % __local_proxy_id) return ret @@ -258,7 +309,7 @@ class RpcPeer: return value - async def handleMessage(self, message: Any, deserializationContext: Dict): + async def handleMessage(self, message: Dict, deserializationContext: Dict): try: messageType = message['type'] if messageType == 'param': @@ -310,11 +361,10 @@ class RpcPeer: value = await maybe_await(target(*args)) result['result'] = self.serialize(value, False, serializationContext) + except StopAsyncIteration as e: + self.createErrorResult(result, e) except Exception as e: - tb = traceback.format_exc() - # print('failure', method, e, tb) - self.createErrorResult( - result, type(e).__name__, str(e), tb) + self.createErrorResult(result, e) if not message.get('oneway', False): self.send(result, None, serializationContext) @@ -323,18 +373,21 @@ class RpcPeer: id = message['id'] future = self.pendingResults.get(id, None) if not future: - raise RpcResultException( + raise RPCResultError( None, 'unknown result %s' % id) del self.pendingResults[id] - if hasattr(message, 'message') or hasattr(message, 'stack'): - e = RpcResultException( + if (hasattr(message, 'message') or hasattr(message, 'stack')) and not hasattr(message, 'throw'): + e = RPCResultError( None, message.get('message', None)) e.stack = message.get('stack', None) e.name = message.get('name', None) future.set_exception(e) return - future.set_result(self.deserialize( - message.get('result', None), deserializationContext)) + deserialized = self.deserialize(message.get('result', None), deserializationContext) + if message.get('throw'): + future.set_exception(deserialized) + else: + future.set_result(deserialized) elif messageType == 'finalize': finalizerId = message.get('__local_proxy_finalizer_id', None) proxyId = message['__local_proxy_id'] @@ -347,7 +400,7 @@ class RpcPeer: self.localProxied.pop(local, None) local = self.localProxyMap.pop(proxyId, None) else: - raise RpcResultException( + raise RPCResultError( None, 'unknown rpc message type %s' % type) except Exception as e: print("unhandled rpc error", self.peerName, e) @@ -361,7 +414,7 @@ class RpcPeer: self.idCounter = self.idCounter + 1 future = Future() self.pendingResults[id] = future - await cb(id, lambda e: future.set_exception(RpcResultException(e, None))) + await cb(id, lambda e: future.set_exception(RPCResultError(e, None))) return await future async def getParam(self, param): diff --git a/server/python/rpc_reader.py b/server/python/rpc_reader.py index 75d88a23a..fc54cc874 100644 --- a/server/python/rpc_reader.py +++ b/server/python/rpc_reader.py @@ -2,33 +2,16 @@ from __future__ import annotations import asyncio import base64 -import gc import json -import sys import os -import platform -import shutil -import subprocess +import sys import threading -import time -import traceback -import zipfile from asyncio.events import AbstractEventLoop -from asyncio.futures import Future -from asyncio.streams import StreamReader, StreamWriter -from collections.abc import Mapping -from io import StringIO from os import sys -from typing import Any, List, Optional, Set, Tuple +from typing import List import aiofiles -import scrypted_python.scrypted_sdk.types -from scrypted_python.scrypted_sdk import ScryptedStatic, PluginFork -from scrypted_python.scrypted_sdk.types import Device, DeviceManifest, EventDetails, ScryptedInterfaceProperty, Storage -from typing_extensions import TypedDict import rpc -import multiprocessing -import multiprocessing.connection class BufferSerializer(rpc.RpcSerializer): diff --git a/server/src/rpc.ts b/server/src/rpc.ts index 6b500d251..6ab779885 100644 --- a/server/src/rpc.ts +++ b/server/src/rpc.ts @@ -91,6 +91,12 @@ export interface PrimitiveProxyHandler extends ProxyHandler } class RpcProxy implements PrimitiveProxyHandler { + static iteratorMethods = new Set([ + 'next', + 'throw', + 'return', + ]); + constructor(public peer: RpcPeer, public entry: LocalProxiedEntry, public constructorName: string, @@ -108,13 +114,14 @@ class RpcProxy implements PrimitiveProxyHandler { if (!this.proxyProps?.[Symbol.asyncIterator.toString()]) return; return () => { - return { - next: async () => { - return this.apply(() => 'next', undefined) - } - } + return new Proxy(() => { }, this); }; } + if (RpcProxy.iteratorMethods.has(p?.toString())) { + const asyncIteratorMethod = this.proxyProps?.[Symbol.asyncIterator.toString()]?.[p]; + if (asyncIteratorMethod) + return new Proxy(() => asyncIteratorMethod, this); + } if (p === RpcPeer.PROPERTY_PROXY_ID) return this.entry.id; if (p === '__proxy_constructor') @@ -182,10 +189,31 @@ class RpcProxy implements PrimitiveProxyHandler { return Promise.resolve(); } - return this.peer.createPendingResult((id, reject) => { + const pendingResult = this.peer.createPendingResult((id, reject) => { rpcApply.id = id; this.peer.send(rpcApply, reject, serializationContext); - }) + }); + + const asyncIterator = this.proxyProps?.[Symbol.asyncIterator.toString()]; + if (!asyncIterator || method !== asyncIterator.next) + return pendingResult; + + return pendingResult + .then(value => { + return ({ + value, + done: false, + }); + }) + .catch(e => { + if (e.name === 'StopAsyncIteration') { + return { + done: true, + value: undefined, + } + } + throw e; + }) } } @@ -252,6 +280,12 @@ interface LocalProxiedEntry { finalizerId: string | undefined; } +interface ErrorType { + name: string; + message: string; + stack?: string; +} + export class RpcPeer { idCounter = 1; params: { [name: string]: any } = {}; @@ -311,10 +345,13 @@ export class RpcPeer { static getProxyProperies(value: any) { if (!value[Symbol.asyncIterator]) return value?.[RpcPeer.PROPERTY_PROXY_PROPERTIES]; - return { - [Symbol.asyncIterator.toString()]: true, - ...value?.[RpcPeer.PROPERTY_PROXY_PROPERTIES], - } + const props = value?.[RpcPeer.PROPERTY_PROXY_PROPERTIES] || {}; + props[Symbol.asyncIterator.toString()] = { + next: 'next', + throw: 'throw', + return: 'return', + }; + return props; } static readonly RPC_RESULT_ERROR_NAME = 'RPCResultError'; @@ -425,7 +462,7 @@ export class RpcPeer { * @param result * @param e */ - createErrorResult(result: RpcResult, e: Error) { + createErrorResult(result: RpcResult, e: ErrorType) { result.result = this.serializeError(e); result.throw = true; result.message = (e as Error).message || 'no message'; @@ -483,7 +520,7 @@ export class RpcPeer { return new RPCResultError(this, message, undefined, { name, stack }); } - serializeError(e: Error): RpcRemoteProxyValue { + serializeError(e: ErrorType): RpcRemoteProxyValue { const __serialized_value: SerialiedRpcResultError = { stack: e.stack || '[no stack]', name: e.name || '[no name]', @@ -646,6 +683,19 @@ export class RpcPeer { if (!method) throw new Error(`target ${target?.constructor?.name} does not have method ${rpcApply.method}`); value = await target[rpcApply.method](...args); + + if (target[Symbol.asyncIterator] && rpcApply.method === 'next') { + if (value.done) { + const errorType: ErrorType = { + name: 'StopAsyncIteration', + message: undefined, + }; + throw errorType; + } + else { + value = value.value; + } + } } else { value = await target(...args); @@ -663,12 +713,13 @@ export class RpcPeer { break; } case 'result': { + // console.log(message) const rpcResult = message as RpcResult; const deferred = this.pendingResults[rpcResult.id]; delete this.pendingResults[rpcResult.id]; if (!deferred) throw new Error(`unknown result ${rpcResult.id}`); - if (rpcResult.message || rpcResult.stack) { + if ((rpcResult.message || rpcResult.stack) && !rpcResult.throw) { const e = new RPCResultError(this, rpcResult.message || 'no message', undefined, { name: rpcResult.result, stack: rpcResult.stack, diff --git a/server/test/rpc-iterator-test.ts b/server/test/rpc-iterator-test.ts index e2e477f6c..557c9273c 100644 --- a/server/test/rpc-iterator-test.ts +++ b/server/test/rpc-iterator-test.ts @@ -1,5 +1,5 @@ import { RpcPeer } from "../src/rpc"; -import {sleep} from '../src/sleep'; +import { sleep } from '../src/sleep'; const p1 = new RpcPeer('p1', 'p2', message => { // console.log('message p1 p2', message); @@ -25,14 +25,22 @@ p1.params['thing'] = generator(); async function test() { const foo = await p2.getParam('thing') as AsyncGenerator; - await sleep(0); - console.log(await foo.next()); - await sleep(0); - // await foo.throw(new Error('barf')); - await foo.return(44); - await sleep(0); - console.log(await foo.next()); - console.log(await foo.next()); + if (true) { + for await (const c of foo) { + console.log(c); + } + } + else { + await sleep(0); + console.log(await foo.next()); + await sleep(0); + // await foo.throw(new Error('barf')); + await foo.return(44); + await sleep(0); + console.log(await foo.next()); + console.log(await foo.next()); + } + } test(); diff --git a/server/test/rpc-python-test.ts b/server/test/rpc-python-test.ts new file mode 100644 index 000000000..ec817e400 --- /dev/null +++ b/server/test/rpc-python-test.ts @@ -0,0 +1,27 @@ +import child_process from 'child_process'; +import path from 'path'; +import type { Readable, Writable } from "stream"; +import { createDuplexRpcPeer } from '../src/rpc-serializer'; +import assert from 'assert'; + +async function main() { + + const cp = child_process.spawn('python3', [path.join(__dirname, '../python/rpc-iterator-test.py')], { + stdio: ['pipe', 'inherit', 'inherit', 'pipe', 'pipe'], + }); + + cp.on('exit', code => console.log('exited', code)) + + const rpcPeer = createDuplexRpcPeer('node', 'python', cp.stdio[3] as Readable, cp.stdio[4] as Writable); + + const foo = await rpcPeer.getParam('foo'); + assert.equal(foo, 3); + + const ticker = await rpcPeer.getParam('ticker'); + for await (const v of ticker) { + console.log(v); + } + process.exit(); +} + +main();