rpc: implement python async iterator

This commit is contained in:
Koushik Dutta
2023-03-02 21:03:29 -08:00
parent b2e5801426
commit 096c036ea2
7 changed files with 211 additions and 93 deletions

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import base64
import gc
import sys
import os
@@ -18,7 +17,7 @@ from asyncio.streams import StreamReader, StreamWriter
from collections.abc import Mapping
from io import StringIO
from os import sys
from typing import Any, List, Optional, Set, Tuple
from typing import Any, Optional, Set, Tuple
import scrypted_python.scrypted_sdk.types
from scrypted_python.scrypted_sdk import ScryptedStatic, PluginFork
@@ -202,30 +201,6 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager):
def getDeviceStorage(self, nativeId: str = None) -> Storage:
return self.nativeIds.get(nativeId, None)
class BufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
return base64.b64encode(value).decode('utf8')
def deserialize(self, value, serializationContext):
return base64.b64decode(value)
class SidebandBufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
buffers = serializationContext.get('buffers', None)
if not buffers:
buffers = []
serializationContext['buffers'] = buffers
buffers.append(value)
return len(buffers) - 1
def deserialize(self, value, serializationContext):
buffers: List = serializationContext.get('buffers', None)
buffer = buffers.pop()
return buffer
class PluginRemote:
systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {}
nativeIds: Mapping[str, DeviceStorage] = {}

View File

@@ -0,0 +1,21 @@
import sys
import asyncio
from rpc_reader import prepare_peer_readloop
async def main():
peer, peerReadLoop = await prepare_peer_readloop(loop, 4, 3)
peer.params['foo'] = 3
async def ticker(delay, to):
for i in range(to):
# print(i)
yield i
await asyncio.sleep(delay)
peer.params['ticker'] = ticker(0, 3)
print('python starting')
await peerReadLoop()
loop = asyncio.new_event_loop()
loop.run_until_complete(main())

View File

