server: python formatting

This commit is contained in:
Koushik Dutta
2024-11-21 14:53:16 -08:00
parent cd0ab104ea
commit 2c5b79291f
10 changed files with 501 additions and 269 deletions

View File

@@ -47,5 +47,8 @@ def needs_cluster_fork_worker(options: ClusterForkOptions) -> bool:
return (
os.environ.get("SCRYPTED_CLUSTER_ADDRESS")
and options
and (not matches_cluster_labels(options, get_cluster_labels()) or options.get("clusterWorkerId", None))
and (
not matches_cluster_labels(options, get_cluster_labels())
or options.get("clusterWorkerId", None)
)
)

View File

@@ -12,6 +12,7 @@ import rpc
import rpc_reader
from typing import TypedDict
class ClusterObject(TypedDict):
id: str
address: str
@@ -20,12 +21,16 @@ class ClusterObject(TypedDict):
sourceKey: str
sha256: str
def isClusterAddress(address: str):
return not address or address == os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None)
def getClusterPeerKey(address: str, port: int):
return f"{address}:{port}"
class ClusterSetup():
class ClusterSetup:
def __init__(self, loop: AbstractEventLoop, peer: rpc.RpcPeer):
self.loop = loop
self.peer = peer
@@ -50,9 +55,13 @@ class ClusterSetup():
sha256 = self.computeClusterObjectHash(o)
if sha256 != o["sha256"]:
raise Exception("secret incorrect")
return await self.resolveObject(o.get('proxyId', None), o.get('sourceKey', None))
return await self.resolveObject(
o.get("proxyId", None), o.get("sourceKey", None)
)
def onProxySerialization(self, peer: rpc.RpcPeer, value: Any, sourceKey: str = None):
def onProxySerialization(
self, peer: rpc.RpcPeer, value: Any, sourceKey: str = None
):
properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {}
clusterEntry = properties.get("__cluster", None)
proxyId: str
@@ -126,7 +135,9 @@ class ClusterSetup():
handleClusterClient, listenAddress, 0
)
self.clusterPort = clusterRpcServer.sockets[0].getsockname()[1]
self.peer.onProxySerialization = lambda value: self.onProxySerialization(self.peer, value, None)
self.peer.onProxySerialization = lambda value: self.onProxySerialization(
self.peer, value, None
)
del self.peer.params["initializeCluster"]
def computeClusterObjectHash(self, o: ClusterObject) -> str:
@@ -215,9 +226,7 @@ class ClusterSetup():
peerConnectRPCObject = clusterPeer.tags.get("connectRPCObject")
if not peerConnectRPCObject:
peerConnectRPCObject = await clusterPeer.getParam(
"connectRPCObject"
)
peerConnectRPCObject = await clusterPeer.getParam("connectRPCObject")
clusterPeer.tags["connectRPCObject"] = peerConnectRPCObject
newValue = await peerConnectRPCObject(clusterObject)
if not newValue:

View File

@@ -1,5 +1,6 @@
import typing
async def writeWorkerGenerator(gen, out: typing.TextIO):
try:
async for item in gen:

View File

