server: improve python rpc buffer wrangling

This commit is contained in:
Koushik Dutta
2022-09-21 21:53:38 -07:00
parent 0a890ddabb
commit a4582da683
4 changed files with 103 additions and 41 deletions

View File

@@ -17,7 +17,7 @@ from asyncio.streams import StreamReader, StreamWriter
from collections.abc import Mapping
from io import StringIO
from os import sys
from typing import Any, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
import aiofiles
import scrypted_python.scrypted_sdk.types
@@ -139,13 +139,27 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager):
return self.nativeIds.get(nativeId, None)
class BufferSerializer(rpc.RpcSerializer):
def serialize(self, value):
def serialize(self, value, serializationContext):
return base64.b64encode(value).decode('utf8')
def deserialize(self, value):
def deserialize(self, value, serializationContext):
return base64.b64decode(value)
class SidebandBufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
buffers = serializationContext.get('buffers', None)
if not buffers:
buffers = []
serializationContext['buffers'] = buffers
buffers.append(value)
return len(buffers) - 1
def deserialize(self, value, serializationContext):
buffers: List = serializationContext.get('buffers', None)
buffer = buffers.pop()
return buffer
class PluginRemote:
systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {}
nativeIds: Mapping[str, DeviceStorage] = {}
@@ -329,29 +343,69 @@ class PluginRemote:
pass
async def readLoop(loop, peer, reader):
async for line in reader:
async def readLoop(loop, peer: rpc.RpcPeer, reader):
deserializationContext = {
'buffers': []
}
while True:
try:
message = json.loads(line)
asyncio.run_coroutine_threadsafe(peer.handleMessage(message), loop)
lengthBytes = await reader.read(4)
typeBytes = await reader.read(1)
type = typeBytes[0]
length = int.from_bytes(lengthBytes, 'big')
data = await reader.read(length - 1)
if type == 1:
deserializationContext['buffers'].append(data)
continue
message = json.loads(data)
asyncio.run_coroutine_threadsafe(peer.handleMessage(message, deserializationContext), loop)
deserializationContext = {
'buffers': []
}
except Exception as e:
print('read loop error', e)
sys.exit()
async def async_main(loop: AbstractEventLoop):
reader = await aiofiles.open(3, mode='r')
reader = await aiofiles.open(3, mode='rb')
def send(message, reject=None, serializationContext = None):
if serializationContext:
buffers = serializationContext.get('buffers', None)
if buffers:
for buffer in buffers:
length = len(buffer) + 1
lb = length.to_bytes(4, 'big')
type = 1
try:
os.write(4, lb)
os.write(4, bytes([type]))
os.write(4, buffer)
except Exception as e:
if reject:
reject(e)
return
def send(message, reject=None):
jsonString = json.dumps(message)
b = bytes(jsonString, 'utf8')
length = len(b) + 1
lb = length.to_bytes(4, 'big')
type = 0
try:
os.write(4, bytes(jsonString + '\n', 'utf8'))
os.write(4, lb)
os.write(4, bytes([type]))
os.write(4, b)
except Exception as e:
if reject:
reject(e)
peer = rpc.RpcPeer(send)
peer.nameDeserializerMap['Buffer'] = BufferSerializer()
peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer()
peer.constructorSerializerMap[bytes] = 'Buffer'
peer.constructorSerializerMap[bytearray] = 'Buffer'
peer.params['print'] = print

View File

@@ -1,5 +1,5 @@
from asyncio.futures import Future
from typing import Any, Callable, Mapping, List
from typing import Any, Callable, Dict, Mapping, List
import traceback
import inspect
from typing_extensions import TypedDict
@@ -31,10 +31,10 @@ class RpcResultException(Exception):
class RpcSerializer:
def serialize(self, value):
def serialize(self, value, serializationContext):
pass
def deserialize(self, value):
def deserialize(self, value, deserializationContext):
pass
@@ -85,7 +85,7 @@ class RpcProxy(object):
class RpcPeer:
def __init__(self, send: Callable[[object, Callable[[Exception], None]], None]) -> None:
def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None:
self.send = send
self.idCounter = 1
self.peerName = 'Unnamed Peer'
@@ -99,9 +99,10 @@ class RpcPeer:
self.nameDeserializerMap: Mapping[str, RpcSerializer] = {}
def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list):
serializationContext: Dict = {}
serializedArgs = []
for arg in args:
serializedArgs.append(self.serialize(arg, False))
serializedArgs.append(self.serialize(arg, False, serializationContext))
rpcApply = {
'type': 'apply',
@@ -113,25 +114,25 @@ class RpcPeer:
if oneWayMethods and method in oneWayMethods:
rpcApply['oneway'] = True
self.send(rpcApply)
self.send(rpcApply, None, serializationContext)
future = Future()
future.set_result(None)
return future
async def send(id: str, reject: Callable[[Exception], None]):
rpcApply['id'] = id
self.send(rpcApply, reject)
self.send(rpcApply, reject, serializationContext)
return self.createPendingResult(send)
def kill(self):
self.killed = True
def createErrorResult(self, result: any, name: str, message: str, tb: str):
def createErrorResult(self, result: Any, name: str, message: str, tb: str):
result['stack'] = tb if tb else 'no stack'
result['result'] = name if name else 'no name'
result['message'] = message if message else 'no message'
def serialize(self, value, requireProxy):
def serialize(self, value, requireProxy, serializationContext: Dict):
if (not value or (not requireProxy and type(value) in jsonSerializable)):
return value
@@ -164,7 +165,7 @@ class RpcPeer:
if serializerMapName:
__remote_constructor_name = serializerMapName
serializer = self.nameDeserializerMap.get(serializerMapName, None)
serialized = serializer.serialize(value)
serialized = serializer.serialize(value, serializationContext)
ret = {
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
@@ -216,7 +217,7 @@ class RpcPeer:
weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry))
return proxy
def deserialize(self, value):
def deserialize(self, value, deserializationContext: Dict):
if not value:
return value
@@ -253,11 +254,11 @@ class RpcPeer:
deserializer = self.nameDeserializerMap.get(
__remote_constructor_name, None)
if deserializer:
return deserializer.deserialize(__serialized_value)
return deserializer.deserialize(__serialized_value, deserializationContext)
return value
async def handleMessage(self, message: any):
async def handleMessage(self, message: Any, deserializationContext: Dict):
try:
messageType = message['type']
if messageType == 'param':
@@ -266,17 +267,18 @@ class RpcPeer:
'id': message['id'],
}
serializationContext: Dict = {}
try:
value = self.params.get(message['param'], None)
value = await maybe_await(value)
result['result'] = self.serialize(
value, message.get('requireProxy', None))
value, message.get('requireProxy', None), serializationContext)
except Exception as e:
tb = traceback.format_exc()
self.createErrorResult(
result, type(e).__name, str(e), tb)
self.send(result)
self.send(result, None, serializationContext)
elif messageType == 'apply':
result = {
@@ -286,6 +288,7 @@ class RpcPeer:
method = message.get('method', None)
try:
serializationContext: Dict = {}
target = self.localProxyMap.get(
message['proxyId'], None)
if not target:
@@ -294,7 +297,7 @@ class RpcPeer:
args = []
for arg in (message['args'] or []):
args.append(self.deserialize(arg))
args.append(self.deserialize(arg, deserializationContext))
value = None
if method:
@@ -306,7 +309,7 @@ class RpcPeer:
else:
value = await maybe_await(target(*args))
result['result'] = self.serialize(value, False)
result['result'] = self.serialize(value, False, serializationContext)
except Exception as e:
tb = traceback.format_exc()
# print('failure', method, e, tb)
@@ -314,7 +317,7 @@ class RpcPeer:
result, type(e).__name__, str(e), tb)
if not message.get('oneway', False):
self.send(result)
self.send(result, serializationContext)
elif messageType == 'result':
id = message['id']
@@ -331,7 +334,7 @@ class RpcPeer:
future.set_exception(e)
return
future.set_result(self.deserialize(
message.get('result', None)))
message.get('result', None), deserializationContext))
elif messageType == 'finalize':
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']

View File

@@ -293,9 +293,9 @@ export class PluginHost {
}
}
this.peer = new RpcPeer('host', this.pluginId, (message, reject) => {
this.peer = new RpcPeer('host', this.pluginId, (message, reject, serializationContext) => {
if (connected) {
this.worker.send(message, reject);
this.worker.send(message, reject, serializationContext);
}
else if (reject) {
reject(new Error('peer disconnected'));

View File

@@ -5,10 +5,12 @@ import path from 'path';
import readline from 'readline';
import { Readable, Writable } from 'stream';
import { RpcMessage, RpcPeer } from "../../rpc";
import { createRpcDuplexSerializer } from '../../rpc-serializer';
import { ChildProcessWorker } from "./child-process-worker";
import { RuntimeWorkerOptions } from "./runtime-worker";
export class PythonRuntimeWorker extends ChildProcessWorker {
serializer: ReturnType<typeof createRpcDuplexSerializer>;
constructor(pluginId: string, options: RuntimeWorkerOptions) {
super(pluginId, options);
@@ -62,21 +64,24 @@ export class PythonRuntimeWorker extends ChildProcessWorker {
const peerin = this.worker.stdio[3] as Writable;
const peerout = this.worker.stdio[4] as Readable;
peerin.on('error', e => this.emit('error', e));
peerout.on('error', e => this.emit('error', e));
const readInterface = readline.createInterface({
input: peerout,
terminal: false,
const serializer = this.serializer = createRpcDuplexSerializer(peerin);
serializer.setupRpcPeer(peer);
peerout.on('data', data => serializer.onData(data));
peerin.on('error', e => {
this.emit('error', e);
serializer.onDisconnected();
});
peerout.on('error', e => {
this.emit('error', e)
serializer.onDisconnected();
});
readInterface.on('line', line => peer.handleMessage(JSON.parse(line)));
}
send(message: RpcMessage, reject?: (e: Error) => void): void {
send(message: RpcMessage, reject?: (e: Error) => void, serializationContext?: any): void {
try {
if (!this.worker)
throw new Error('worked has been killed');
(this.worker.stdio[3] as Writable).write(JSON.stringify(message) + '\n', e => e && reject?.(e));
this.serializer.sendMessage(message, reject, serializationContext);
}
catch (e) {
reject?.(e);