@@ -4,7 +4,6 @@ import traceback
import inspect
from typing_extensions import TypedDict
import weakref
import sys
jsonSerializable = set()
jsonSerializable.add(float)
@@ -16,14 +15,16 @@ jsonSerializable.add(list)
async def maybe_await(value):
if (inspect.iscoroutinefunction(value) or inspect.iscoroutine(value)):
if (inspect.isawaitable(value)):
return await value
return value
class RpcResultException(Exception):
name = None
stack = None
class RPCResultError(Exception):
name: str
stack: str
message: str
caught: Exception
def __init__(self, caught, message):
self.caught = caught
@@ -85,6 +86,8 @@ class RpcProxy(object):
class RpcPeer:
RPC_RESULT_ERROR_NAME = 'RPCResultError'
def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None:
self.send = send
self.idCounter = 1
@@ -127,10 +130,52 @@ class RpcPeer:
def kill(self):
self.killed = True
def createErrorResult(self, result: Any, name: str, message: str, tb: str):
result['stack'] = tb if tb else 'no stack'
result['result'] = name if name else 'no name'
result['message'] = message if message else 'no message'
def createErrorResult(self, result: Any, e: Exception):
s = self.serializeError(e)
result['result'] = s
result['throw'] = True
# TODO 3/2/2023 deprecate these properties
tb = traceback.format_exc()
message = str(e)
result['stack'] = tb or '[no stack]',
result['message'] = message or '[no message]',
# END TODO
def deserializeError(e: Dict) -> RPCResultError:
error = RPCResultError(None, e.get('message'))
error.stack = e.get('stack')
error.name = e.get('name')
return error
def serializeError(self, e: Exception):
tb = traceback.format_exc()
name = type(e).__name__
message = str(e)
serialized = {
'stack': tb or '[no stack]',
'name': name or '[no name]',
'message': message or '[no message]',
}
return {
'__remote_constructor_name': RpcPeer.RPC_RESULT_ERROR_NAME,
'__serialized_value': serialized,
}
def getProxyProperties(self, value):
if not hasattr(value, '__aiter__') or not hasattr(value, '__anext__'):
return getattr(value, '__proxy_props', None)
props = getattr(value, '__proxy_props', None) or {}
props['Symbol(Symbol.asyncIterator)'] = {
'next': '__anext__',
'throw': 'athrow',
'return': 'asend',
}
return props
def serialize(self, value, requireProxy, serializationContext: Dict):
if (not value or (not requireProxy and type(value) in jsonSerializable)):
@@ -139,6 +184,9 @@ class RpcPeer:
__remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr(
value, '__proxy_constructor') else type(value).__name__
if isinstance(value, Exception):
return self.serializeError(value)
proxiedEntry = self.localProxied.get(value, None)
if proxiedEntry:
proxiedEntry['finalizerId'] = str(self.proxyCounter)
@@ -147,7 +195,7 @@ class RpcPeer:
'__remote_proxy_id': proxiedEntry['id'],
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
}
return ret
@@ -170,7 +218,7 @@ class RpcPeer:
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
'__serialized_value': serialized,
}
@@ -189,7 +237,7 @@ class RpcPeer:
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
}
@@ -235,6 +283,9 @@ class RpcPeer:
__remote_proxy_oneway_methods = value.get(
'__remote_proxy_oneway_methods', None)
if __remote_constructor_name == RpcPeer.RPC_RESULT_ERROR_NAME:
return self.deserializeError(__serialized_value);
if __remote_proxy_id:
weakref = self.remoteWeakProxies.get('__remote_proxy_id', None)
proxy = weakref() if weakref else None
@@ -247,7 +298,7 @@ class RpcPeer:
if __local_proxy_id:
ret = self.localProxyMap.get(__local_proxy_id, None)
if not ret:
raise RpcResultException(
raise RPCResultError(
None, 'invalid local proxy id %s' % __local_proxy_id)
return ret
@@ -258,7 +309,7 @@ class RpcPeer:
return value
async def handleMessage(self, message: Any, deserializationContext: Dict):
async def handleMessage(self, message: Dict, deserializationContext: Dict):
try:
messageType = message['type']
if messageType == 'param':
@@ -310,11 +361,10 @@ class RpcPeer:
value = await maybe_await(target(*args))
result['result'] = self.serialize(value, False, serializationContext)
except StopAsyncIteration as e:
self.createErrorResult(result, e)
except Exception as e:
tb = traceback.format_exc()
# print('failure', method, e, tb)
self.createErrorResult(
result, type(e).__name__, str(e), tb)
self.createErrorResult(result, e)
if not message.get('oneway', False):
self.send(result, None, serializationContext)
@@ -323,18 +373,21 @@ class RpcPeer:
id = message['id']
future = self.pendingResults.get(id, None)
if not future:
raise RpcResultException(
raise RPCResultError(
None, 'unknown result %s' % id)
del self.pendingResults[id]
if hasattr(message, 'message') or hasattr(message, 'stack'):
e = RpcResultException(
if (hasattr(message, 'message') or hasattr(message, 'stack')) and not hasattr(message, 'throw'):
e = RPCResultError(
None, message.get('message', None))
e.stack = message.get('stack', None)
e.name = message.get('name', None)
future.set_exception(e)
return
future.set_result(self.deserialize(
message.get('result', None), deserializationContext))
deserialized = self.deserialize(message.get('result', None), deserializationContext)
if message.get('throw'):
future.set_exception(deserialized)
else:
future.set_result(deserialized)
elif messageType == 'finalize':
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']
@@ -347,7 +400,7 @@ class RpcPeer:
self.localProxied.pop(local, None)
local = self.localProxyMap.pop(proxyId, None)
else:
raise RpcResultException(
raise RPCResultError(
None, 'unknown rpc message type %s' % type)
except Exception as e:
print("unhandled rpc error", self.peerName, e)
@@ -361,7 +414,7 @@ class RpcPeer:
self.idCounter = self.idCounter + 1
future = Future()
self.pendingResults[id] = future
await cb(id, lambda e: future.set_exception(RpcResultException(e, None)))
await cb(id, lambda e: future.set_exception(RPCResultError(e, None)))
return await future
async def getParam(self, param):

View File

@@ -2,33 +2,16 @@ from __future__ import annotations
import asyncio
import base64
import gc
import json
import sys
import os
import platform
import shutil
import subprocess
import sys
import threading
import time
import traceback
import zipfile
from asyncio.events import AbstractEventLoop
from asyncio.futures import Future
from asyncio.streams import StreamReader, StreamWriter
from collections.abc import Mapping
from io import StringIO
from os import sys
from typing import Any, List, Optional, Set, Tuple
from typing import List
import aiofiles
import scrypted_python.scrypted_sdk.types
from scrypted_python.scrypted_sdk import ScryptedStatic, PluginFork
from scrypted_python.scrypted_sdk.types import Device, DeviceManifest, EventDetails, ScryptedInterfaceProperty, Storage
from typing_extensions import TypedDict
import rpc
import multiprocessing
import multiprocessing.connection
class BufferSerializer(rpc.RpcSerializer):

View File

@@ -91,6 +91,12 @@ export interface PrimitiveProxyHandler<T extends object> extends ProxyHandler<T>
}
class RpcProxy implements PrimitiveProxyHandler<any> {
static iteratorMethods = new Set([
'next',
'throw',
'return',
]);
constructor(public peer: RpcPeer,
public entry: LocalProxiedEntry,
public constructorName: string,
@@ -108,13 +114,14 @@ class RpcProxy implements PrimitiveProxyHandler<any> {
if (!this.proxyProps?.[Symbol.asyncIterator.toString()])
return;
return () => {
return {
next: async () => {
return this.apply(() => 'next', undefined)
}
}
return new Proxy(() => { }, this);
};
}
if (RpcProxy.iteratorMethods.has(p?.toString())) {
const asyncIteratorMethod = this.proxyProps?.[Symbol.asyncIterator.toString()]?.[p];
if (asyncIteratorMethod)
return new Proxy(() => asyncIteratorMethod, this);
}
if (p === RpcPeer.PROPERTY_PROXY_ID)
return this.entry.id;
if (p === '__proxy_constructor')
@@ -182,10 +189,31 @@ class RpcProxy implements PrimitiveProxyHandler<any> {
return Promise.resolve();
}
return this.peer.createPendingResult((id, reject) => {
const pendingResult = this.peer.createPendingResult((id, reject) => {
rpcApply.id = id;
this.peer.send(rpcApply, reject, serializationContext);
})
});
const asyncIterator = this.proxyProps?.[Symbol.asyncIterator.toString()];
if (!asyncIterator || method !== asyncIterator.next)
return pendingResult;
return pendingResult
.then(value => {
return ({
value,
done: false,
});
})
.catch(e => {
if (e.name === 'StopAsyncIteration') {
return {
done: true,
value: undefined,
}
}
throw e;
})
}
}
@@ -252,6 +280,12 @@ interface LocalProxiedEntry {
finalizerId: string | undefined;
}
interface ErrorType {
name: string;
message: string;
stack?: string;
}
export class RpcPeer {
idCounter = 1;
params: { [name: string]: any } = {};
@@ -311,10 +345,13 @@ export class RpcPeer {
static getProxyProperies(value: any) {
if (!value[Symbol.asyncIterator])
return value?.[RpcPeer.PROPERTY_PROXY_PROPERTIES];
return {
[Symbol.asyncIterator.toString()]: true,
...value?.[RpcPeer.PROPERTY_PROXY_PROPERTIES],
}
const props = value?.[RpcPeer.PROPERTY_PROXY_PROPERTIES] || {};
props[Symbol.asyncIterator.toString()] = {
next: 'next',
throw: 'throw',
return: 'return',
};
return props;
}
static readonly RPC_RESULT_ERROR_NAME = 'RPCResultError';
@@ -425,7 +462,7 @@ export class RpcPeer {
* @param result
* @param e
*/
createErrorResult(result: RpcResult, e: Error) {
createErrorResult(result: RpcResult, e: ErrorType) {
result.result = this.serializeError(e);
result.throw = true;
result.message = (e as Error).message || 'no message';
@@ -483,7 +520,7 @@ export class RpcPeer {
return new RPCResultError(this, message, undefined, { name, stack });
}
serializeError(e: Error): RpcRemoteProxyValue {
serializeError(e: ErrorType): RpcRemoteProxyValue {
const __serialized_value: SerialiedRpcResultError = {
stack: e.stack || '[no stack]',
name: e.name || '[no name]',
@@ -646,6 +683,19 @@ export class RpcPeer {
if (!method)
throw new Error(`target ${target?.constructor?.name} does not have method ${rpcApply.method}`);
value = await target[rpcApply.method](...args);
if (target[Symbol.asyncIterator] && rpcApply.method === 'next') {
if (value.done) {
const errorType: ErrorType = {
name: 'StopAsyncIteration',
message: undefined,
};
throw errorType;
}
else {
value = value.value;
}
}
}
else {
value = await target(...args);
@@ -663,12 +713,13 @@ export class RpcPeer {
break;
}
case 'result': {
// console.log(message)
const rpcResult = message as RpcResult;
const deferred = this.pendingResults[rpcResult.id];
delete this.pendingResults[rpcResult.id];
if (!deferred)
throw new Error(`unknown result ${rpcResult.id}`);
if (rpcResult.message || rpcResult.stack) {
if ((rpcResult.message || rpcResult.stack) && !rpcResult.throw) {
const e = new RPCResultError(this, rpcResult.message || 'no message', undefined, {
name: rpcResult.result,
stack: rpcResult.stack,

View File

@@ -1,5 +1,5 @@
import { RpcPeer } from "../src/rpc";
import {sleep} from '../src/sleep';
import { sleep } from '../src/sleep';
const p1 = new RpcPeer('p1', 'p2', message => {
// console.log('message p1 p2', message);
@@ -25,14 +25,22 @@ p1.params['thing'] = generator();
async function test() {
const foo = await p2.getParam('thing') as AsyncGenerator<number>;
await sleep(0);
console.log(await foo.next());
await sleep(0);
// await foo.throw(new Error('barf'));
await foo.return(44);
await sleep(0);
console.log(await foo.next());
console.log(await foo.next());
if (true) {
for await (const c of foo) {
console.log(c);
}
}
else {
await sleep(0);
console.log(await foo.next());
await sleep(0);
// await foo.throw(new Error('barf'));
await foo.return(44);
await sleep(0);
console.log(await foo.next());
console.log(await foo.next());
}
}
test();

View File

@@ -0,0 +1,27 @@
import child_process from 'child_process';
import path from 'path';
import type { Readable, Writable } from "stream";
import { createDuplexRpcPeer } from '../src/rpc-serializer';
import assert from 'assert';
async function main() {
const cp = child_process.spawn('python3', [path.join(__dirname, '../python/rpc-iterator-test.py')], {
stdio: ['pipe', 'inherit', 'inherit', 'pipe', 'pipe'],
});
cp.on('exit', code => console.log('exited', code))
const rpcPeer = createDuplexRpcPeer('node', 'python', cp.stdio[3] as Readable, cp.stdio[4] as Writable);
const foo = await rpcPeer.getParam('foo');
assert.equal(foo, 3);
const ticker = await rpcPeer.getParam('ticker');
for await (const v of ticker) {
console.log(v);
}
process.exit();
}
main();