@@ -4,22 +4,25 @@ import sys
from typing import Any
import shutil
def get_requirements_files(requirements: str):
want_requirements = requirements + '.txt'
installed_requirementstxt = requirements + '.installed.txt'
want_requirements = requirements + ".txt"
installed_requirementstxt = requirements + ".installed.txt"
return want_requirements, installed_requirementstxt
def need_requirements(requirements_basename: str, requirements_str: str):
_, installed_requirementstxt = get_requirements_files(requirements_basename)
if not os.path.exists(installed_requirementstxt):
return True
try:
f = open(installed_requirementstxt, "rb")
installed_requirements = f.read().decode('utf8')
installed_requirements = f.read().decode("utf8")
return requirements_str != installed_requirements
except:
return True
def remove_pip_dirs(plugin_volume: str):
try:
for de in os.listdir(plugin_volume):
@@ -48,7 +51,9 @@ def install_with_pip(
ignore_error: bool = False,
site_packages: str = None,
):
requirementstxt, installed_requirementstxt = get_requirements_files(requirements_basename)
requirementstxt, installed_requirementstxt = get_requirements_files(
requirements_basename
)
os.makedirs(python_prefix, exist_ok=True)
@@ -81,15 +86,16 @@ def install_with_pip(
# force reinstall even if it exists in system packages.
pipArgs.append("--force-reinstall")
env = None
if site_packages:
env = dict(os.environ)
PYTHONPATH = env['PYTHONPATH'] or ''
PYTHONPATH += ':' + site_packages
PYTHONPATH = env["PYTHONPATH"] or ""
PYTHONPATH += ":" + site_packages
env["PYTHONPATH"] = PYTHONPATH
print("PYTHONPATH", env["PYTHONPATH"])
p = subprocess.Popen(pipArgs, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
p = subprocess.Popen(
pipArgs, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env
)
while True:
line = p.stdout.readline()

View File

@@ -27,18 +27,22 @@ from cluster_setup import ClusterSetup
import cluster_labels
from plugin_pip import install_with_pip, need_requirements, remove_pip_dirs
from scrypted_python.scrypted_sdk import PluginFork, ScryptedStatic
from scrypted_python.scrypted_sdk.types import (Device, DeviceManifest,
EventDetails,
ScryptedInterface,
ScryptedInterfaceMethods,
ScryptedInterfaceProperty,
Storage)
from scrypted_python.scrypted_sdk.types import (
Device,
DeviceManifest,
EventDetails,
ScryptedInterface,
ScryptedInterfaceMethods,
ScryptedInterfaceProperty,
Storage,
)
SCRYPTED_REQUIREMENTS = """
ptpython
wheel
""".strip()
class SystemDeviceState(TypedDict):
lastEventTime: int
stateTime: int
@@ -47,8 +51,10 @@ class SystemDeviceState(TypedDict):
def ensure_not_coroutine(fn: Callable | Coroutine) -> Callable:
if inspect.iscoroutinefunction(fn):
def wrapper(*args, **kwargs):
return asyncio.create_task(fn(*args, **kwargs))
return wrapper
return fn
@@ -96,25 +102,29 @@ class DeviceProxy(object):
class EventListenerRegisterImpl(scrypted_python.scrypted_sdk.EventListenerRegister):
removeListener: Callable[[], None]
def __init__(self, removeListener: Callable[[], None] | Coroutine[Any, None, None]) -> None:
def __init__(
self, removeListener: Callable[[], None] | Coroutine[Any, None, None]
) -> None:
self.removeListener = ensure_not_coroutine(removeListener)
class EventRegistry(object):
systemListeners: Set[scrypted_python.scrypted_sdk.EventListener]
listeners: Mapping[str, Set[Callable[[scrypted_python.scrypted_sdk.EventDetails, Any], None]]]
listeners: Mapping[
str, Set[Callable[[scrypted_python.scrypted_sdk.EventDetails, Any], None]]
]
__allowedEventInterfaces = set([
ScryptedInterface.ScryptedDevice.value,
'Logger',
'Storage'
])
__allowedEventInterfaces = set(
[ScryptedInterface.ScryptedDevice.value, "Logger", "Storage"]
)
def __init__(self) -> None:
self.systemListeners = set()
self.listeners = {}
def __getMixinEventName(self, options: str | scrypted_python.scrypted_sdk.EventListenerOptions) -> str:
def __getMixinEventName(
self, options: str | scrypted_python.scrypted_sdk.EventListenerOptions
) -> str:
mixinId = None
if type(options) == str:
event = options
@@ -155,7 +165,15 @@ class EventRegistry(object):
self.listeners[id].add(callback)
return EventListenerRegisterImpl(lambda: self.listeners[id].remove(callback))
def notify(self, id: str, eventTime: int, eventInterface: str, property: str, value: Any, options: dict = None):
def notify(
self,
id: str,
eventTime: int,
eventInterface: str,
property: str,
value: Any,
options: dict = None,
):
options = options or {}
changed = options.get("changed")
mixinId = options.get("mixinId")
@@ -174,7 +192,13 @@ class EventRegistry(object):
return self.notifyEventDetails(id, eventDetails, value)
def notifyEventDetails(self, id: str, eventDetails: scrypted_python.scrypted_sdk.EventDetails, value: Any, eventInterface: str = None):
def notifyEventDetails(
self,
id: str,
eventDetails: scrypted_python.scrypted_sdk.EventDetails,
value: Any,
eventInterface: str = None,
):
if not eventDetails.get("eventId"):
eventDetails["eventId"] = self.__generateBase36Str()
if not eventInterface:
@@ -183,8 +207,9 @@ class EventRegistry(object):
# system listeners only get state changes.
# there are many potentially noisy stateless events, like
# object detection and settings changes
if (eventDetails.get("property") and not eventDetails.get("mixinId")) or \
(eventInterface in EventRegistry.__allowedEventInterfaces):
if (eventDetails.get("property") and not eventDetails.get("mixinId")) or (
eventInterface in EventRegistry.__allowedEventInterfaces
):
for listener in self.systemListeners:
listener(id, eventDetails, value)
@@ -202,6 +227,7 @@ class EventRegistry(object):
return True
class ClusterManager(scrypted_python.scrypted_sdk.types.ClusterManager):
def __init__(self, api: Any):
self.api = api
@@ -213,11 +239,16 @@ class ClusterManager(scrypted_python.scrypted_sdk.types.ClusterManager):
def getClusterWorkerId(self) -> str:
return os.getenv("SCRYPTED_CLUSTER_WORKER_ID", None)
async def getClusterWorkers(self) -> Mapping[str, scrypted_python.scrypted_sdk.types.ClusterWorker]:
self.clusterService = self.clusterService or asyncio.ensure_future(self.api.getComponent("cluster-fork"))
async def getClusterWorkers(
self,
) -> Mapping[str, scrypted_python.scrypted_sdk.types.ClusterWorker]:
self.clusterService = self.clusterService or asyncio.ensure_future(
self.api.getComponent("cluster-fork")
)
cs = await self.clusterService
return await cs.getClusterWorkers()
class SystemManager(scrypted_python.scrypted_sdk.types.SystemManager):
def __init__(
self, api: Any, systemState: Mapping[str, Mapping[str, SystemDeviceState]]
@@ -306,19 +337,27 @@ class SystemManager(scrypted_python.scrypted_sdk.types.SystemManager):
callback = ensure_not_coroutine(callback)
if type(options) != str and options.get("watch"):
return self.events.listenDevice(
id, options,
lambda eventDetails, eventData: callback(self.getDeviceById(id), eventDetails, eventData)
id,
options,
lambda eventDetails, eventData: callback(
self.getDeviceById(id), eventDetails, eventData
),
)
register_fut = asyncio.ensure_future(
self.api.listenDevice(
id, options,
lambda eventDetails, eventData: callback(self.getDeviceById(id), eventDetails, eventData)
id,
options,
lambda eventDetails, eventData: callback(
self.getDeviceById(id), eventDetails, eventData
),
)
)
async def unregister():
register = await register_fut
await register.removeListener()
return EventListenerRegisterImpl(lambda: asyncio.ensure_future(unregister()))
async def removeDevice(self, id: str) -> None:
@@ -555,6 +594,7 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager):
def getDeviceStorage(self, nativeId: str = None) -> Storage:
return self.nativeIds.get(nativeId, None)
class PeerLiveness:
def __init__(self, loop: AbstractEventLoop):
self.killed = Future(loop=loop)
@@ -562,15 +602,22 @@ class PeerLiveness:
async def waitKilled(self):
await self.killed
def safe_set_result(fut: Future, result: Any):
try:
fut.set_result(result)
except:
pass
class PluginRemote:
def __init__(
self, clusterSetup: ClusterSetup, api, pluginId: str, hostInfo, loop: AbstractEventLoop
self,
clusterSetup: ClusterSetup,
api,
pluginId: str,
hostInfo,
loop: AbstractEventLoop,
):
self.systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {}
self.nativeIds: Mapping[str, DeviceStorage] = {}
@@ -606,7 +653,9 @@ class PluginRemote:
consoleFuture = Future()
self.consoles[nativeId] = consoleFuture
plugins = await self.api.getComponent("plugins")
port, hostname = await plugins.getRemoteServicePort(self.pluginId, "console-writer")
port, hostname = await plugins.getRemoteServicePort(
self.pluginId, "console-writer"
)
connection = await asyncio.open_connection(host=hostname, port=port)
_, writer = connection
if not nativeId:
@@ -682,7 +731,7 @@ class PluginRemote:
if not forkMain:
multiprocessing.set_start_method("spawn")
# forkMain may be set to true, but the environment may not be initialized
# if the plugin is loaded in another cluster worker.
# instead rely on a environemnt variable that will be passed to
@@ -819,10 +868,13 @@ class PluginRemote:
async def getZip(self):
return await zipAPI.getZip()
return await remote.loadZip(packageJson, PluginZipAPI(), forkOptions)
return await remote.loadZip(
packageJson, PluginZipAPI(), forkOptions
)
if cluster_labels.needs_cluster_fork_worker(options):
peerLiveness = PeerLiveness(self.loop)
async def getClusterFork():
runtimeWorkerOptions = {
"packageJson": packageJson,
@@ -835,14 +887,17 @@ class PluginRemote:
forkComponent = await self.api.getComponent("cluster-fork")
sanitizedOptions = options.copy()
sanitizedOptions["runtime"] = sanitizedOptions.get("runtime", "python")
sanitizedOptions["runtime"] = sanitizedOptions.get(
"runtime", "python"
)
sanitizedOptions["zipHash"] = zipHash
clusterForkResult = await forkComponent.fork(
runtimeWorkerOptions,
sanitizedOptions,
peerLiveness, lambda: zipAPI.getZip()
peerLiveness,
lambda: zipAPI.getZip(),
)
async def waitPeerLiveness():
try:
await peerLiveness.waitKilled()
@@ -851,6 +906,7 @@ class PluginRemote:
await clusterForkResult.kill()
except:
pass
asyncio.ensure_future(waitPeerLiveness(), loop=self.loop)
async def waitClusterForkKilled():
@@ -859,30 +915,48 @@ class PluginRemote:
except:
pass
safe_set_result(peerLiveness.killed, None)
asyncio.ensure_future(waitClusterForkKilled(), loop=self.loop)
clusterGetRemote = await self.clusterSetup.connectRPCObject(await clusterForkResult.getResult())
clusterGetRemote = await self.clusterSetup.connectRPCObject(
await clusterForkResult.getResult()
)
remoteDict = await clusterGetRemote()
asyncio.ensure_future(plugin_console.writeWorkerGenerator(remoteDict["stdout"], sys.stdout))
asyncio.ensure_future(plugin_console.writeWorkerGenerator(remoteDict["stderr"], sys.stderr))
asyncio.ensure_future(
plugin_console.writeWorkerGenerator(
remoteDict["stdout"], sys.stdout
)
)
asyncio.ensure_future(
plugin_console.writeWorkerGenerator(
remoteDict["stderr"], sys.stderr
)
)
getRemote = remoteDict["getRemote"]
directGetRemote = await self.clusterSetup.connectRPCObject(getRemote)
directGetRemote = await self.clusterSetup.connectRPCObject(
getRemote
)
if directGetRemote is getRemote:
raise Exception("cluster fork peer not direct connected")
forkPeer = getattr(directGetRemote, rpc.RpcPeer.PROPERTY_PROXY_PEER)
forkPeer = getattr(
directGetRemote, rpc.RpcPeer.PROPERTY_PROXY_PEER
)
return await finishFork(forkPeer)
pluginFork = PluginFork()
pluginFork.result = asyncio.create_task(getClusterFork())
async def waitKilled():
await peerLiveness.killed
pluginFork.exit = asyncio.create_task(waitKilled())
def terminate():
safe_set_result(peerLiveness.killed, None)
pluginFork.worker.terminate()
pluginFork.terminate = terminate
pluginFork.worker = None
@@ -902,12 +976,16 @@ class PluginRemote:
pluginFork = PluginFork()
killed = Future(loop=self.loop)
async def waitKilled():
await killed
pluginFork.exit = asyncio.create_task(waitKilled())
def terminate():
safe_set_result(killed, None)
pluginFork.worker.kill()
pluginFork.terminate = terminate
pluginFork.worker = multiprocessing.Process(
@@ -956,6 +1034,7 @@ class PluginRemote:
# sdk.
from scrypted_sdk import sdk_init2 # type: ignore
sdk_init2(sdk)
except:
from scrypted_sdk import sdk_init # type: ignore
@@ -1048,7 +1127,9 @@ async def plugin_async_main(
peer.params["print"] = print
clusterSetup = ClusterSetup(loop, peer)
peer.params["initializeCluster"] = lambda options: clusterSetup.initializeCluster(options)
peer.params["initializeCluster"] = lambda options: clusterSetup.initializeCluster(
options
)
async def ping(time: int):
return time
@@ -1077,6 +1158,7 @@ def main(rpcTransport: rpc_reader.RpcTransport):
loop.run_until_complete(plugin_async_main(loop, rpcTransport))
loop.close()
def plugin_fork(conn: multiprocessing.connection.Connection):
main(rpc_reader.RpcConnectionTransport(conn))

View File

@@ -39,6 +39,8 @@ ColorDepth.default = lambda *args, **kwargs: ColorDepth.DEPTH_4_BIT
# the library. The patches here allow us to scope a particular call stack
# to a particular REPL, and to get the current Application from the stack.
default_get_app = prompt_toolkit.application.current.get_app
def get_app_patched() -> Application[Any]:
stack = inspect.stack()
for frame in stack:
@@ -46,6 +48,8 @@ def get_app_patched() -> Application[Any]:
if self_var is not None and isinstance(self_var, Application):
return self_var
return default_get_app()
prompt_toolkit.application.current.get_app = get_app_patched
prompt_toolkit.key_binding.key_processor.get_app = get_app_patched
prompt_toolkit.contrib.telnet.server.get_app = get_app_patched
@@ -141,7 +145,9 @@ async def eval_async_patched(self: PythonRepl, line: str) -> object:
def eval_across_loops(code, *args, **kwargs):
future = concurrent.futures.Future()
scrypted_loop.call_soon_threadsafe(partial(eval_in_scrypted, future), code, *args, **kwargs)
scrypted_loop.call_soon_threadsafe(
partial(eval_in_scrypted, future), code, *args, **kwargs
)
return future.result()
# WORKAROUND: Due to a bug in Jedi, the current directory is removed
@@ -192,7 +198,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(None)
sock.bind(('localhost', 0))
sock.bind(("localhost", 0))
sock.listen()
scrypted_loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
@@ -222,7 +228,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
# Select a free port for the telnet server
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('localhost', 0))
s.bind(("localhost", 0))
telnet_port = s.getsockname()[1]
s.close()
@@ -230,14 +236,19 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
# repl_loop owns the print capabilities, but the prints will
# be executed in scrypted_loop. We need to bridge the two here
repl_print = partial(print_formatted_text, output=connection.vt100_output)
def print_across_loops(*args, **kwargs):
repl_loop.call_soon_threadsafe(repl_print, *args, **kwargs)
global_dict = {
**globals(),
"print": print_across_loops,
"help": lambda *args, **kwargs: print_across_loops("Help is not available in this environment"),
"input": lambda *args, **kwargs: print_across_loops("Input is not available in this environment"),
"help": lambda *args, **kwargs: print_across_loops(
"Help is not available in this environment"
),
"input": lambda *args, **kwargs: print_across_loops(
"Input is not available in this environment"
),
}
locals_dict = {
"device": device,
@@ -245,19 +256,32 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
"deviceManager": deviceManager,
"mediaManager": mediaManager,
"sdk": sdk,
"realDevice": realDevice
"realDevice": realDevice,
}
vars_prompt = '\n'.join([f" {k}" for k in locals_dict.keys()])
vars_prompt = "\n".join([f" {k}" for k in locals_dict.keys()])
banner = f"Python REPL variables:\n{vars_prompt}"
print_formatted_text(banner)
await embed(return_asyncio_coroutine=True, globals=global_dict, locals=locals_dict, configure=partial(configure, scrypted_loop))
await embed(
return_asyncio_coroutine=True,
globals=global_dict,
locals=locals_dict,
configure=partial(configure, scrypted_loop),
)
server_task: asyncio.Task = None
def ready_cb():
future.set_result((telnet_port, lambda: repl_loop.call_soon_threadsafe(server_task.cancel)))
future.set_result(
(
telnet_port,
lambda: repl_loop.call_soon_threadsafe(server_task.cancel),
)
)
# Start the REPL server
telnet_server = TelnetServer(interact=interact, port=telnet_port, enable_cpr=False)
telnet_server = TelnetServer(
interact=interact, port=telnet_port, enable_cpr=False
)
server_task = asyncio.create_task(telnet_server.run(ready_cb=ready_cb))
try:
await server_task
@@ -277,16 +301,19 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
def finish_setup():
telnet_port, exit_server = server_started_future.result()
telnet_client = telnetlib.Telnet('localhost', telnet_port, timeout=None)
telnet_client = telnetlib.Telnet("localhost", telnet_port, timeout=None)
def telnet_negotiation_cb(telnet_socket, command, option):
pass # ignore telnet negotiation
telnet_client.set_option_negotiation_callback(telnet_negotiation_cb)
# initialize telnet terminal
# this tells the telnet server we are a vt100 terminal
telnet_client.get_socket().sendall(b'\xff\xfb\x18\xff\xfa\x18\x00\x61\x6e\x73\x69\xff\xf0')
telnet_client.get_socket().sendall(b'\r\n')
telnet_client.get_socket().sendall(
b"\xff\xfb\x18\xff\xfa\x18\x00\x61\x6e\x73\x69\xff\xf0"
)
telnet_client.get_socket().sendall(b"\r\n")
# Bridge the connection to the telnet server, two way
def forward_to_telnet():
@@ -303,7 +330,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
while True:
data = telnet_client.read_some()
if not data:
conn.sendall('REPL exited'.encode())
conn.sendall("REPL exited".encode())
break
if b">>>" in data:
# This is an ugly hack - somewhere in ptpython, the
@@ -333,4 +360,4 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
threading.Thread(target=accept_connection).start()
proxy_port = sock.getsockname()[1]
return proxy_port
return proxy_port

View File

@@ -1,20 +1,24 @@
import os
from pathlib import Path
def get_scrypted_volume():
volume_dir = os.getenv('SCRYPTED_VOLUME') or Path.home() / '.scrypted' / 'volume'
volume_dir = os.getenv("SCRYPTED_VOLUME") or Path.home() / ".scrypted" / "volume"
return str(volume_dir)
def get_plugins_volume():
volume = get_scrypted_volume()
plugins_volume = Path(volume) / 'plugins'
plugins_volume = Path(volume) / "plugins"
return str(plugins_volume)
def get_plugin_volume(plugin_id):
volume = get_plugins_volume()
plugin_volume = Path(volume) / plugin_id
return str(plugin_volume)
def ensure_plugin_volume(plugin_id):
plugin_volume = get_plugin_volume(plugin_id)
try:
@@ -23,23 +27,25 @@ def ensure_plugin_volume(plugin_id):
pass
return plugin_volume
def create_adm_zip_hash(hash):
extract_version = "1-"
return extract_version + hash
def prep(plugin_volume, hash):
hash = create_adm_zip_hash(hash)
zip_filename = f"{hash}.zip"
zip_dir = os.path.join(plugin_volume, 'zip')
zip_dir = os.path.join(plugin_volume, "zip")
zip_file = os.path.join(zip_dir, zip_filename)
unzipped_path = os.path.join(zip_dir, 'unzipped')
zip_dir_tmp = zip_dir + '.tmp'
unzipped_path = os.path.join(zip_dir, "unzipped")
zip_dir_tmp = zip_dir + ".tmp"
return {
'unzipped_path': unzipped_path,
'zip_filename': zip_filename,
'zip_dir': zip_dir,
'zip_file': zip_file,
'zip_dir_tmp': zip_dir_tmp,
}
"unzipped_path": unzipped_path,
"zip_filename": zip_filename,
"zip_dir": zip_dir,
"zip_file": zip_file,
"zip_dir_tmp": zip_dir_tmp,
}

View File

@@ -1,22 +1,24 @@
import asyncio
import rpc
from rpc_reader import prepare_peer_readloop, RpcFileTransport
from rpc_reader import prepare_peer_readloop, RpcFileTransport
import traceback
class Bar:
pass
async def main():
peer, peerReadLoop = await prepare_peer_readloop(loop, RpcFileTransport(4, 3))
peer.params['foo'] = 3
peer.params["foo"] = 3
jsoncopy = {}
jsoncopy[rpc.RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN] = True
jsoncopy['bar'] = Bar()
peer.params['bar'] = jsoncopy
jsoncopy["bar"] = Bar()
peer.params["bar"] = jsoncopy
# reader, writer = await asyncio.open_connection(
# '127.0.0.1', 6666)
# writer.write(bytes('abcd', 'utf8'))
# async def ticker(delay, to):
@@ -27,12 +29,12 @@ async def main():
# peer.params['ticker'] = ticker(0, 3)
print('python starting')
print("python starting")
# await peerReadLoop()
asyncio.ensure_future(peerReadLoop())
# print('getting param')
test = await peer.getParam('test')
test = await peer.getParam("test")
print(test)
try:
i = 0
@@ -43,7 +45,8 @@ async def main():
i = i + 1
except:
traceback.print_exc()
print('all done iterating')
print("all done iterating")
loop = asyncio.new_event_loop()
loop.run_until_complete(main())

View File

@@ -16,7 +16,7 @@ jsonSerializable.add(list)
async def maybe_await(value):
if (inspect.isawaitable(value)):
if inspect.isawaitable(value):
return await value
return value
@@ -58,69 +58,112 @@ class LocalProxiedEntry(TypedDict):
class RpcProxy(object):
def __init__(self, peer, entry: LocalProxiedEntry, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
self.__dict__['__proxy_id'] = entry['id']
self.__dict__['__proxy_entry'] = entry
self.__dict__['__proxy_constructor'] = proxyConstructorName
def __init__(
self,
peer,
entry: LocalProxiedEntry,
proxyConstructorName: str,
proxyProps: any,
proxyOneWayMethods: List[str],
):
self.__dict__["__proxy_id"] = entry["id"]
self.__dict__["__proxy_entry"] = entry
self.__dict__["__proxy_constructor"] = proxyConstructorName
self.__dict__[RpcPeer.PROPERTY_PROXY_PEER] = peer
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] = proxyProps
self.__dict__['__proxy_oneway_methods'] = proxyOneWayMethods
self.__dict__["__proxy_oneway_methods"] = proxyOneWayMethods
def __aiter__(self):
if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and 'Symbol(Symbol.asyncIterator)' in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]:
if (
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
and "Symbol(Symbol.asyncIterator)"
in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
):
return self
raise Exception('RpcProxy is not an async iterable')
raise Exception("RpcProxy is not an async iterable")
async def __anext__(self):
if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and 'Symbol(Symbol.asyncIterator)' in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]:
if (
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
and "Symbol(Symbol.asyncIterator)"
in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
):
try:
return await RpcProxyMethod(self, self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]['Symbol(Symbol.asyncIterator)']['next'])()
return await RpcProxyMethod(
self,
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES][
"Symbol(Symbol.asyncIterator)"
]["next"],
)()
except RPCResultError as e:
if e.name == 'StopAsyncIteration':
if e.name == "StopAsyncIteration":
raise StopAsyncIteration()
raise
raise Exception('RpcProxy is not an async iterable')
raise Exception("RpcProxy is not an async iterable")
async def aclose(self):
if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and 'Symbol(Symbol.asyncIterator)' in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]:
if (
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
and "Symbol(Symbol.asyncIterator)"
in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
):
try:
return await RpcProxyMethod(self, self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]['Symbol(Symbol.asyncIterator)']['return'])()
return await RpcProxyMethod(
self,
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES][
"Symbol(Symbol.asyncIterator)"
]["return"],
)()
except Exception:
return
raise Exception('RpcProxy is not an async iterable')
raise Exception("RpcProxy is not an async iterable")
def __getattr__(self, name):
if name == '__proxy_finalizer_id':
return self.dict['__proxy_entry']['finalizerId']
if name == "__proxy_finalizer_id":
return self.dict["__proxy_entry"]["finalizerId"]
if name in self.__dict__:
return self.__dict__[name]
if self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES] and name in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]:
if (
self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
and name in self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES]
):
return self.__dict__[RpcPeer.PROPERTY_PROXY_PROPERTIES][name]
return RpcProxyMethod(self, name)
def __setattr__(self, name: str, value: Any) -> None:
if name == '__proxy_finalizer_id':
self.__dict__['__proxy_entry']['finalizerId'] = value
if name == "__proxy_finalizer_id":
self.__dict__["__proxy_entry"]["finalizerId"] = value
return super().__setattr__(name, value)
def __call__(self, *args, **kwargs):
return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__(self.__dict__['__proxy_id'], self.__dict__['__proxy_oneway_methods'], None, args)
return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__(
self.__dict__["__proxy_id"],
self.__dict__["__proxy_oneway_methods"],
None,
args,
)
def __apply__(self, method: str, args: list):
return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__(self.__dict__['__proxy_id'], self.__dict__['__proxy_oneway_methods'], method, args)
return self.__dict__[RpcPeer.PROPERTY_PROXY_PEER].__apply__(
self.__dict__["__proxy_id"],
self.__dict__["__proxy_oneway_methods"],
method,
args,
)
class RpcPeer:
RPC_RESULT_ERROR_NAME = 'RPCResultError'
PROPERTY_PROXY_PROPERTIES = '__proxy_props'
PROPERTY_JSON_COPY_SERIALIZE_CHILDREN = '__json_copy_serialize_children'
PROPERTY_PROXY_PEER = '__proxy_peer'
RPC_RESULT_ERROR_NAME = "RPCResultError"
PROPERTY_PROXY_PROPERTIES = "__proxy_props"
PROPERTY_JSON_COPY_SERIALIZE_CHILDREN = "__json_copy_serialize_children"
PROPERTY_PROXY_PEER = "__proxy_peer"
def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None:
def __init__(
self, send: Callable[[object, Callable[[Exception], None], Dict], None]
) -> None:
self.send = send
self.peerName = 'Unnamed Peer'
self.peerName = "Unnamed Peer"
self.params: Mapping[str, any] = {}
self.localProxied: Mapping[any, LocalProxiedEntry] = {}
self.localProxyMap: Mapping[str, any] = {}
@@ -132,7 +175,9 @@ class RpcPeer:
self.killed = False
self.tags = {}
def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list):
def __apply__(
self, proxyId: str, oneWayMethods: List[str], method: str, args: list
):
oneway = oneWayMethods and method in oneWayMethods
if self.killed:
@@ -140,7 +185,9 @@ class RpcPeer:
if oneway:
future.set_result(None)
return future
future.set_exception(RPCResultError(None, 'RpcPeer has been killed (apply) ' + str(method)))
future.set_exception(
RPCResultError(None, "RpcPeer has been killed (apply) " + str(method))
)
return future
serializationContext: Dict = {}
@@ -149,23 +196,24 @@ class RpcPeer:
serializedArgs.append(self.serialize(arg, serializationContext))
rpcApply = {
'type': 'apply',
'id': None,
'proxyId': proxyId,
'args': serializedArgs,
'method': method,
"type": "apply",
"id": None,
"proxyId": proxyId,
"args": serializedArgs,
"method": method,
}
if oneway:
rpcApply['oneway'] = True
rpcApply["oneway"] = True
self.send(rpcApply, None, serializationContext)
future = Future()
future.set_result(None)
return future
async def send(id: str, reject: Callable[[Exception], None]):
rpcApply['id'] = id
rpcApply["id"] = id
self.send(rpcApply, reject, serializationContext)
return self.createPendingResult(send)
def kill(self, message: str = None):
@@ -174,7 +222,7 @@ class RpcPeer:
return
self.killed = True
error = RPCResultError(None, message or 'peer was killed')
error = RPCResultError(None, message or "peer was killed")
# this.killedDeferred.reject(error);
for str, future in self.pendingResults.items():
future.set_exception(error)
@@ -187,14 +235,14 @@ class RpcPeer:
def createErrorResult(self, result: Any, e: Exception):
s = self.serializeError(e)
result['result'] = s
result['throw'] = True
result["result"] = s
result["throw"] = True
return result
def deserializeError(e: Dict) -> RPCResultError:
error = RPCResultError(None, e.get('message'))
error.stack = e.get('stack')
error.name = e.get('name')
error = RPCResultError(None, e.get("message"))
error.stack = e.get("stack")
error.name = e.get("name")
return error
def serializeError(self, e: Exception):
@@ -203,16 +251,16 @@ class RpcPeer:
message = str(e)
serialized = {
'stack': tb or '[no stack]',
'name': name or '[no name]',
'message': message or '[no message]',
"stack": tb or "[no stack]",
"name": name or "[no name]",
"message": message or "[no message]",
}
return {
'__remote_constructor_name': RpcPeer.RPC_RESULT_ERROR_NAME,
'__serialized_value': serialized,
"__remote_constructor_name": RpcPeer.RPC_RESULT_ERROR_NAME,
"__serialized_value": serialized,
}
# def getProxyProperties(value):
# return getattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, None)
@@ -220,15 +268,15 @@ class RpcPeer:
# setattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, properties)
def prepareProxyProperties(value):
if not hasattr(value, '__aiter__') or not hasattr(value, '__anext__'):
if not hasattr(value, "__aiter__") or not hasattr(value, "__anext__"):
return getattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, None)
props = getattr(value, RpcPeer.PROPERTY_PROXY_PROPERTIES, None) or {}
if not props.get('Symbol(Symbol.asyncIterator)'):
props['Symbol(Symbol.asyncIterator)'] = {
'next': '__anext__',
'throw': 'athrow',
'return': 'aclose',
if not props.get("Symbol(Symbol.asyncIterator)"):
props["Symbol(Symbol.asyncIterator)"] = {
"next": "__anext__",
"throw": "athrow",
"return": "aclose",
}
return props
@@ -236,34 +284,44 @@ class RpcPeer:
return not value or (type(value) in jsonSerializable)
def serialize(self, value, serializationContext: Dict):
if type(value) == dict and value.get(RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None):
if type(value) == dict and value.get(
RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None
):
ret = {}
for (key, val) in value.items():
for key, val in value.items():
ret[key] = self.serialize(val, serializationContext)
return ret
if (RpcPeer.isTransportSafe(value)):
if RpcPeer.isTransportSafe(value):
return value
__remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr(
value, '__proxy_constructor') else type(value).__name__
__remote_constructor_name = (
"Function"
if callable(value)
else (
value.__proxy_constructor
if hasattr(value, "__proxy_constructor")
else type(value).__name__
)
)
if isinstance(value, Exception):
return self.serializeError(value)
serializerMapName = self.constructorSerializerMap.get(
type(value), None)
serializerMapName = self.constructorSerializerMap.get(type(value), None)
if serializerMapName:
__remote_constructor_name = serializerMapName
serializer = self.nameDeserializerMap.get(serializerMapName, None)
serialized = serializer.serialize(value, serializationContext)
ret = {
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': RpcPeer.prepareProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
'__serialized_value': serialized,
"__remote_proxy_id": None,
"__remote_proxy_finalizer_id": None,
"__remote_constructor_name": __remote_constructor_name,
"__remote_proxy_props": RpcPeer.prepareProxyProperties(value),
"__remote_proxy_oneway_methods": getattr(
value, "__proxy_oneway_methods", None
),
"__serialized_value": serialized,
}
return ret
@@ -273,26 +331,28 @@ class RpcPeer:
proxyId, __remote_proxy_props = self.onProxySerialization(value)
else:
__remote_proxy_props = RpcPeer.prepareProxyProperties(value)
proxyId = proxiedEntry['id']
proxyId = proxiedEntry["id"]
if proxyId != proxiedEntry['id']:
raise Exception('onProxySerialization proxy id mismatch')
if proxyId != proxiedEntry["id"]:
raise Exception("onProxySerialization proxy id mismatch")
proxiedEntry['finalizerId'] = RpcPeer.generateId()
proxiedEntry["finalizerId"] = RpcPeer.generateId()
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': __remote_proxy_props,
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
"__remote_proxy_id": proxyId,
"__remote_proxy_finalizer_id": proxiedEntry["finalizerId"],
"__remote_constructor_name": __remote_constructor_name,
"__remote_proxy_props": __remote_proxy_props,
"__remote_proxy_oneway_methods": getattr(
value, "__proxy_oneway_methods", None
),
}
return ret
__proxy_id = getattr(value, '__proxy_id', None)
__proxy_id = getattr(value, "__proxy_id", None)
__proxy_peer = getattr(value, RpcPeer.PROPERTY_PROXY_PEER, None)
if __proxy_id and __proxy_peer == self:
ret = {
'__local_proxy_id': __proxy_id,
"__local_proxy_id": __proxy_id,
}
return ret
@@ -303,18 +363,20 @@ class RpcPeer:
proxyId = RpcPeer.generateId()
proxiedEntry = {
'id': proxyId,
'finalizerId': proxyId,
"id": proxyId,
"finalizerId": proxyId,
}
self.localProxied[value] = proxiedEntry
self.localProxyMap[proxyId] = value
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': __remote_proxy_props,
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
"__remote_proxy_id": proxyId,
"__remote_proxy_finalizer_id": proxyId,
"__remote_constructor_name": __remote_constructor_name,
"__remote_proxy_props": __remote_proxy_props,
"__remote_proxy_oneway_methods": getattr(
value, "__proxy_oneway_methods", None
),
}
return ret
@@ -322,22 +384,33 @@ class RpcPeer:
def finalize(self, localProxiedEntry: LocalProxiedEntry):
if self.killed:
return
id = localProxiedEntry['id']
id = localProxiedEntry["id"]
self.remoteWeakProxies.pop(id, None)
rpcFinalize = {
'__local_proxy_id': id,
'__local_proxy_finalizer_id': localProxiedEntry['finalizerId'],
'type': 'finalize',
"__local_proxy_id": id,
"__local_proxy_finalizer_id": localProxiedEntry["finalizerId"],
"type": "finalize",
}
self.send(rpcFinalize)
def newProxy(self, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
def newProxy(
self,
proxyId: str,
proxyConstructorName: str,
proxyProps: any,
proxyOneWayMethods: List[str],
):
localProxiedEntry: LocalProxiedEntry = {
'id': proxyId,
'finalizerId': None,
"id": proxyId,
"finalizerId": None,
}
proxy = RpcProxy(self, localProxiedEntry, proxyConstructorName,
proxyProps, proxyOneWayMethods)
proxy = RpcProxy(
self,
localProxiedEntry,
proxyConstructorName,
proxyProps,
proxyOneWayMethods,
)
wr = weakref.ref(proxy)
self.remoteWeakProxies[proxyId] = wr
weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry))
@@ -350,23 +423,22 @@ class RpcPeer:
if type(value) != dict:
return value
copySerializeChildren = value.get(RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None)
copySerializeChildren = value.get(
RpcPeer.PROPERTY_JSON_COPY_SERIALIZE_CHILDREN, None
)
if copySerializeChildren:
ret = {}
for (key, val) in value.items():
for key, val in value.items():
ret[key] = self.deserialize(val, deserializationContext)
return ret
__remote_proxy_id = value.get('__remote_proxy_id', None)
__remote_proxy_finalizer_id = value.get(
'__remote_proxy_finalizer_id', None)
__local_proxy_id = value.get('__local_proxy_id', None)
__remote_constructor_name = value.get(
'__remote_constructor_name', None)
__serialized_value = value.get('__serialized_value', None)
__remote_proxy_props = value.get('__remote_proxy_props', None)
__remote_proxy_oneway_methods = value.get(
'__remote_proxy_oneway_methods', None)
__remote_proxy_id = value.get("__remote_proxy_id", None)
__remote_proxy_finalizer_id = value.get("__remote_proxy_finalizer_id", None)
__local_proxy_id = value.get("__local_proxy_id", None)
__remote_constructor_name = value.get("__remote_constructor_name", None)
__serialized_value = value.get("__serialized_value", None)
__remote_proxy_props = value.get("__remote_proxy_props", None)
__remote_proxy_oneway_methods = value.get("__remote_proxy_oneway_methods", None)
if __remote_constructor_name == RpcPeer.RPC_RESULT_ERROR_NAME:
return RpcPeer.deserializeError(__serialized_value)
@@ -375,66 +447,71 @@ class RpcPeer:
weakref = self.remoteWeakProxies.get(__remote_proxy_id, None)
proxy = weakref() if weakref else None
if not proxy:
proxy = self.newProxy(__remote_proxy_id, __remote_constructor_name,
__remote_proxy_props, __remote_proxy_oneway_methods)
setattr(proxy, '__proxy_finalizer_id', __remote_proxy_finalizer_id)
proxy = self.newProxy(
__remote_proxy_id,
__remote_constructor_name,
__remote_proxy_props,
__remote_proxy_oneway_methods,
)
setattr(proxy, "__proxy_finalizer_id", __remote_proxy_finalizer_id)
return proxy
if __local_proxy_id:
ret = self.localProxyMap.get(__local_proxy_id, None)
if not ret:
raise RPCResultError(
None, 'invalid local proxy id %s' % __local_proxy_id)
None, "invalid local proxy id %s" % __local_proxy_id
)
return ret
deserializer = self.nameDeserializerMap.get(
__remote_constructor_name, None)
deserializer = self.nameDeserializerMap.get(__remote_constructor_name, None)
if deserializer:
return deserializer.deserialize(__serialized_value, deserializationContext)
return value
def sendResult(self, result: Dict, serializationContext: Dict):
self.send(result, lambda e: self.send(self.createErrorResult(result, e, None), None), serializationContext)
self.send(
result,
lambda e: self.send(self.createErrorResult(result, e, None), None),
serializationContext,
)
async def handleMessage(self, message: Dict, deserializationContext: Dict):
try:
messageType = message['type']
if messageType == 'param':
messageType = message["type"]
if messageType == "param":
result = {
'type': 'result',
'id': message['id'],
"type": "result",
"id": message["id"],
}
serializationContext: Dict = {}
try:
value = self.params.get(message['param'], None)
value = self.params.get(message["param"], None)
value = await maybe_await(value)
result['result'] = self.serialize(value, serializationContext)
result["result"] = self.serialize(value, serializationContext)
except Exception as e:
tb = traceback.format_exc()
self.createErrorResult(
result, type(e).__name, str(e), tb)
self.createErrorResult(result, type(e).__name, str(e), tb)
self.sendResult(result, serializationContext)
elif messageType == 'apply':
elif messageType == "apply":
result = {
'type': 'result',
'id': message.get('id', None),
"type": "result",
"id": message.get("id", None),
}
method = message.get('method', None)
method = message.get("method", None)
try:
serializationContext: Dict = {}
target = self.localProxyMap.get(
message['proxyId'], None)
target = self.localProxyMap.get(message["proxyId"], None)
if not target:
raise Exception('proxy id %s not found' %
message['proxyId'])
raise Exception("proxy id %s not found" % message["proxyId"])
args = []
for arg in (message['args'] or []):
for arg in message["args"] or []:
args.append(self.deserialize(arg, deserializationContext))
# if method == 'asend' and hasattr(target, '__aiter__') and hasattr(target, '__anext__') and not len(args):
@@ -444,47 +521,53 @@ class RpcPeer:
if method:
if not hasattr(target, method):
raise Exception(
'target %s does not have method %s' % (type(target), method))
"target %s does not have method %s"
% (type(target), method)
)
invoke = getattr(target, method)
value = await maybe_await(invoke(*args))
else:
value = await maybe_await(target(*args))
result['result'] = self.serialize(value, serializationContext)
result["result"] = self.serialize(value, serializationContext)
except StopAsyncIteration as e:
self.createErrorResult(result, e)
except Exception as e:
self.createErrorResult(result, e)
if not message.get('oneway', False):
if not message.get("oneway", False):
self.sendResult(result, serializationContext)
elif messageType == 'result':
id = message['id']
elif messageType == "result":
id = message["id"]
future = self.pendingResults.get(id, None)
if not future:
raise RPCResultError(
None, 'unknown result %s' % id)
raise RPCResultError(None, "unknown result %s" % id)
del self.pendingResults[id]
deserialized = self.deserialize(message.get('result', None), deserializationContext)
if message.get('throw'):
deserialized = self.deserialize(
message.get("result", None), deserializationContext
)
if message.get("throw"):
future.set_exception(deserialized)
else:
future.set_result(deserialized)
elif messageType == 'finalize':
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']
elif messageType == "finalize":
finalizerId = message.get("__local_proxy_finalizer_id", None)
proxyId = message["__local_proxy_id"]
local = self.localProxyMap.get(proxyId, None)
if local:
localProxiedEntry = self.localProxied.get(local)
if localProxiedEntry and finalizerId and localProxiedEntry['finalizerId'] != finalizerId:
if (
localProxiedEntry
and finalizerId
and localProxiedEntry["finalizerId"] != finalizerId
):
# print('mismatch finalizer id', file=sys.stderr)
return
self.localProxied.pop(local, None)
local = self.localProxyMap.pop(proxyId, None)
else:
raise RPCResultError(
None, 'unknown rpc message type %s' % type)
raise RPCResultError(None, "unknown rpc message type %s" % type)
except Exception as e:
print("unhandled rpc error", self.peerName, e)
pass
@@ -492,12 +575,16 @@ class RpcPeer:
randomDigits = string.ascii_uppercase + string.ascii_lowercase + string.digits
def generateId():
return ''.join(random.choices(RpcPeer.randomDigits, k=8))
return "".join(random.choices(RpcPeer.randomDigits, k=8))
async def createPendingResult(self, cb: Callable[[str, Callable[[Exception], None]], None]):
async def createPendingResult(
self, cb: Callable[[str, Callable[[Exception], None]], None]
):
future = Future()
if self.killed:
future.set_exception(RPCResultError(None, 'RpcPeer has been killed (createPendingResult)'))
future.set_exception(
RPCResultError(None, "RpcPeer has been killed (createPendingResult)")
)
return future
id = RpcPeer.generateId()
@@ -508,9 +595,10 @@ class RpcPeer:
async def getParam(self, param):
async def send(id: str, reject: Callable[[Exception], None]):
paramMessage = {
'id': id,
'type': 'param',
'param': param,
"id": id,
"type": "param",
"param": param,
}
self.send(paramMessage, reject)
return await self.createPendingResult(send)

View File

@@ -16,7 +16,7 @@ import json
class BufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
return base64.b64encode(value).decode('utf8')
return base64.b64encode(value).decode("utf8")
def deserialize(self, value, serializationContext):
return base64.b64decode(value)
@@ -24,15 +24,15 @@ class BufferSerializer(rpc.RpcSerializer):
class SidebandBufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
buffers = serializationContext.get('buffers', None)
buffers = serializationContext.get("buffers", None)
if not buffers:
buffers = []
serializationContext['buffers'] = buffers
serializationContext["buffers"] = buffers
buffers.append(value)
return len(buffers) - 1
def deserialize(self, value, serializationContext):
buffers: List = serializationContext.get('buffers', None)
buffers: List = serializationContext.get("buffers", None)
buffer = buffers.pop()
return buffer
@@ -56,7 +56,7 @@ class RpcFileTransport(RpcTransport):
super().__init__()
self.readFd = readFd
self.writeFd = writeFd
self.executor = ThreadPoolExecutor(1, 'rpc-read')
self.executor = ThreadPoolExecutor(1, "rpc-read")
def osReadExact(self, size: int):
b = bytes(0)
@@ -64,7 +64,7 @@ class RpcFileTransport(RpcTransport):
got = os.read(self.readFd, size)
if not len(got):
self.executor.shutdown(False)
raise Exception('rpc end of stream reached')
raise Exception("rpc end of stream reached")
size -= len(got)
b += got
return b
@@ -73,7 +73,7 @@ class RpcFileTransport(RpcTransport):
lengthBytes = self.osReadExact(4)
typeBytes = self.osReadExact(1)
type = typeBytes[0]
length = int.from_bytes(lengthBytes, 'big')
length = int.from_bytes(lengthBytes, "big")
data = self.osReadExact(length - 1)
if type == 1:
return data
@@ -81,11 +81,13 @@ class RpcFileTransport(RpcTransport):
return message
async def read(self):
return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.readMessageInternal())
return await asyncio.get_event_loop().run_in_executor(
self.executor, lambda: self.readMessageInternal()
)
def writeMessage(self, type: int, buffer, reject):
length = len(buffer) + 1
lb = length.to_bytes(4, 'big')
lb = length.to_bytes(4, "big")
try:
for b in [lb, bytes([type]), buffer]:
os.write(self.writeFd, b)
@@ -94,14 +96,18 @@ class RpcFileTransport(RpcTransport):
reject(e)
def writeJSON(self, j, reject):
return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject)
return self.writeMessage(
0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject
)
def writeBuffer(self, buffer, reject):
return self.writeMessage(1, buffer, reject)
class RpcStreamTransport(RpcTransport):
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
def __init__(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
super().__init__()
self.reader = reader
self.writer = writer
@@ -110,7 +116,7 @@ class RpcStreamTransport(RpcTransport):
lengthBytes = await self.reader.readexactly(4)
typeBytes = await self.reader.readexactly(1)
type = typeBytes[0]
length = int.from_bytes(lengthBytes, 'big')
length = int.from_bytes(lengthBytes, "big")
data = await self.reader.readexactly(length - 1)
if type == 1:
return data
@@ -119,7 +125,7 @@ class RpcStreamTransport(RpcTransport):
def writeMessage(self, type: int, buffer, reject):
length = len(buffer) + 1
lb = length.to_bytes(4, 'big')
lb = length.to_bytes(4, "big")
try:
for b in [lb, bytes([type]), buffer]:
self.writer.write(b)
@@ -128,7 +134,9 @@ class RpcStreamTransport(RpcTransport):
reject(e)
def writeJSON(self, j, reject):
return self.writeMessage(0, bytes(json.dumps(j, allow_nan=False), 'utf8'), reject)
return self.writeMessage(
0, bytes(json.dumps(j, allow_nan=False), "utf8"), reject
)
def writeBuffer(self, buffer, reject):
return self.writeMessage(1, buffer, reject)
@@ -141,7 +149,9 @@ class RpcConnectionTransport(RpcTransport):
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
async def read(self):
return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.connection.recv())
return await asyncio.get_event_loop().run_in_executor(
self.executor, lambda: self.connection.recv()
)
def writeMessage(self, json, reject):
try:
@@ -158,23 +168,20 @@ class RpcConnectionTransport(RpcTransport):
async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport):
deserializationContext = {
'buffers': []
}
deserializationContext = {"buffers": []}
while True:
message = await rpcTransport.read()
if type(message) != dict:
deserializationContext['buffers'].append(message)
deserializationContext["buffers"].append(message)
continue
asyncio.run_coroutine_threadsafe(
peer.handleMessage(message, deserializationContext), loop)
peer.handleMessage(message, deserializationContext), loop
)
deserializationContext = {
'buffers': []
}
deserializationContext = {"buffers": []}
async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport):
@@ -185,7 +192,7 @@ async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransp
def send(message, reject=None, serializationContext=None):
with mutex:
if serializationContext:
buffers = serializationContext.get('buffers', None)
buffers = serializationContext.get("buffers", None)
if buffers:
for buffer in buffers:
rpcTransport.writeBuffer(buffer, reject)
@@ -193,10 +200,10 @@ async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransp
rpcTransport.writeJSON(message, reject)
peer = rpc.RpcPeer(send)
peer.nameDeserializerMap['Buffer'] = SidebandBufferSerializer()
peer.constructorSerializerMap[bytes] = 'Buffer'
peer.constructorSerializerMap[bytearray] = 'Buffer'
peer.constructorSerializerMap[memoryview] = 'Buffer'
peer.nameDeserializerMap["Buffer"] = SidebandBufferSerializer()
peer.constructorSerializerMap[bytes] = "Buffer"
peer.constructorSerializerMap[bytearray] = "Buffer"
peer.constructorSerializerMap[memoryview] = "Buffer"
async def peerReadLoop():
try: