Files
scrypted/server/python/rpc_reader.py
2024-11-10 12:30:05 -08:00

209 lines
6.1 KiB
Python

from __future__ import annotations
import asyncio
import base64
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from asyncio.events import AbstractEventLoop
from typing import List, Any
import multiprocessing.connection
import rpc
import concurrent.futures
import json
class BufferSerializer(rpc.RpcSerializer):
    """Serializer that carries binary values inline as base64 text."""

    def serialize(self, value, serializationContext):
        """Return ``value`` base64-encoded as a ``str``.

        ``serializationContext`` is accepted for interface compatibility
        and is not used.
        """
        encoded = base64.b64encode(value)
        return str(encoded, 'utf8')

    def deserialize(self, value, serializationContext):
        """Return the bytes obtained by base64-decoding ``value``.

        ``serializationContext`` is accepted for interface compatibility
        and is not used.
        """
        decoded = base64.b64decode(value)
        return decoded
class SidebandBufferSerializer(rpc.RpcSerializer):
    """Serializer that moves binary values out of band.

    Instead of embedding the bytes in the serialized message, each value is
    appended to the ``'buffers'`` list held in the serialization context and
    the message carries only the value's index in that list.
    """

    def serialize(self, value, serializationContext):
        """Append ``value`` to the context's buffer list and return its index.

        A new list is stored under ``serializationContext['buffers']`` when
        the context has none yet (or only an empty one).
        """
        buffers = serializationContext.get('buffers', None)
        if not buffers:
            buffers = []
            serializationContext['buffers'] = buffers
        buffers.append(value)
        return len(buffers) - 1

    def deserialize(self, value, serializationContext):
        """Return the sideband buffer that ``value`` refers to.

        ``value`` is the index produced by ``serialize``. Looking the buffer
        up by that index keeps messages with several buffers correct no
        matter in which order their placeholders are deserialized; popping
        from the end would hand the buffers back in reverse.
        """
        buffers: List = serializationContext.get('buffers', None)
        if isinstance(value, int) and 0 <= value < len(buffers):
            return buffers[value]
        # Not a usable index: keep the historical behaviour of taking the
        # most recently received buffer.
        return buffers.pop()
class RpcTransport:
    """Base interface for the transports that carry RPC messages.

    Every method of this base class is a no-op returning ``None``;
    subclasses override the ones they need.
    """

    async def prepare(self):
        """Set-up hook; the base implementation does nothing."""
        return None

    async def read(self):
        """Read one message; the base implementation returns ``None``."""
        return None

    def writeBuffer(self, buffer, reject):
        """Write a binary message; the base implementation does nothing."""
        return None

    def writeJSON(self, json, reject):
        """Write a JSON message; the base implementation does nothing."""
        return None
