mirror of
https://github.com/koush/scrypted.git
synced 2026-02-03 14:13:28 +00:00
209 lines
6.1 KiB
Python
209 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import os
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from asyncio.events import AbstractEventLoop
|
|
from typing import List, Any
|
|
import multiprocessing.connection
|
|
import rpc
|
|
import concurrent.futures
|
|
import json
|
|
|
|
|
|
class BufferSerializer(rpc.RpcSerializer):
    """Serializer that embeds binary buffers inline as base64 text."""

    def serialize(self, value, serializationContext):
        """Return ``value`` base64-encoded and decoded to a UTF-8 ``str``.

        ``serializationContext`` is accepted for interface compatibility and unused.
        """
        encoded = base64.b64encode(value)
        return str(encoded, 'utf8')

    def deserialize(self, value, serializationContext):
        """Decode the base64 text ``value`` back into ``bytes``.

        ``serializationContext`` is accepted for interface compatibility and unused.
        """
        decoded = base64.b64decode(value)
        return decoded
|
|
|
|
|
|
class SidebandBufferSerializer(rpc.RpcSerializer):
    """Serializer that moves binary buffers out of the JSON message ("sideband").

    ``serialize`` stores the buffer in ``serializationContext['buffers']`` and
    puts only its index into the message. ``deserialize`` resolves that index
    against the buffers the reader collected for the same message.
    """

    def serialize(self, value, serializationContext):
        """Append ``value`` to the context's buffer list and return its index.

        If the context has no (or an empty) ``'buffers'`` list, a new list is
        created and stored in the context first.
        """
        buffers = serializationContext.get('buffers', None)
        if not buffers:
            buffers = []
            serializationContext['buffers'] = buffers
        buffers.append(value)
        return len(buffers) - 1

    def deserialize(self, value, serializationContext):
        """Return the sideband buffer referenced by ``value``.

        ``value`` is the index produced by ``serialize`` on the sending side.
        Indexing (rather than popping the last buffer) keeps messages carrying
        several buffers in the right order. If ``value`` is not a usable index,
        this falls back to the previous behavior of popping the last buffer.
        """
        buffers: List = serializationContext.get('buffers', None)
        if (
            buffers is not None
            and isinstance(value, int)
            and not isinstance(value, bool)
            and 0 <= value < len(buffers)
        ):
            return buffers[value]
        # Fallback for peers that do not send a valid index.
        return buffers.pop()
|
|
|
|
|
|
class RpcTransport:
    """Base interface for moving RPC messages between peers.

    Every hook here is a no-op that returns ``None``. Concrete transports
    override the ones they need.
    """

    async def prepare(self):
        """Optional async setup hook, awaited before the peer is created (no-op here)."""

    async def read(self):
        """Return the next incoming message (always ``None`` here)."""

    def writeBuffer(self, buffer, reject):
        """Send a binary buffer; implementations report failures via ``reject`` (no-op here)."""

    def writeJSON(self, json, reject):
        """Send a JSON-compatible message; implementations report failures via ``reject`` (no-op here)."""
|
|
|
|
|
|
class RpcFileTransport(RpcTransport):
    """RPC transport over a pair of raw OS file descriptors.

    Frame layout, in both directions:
    - a 4-byte big-endian length that counts the type byte plus the payload;
    - a 1-byte type (0 = JSON text, 1 = binary buffer);
    - the payload itself.

    Reads block, so they run on a dedicated single-worker executor instead of
    the event loop thread.
    """

    def __init__(self, readFd: int, writeFd: int) -> None:
        super().__init__()
        self.readFd = readFd
        self.writeFd = writeFd
        # A single worker keeps frame reads strictly sequential on readFd.
        self.executor = ThreadPoolExecutor(1, 'rpc-read')

    def osReadExact(self, size: int):
        """Read exactly ``size`` bytes from ``readFd`` and return them as ``bytes``.

        Raises:
            Exception: if end of stream is reached first. The read executor is
                shut down (without waiting) before raising.
        """
        chunks = []
        while size:
            got = os.read(self.readFd, size)
            if not len(got):
                self.executor.shutdown(False)
                raise Exception('rpc end of stream reached')
            size -= len(got)
            chunks.append(got)
        # Join once instead of repeated ``bytes +=``, which is quadratic for big frames.
        return b''.join(chunks)

    def readMessageInternal(self):
        """Blocking read of one frame.

        Returns the payload ``bytes`` for binary frames (type 1); otherwise
        returns the JSON-decoded payload.
        """
        lengthBytes = self.osReadExact(4)
        typeBytes = self.osReadExact(1)
        messageType = typeBytes[0]
        length = int.from_bytes(lengthBytes, 'big')
        data = self.osReadExact(length - 1)
        if messageType == 1:
            return data
        message = json.loads(data)
        return message

    async def read(self):
        """Read the next frame on the executor, without blocking the event loop."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self.readMessageInternal)

    def osWriteAll(self, data) -> None:
        """Write every byte of ``data`` to ``writeFd``.

        ``os.write`` may write fewer bytes than requested, so the remainder is
        re-submitted until the whole buffer has been written.

        Raises:
            OSError: from ``os.write``, or if a write makes no progress.
        """
        view = memoryview(data).cast('B')
        while view:
            written = os.write(self.writeFd, view)
            if not written:
                raise OSError('rpc write made no progress')
            view = view[written:]

    def writeMessage(self, type: int, buffer, reject):
        """Write one frame of the given ``type`` containing ``buffer``.

        Any exception is passed to ``reject`` when one is provided. Otherwise
        the exception is dropped, matching the original behavior.
        """
        try:
            # Byte view: len() of a multi-byte memoryview counts items, not bytes.
            payload = memoryview(buffer).cast('B')
            length = payload.nbytes + 1
            self.osWriteAll(length.to_bytes(4, 'big') + bytes([type]))
            self.osWriteAll(payload)
        except Exception as e:
            if reject:
                reject(e)

    def writeJSON(self, j, reject):
        """Encode ``j`` as strict JSON (no NaN/Infinity) and send it as a type-0 frame."""
        return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject)

    def writeBuffer(self, buffer, reject):
        """Send ``buffer`` as a type-1 (binary) frame."""
        return self.writeMessage(1, buffer, reject)
|
|
|
|
|
|
class RpcStreamTransport(RpcTransport):
    """RPC transport over an asyncio ``StreamReader``/``StreamWriter`` pair.

    Frame layout, in both directions:
    - a 4-byte big-endian length that counts the type byte plus the payload;
    - a 1-byte type (0 = JSON text, 1 = binary buffer);
    - the payload itself.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        super().__init__()
        self.reader = reader
        self.writer = writer

    async def read(self):
        """Read one frame.

        Returns the payload ``bytes`` for binary frames (type 1); otherwise
        returns the JSON-decoded payload.
        """
        lengthBytes = await self.reader.readexactly(4)
        typeBytes = await self.reader.readexactly(1)
        messageType = typeBytes[0]
        length = int.from_bytes(lengthBytes, 'big')
        data = await self.reader.readexactly(length - 1)
        if messageType == 1:
            return data
        return json.loads(data)

    def writeMessage(self, type: int, buffer, reject):
        """Queue one frame of the given ``type`` containing ``buffer`` on the writer.

        Any exception is passed to ``reject`` when one is provided. Otherwise
        the exception is dropped, matching the original behavior.
        """
        try:
            view = memoryview(buffer)
            if view.itemsize != 1 or view.ndim != 1:
                # len() of a multi-byte/multi-dim view counts items, not bytes;
                # flatten so the header and the written data agree.
                buffer = view.cast('B')
            length = view.nbytes + 1
            for b in [length.to_bytes(4, 'big'), bytes([type]), buffer]:
                self.writer.write(b)
        except Exception as e:
            if reject:
                reject(e)

    def writeJSON(self, j, reject):
        """Encode ``j`` as strict JSON (no NaN/Infinity) and send it as a type-0 frame."""
        return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject)

    def writeBuffer(self, buffer, reject):
        """Send ``buffer`` as a type-1 (binary) frame."""
        return self.writeMessage(1, buffer, reject)
|
|
|
|
|
|
class RpcConnectionTransport(RpcTransport):
    """RPC transport over a ``multiprocessing.connection.Connection``.

    Whole objects are passed to ``Connection.send``/``recv``; no length/type
    framing is applied here. Blocking ``recv`` calls run on a single-worker
    executor.
    """

    def __init__(self, connection: multiprocessing.connection.Connection) -> None:
        super().__init__()
        self.connection = connection
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    async def read(self):
        """Receive the next object from the connection, off the event loop thread."""
        eventLoop = asyncio.get_event_loop()
        receive = lambda: self.connection.recv()
        return await eventLoop.run_in_executor(self.executor, receive)

    def writeMessage(self, json, reject):
        """Send ``json`` as-is; any exception is handed to ``reject`` if given."""
        try:
            self.connection.send(json)
        except Exception as error:
            if reject:
                reject(error)

    def writeJSON(self, json, reject):
        """Send a JSON-compatible message object unchanged."""
        return self.writeMessage(json, reject)

    def writeBuffer(self, buffer, reject):
        """Send a binary buffer after copying it into ``bytes``."""
        payload = bytes(buffer)
        return self.writeMessage(payload, reject)
|
|
|
|
|
|
async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport):
    """Pump messages from ``rpcTransport`` into ``peer`` until reading fails.

    Behavior per message:
    - A non-dict message (a binary buffer) is collected in the pending
      context's ``'buffers'`` list.
    - A dict message is handed, with that context, to ``peer.handleMessage``.
      A fresh context is then started for the next message.

    Handling is scheduled on ``loop`` via ``asyncio.run_coroutine_threadsafe``.
    The returned future is discarded, so this loop neither waits for handling
    nor observes errors raised by ``handleMessage``. The function only exits
    by raising (e.g. when ``rpcTransport.read()`` fails).
    """
    def newContext():
        return {'buffers': []}

    pending = newContext()
    while True:
        incoming = await rpcTransport.read()
        if type(incoming) == dict:
            asyncio.run_coroutine_threadsafe(
                peer.handleMessage(incoming, pending), loop)
            pending = newContext()
        else:
            pending['buffers'].append(incoming)
|
|
|
|
|
|
async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport):
    """Build an ``rpc.RpcPeer`` on top of ``rpcTransport``.

    Awaits ``rpcTransport.prepare()`` first, then returns ``(peer, peerReadLoop)``:
    - ``peer`` sends each outgoing message through the transport, writing any
      sideband buffers from its serialization context before the JSON message.
    - ``peerReadLoop`` is a coroutine function that feeds incoming messages to
      the peer. It calls ``peer.kill()`` and re-raises if reading stops with
      any exception.
    """
    await rpcTransport.prepare()

    # Held for the whole message so concurrent send() calls cannot interleave
    # one message's buffer frames with another's.
    writeLock = threading.Lock()

    def send(message, reject=None, serializationContext=None):
        with writeLock:
            pending = serializationContext.get('buffers', None) if serializationContext else None
            for buffer in pending or ():
                rpcTransport.writeBuffer(buffer, reject)
            rpcTransport.writeJSON(message, reject)

    peer = rpc.RpcPeer(send)
    peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer()
    for bufferType in (bytes, bytearray, memoryview):
        peer.constructorSerializerMap[bufferType] = 'Buffer'

    async def peerReadLoop():
        try:
            await readLoop(loop, peer, rpcTransport)
        except BaseException:
            # Any exit from the read loop (errors, cancellation) tears the peer down.
            peer.kill()
            raise

    return peer, peerReadLoop
|