mirror of
https://github.com/koush/scrypted.git
synced 2026-06-20 16:40:30 +01:00
server: improve python rpc buffer wrangling
This commit is contained in:
@@ -17,7 +17,7 @@ from asyncio.streams import StreamReader, StreamWriter
|
||||
from collections.abc import Mapping
|
||||
from io import StringIO
|
||||
from os import sys
|
||||
from typing import Any, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import aiofiles
|
||||
import scrypted_python.scrypted_sdk.types
|
||||
@@ -139,13 +139,27 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager):
|
||||
return self.nativeIds.get(nativeId, None)
|
||||
|
||||
class BufferSerializer(rpc.RpcSerializer):
|
||||
def serialize(self, value):
|
||||
def serialize(self, value, serializationContext):
|
||||
return base64.b64encode(value).decode('utf8')
|
||||
|
||||
def deserialize(self, value):
|
||||
def deserialize(self, value, serializationContext):
|
||||
return base64.b64decode(value)
|
||||
|
||||
|
||||
class SidebandBufferSerializer(rpc.RpcSerializer):
|
||||
def serialize(self, value, serializationContext):
|
||||
buffers = serializationContext.get('buffers', None)
|
||||
if not buffers:
|
||||
buffers = []
|
||||
serializationContext['buffers'] = buffers
|
||||
buffers.append(value)
|
||||
return len(buffers) - 1
|
||||
|
||||
def deserialize(self, value, serializationContext):
|
||||
buffers: List = serializationContext.get('buffers', None)
|
||||
buffer = buffers.pop()
|
||||
return buffer
|
||||
|
||||
class PluginRemote:
|
||||
systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {}
|
||||
nativeIds: Mapping[str, DeviceStorage] = {}
|
||||
@@ -329,29 +343,69 @@ class PluginRemote:
|
||||
pass
|
||||
|
||||
|
||||
async def readLoop(loop, peer, reader):
|
||||
async for line in reader:
|
||||
async def readLoop(loop, peer: rpc.RpcPeer, reader):
|
||||
deserializationContext = {
|
||||
'buffers': []
|
||||
}
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = json.loads(line)
|
||||
asyncio.run_coroutine_threadsafe(peer.handleMessage(message), loop)
|
||||
lengthBytes = await reader.read(4)
|
||||
typeBytes = await reader.read(1)
|
||||
type = typeBytes[0]
|
||||
length = int.from_bytes(lengthBytes, 'big')
|
||||
data = await reader.read(length - 1)
|
||||
|
||||
if type == 1:
|
||||
deserializationContext['buffers'].append(data)
|
||||
continue
|
||||
|
||||
message = json.loads(data)
|
||||
asyncio.run_coroutine_threadsafe(peer.handleMessage(message, deserializationContext), loop)
|
||||
|
||||
deserializationContext = {
|
||||
'buffers': []
|
||||
}
|
||||
except Exception as e:
|
||||
print('read loop error', e)
|
||||
sys.exit()
|
||||
|
||||
|
||||
async def async_main(loop: AbstractEventLoop):
|
||||
reader = await aiofiles.open(3, mode='r')
|
||||
reader = await aiofiles.open(3, mode='rb')
|
||||
|
||||
def send(message, reject=None, serializationContext = None):
|
||||
if serializationContext:
|
||||
buffers = serializationContext.get('buffers', None)
|
||||
if buffers:
|
||||
for buffer in buffers:
|
||||
length = len(buffer) + 1
|
||||
lb = length.to_bytes(4, 'big')
|
||||
type = 1
|
||||
try:
|
||||
os.write(4, lb)
|
||||
os.write(4, bytes([type]))
|
||||
os.write(4, buffer)
|
||||
except Exception as e:
|
||||
if reject:
|
||||
reject(e)
|
||||
return
|
||||
|
||||
def send(message, reject=None):
|
||||
jsonString = json.dumps(message)
|
||||
b = bytes(jsonString, 'utf8')
|
||||
length = len(b) + 1
|
||||
lb = length.to_bytes(4, 'big')
|
||||
type = 0
|
||||
try:
|
||||
os.write(4, bytes(jsonString + '\n', 'utf8'))
|
||||
os.write(4, lb)
|
||||
os.write(4, bytes([type]))
|
||||
os.write(4, b)
|
||||
except Exception as e:
|
||||
if reject:
|
||||
reject(e)
|
||||
|
||||
peer = rpc.RpcPeer(send)
|
||||
peer.nameDeserializerMap['Buffer'] = BufferSerializer()
|
||||
peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer()
|
||||
peer.constructorSerializerMap[bytes] = 'Buffer'
|
||||
peer.constructorSerializerMap[bytearray] = 'Buffer'
|
||||
peer.params['print'] = print
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from asyncio.futures import Future
|
||||
from typing import Any, Callable, Mapping, List
|
||||
from typing import Any, Callable, Dict, Mapping, List
|
||||
import traceback
|
||||
import inspect
|
||||
from typing_extensions import TypedDict
|
||||
@@ -31,10 +31,10 @@ class RpcResultException(Exception):
|
||||
|
||||
|
||||
class RpcSerializer:
|
||||
def serialize(self, value):
|
||||
def serialize(self, value, serializationContext):
|
||||
pass
|
||||
|
||||
def deserialize(self, value):
|
||||
def deserialize(self, value, deserializationContext):
|
||||
pass
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ class RpcProxy(object):
|
||||
|
||||
|
||||
class RpcPeer:
|
||||
def __init__(self, send: Callable[[object, Callable[[Exception], None]], None]) -> None:
|
||||
def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None:
|
||||
self.send = send
|
||||
self.idCounter = 1
|
||||
self.peerName = 'Unnamed Peer'
|
||||
@@ -99,9 +99,10 @@ class RpcPeer:
|
||||
self.nameDeserializerMap: Mapping[str, RpcSerializer] = {}
|
||||
|
||||
def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list):
|
||||
serializationContext: Dict = {}
|
||||
serializedArgs = []
|
||||
for arg in args:
|
||||
serializedArgs.append(self.serialize(arg, False))
|
||||
serializedArgs.append(self.serialize(arg, False, serializationContext))
|
||||
|
||||
rpcApply = {
|
||||
'type': 'apply',
|
||||
@@ -113,25 +114,25 @@ class RpcPeer:
|
||||
|
||||
if oneWayMethods and method in oneWayMethods:
|
||||
rpcApply['oneway'] = True
|
||||
self.send(rpcApply)
|
||||
self.send(rpcApply, None, serializationContext)
|
||||
future = Future()
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
async def send(id: str, reject: Callable[[Exception], None]):
|
||||
rpcApply['id'] = id
|
||||
self.send(rpcApply, reject)
|
||||
self.send(rpcApply, reject, serializationContext)
|
||||
return self.createPendingResult(send)
|
||||
|
||||
def kill(self):
|
||||
self.killed = True
|
||||
|
||||
def createErrorResult(self, result: any, name: str, message: str, tb: str):
|
||||
def createErrorResult(self, result: Any, name: str, message: str, tb: str):
|
||||
result['stack'] = tb if tb else 'no stack'
|
||||
result['result'] = name if name else 'no name'
|
||||
result['message'] = message if message else 'no message'
|
||||
|
||||
def serialize(self, value, requireProxy):
|
||||
def serialize(self, value, requireProxy, serializationContext: Dict):
|
||||
if (not value or (not requireProxy and type(value) in jsonSerializable)):
|
||||
return value
|
||||
|
||||
@@ -164,7 +165,7 @@ class RpcPeer:
|
||||
if serializerMapName:
|
||||
__remote_constructor_name = serializerMapName
|
||||
serializer = self.nameDeserializerMap.get(serializerMapName, None)
|
||||
serialized = serializer.serialize(value)
|
||||
serialized = serializer.serialize(value, serializationContext)
|
||||
ret = {
|
||||
'__remote_proxy_id': None,
|
||||
'__remote_proxy_finalizer_id': None,
|
||||
@@ -216,7 +217,7 @@ class RpcPeer:
|
||||
weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry))
|
||||
return proxy
|
||||
|
||||
def deserialize(self, value):
|
||||
def deserialize(self, value, deserializationContext: Dict):
|
||||
if not value:
|
||||
return value
|
||||
|
||||
@@ -253,11 +254,11 @@ class RpcPeer:
|
||||
deserializer = self.nameDeserializerMap.get(
|
||||
__remote_constructor_name, None)
|
||||
if deserializer:
|
||||
return deserializer.deserialize(__serialized_value)
|
||||
return deserializer.deserialize(__serialized_value, deserializationContext)
|
||||
|
||||
return value
|
||||
|
||||
async def handleMessage(self, message: any):
|
||||
async def handleMessage(self, message: Any, deserializationContext: Dict):
|
||||
try:
|
||||
messageType = message['type']
|
||||
if messageType == 'param':
|
||||
@@ -266,17 +267,18 @@ class RpcPeer:
|
||||
'id': message['id'],
|
||||
}
|
||||
|
||||
serializationContext: Dict = {}
|
||||
try:
|
||||
value = self.params.get(message['param'], None)
|
||||
value = await maybe_await(value)
|
||||
result['result'] = self.serialize(
|
||||
value, message.get('requireProxy', None))
|
||||
value, message.get('requireProxy', None), serializationContext)
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
self.createErrorResult(
|
||||
result, type(e).__name, str(e), tb)
|
||||
|
||||
self.send(result)
|
||||
self.send(result, None, serializationContext)
|
||||
|
||||
elif messageType == 'apply':
|
||||
result = {
|
||||
@@ -286,6 +288,7 @@ class RpcPeer:
|
||||
method = message.get('method', None)
|
||||
|
||||
try:
|
||||
serializationContext: Dict = {}
|
||||
target = self.localProxyMap.get(
|
||||
message['proxyId'], None)
|
||||
if not target:
|
||||
@@ -294,7 +297,7 @@ class RpcPeer:
|
||||
|
||||
args = []
|
||||
for arg in (message['args'] or []):
|
||||
args.append(self.deserialize(arg))
|
||||
args.append(self.deserialize(arg, deserializationContext))
|
||||
|
||||
value = None
|
||||
if method:
|
||||
@@ -306,7 +309,7 @@ class RpcPeer:
|
||||
else:
|
||||
value = await maybe_await(target(*args))
|
||||
|
||||
result['result'] = self.serialize(value, False)
|
||||
result['result'] = self.serialize(value, False, serializationContext)
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
# print('failure', method, e, tb)
|
||||
@@ -314,7 +317,7 @@ class RpcPeer:
|
||||
result, type(e).__name__, str(e), tb)
|
||||
|
||||
if not message.get('oneway', False):
|
||||
self.send(result)
|
||||
self.send(result, serializationContext)
|
||||
|
||||
elif messageType == 'result':
|
||||
id = message['id']
|
||||
@@ -331,7 +334,7 @@ class RpcPeer:
|
||||
future.set_exception(e)
|
||||
return
|
||||
future.set_result(self.deserialize(
|
||||
message.get('result', None)))
|
||||
message.get('result', None), deserializationContext))
|
||||
elif messageType == 'finalize':
|
||||
finalizerId = message.get('__local_proxy_finalizer_id', None)
|
||||
proxyId = message['__local_proxy_id']
|
||||
|
||||
@@ -293,9 +293,9 @@ export class PluginHost {
|
||||
}
|
||||
}
|
||||
|
||||
this.peer = new RpcPeer('host', this.pluginId, (message, reject) => {
|
||||
this.peer = new RpcPeer('host', this.pluginId, (message, reject, serializationContext) => {
|
||||
if (connected) {
|
||||
this.worker.send(message, reject);
|
||||
this.worker.send(message, reject, serializationContext);
|
||||
}
|
||||
else if (reject) {
|
||||
reject(new Error('peer disconnected'));
|
||||
|
||||
@@ -5,10 +5,12 @@ import path from 'path';
|
||||
import readline from 'readline';
|
||||
import { Readable, Writable } from 'stream';
|
||||
import { RpcMessage, RpcPeer } from "../../rpc";
|
||||
import { createRpcDuplexSerializer } from '../../rpc-serializer';
|
||||
import { ChildProcessWorker } from "./child-process-worker";
|
||||
import { RuntimeWorkerOptions } from "./runtime-worker";
|
||||
|
||||
export class PythonRuntimeWorker extends ChildProcessWorker {
|
||||
serializer: ReturnType<typeof createRpcDuplexSerializer>;
|
||||
|
||||
constructor(pluginId: string, options: RuntimeWorkerOptions) {
|
||||
super(pluginId, options);
|
||||
@@ -62,21 +64,24 @@ export class PythonRuntimeWorker extends ChildProcessWorker {
|
||||
const peerin = this.worker.stdio[3] as Writable;
|
||||
const peerout = this.worker.stdio[4] as Readable;
|
||||
|
||||
peerin.on('error', e => this.emit('error', e));
|
||||
peerout.on('error', e => this.emit('error', e));
|
||||
|
||||
const readInterface = readline.createInterface({
|
||||
input: peerout,
|
||||
terminal: false,
|
||||
const serializer = this.serializer = createRpcDuplexSerializer(peerin);
|
||||
serializer.setupRpcPeer(peer);
|
||||
peerout.on('data', data => serializer.onData(data));
|
||||
peerin.on('error', e => {
|
||||
this.emit('error', e);
|
||||
serializer.onDisconnected();
|
||||
});
|
||||
peerout.on('error', e => {
|
||||
this.emit('error', e)
|
||||
serializer.onDisconnected();
|
||||
});
|
||||
readInterface.on('line', line => peer.handleMessage(JSON.parse(line)));
|
||||
}
|
||||
|
||||
send(message: RpcMessage, reject?: (e: Error) => void): void {
|
||||
send(message: RpcMessage, reject?: (e: Error) => void, serializationContext?: any): void {
|
||||
try {
|
||||
if (!this.worker)
|
||||
throw new Error('worked has been killed');
|
||||
(this.worker.stdio[3] as Writable).write(JSON.stringify(message) + '\n', e => e && reject?.(e));
|
||||
this.serializer.sendMessage(message, reject, serializationContext);
|
||||
}
|
||||
catch (e) {
|
||||
reject?.(e);
|
||||
|
||||
Reference in New Issue
Block a user