class RpcFileTransport(RpcTransport):
    """RPC transport over a pair of raw OS file descriptors.

    Each frame is a 4-byte big-endian length, one type byte and the payload.
    The length counts the type byte plus the payload. Type ``1`` marks a raw
    binary payload; any other type is parsed as JSON (``writeJSON`` uses
    type ``0``).

    Blocking reads run on a dedicated single-worker thread pool so they do
    not stall the event loop.
    """

    def __init__(self, readFd: int, writeFd: int) -> None:
        """Store the descriptors and create the reader thread pool."""
        super().__init__()
        self.readFd = readFd
        self.writeFd = writeFd
        self.executor = ThreadPoolExecutor(1, 'rpc-read')

    def osReadExact(self, size: int):
        """Block until exactly ``size`` bytes were read and return them.

        Raises:
            Exception: when the descriptor reports end of stream first. The
                reader executor is shut down (without waiting) beforehand.
        """
        # Collect the pieces and join once; repeated ``bytes +=`` copies the
        # accumulated data on every short read.
        chunks = []
        while size:
            got = os.read(self.readFd, size)
            if not got:
                self.executor.shutdown(False)
                raise Exception('rpc end of stream reached')
            size -= len(got)
            chunks.append(got)
        return b''.join(chunks)

    def readMessageInternal(self):
        """Read one frame, blocking the calling thread.

        Returns the raw payload bytes for a type ``1`` frame, otherwise the
        payload decoded with ``json.loads``.
        """
        lengthBytes = self.osReadExact(4)
        typeBytes = self.osReadExact(1)
        messageType = typeBytes[0]
        length = int.from_bytes(lengthBytes, 'big')
        # The length includes the type byte that was already consumed.
        data = self.osReadExact(length - 1)
        if messageType == 1:
            return data
        return json.loads(data)

    async def read(self):
        """Read one frame on the reader executor and return the message."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self.readMessageInternal)

    def _writeAll(self, data):
        """Write all of ``data`` to the write descriptor.

        ``os.write`` may accept fewer bytes than requested, so keep writing
        the remainder until nothing is left.
        """
        view = memoryview(data).cast('B')
        while len(view):
            written = os.write(self.writeFd, view)
            view = view[written:]

    def writeMessage(self, type: int, buffer, reject):
        """Write one frame of the given ``type`` carrying ``buffer``.

        Any exception raised while building or writing the frame is passed
        to ``reject`` when one is supplied, and is otherwise dropped.
        """
        try:
            # View the payload as plain bytes so the advertised length is a
            # byte count even for memoryviews with multi-byte items.
            payload = memoryview(buffer).cast('B')
            header = (len(payload) + 1).to_bytes(4, 'big') + bytes([type])
            self._writeAll(header)
            self._writeAll(payload)
        except Exception as e:
            if reject:
                reject(e)

    def writeJSON(self, j, reject):
        """Serialize ``j`` to UTF-8 JSON and write it as a type ``0`` frame.

        NaN and infinity are refused (``allow_nan=False``); that error is
        raised to the caller rather than passed to ``reject``.
        """
        return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject)

    def writeBuffer(self, buffer, reject):
        """Write ``buffer`` as a type ``1`` (binary) frame."""
        return self.writeMessage(1, buffer, reject)
class RpcStreamTransport(RpcTransport):
    """RPC transport over an asyncio ``StreamReader``/``StreamWriter`` pair.

    Each frame is a 4-byte big-endian length, one type byte and the payload.
    The length counts the type byte plus the payload. Type ``1`` marks a raw
    binary payload; any other type is parsed as JSON.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Keep references to the stream reader and writer."""
        super().__init__()
        self.reader = reader
        self.writer = writer

    async def read(self):
        """Read one frame and return its payload.

        Returns the raw bytes for a type ``1`` frame, otherwise the payload
        decoded with ``json.loads``.
        """
        header = await self.reader.readexactly(4)
        kind = (await self.reader.readexactly(1))[0]
        # The advertised length includes the type byte read above.
        payload = await self.reader.readexactly(int.from_bytes(header, 'big') - 1)
        return payload if kind == 1 else json.loads(payload)

    def writeMessage(self, type: int, buffer, reject):
        """Queue one frame of the given ``type`` on the writer.

        An exception raised while writing is passed to ``reject`` when one
        is supplied, and is otherwise dropped.
        """
        prefix = (len(buffer) + 1).to_bytes(4, 'big')
        try:
            for chunk in (prefix, bytes([type]), buffer):
                self.writer.write(chunk)
        except Exception as err:
            if reject:
                reject(err)

    def writeJSON(self, j, reject):
        """Serialize ``j`` to UTF-8 JSON and write it as a type ``0`` frame."""
        encoded = json.dumps(j, allow_nan=False).encode('utf8')
        return self.writeMessage(0, encoded, reject)

    def writeBuffer(self, buffer, reject):
        """Write ``buffer`` as a type ``1`` (binary) frame."""
        return self.writeMessage(1, buffer, reject)
class RpcConnectionTransport(RpcTransport):
    """RPC transport over a ``multiprocessing`` connection object.

    Messages are passed to ``connection.send`` and taken from
    ``connection.recv`` as Python objects; no framing is done here.
    """

    def __init__(self, connection: multiprocessing.connection.Connection) -> None:
        """Store the connection and create a single-worker reader pool."""
        super().__init__()
        self.connection = connection
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    async def read(self):
        """Receive one object from the connection on the reader executor."""
        def receive():
            return self.connection.recv()

        event_loop = asyncio.get_event_loop()
        return await event_loop.run_in_executor(self.executor, receive)

    def writeMessage(self, json, reject):
        """Send ``json`` over the connection.

        An exception raised by ``send`` is passed to ``reject`` when one is
        supplied, and is otherwise dropped.
        """
        try:
            self.connection.send(json)
        except Exception as err:
            if reject:
                reject(err)

    def writeJSON(self, json, reject):
        """Send a JSON-style message object as is."""
        return self.writeMessage(json, reject)

    def writeBuffer(self, buffer, reject):
        """Send ``buffer`` after copying it into an immutable ``bytes``."""
        payload = bytes(buffer)
        return self.writeMessage(payload, reject)
async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport):
    """Read messages from ``rpcTransport`` forever and dispatch them.

    Anything read that is not exactly a ``dict`` is treated as a sideband
    buffer and queued. When a ``dict`` arrives it is handed to
    ``peer.handleMessage`` together with a context whose ``'buffers'`` list
    holds everything queued since the previous ``dict``. The handler
    coroutine is scheduled on ``loop``; this function does not wait for it.

    The loop only ends when ``rpcTransport.read()`` raises.
    """
    pending = []
    while True:
        message = await rpcTransport.read()
        if type(message) is not dict:
            pending.append(message)
            continue
        context = {
            'buffers': pending
        }
        # Later buffers belong to the next message, so start a fresh list.
        pending = []
        asyncio.run_coroutine_threadsafe(
            peer.handleMessage(message, context), loop)
async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport):
    """Create an ``rpc.RpcPeer`` bound to ``rpcTransport``.

    Awaits ``rpcTransport.prepare()``, builds a peer whose outgoing messages
    go through the transport, and registers the sideband serializer for
    ``bytes``, ``bytearray`` and ``memoryview`` under the name ``'Buffer'``.

    Returns:
        A ``(peer, peerReadLoop)`` tuple. ``peerReadLoop`` is a coroutine
        function that runs ``readLoop`` and, if it fails for any reason,
        kills the peer before re-raising.
    """
    await rpcTransport.prepare()

    write_lock = threading.Lock()

    def send(message, reject=None, serializationContext=None):
        # One lock around the whole message keeps its sideband buffers and
        # its JSON part together when several threads send.
        with write_lock:
            sideband = None
            if serializationContext:
                sideband = serializationContext.get('buffers', None)
            for buffer in sideband or ():
                rpcTransport.writeBuffer(buffer, reject)
            rpcTransport.writeJSON(message, reject)

    peer = rpc.RpcPeer(send)
    peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer()
    for binary_type in (bytes, bytearray, memoryview):
        peer.constructorSerializerMap[binary_type] = 'Buffer'

    async def peerReadLoop():
        try:
            await readLoop(loop, peer, rpcTransport)
        except BaseException:
            # Same reach as a bare ``except``: the peer is killed on
            # cancellation too, then the error propagates.
            peer.kill()
            raise

    return peer, peerReadLoop