From 7d28d1d9d429f36ed6db2452f229a62f7d03e41b Mon Sep 17 00:00:00 2001 From: Koushik Dutta Date: Tue, 12 Nov 2024 21:22:22 -0800 Subject: [PATCH] server: wip python clustering --- server/python/plugin_remote.py | 272 +++++++++++++++++---------------- 1 file changed, 140 insertions(+), 132 deletions(-) diff --git a/server/python/plugin_remote.py b/server/python/plugin_remote.py index 7354edb06..4bd8517d7 100644 --- a/server/python/plugin_remote.py +++ b/server/python/plugin_remote.py @@ -555,9 +555,125 @@ class DeviceManager(scrypted_python.scrypted_sdk.types.DeviceManager): return self.nativeIds.get(nativeId, None) +def isClusterAddress(address: str): + return not address or address == os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None) + +def getClusterPeerKey(address: str, port: int): + return f"{address}:{port}" + +class ClusterSetup(): + def __init__(self, peer: rpc.RpcPeer): + self.peer = peer + self.clusterId: str = None + self.clusterSecret: str = None + self.clusterAddress: str = None + self.clusterPort: int = None + self.SCRYPTED_CLUSTER_ADDRESS: str = None + self.clusterPeers: Mapping[str, asyncio.Future[rpc.RpcPeer]] = {} + + async def resolveObject(self, id: str, sourceKey: str): + sourcePeer: rpc.RpcPeer = ( + self.peer + if not sourceKey + else await rpc.maybe_await(self.clusterPeers.get(sourceKey, None)) + ) + if not sourcePeer: + return + return sourcePeer.localProxyMap.get(id, None) + + async def connectRPCObject(self, o: ClusterObject): + sha256 = self.computeClusterObjectHash(o) + if sha256 != o["sha256"]: + raise Exception("secret incorrect") + return await self.resolveObject(o.get('proxyId', None), o.get('sourceKey', None)) + + def onProxySerialization(self, peer: rpc.RpcPeer, value: Any, sourceKey: str = None): + properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {} + clusterEntry = properties.get("__cluster", None) + proxyId: str + existing = peer.localProxied.get(value, None) + if existing: + proxyId = existing["id"] + else: + proxyId = ( + clusterEntry and clusterEntry.get("proxyId", None) + ) or rpc.RpcPeer.generateId() + + if clusterEntry: + if ( + isClusterAddress(clusterEntry.get("address", None)) + and self.clusterPort == clusterEntry["port"] + and sourceKey != clusterEntry.get("sourceKey", None) + ): + clusterEntry = None + + if not clusterEntry: + clusterEntry: ClusterObject = { + "id": self.clusterId, + "proxyId": proxyId, + "address": self.SCRYPTED_CLUSTER_ADDRESS, + "port": self.clusterPort, + "sourceKey": sourceKey, + } + clusterEntry["sha256"] = self.computeClusterObjectHash(clusterEntry) + properties["__cluster"] = clusterEntry + + return proxyId, properties + + async def initializeCluster(self, options: dict): + if self.clusterPort: + return + self.clusterId = options["clusterId"] + self.clusterSecret = options["clusterSecret"] + self.SCRYPTED_CLUSTER_ADDRESS = os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None) + + async def handleClusterClient( + reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ): + clusterPeerAddress, clusterPeerPort = writer.get_extra_info("peername") + clusterPeerKey = getClusterPeerKey(clusterPeerAddress, clusterPeerPort) + rpcTransport = rpc_reader.RpcStreamTransport(reader, writer) + peer: rpc.RpcPeer + peer, peerReadLoop = await rpc_reader.prepare_peer_readloop( + self.loop, rpcTransport + ) + peer.onProxySerialization = lambda value: self.onProxySerialization( + peer, value, clusterPeerKey + ) + future: asyncio.Future[rpc.RpcPeer] = asyncio.Future() + future.set_result(peer) + self.clusterPeers[clusterPeerKey] = future + peer.params["connectRPCObject"] = lambda o: self.connectRPCObject(o) + try: + await peerReadLoop() + except: + pass + finally: + self.clusterPeers.pop(clusterPeerKey) + peer.kill("cluster client killed") + writer.close() + + listenAddress = "0.0.0.0" if self.SCRYPTED_CLUSTER_ADDRESS else "127.0.0.1" + clusterRpcServer = await asyncio.start_server( + handleClusterClient, listenAddress, 0 + ) + self.clusterPort = clusterRpcServer.sockets[0].getsockname()[1] + self.peer.onProxySerialization = lambda value: self.onProxySerialization(self.peer, value, None) + del self.peer.params["initializeCluster"] + + def computeClusterObjectHash(self, o: ClusterObject) -> str: + m = hashlib.sha256() + m.update( + bytes( + f"{o['id']}{o.get('address') or ''}{o['port']}{o.get('sourceKey', None) or ''}{o['proxyId']}{self.clusterSecret}", + "utf8", + ) + ) + return base64.b64encode(m.digest()).decode("utf-8") + class PluginRemote: def __init__( - self, peer: rpc.RpcPeer, api, pluginId: str, hostInfo, loop: AbstractEventLoop + self, clusterSetup: ClusterSetup, api, pluginId: str, hostInfo, loop: AbstractEventLoop ): self.systemState: Mapping[str, Mapping[str, SystemDeviceState]] = {} self.nativeIds: Mapping[str, DeviceStorage] = {} @@ -565,7 +681,8 @@ class PluginRemote: self.consoles: Mapping[str, Future[Tuple[StreamReader, StreamWriter]]] = {} self.ptimeSum = 0 self.allMemoryStats = {} - self.peer = peer + self.peer = clusterSetup.peer + self.clusterSetup = clusterSetup self.api = api self.pluginId = pluginId self.hostInfo = hostInfo @@ -631,122 +748,16 @@ class PluginRemote: traceback.print_exc() raise - async def loadZipWrapped(self, packageJson, getZip: Any, options: dict): + async def loadZipWrapped(self, packageJson, zipAPI: Any, options: dict): + await self.clusterSetup.initializeCluster(options) + sdk = ScryptedStatic() - clusterId = options["clusterId"] - clusterSecret = options["clusterSecret"] - SCRYPTED_CLUSTER_ADDRESS = os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None) - - def computeClusterObjectHash(o: ClusterObject) -> str: - m = hashlib.sha256() - m.update( - bytes( - f"{o['id']}{o.get('address') or ''}{o['port']}{o.get('sourceKey', None) or ''}{o['proxyId']}{clusterSecret}", - "utf8", - ) - ) - return base64.b64encode(m.digest()).decode("utf-8") - - def isClusterAddress(address: str): - return not address or address == SCRYPTED_CLUSTER_ADDRESS - - def onProxySerialization(peer: rpc.RpcPeer, value: Any, sourceKey: str = None): - properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {} - clusterEntry = properties.get("__cluster", None) - proxyId: str - existing = peer.localProxied.get(value, None) - if existing: - proxyId = existing["id"] - else: - proxyId = ( - clusterEntry and clusterEntry.get("proxyId", None) - ) or rpc.RpcPeer.generateId() - - if clusterEntry: - if ( - isClusterAddress(clusterEntry.get("address", None)) - and clusterPort == clusterEntry["port"] - and sourceKey != clusterEntry.get("sourceKey", None) - ): - clusterEntry = None - - if not clusterEntry: - clusterEntry: ClusterObject = { - "id": clusterId, - "proxyId": proxyId, - "address": SCRYPTED_CLUSTER_ADDRESS, - "port": clusterPort, - "sourceKey": sourceKey, - } - clusterEntry["sha256"] = computeClusterObjectHash(clusterEntry) - properties["__cluster"] = clusterEntry - - return proxyId, properties - - self.peer.onProxySerialization = lambda value: onProxySerialization( - self.peer, value, None - ) - - async def resolveObject(id: str, sourceKey: str): - sourcePeer: rpc.RpcPeer = ( - self.peer - if not sourceKey - else await rpc.maybe_await(clusterPeers.get(sourceKey, None)) - ) - if not sourcePeer: - return - return sourcePeer.localProxyMap.get(id, None) - - clusterPeers: Mapping[str, asyncio.Future[rpc.RpcPeer]] = {} - - def getClusterPeerKey(address: str, port: int): - return f"{address}:{port}" - - async def handleClusterClient( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ): - clusterPeerAddress, clusterPeerPort = writer.get_extra_info("peername") - clusterPeerKey = getClusterPeerKey(clusterPeerAddress, clusterPeerPort) - rpcTransport = rpc_reader.RpcStreamTransport(reader, writer) - peer: rpc.RpcPeer - peer, peerReadLoop = await rpc_reader.prepare_peer_readloop( - self.loop, rpcTransport - ) - peer.onProxySerialization = lambda value: onProxySerialization( - peer, value, clusterPeerKey - ) - future: asyncio.Future[rpc.RpcPeer] = asyncio.Future() - future.set_result(peer) - clusterPeers[clusterPeerKey] = future - - async def connectRPCObject(o: ClusterObject): - sha256 = computeClusterObjectHash(o) - if sha256 != o["sha256"]: - raise Exception("secret incorrect") - return await resolveObject(o["proxyId"], o.get("sourceKey", None)) - - peer.params["connectRPCObject"] = connectRPCObject - try: - await peerReadLoop() - except: - pass - finally: - clusterPeers.pop(clusterPeerKey) - peer.kill("cluster client killed") - writer.close() - - listenAddress = "0.0.0.0" if SCRYPTED_CLUSTER_ADDRESS else "127.0.0.1" - clusterRpcServer = await asyncio.start_server( - handleClusterClient, listenAddress, 0 - ) - clusterPort = clusterRpcServer.sockets[0].getsockname()[1] - def ensureClusterPeer(address: str, port: int): if isClusterAddress(address): address = "127.0.0.1" clusterPeerKey = getClusterPeerKey(address, port) - clusterPeerPromise = clusterPeers.get(clusterPeerKey) + clusterPeerPromise = self.clusterSetup.clusterPeers.get(clusterPeerKey) if clusterPeerPromise: return clusterPeerPromise @@ -755,7 +766,7 @@ class PluginRemote: reader, writer = await asyncio.open_connection(address, port) sourceAddress, sourcePort = writer.get_extra_info("sockname") if ( - sourceAddress != SCRYPTED_CLUSTER_ADDRESS + sourceAddress != self.clusterSetup.SCRYPTED_CLUSTER_ADDRESS and sourceAddress != "127.0.0.1" ): print("source address mismatch", sourceAddress) @@ -764,12 +775,12 @@ class PluginRemote: self.loop, rpcTransport ) clusterPeer.onProxySerialization = ( - lambda value: onProxySerialization( + lambda value: self.clusterSetup.onProxySerialization( clusterPeer, value, clusterPeerKey ) ) except: - clusterPeers.pop(clusterPeerKey) + self.clusterSetup.clusterPeers.pop(clusterPeerKey) raise async def run_loop(): @@ -778,14 +789,14 @@ class PluginRemote: except: pass finally: - clusterPeers.pop(clusterPeerKey) + self.clusterSetup.clusterPeers.pop(clusterPeerKey) asyncio.run_coroutine_threadsafe(run_loop(), self.loop) return clusterPeer clusterPeerPromise = self.loop.create_task(connectClusterPeer()) - clusterPeers[clusterPeerKey] = clusterPeerPromise + self.clusterSetup.clusterPeers[clusterPeerKey] = clusterPeerPromise return clusterPeerPromise async def connectRPCObject(value): @@ -795,16 +806,14 @@ class PluginRemote: clusterObject: ClusterObject = __cluster - if clusterObject.get("id", None) != clusterId: + if clusterObject.get("id", None) != self.clusterSetup.clusterId: return value address = clusterObject.get("address", None) port = clusterObject["port"] proxyId = clusterObject["proxyId"] - if port == clusterPort: - return await resolveObject( - proxyId, clusterObject.get("sourceKey", None) - ) + if port == self.clusterSetup.clusterPort: + return await self.clusterSetup.connectRPCObject(clusterObject) clusterPeerPromise = ensureClusterPeer(address, port) @@ -846,7 +855,7 @@ class PluginRemote: if not os.path.exists(zipPath) or debug: os.makedirs(os.path.dirname(zipPath), exist_ok=True) - zipData = await getZip() + zipData = await zipAPI.getZip() zipPathTmp = zipPath + ".tmp" with open(zipPathTmp, "wb") as f: f.write(zipData) @@ -962,7 +971,7 @@ class PluginRemote: self.deviceManager = DeviceManager(self.nativeIds, self.systemManager) self.mediaManager = MediaManager(await self.api.getMediaManager()) - await self.start_stats_runner() + await self.start_stats_runner(zipAPI.updateState) try: from scrypted_sdk import sdk_init2 # type: ignore @@ -1030,7 +1039,7 @@ class PluginRemote: forkOptions = options.copy() forkOptions["fork"] = True forkOptions["debug"] = debug - return await remote.loadZip(packageJson, getZip, forkOptions) + return await remote.loadZip(packageJson, zipAPI.getZip, forkOptions) pluginFork.result = asyncio.create_task(getFork()) return pluginFork @@ -1122,7 +1131,7 @@ class PluginRemote: return [self.replPort, os.getenv("SCRYPTED_CLUSTER_ADDRESS", None)] raise Exception(f"unknown service {name}") - async def start_stats_runner(self): + async def start_stats_runner(self, update_stats): pong = None async def ping(time: int): @@ -1132,11 +1141,6 @@ class PluginRemote: self.peer.params["ping"] = ping - update_stats = await self.peer.getParam("updateStats") - if not update_stats: - print("host did not provide update_stats") - return - def stats_runner(): ptime = round(time.process_time() * 1000000) + self.ptimeSum try: @@ -1175,8 +1179,12 @@ async def plugin_async_main( ): peer, readLoop = await rpc_reader.prepare_peer_readloop(loop, rpcTransport) peer.params["print"] = print + + clusterSetup = ClusterSetup(peer) + peer.params["initializeCluster"] = lambda options: clusterSetup.initializeCluster(options) + peer.params["getRemote"] = lambda api, pluginId, hostInfo: PluginRemote( - peer, api, pluginId, hostInfo, loop + clusterSetup, api, pluginId, hostInfo, loop ) try: