mirror of
https://github.com/koush/scrypted.git
synced 2026-02-03 14:13:28 +00:00
run_coroutine_threadsafe is designed for scheduling coroutines from a different thread onto the event loop. Since readLoop is already running as an async function on the event loop, using create_task is the correct and more efficient approach. This removes unnecessary thread-safe queue overhead for every RPC message. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
253 lines
7.3 KiB
Python
253 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import concurrent.futures
|
|
import json
|
|
import multiprocessing.connection
|
|
import os
|
|
import pickle
|
|
import threading
|
|
from asyncio.events import AbstractEventLoop
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import List
|
|
|
|
import rpc
|
|
|
|
|
|
class BufferSerializer(rpc.RpcSerializer):
    """Carry byte buffers inline inside the message as base64 text."""

    def serialize(self, value, serializationContext):
        """Return *value* base64-encoded as a ``str``; the context is not used."""
        encoded = base64.b64encode(value)
        return encoded.decode("utf8")

    def deserialize(self, value, serializationContext):
        """Decode base64 text *value* back into ``bytes``; the context is not used."""
        decoded = base64.b64decode(value)
        return decoded
|
|
|
|
|
|
class SidebandBufferSerializer(rpc.RpcSerializer):
    """Carry byte buffers out-of-band, next to the serialized message.

    ``serialize`` appends the buffer to ``serializationContext["buffers"]`` and
    puts only its index into the message. The sender writes those buffers as
    separate frames ahead of the message. ``deserialize`` looks the buffer up
    again in the list the reader collected for that message.
    """

    def serialize(self, value, serializationContext):
        """Stash *value* in the context's buffer list and return its index."""
        buffers = serializationContext.get("buffers", None)
        if not buffers:
            buffers = []
            serializationContext["buffers"] = buffers
        buffers.append(value)
        return len(buffers) - 1

    def deserialize(self, value, serializationContext):
        """Return the sideband buffer referenced by *value*.

        *value* is the index produced by ``serialize`` on the sending side.
        The buffer is looked up by that index instead of popping the last one.
        This keeps messages that carry several buffers correct regardless of
        the order in which the deserializer visits them. If *value* is not a
        usable index, the last buffer is popped, as before.
        """
        buffers: List = serializationContext.get("buffers", None)
        if (
            buffers
            and isinstance(value, int)
            and not isinstance(value, bool)
            and 0 <= value < len(buffers)
        ):
            return buffers[value]
        return buffers.pop()
|
|
|
|
|
|
class RpcTransport:
    """Base interface for the framing layer used by the RPC read loop and sender.

    Every hook is a no-op here; subclasses override the ones they need.
    """

    async def prepare(self):
        """Hook awaited once before the peer starts; the base version does nothing."""
        return None

    async def read(self):
        """Return the next inbound item; the base version yields ``None``."""
        return None

    def writeBuffer(self, buffer, reject):
        """Send one raw buffer; the base version discards it."""
        return None

    def writeSerialized(self, json, reject):
        """Send one serialized message; the base version discards it."""
        return None
