Files
scrypted/server/python/cluster_setup.py
2025-11-09 08:23:06 -08:00

284 lines
10 KiB
Python

from __future__ import annotations
import asyncio
import base64
import hashlib
import os
from asyncio.events import AbstractEventLoop
from collections.abc import Mapping
from typing import Any
import rpc
import rpc_reader
from typing import TypedDict, Callable
class ClusterObject(TypedDict):
    """Descriptor attached to a serialized proxy under the ``__cluster`` key.

    It identifies which cluster node owns the proxied object and carries a
    hash that lets the receiving node verify the descriptor was produced by a
    holder of the cluster secret.
    """

    # Cluster identifier; compared against ClusterSetup.clusterId.
    id: str
    # Address of the owning node; may be None when no SCRYPTED_CLUSTER_ADDRESS
    # is configured.
    address: str
    # Port of the owning node's cluster RPC server.
    port: int
    # Id of the proxied object on the owning peer.
    proxyId: str
    # Key ("address:port") of the cluster peer connection the object was
    # serialized for; None for the main peer.
    sourceKey: str
    # Base64-encoded SHA-256 over the fields above plus the cluster secret.
    # Filled in after the rest of the dict is built.
    sha256: str
def isClusterAddress(address: str):
    """Return True when *address* refers to this node.

    An empty/None address is treated as local; otherwise the address must
    equal the SCRYPTED_CLUSTER_ADDRESS environment variable.
    """
    if not address:
        return True
    configured = os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None)
    return address == configured
def getClusterPeerKey(address: str, port: int):
    """Build the ``"address:port"`` key used to index cluster peer connections."""
    return "{}:{}".format(address, port)
class ClusterSetup:
    """Cluster-aware RPC plumbing for one process.

    Holds the cluster identity (id, secret, address, port), stamps every
    serialized proxy with a hashed ``__cluster`` descriptor, runs the cluster
    RPC server that other nodes connect to, and keeps a table of connections
    to other cluster peers keyed by ``"address:port"``.
    """

    def __init__(self, loop: AbstractEventLoop, peer: rpc.RpcPeer):
        """Store the event loop and main RPC peer; cluster state is filled in
        later by ``initializeCluster``."""
        self.loop = loop
        self.peer = peer
        self.clusterId: str | None = None
        self.clusterSecret: str | None = None
        self.clusterAddress: str | None = None
        self.clusterPort: int | None = None
        # Declared up front so the attribute exists before initializeCluster.
        self.clusterWorkerId: str | None = None
        self.SCRYPTED_CLUSTER_ADDRESS: str | None = None
        # Values are futures/tasks resolving to the connected peer.
        self.clusterPeers: dict[str, asyncio.Future[rpc.RpcPeer]] = {}

    async def resolveObject(self, id: str, sourceKey: str):
        """Look up a local proxy by id.

        With no ``sourceKey`` the main peer is consulted; otherwise the
        cluster peer registered under that key. Returns None when the peer or
        the proxy is unknown.
        """
        sourcePeer: rpc.RpcPeer = (
            self.peer
            if not sourceKey
            else await rpc.maybe_await(self.clusterPeers.get(sourceKey, None))
        )
        if not sourcePeer:
            return None
        return sourcePeer.localProxyMap.get(id, None)

    async def connectClusterObject(self, o: ClusterObject):
        """Verify the descriptor's hash and resolve the object it refers to.

        Raises:
            Exception: ``"secret incorrect"`` when the hash is missing or does
                not match the one computed with this node's cluster secret.
        """
        sha256 = self.computeClusterObjectHash(o)
        # .get() so a descriptor without a hash is rejected with the same
        # error instead of leaking a KeyError.
        if sha256 != o.get("sha256", None):
            raise Exception("secret incorrect")
        return await self.resolveObject(
            o.get("proxyId", None), o.get("sourceKey", None)
        )

    def onProxySerialization(
        self, peer: rpc.RpcPeer, value: Any, sourceKey: str | None = None
    ):
        """Produce ``(proxyId, properties)`` for a value being serialized.

        The returned properties always contain a ``__cluster`` descriptor.
        An existing descriptor is kept unless it points back at this node
        with a different ``sourceKey``; in that case, or when none exists, a
        fresh descriptor for this node is generated and hashed.
        """
        properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {}
        clusterEntry = properties.get("__cluster", None)
        proxyId: str
        existing = peer.localProxied.get(value, None)
        if existing:
            proxyId = existing["id"]
        else:
            proxyId = (
                clusterEntry and clusterEntry.get("proxyId", None)
            ) or rpc.RpcPeer.generateId()

        if clusterEntry:
            if (
                isClusterAddress(clusterEntry.get("address", None))
                and self.clusterPort == clusterEntry["port"]
                and sourceKey != clusterEntry.get("sourceKey", None)
            ):
                # Descriptor targets this node but via another connection:
                # discard it so it is regenerated for the current one.
                clusterEntry = None

        if not clusterEntry:
            clusterEntry: ClusterObject = {
                "id": self.clusterId,
                "proxyId": proxyId,
                "address": self.SCRYPTED_CLUSTER_ADDRESS,
                "port": self.clusterPort,
                "sourceKey": sourceKey,
            }
            clusterEntry["sha256"] = self.computeClusterObjectHash(clusterEntry)
            properties["__cluster"] = clusterEntry

        return proxyId, properties

    async def initializeCluster(self, options: dict):
        """Start the cluster RPC server and hook proxy serialization.

        Does nothing if a cluster port is already set. ``options`` must
        contain ``clusterId`` and ``clusterSecret``; ``clusterWorkerId`` is
        optional. On success the ``initializeCluster`` param is removed from
        the main peer.
        """
        if self.clusterPort:
            return
        self.clusterId = options["clusterId"]
        self.clusterSecret = options["clusterSecret"]
        self.clusterWorkerId = options.get("clusterWorkerId", None)
        self.SCRYPTED_CLUSTER_ADDRESS = os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None)

        async def handleClusterClient(
            reader: asyncio.StreamReader, writer: asyncio.StreamWriter
        ):
            # IPv6 peernames are 4-tuples (host, port, flowinfo, scope_id);
            # only the first two entries are needed.
            clusterPeerAddress, clusterPeerPort = writer.get_extra_info("peername")[:2]
            clusterPeerKey = getClusterPeerKey(clusterPeerAddress, clusterPeerPort)
            rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
            peer: rpc.RpcPeer
            peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(
                self.loop, rpcTransport
            )
            # set all params from self.peer
            for key, value in self.peer.params.items():
                peer.params[key] = value
            peer.onProxySerialization = lambda value: self.onProxySerialization(
                peer, value, clusterPeerKey
            )
            future: asyncio.Future[rpc.RpcPeer] = asyncio.Future()
            future.set_result(peer)
            self.clusterPeers[clusterPeerKey] = future
            peer.params["connectRPCObject"] = lambda o: self.connectClusterObject(o)
            try:
                await peerReadLoop()
            except Exception:
                # Read-loop failures (e.g. disconnects) are expected; the
                # cleanup below is all that is required. Cancellation is
                # deliberately allowed to propagate after cleanup.
                pass
            finally:
                # Default avoids a KeyError that would skip the cleanup below.
                self.clusterPeers.pop(clusterPeerKey, None)
                try:
                    peer.kill("cluster client killed")
                finally:
                    writer.close()

        clusterRpcServerInfo = await cluster_listen_zero(handleClusterClient)
        self.clusterPort = clusterRpcServerInfo["port"]
        self.peer.onProxySerialization = lambda value: self.onProxySerialization(
            self.peer, value, None
        )
        del self.peer.params["initializeCluster"]

    def computeClusterObjectHash(self, o: ClusterObject) -> str:
        """Return the base64-encoded SHA-256 of the descriptor fields
        concatenated with the cluster secret."""
        m = hashlib.sha256()
        m.update(
            bytes(
                # The use of ` o.get(key, None) or '' ` is to ensure that optional fields
                # are omitted from the hash, matching the JS implementation. Otherwise, since
                # the dict may contain the keys initialized to None, ` o.get(key, '') ` would
                # return None instead of ''.
                f"{o['id']}{o.get('address', None) or ''}{o['port']}{o.get('sourceKey', None) or ''}{o['proxyId']}{self.clusterSecret}",
                "utf8",
            )
        )
        return base64.b64encode(m.digest()).decode("utf-8")

    def ensureClusterPeer(self, address: str, port: int):
        """Return the (possibly pending) connection to the peer at
        ``address:port``, starting a new connection task if none is tracked.

        Addresses that refer to this node are dialed via ``127.0.0.1``.
        """
        if isClusterAddress(address):
            address = "127.0.0.1"
        clusterPeerKey = getClusterPeerKey(address, port)
        clusterPeerPromise = self.clusterPeers.get(clusterPeerKey)
        if clusterPeerPromise:
            return clusterPeerPromise

        async def connectClusterPeer():
            writer = None
            try:
                reader, writer = await asyncio.open_connection(address, port)
                # sockname may be a 4-tuple for IPv6; only the host is used.
                sourceAddress, _ = writer.get_extra_info("sockname")[:2]
                if (
                    sourceAddress != self.SCRYPTED_CLUSTER_ADDRESS
                    and sourceAddress != "127.0.0.1"
                ):
                    print("source address mismatch", sourceAddress)
                rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
                clusterPeer, peerReadLoop = await rpc_reader.prepare_peer_readloop(
                    self.loop, rpcTransport
                )
                # set all params from self.peer
                for key, value in self.peer.params.items():
                    clusterPeer.params[key] = value
                clusterPeer.onProxySerialization = (
                    lambda value: self.onProxySerialization(
                        clusterPeer, value, clusterPeerKey
                    )
                )
            except BaseException:
                # Forget the failed attempt so a later call can retry, and do
                # not leak the socket if the connection had been opened.
                self.clusterPeers.pop(clusterPeerKey, None)
                if writer is not None:
                    writer.close()
                raise

            async def run_loop():
                try:
                    await peerReadLoop()
                except Exception:
                    # Read-loop errors just end the connection.
                    pass
                finally:
                    self.clusterPeers.pop(clusterPeerKey, None)
                    writer.close()

            asyncio.run_coroutine_threadsafe(run_loop(), self.loop)
            return clusterPeer

        clusterPeerPromise = self.loop.create_task(connectClusterPeer())
        self.clusterPeers[clusterPeerKey] = clusterPeerPromise
        return clusterPeerPromise

    async def connectRPCObject(self, value):
        """Try to replace ``value`` with a direct proxy from its owning node.

        ``value`` is returned unchanged when it carries no ``__cluster``
        dict, belongs to a different cluster, or when connecting to the
        owning peer fails for any reason.
        """
        # Default so values without a descriptor pass through untouched
        # rather than raising AttributeError.
        clusterEntry = getattr(value, "__cluster", None)
        if not isinstance(clusterEntry, dict):
            return value

        clusterObject: ClusterObject = clusterEntry

        if clusterObject.get("id", None) != self.clusterId:
            return value

        address = clusterObject.get("address", None)
        port = clusterObject["port"]
        proxyId = clusterObject["proxyId"]
        if port == self.clusterPort:
            # The object lives on this node: resolve it locally.
            return await self.connectClusterObject(clusterObject)

        clusterPeerPromise = self.ensureClusterPeer(address, port)

        try:
            clusterPeer = await clusterPeerPromise
            weakref = clusterPeer.remoteWeakProxies.get(proxyId, None)
            existing = weakref() if weakref else None
            if existing:
                return existing

            # Cache the remote connectRPCObject callable on the peer.
            peerConnectRPCObject = clusterPeer.tags.get("connectRPCObject")
            if not peerConnectRPCObject:
                peerConnectRPCObject = await clusterPeer.getParam("connectRPCObject")
                clusterPeer.tags["connectRPCObject"] = peerConnectRPCObject
            newValue = await peerConnectRPCObject(clusterObject)
            if not newValue:
                raise Exception("rpc object not found?")
            return newValue
        except Exception:
            # Best effort: fall back to the original proxy.
            return value
class ClusterServerListener(TypedDict):
    """Return value of ``cluster_listen_zero``: the listening server and the
    port it was bound to."""

    # The asyncio server accepting cluster connections.
    server: asyncio.Server
    # Port number read back from the server's first listening socket.
    port: int
async def cluster_listen_zero(
    callback: Callable[[asyncio.StreamReader, asyncio.StreamWriter], Any]
) -> ClusterServerListener:
    """Start the cluster RPC server on an OS-assigned port.

    Without SCRYPTED_CLUSTER_ADDRESS (or when it is ``127.0.0.1``) a single
    server is started with ``host=None``. Otherwise the server is bound to
    the cluster address and a second listener is bound to ``127.0.0.1`` on
    the same port, retrying up to five times if that port is taken on
    loopback.

    Args:
        callback: client-connected callback passed to ``asyncio.start_server``.

    Returns:
        Dict with the ``server`` and the ``port`` it listens on.

    Raises:
        Exception: when no port could be bound on both addresses.
    """
    SCRYPTED_CLUSTER_ADDRESS = os.getenv("SCRYPTED_CLUSTER_ADDRESS")
    if not SCRYPTED_CLUSTER_ADDRESS or SCRYPTED_CLUSTER_ADDRESS == "127.0.0.1":
        # NOTE(review): with host=None several sockets may be created; the
        # port is taken from the first one only - confirm they always agree.
        server = await asyncio.start_server(callback, host=None, port=0)
        port = server.sockets[0].getsockname()[1]
        return {
            "server": server,
            "port": port,
        }

    # need to listen on the cluster address and 127.0.0.1 on the same port.
    for _ in range(5):
        cluster_server = await asyncio.start_server(
            callback, host=SCRYPTED_CLUSTER_ADDRESS, port=0
        )
        port = cluster_server.sockets[0].getsockname()[1]
        try:
            local_server = await asyncio.start_server(
                callback, host="127.0.0.1", port=port
            )
        except OSError:
            # Port may be in use, keep trying.
            cluster_server.close()
            continue

        # Tie the loopback listener's lifetime to the returned server: once
        # the cluster server is closed, close the loopback one too. Done
        # callbacks receive the future as their only argument.
        closed = asyncio.ensure_future(cluster_server.wait_closed())
        closed.add_done_callback(lambda _fut, srv=local_server: srv.close())
        return {
            "server": cluster_server,
            "port": port,
        }

    raise Exception("failed to bind to cluster address.")