|
|
|
|
|
|
class RpcFileTransport(RpcTransport):
    """RPC transport over a pair of raw file descriptors (for example, pipes).

    Each frame is a 4-byte big-endian length (payload size + 1), a 1-byte
    type (0 = UTF-8 JSON message, 1 = raw buffer), then the payload.
    Blocking reads run on a dedicated single-thread executor so they do not
    stall the event loop; writes are synchronous ``os.write`` calls.
    """

    def __init__(self, readFd: int, writeFd: int) -> None:
        super().__init__()
        self.readFd = readFd
        self.writeFd = writeFd
        self.executor = ThreadPoolExecutor(1, "rpc-read")

    def osReadExact(self, size: int):
        """Block until exactly *size* bytes have been read from ``readFd``.

        Raises:
            Exception: if the stream ends first. The read executor is shut
                down because no later read can succeed.
        """
        # bytearray growth is amortized; ``bytes +=`` would recopy everything
        # read so far on every short read.
        received = bytearray()
        while size > 0:
            got = os.read(self.readFd, size)
            if not got:
                self.executor.shutdown(False)
                raise Exception("rpc end of stream reached")
            size -= len(got)
            received += got
        return bytes(received)

    def readMessageInternal(self):
        """Read one frame (blocking).

        Returns ``bytes`` for a buffer frame, otherwise the decoded JSON.
        """
        length = int.from_bytes(self.osReadExact(4), "big")
        if length < 1:
            # Every valid frame counts at least its type byte.
            raise Exception("rpc invalid message length")
        messageType = self.osReadExact(1)[0]
        data = self.osReadExact(length - 1)
        if messageType == 1:
            return data
        return json.loads(data)

    async def read(self):
        """Read the next frame on the executor without blocking the event loop."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self.readMessageInternal)

    def _writeAll(self, data):
        """Write every byte of *data*; ``os.write`` may accept fewer bytes than given."""
        view = memoryview(data)
        while view.nbytes:
            written = os.write(self.writeFd, view)
            if written <= 0:
                raise OSError("rpc write made no progress")
            view = view[written:]

    def writeMessage(self, type: int, buffer, reject):
        """Write one frame; errors go to *reject* (when provided) instead of raising."""
        try:
            if isinstance(buffer, memoryview):
                # len() of a memoryview counts items (first dimension only),
                # not bytes. A flat byte view keeps the length prefix in
                # agreement with the bytes actually written.
                buffer = buffer.cast("B")
            header = (len(buffer) + 1).to_bytes(4, "big") + bytes([type])
            self._writeAll(header)
            self._writeAll(buffer)
        except Exception as e:
            if reject:
                reject(e)

    def writeSerialized(self, j, reject):
        """Write *j* as a UTF-8 JSON frame (type 0); NaN/Infinity are rejected."""
        return self.writeMessage(
            0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject
        )

    def writeBuffer(self, buffer, reject):
        """Write *buffer* as a raw buffer frame (type 1)."""
        return self.writeMessage(1, buffer, reject)
|
|
|
|
|
|
class RpcStreamTransport(RpcTransport):
    """RPC transport over an asyncio ``StreamReader``/``StreamWriter`` pair.

    Each frame is a 4-byte big-endian length (payload size + 1), a 1-byte
    type (0 = UTF-8 JSON message, 1 = raw buffer), then the payload.
    """

    def __init__(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        super().__init__()
        self.reader = reader
        self.writer = writer

    async def read(self):
        """Read one frame.

        Returns ``bytes`` for a buffer frame, otherwise the decoded JSON.

        Raises:
            asyncio.IncompleteReadError: if the stream ends mid-frame.
            Exception: if the length prefix is smaller than 1.
        """
        lengthBytes = await self.reader.readexactly(4)
        length = int.from_bytes(lengthBytes, "big")
        if length < 1:
            # Every valid frame counts at least its type byte.
            raise Exception("rpc invalid message length")
        messageType = (await self.reader.readexactly(1))[0]
        data = await self.reader.readexactly(length - 1)
        if messageType == 1:
            return data
        return json.loads(data)

    def writeMessage(self, type: int, buffer, reject):
        """Queue one frame on the writer; errors go to *reject* when provided."""
        try:
            if isinstance(buffer, memoryview):
                # len() of a memoryview counts items (first dimension only),
                # not bytes. A flat byte view keeps the length prefix in
                # agreement with the bytes actually written.
                buffer = buffer.cast("B")
            header = (len(buffer) + 1).to_bytes(4, "big") + bytes([type])
            self.writer.write(header)
            self.writer.write(buffer)
        except Exception as e:
            if reject:
                reject(e)

    def writeSerialized(self, j, reject):
        """Write *j* as a UTF-8 JSON frame (type 0); NaN/Infinity are rejected."""
        return self.writeMessage(
            0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject
        )

    def writeBuffer(self, buffer, reject):
        """Write *buffer* as a raw buffer frame (type 1)."""
        return self.writeMessage(1, buffer, reject)
|
|
|
|
|
|
class RpcPickleStreamTransport(RpcTransport):
    """RPC transport over asyncio streams that pickles messages instead of using JSON.

    Each frame is a 4-byte big-endian length (payload size + 1), a 1-byte
    type (0 = pickled message, 1 = raw buffer), then the payload.
    *pickler* is any object with ``dumps``/``loads`` (defaults to ``pickle``).
    """

    def __init__(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, pickler = pickle
    ) -> None:
        super().__init__()
        self.reader = reader
        self.writer = writer
        self.pickler = pickler

    async def read(self):
        """Read one frame.

        Returns ``bytes`` for a buffer frame, otherwise the unpickled message.

        Raises:
            asyncio.IncompleteReadError: if the stream ends mid-frame.
            Exception: if the length prefix is smaller than 1.
        """
        lengthBytes = await self.reader.readexactly(4)
        length = int.from_bytes(lengthBytes, "big")
        if length < 1:
            # Every valid frame counts at least its type byte.
            raise Exception("rpc invalid message length")
        messageType = (await self.reader.readexactly(1))[0]
        data = await self.reader.readexactly(length - 1)
        if messageType == 1:
            return data
        # NOTE(review): unpickling can execute code chosen by the sender; this
        # transport is only safe when the peer is trusted.
        return self.pickler.loads(data)

    def writeMessage(self, type: int, buffer, reject):
        """Queue one frame on the writer; errors go to *reject* when provided."""
        try:
            if isinstance(buffer, memoryview):
                # len() of a memoryview counts items (first dimension only),
                # not bytes. A flat byte view keeps the length prefix in
                # agreement with the bytes actually written.
                buffer = buffer.cast("B")
            header = (len(buffer) + 1).to_bytes(4, "big") + bytes([type])
            self.writer.write(header)
            self.writer.write(buffer)
        except Exception as e:
            if reject:
                reject(e)

    def writeSerialized(self, j, reject):
        """Pickle *j* and write it as a message frame (type 0)."""
        pickled = self.pickler.dumps(j)
        return self.writeMessage(0, pickled, reject)

    def writeBuffer(self, buffer, reject):
        """Write *buffer* as a raw buffer frame (type 1)."""
        return self.writeMessage(1, buffer, reject)
|
|
|
|
|
|
class RpcConnectionTransport(RpcTransport):
    """RPC transport over a ``multiprocessing.connection.Connection``.

    The connection sends whole objects, so no framing is done here.
    Messages are sent as-is and buffers are sent as ``bytes``. Blocking
    ``recv`` calls run on a single-thread executor.
    """

    def __init__(self, connection: multiprocessing.connection.Connection) -> None:
        super().__init__()
        self.connection = connection
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    async def read(self):
        """Receive the next object from the connection without blocking the event loop.

        Raises:
            EOFError / OSError: when the connection is finished. The reader
                executor is shut down first, since no later read can succeed.
        """
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(self.executor, self.connection.recv)
        except (EOFError, OSError):
            # Release the reader thread once the connection has ended, as
            # RpcFileTransport does at end of stream.
            self.executor.shutdown(wait=False)
            raise

    def writeMessage(self, json, reject):
        """Send *json* over the connection; errors go to *reject* when provided."""
        try:
            self.connection.send(json)
        except Exception as e:
            if reject:
                reject(e)

    def writeSerialized(self, json, reject):
        """Send a message object unchanged."""
        return self.writeMessage(json, reject)

    def writeBuffer(self, buffer, reject):
        """Send a buffer as a ``bytes`` copy."""
        return self.writeMessage(bytes(buffer), reject)
|
|
|
|
|
|
async def readLoop(peer: rpc.RpcPeer, rpcTransport: RpcTransport):
|
|
deserializationContext = {"buffers": []}
|
|
|
|
while True:
|
|
message = await rpcTransport.read()
|
|
|
|
if type(message) != dict:
|
|
deserializationContext["buffers"].append(message)
|
|
continue
|
|
|
|
asyncio.create_task(peer.handleMessage(message, deserializationContext))
|
|
|
|
deserializationContext = {"buffers": []}
|
|
|
|
|
|
async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport):
    """Build an ``rpc.RpcPeer`` on top of *rpcTransport* plus a coroutine that feeds it.

    Args:
        loop: not referenced by this function; kept for caller compatibility.
        rpcTransport: used both for outgoing writes and for the read loop;
            its ``prepare`` hook is awaited first.

    Returns:
        ``(peer, peerReadLoop)``. ``peerReadLoop`` is an async function that
        runs ``readLoop``. If that loop ends with any exception (including
        cancellation), it kills the peer and re-raises.
    """
    await rpcTransport.prepare()

    writeLock = threading.Lock()

    def send(message, reject=None, serializationContext=None):
        # The lock keeps the writes of concurrent senders from interleaving.
        # Sideband buffers go out before the message that references them.
        with writeLock:
            sideband = (
                serializationContext.get("buffers", None)
                if serializationContext
                else None
            )
            for buffer in sideband or ():
                rpcTransport.writeBuffer(buffer, reject)
            rpcTransport.writeSerialized(message, reject)

    peer = rpc.RpcPeer(send)
    peer.nameDeserializerMap["Buffer"] = SidebandBufferSerializer()
    for bufferType in (bytes, bytearray, memoryview):
        peer.constructorSerializerMap[bufferType] = "Buffer"

    async def peerReadLoop():
        try:
            await readLoop(peer, rpcTransport)
        except BaseException:
            peer.kill()
            raise

    return peer, peerReadLoop
|