mirror of
https://github.com/koush/scrypted.git
synced 2026-02-03 14:13:28 +00:00
284 lines
10 KiB
Python
284 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import os
|
|
from asyncio.events import AbstractEventLoop
|
|
from collections.abc import Mapping
|
|
from typing import Any
|
|
|
|
import rpc
|
|
import rpc_reader
|
|
from typing import TypedDict, Callable
|
|
|
|
|
|
class ClusterObject(TypedDict):
    """Signed locator stored in a proxy's ``__cluster`` property.

    It names the cluster process that owns the underlying object so another
    process can connect back to it. ``sha256`` is produced by
    ``ClusterSetup.computeClusterObjectHash`` from the other fields plus the
    cluster secret.
    """

    # Cluster id; entries whose id differs from ours are ignored.
    id: str
    # Address of the owning process; None/empty means the local cluster address.
    address: str | None
    # Port of the owning process's cluster RPC server.
    port: int
    # Id of the object in the owning peer's local proxy map.
    proxyId: str
    # "address:port" key of the cluster connection the object was proxied
    # over; None when it was serialized over the primary peer.
    sourceKey: str | None
    # Base64-encoded SHA-256 signature over the fields above and the secret.
    sha256: str
|
|
|
|
|
|
def isClusterAddress(address: str):
    """Tell whether *address* designates this node's cluster address.

    A missing or empty address is treated as local. Otherwise the address
    must equal the ``SCRYPTED_CLUSTER_ADDRESS`` environment variable, which
    is re-read on every call.
    """
    if not address:
        return True
    configured = os.environ.get("SCRYPTED_CLUSTER_ADDRESS")
    return address == configured
|
|
|
|
|
|
def getClusterPeerKey(address: str, port: int):
    """Build the ``"address:port"`` string used to index cluster peers."""
    template = "{}:{}"
    return template.format(address, port)
|
|
|
|
|
|
class ClusterSetup:
    """Cluster-mode RPC plumbing for one scrypted Python process.

    After ``initializeCluster`` runs, the process accepts inbound cluster
    connections, stamps every proxy it serializes with a signed ``__cluster``
    entry (``onProxySerialization``), and turns such entries back into
    objects, either from a local proxy map or through a connection to the
    process that owns the object (``connectRPCObject``).
    """

    def __init__(self, loop: AbstractEventLoop, peer: rpc.RpcPeer):
        """Create an uninitialized helper.

        Args:
            loop: event loop on which peer read loops and connection tasks
                are scheduled.
            peer: the primary RPC peer; its params are copied onto every
                cluster connection.
        """
        self.loop = loop
        self.peer = peer
        # Identity and secret, filled in by initializeCluster().
        self.clusterId: str | None = None
        self.clusterSecret: str | None = None
        self.clusterAddress: str | None = None
        # Port of this process's cluster server; also serves as the
        # "already initialized" flag checked by initializeCluster().
        self.clusterPort: int | None = None
        # Optional worker id passed to initializeCluster(), stored as given.
        self.clusterWorkerId = None
        # Snapshot of the SCRYPTED_CLUSTER_ADDRESS env var taken at initialization.
        self.SCRYPTED_CLUSTER_ADDRESS: str | None = None
        # Listening server returned by cluster_listen_zero, kept so it can be closed.
        self.clusterServer: asyncio.Server | None = None
        # "address:port" -> future/task resolving to the connected RpcPeer.
        # Mutated in place, so it is a dict rather than the read-only Mapping ABC.
        self.clusterPeers: dict[str, asyncio.Future[rpc.RpcPeer]] = {}

    async def resolveObject(self, id: str, sourceKey: str | None):
        """Look up an object in a peer's local proxy map.

        Args:
            id: proxy id of the object.
            sourceKey: ``"address:port"`` key of the cluster connection the
                object was proxied over, or falsy for the primary peer.

        Returns:
            The object, or None when the peer or the id is not found. An
            unknown ``sourceKey`` relies on ``rpc.maybe_await`` passing None
            through.
        """
        sourcePeer: rpc.RpcPeer = (
            self.peer
            if not sourceKey
            else await rpc.maybe_await(self.clusterPeers.get(sourceKey, None))
        )
        if not sourcePeer:
            return
        return sourcePeer.localProxyMap.get(id, None)

    async def connectClusterObject(self, o: ClusterObject):
        """Verify a cluster entry's signature and resolve it to a local object.

        Args:
            o: the cluster entry, including its ``sha256`` signature.

        Returns:
            Whatever ``resolveObject`` finds for the entry's proxyId/sourceKey.

        Raises:
            Exception: "secret incorrect" if ``sha256`` is missing or does not
                match the hash computed with this process's cluster secret.
        """
        import hmac

        expected = self.computeClusterObjectHash(o)
        received = o.get("sha256", None)
        # Constant-time comparison, so response timing does not reveal how
        # much of a forged signature matched.
        if not isinstance(received, str) or not hmac.compare_digest(
            expected.encode("utf-8"), received.encode("utf-8")
        ):
            raise Exception("secret incorrect")
        return await self.resolveObject(
            o.get("proxyId", None), o.get("sourceKey", None)
        )

    def onProxySerialization(
        self, peer: rpc.RpcPeer, value: Any, sourceKey: str | None = None
    ):
        """Pick a proxy id for *value* and attach a signed ``__cluster`` entry.

        Installed on the primary peer (with ``sourceKey=None``) and on every
        cluster connection (with that connection's ``"address:port"`` key).

        Args:
            peer: the peer *value* is being serialized over.
            value: the object being proxied.
            sourceKey: key of the cluster connection, or None for the primary peer.

        Returns:
            ``(proxyId, properties)``. The id reuses the peer's existing id for
            *value*, else the entry's ``proxyId``, else a fresh id. The
            properties always carry a ``__cluster`` entry.
        """
        properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {}
        clusterEntry = properties.get("__cluster", None)

        proxyId: str
        existing = peer.localProxied.get(value, None)
        if existing:
            proxyId = existing["id"]
        else:
            proxyId = (
                clusterEntry and clusterEntry.get("proxyId", None)
            ) or rpc.RpcPeer.generateId()

        # An entry minted by this process for a different connection (local
        # address, our port, other sourceKey) is discarded and re-created
        # below, so it names the connection the value is now being sent over.
        if (
            clusterEntry
            and isClusterAddress(clusterEntry.get("address", None))
            and self.clusterPort == clusterEntry["port"]
            and sourceKey != clusterEntry.get("sourceKey", None)
        ):
            clusterEntry = None

        if not clusterEntry:
            clusterEntry: ClusterObject = {
                "id": self.clusterId,
                "proxyId": proxyId,
                "address": self.SCRYPTED_CLUSTER_ADDRESS,
                "port": self.clusterPort,
                "sourceKey": sourceKey,
            }
            clusterEntry["sha256"] = self.computeClusterObjectHash(clusterEntry)
            properties["__cluster"] = clusterEntry

        return proxyId, properties

    async def initializeCluster(self, options: dict):
        """Configure cluster identity and start the cluster RPC server.

        Does nothing once ``clusterPort`` is set.

        Args:
            options: must contain ``clusterId`` and ``clusterSecret``; may
                contain ``clusterWorkerId``.
        """
        if self.clusterPort:
            return
        self.clusterId = options["clusterId"]
        self.clusterSecret = options["clusterSecret"]
        self.clusterWorkerId = options.get("clusterWorkerId", None)
        self.SCRYPTED_CLUSTER_ADDRESS = os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None)

        async def handleClusterClient(
            reader: asyncio.StreamReader, writer: asyncio.StreamWriter
        ):
            # Serve one inbound cluster connection until its read loop ends.
            # Index rather than unpack: IPv6 peernames are 4-tuples.
            peername = writer.get_extra_info("peername")
            clusterPeerKey = getClusterPeerKey(peername[0], peername[1])
            rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
            peer: rpc.RpcPeer
            try:
                peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(
                    self.loop, rpcTransport
                )
            except BaseException:
                writer.close()
                raise

            # Expose the same params as the primary peer.
            for key, value in self.peer.params.items():
                peer.params[key] = value
            peer.onProxySerialization = lambda value: self.onProxySerialization(
                peer, value, clusterPeerKey
            )
            future: asyncio.Future[rpc.RpcPeer] = asyncio.get_running_loop().create_future()
            future.set_result(peer)
            self.clusterPeers[clusterPeerKey] = future
            peer.params["connectRPCObject"] = lambda o: self.connectClusterObject(o)
            try:
                await peerReadLoop()
            except Exception:
                # A failing connection just ends this peer; cleanup is below.
                pass
            finally:
                # Default avoids a KeyError that would skip kill()/close().
                self.clusterPeers.pop(clusterPeerKey, None)
                peer.kill("cluster client killed")
                writer.close()

        clusterRpcServerInfo = await cluster_listen_zero(handleClusterClient)
        self.clusterServer = clusterRpcServerInfo["server"]
        self.clusterPort = clusterRpcServerInfo["port"]
        self.peer.onProxySerialization = lambda value: self.onProxySerialization(
            self.peer, value, None
        )
        # One-shot: stop exposing initializeCluster on the primary peer.
        self.peer.params.pop("initializeCluster", None)

    def computeClusterObjectHash(self, o: ClusterObject) -> str:
        """Return the base64-encoded SHA-256 signature for a cluster entry.

        The digest covers id, address, port, sourceKey and proxyId, in that
        order, followed by the cluster secret.
        """
        # `o.get(key, None) or ''` (not `o.get(key, '')`) maps keys that are
        # present but None to '', so optional fields are omitted from the hash,
        # matching the JS implementation.
        payload = (
            f"{o['id']}{o.get('address', None) or ''}{o['port']}"
            f"{o.get('sourceKey', None) or ''}{o['proxyId']}{self.clusterSecret}"
        )
        digest = hashlib.sha256(payload.encode("utf8")).digest()
        return base64.b64encode(digest).decode("utf-8")

    def ensureClusterPeer(self, address: str, port: int):
        """Return an awaitable resolving to an RpcPeer connected to address:port.

        An existing or pending connection for the same key is reused.
        Cluster addresses (see ``isClusterAddress``) are dialed on 127.0.0.1.
        The entry is removed from ``clusterPeers`` if connecting fails, or
        once the connection's read loop ends.

        NOTE(review): uses ``loop.create_task``, so presumably this must be
        called from the loop's own thread.
        """
        if isClusterAddress(address):
            address = "127.0.0.1"
        clusterPeerKey = getClusterPeerKey(address, port)
        clusterPeerPromise = self.clusterPeers.get(clusterPeerKey)
        if clusterPeerPromise:
            return clusterPeerPromise

        async def connectClusterPeer():
            writer = None
            try:
                reader, writer = await asyncio.open_connection(address, port)
                # Index rather than unpack: IPv6 socknames are 4-tuples.
                sourceAddress = writer.get_extra_info("sockname")[0]
                if (
                    sourceAddress != self.SCRYPTED_CLUSTER_ADDRESS
                    and sourceAddress != "127.0.0.1"
                ):
                    print("source address mismatch", sourceAddress)
                rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
                clusterPeer, peerReadLoop = await rpc_reader.prepare_peer_readloop(
                    self.loop, rpcTransport
                )
                # Expose the same params as the primary peer.
                for key, value in self.peer.params.items():
                    clusterPeer.params[key] = value
                clusterPeer.onProxySerialization = (
                    lambda value: self.onProxySerialization(
                        clusterPeer, value, clusterPeerKey
                    )
                )
            except BaseException:
                # Forget the failed attempt so a later call can retry, and
                # don't leak the socket if it was already opened.
                self.clusterPeers.pop(clusterPeerKey, None)
                if writer is not None:
                    writer.close()
                raise

            async def run_loop():
                try:
                    await peerReadLoop()
                except Exception:
                    pass
                finally:
                    # Same teardown as inbound connections in initializeCluster.
                    self.clusterPeers.pop(clusterPeerKey, None)
                    clusterPeer.kill("cluster peer killed")
                    writer.close()

            asyncio.run_coroutine_threadsafe(run_loop(), self.loop)
            return clusterPeer

        clusterPeerPromise = self.loop.create_task(connectClusterPeer())
        self.clusterPeers[clusterPeerKey] = clusterPeerPromise
        return clusterPeerPromise

    async def connectRPCObject(self, value):
        """Resolve a value carrying a ``__cluster`` entry to the owner's object.

        Args:
            value: any value. Only values exposing a ``__cluster`` dict with
                this cluster's id are resolved.

        Returns:
            The local object when this process owns it. Otherwise, a proxy
            obtained through the owning process's ``connectRPCObject`` param;
            a still-live proxy for the same id is reused. ``value`` itself is
            returned unchanged when it has no matching entry, or when
            resolving through the remote process fails.

        Raises:
            Exception: from ``connectClusterObject`` when a locally owned
                entry fails signature verification.
        """
        clusterObject = getattr(value, "__cluster", None)
        if not isinstance(clusterObject, dict):
            return value
        if clusterObject.get("id", None) != self.clusterId:
            return value

        address = clusterObject.get("address", None)
        port = clusterObject["port"]
        proxyId = clusterObject["proxyId"]

        # Same locality test as onProxySerialization: the object is ours only
        # when both the port and the address match. A process on another host
        # could have been assigned the same port number.
        if port == self.clusterPort and isClusterAddress(address):
            return await self.connectClusterObject(clusterObject)

        clusterPeerPromise = self.ensureClusterPeer(address, port)
        try:
            clusterPeer = await clusterPeerPromise
            weakRef = clusterPeer.remoteWeakProxies.get(proxyId, None)
            existing = weakRef() if weakRef else None
            if existing:
                return existing

            # Fetch the remote connectRPCObject once per peer and cache it in tags.
            peerConnectRPCObject = clusterPeer.tags.get("connectRPCObject")
            if not peerConnectRPCObject:
                peerConnectRPCObject = await clusterPeer.getParam("connectRPCObject")
                clusterPeer.tags["connectRPCObject"] = peerConnectRPCObject
            newValue = await peerConnectRPCObject(clusterObject)
            if not newValue:
                raise Exception("rpc object not found?")
            return newValue
        except Exception:
            # Best effort: fall back to the original (indirect) proxy.
            return value
|
|
|
|
|
|
class ClusterServerListener(TypedDict):
    """Result of ``cluster_listen_zero``: the listening server and its port."""

    # The asyncio server accepting cluster connections.
    server: asyncio.Server
    # TCP port the server was bound to (requested as port 0, so OS-assigned).
    port: int
|
|
|
|
|
|
async def cluster_listen_zero(
    callback: Callable[[asyncio.StreamReader, asyncio.StreamWriter], Any]
) -> ClusterServerListener:
    """Start the cluster RPC server on an OS-assigned port.

    Without ``SCRYPTED_CLUSTER_ADDRESS`` (or when it is ``127.0.0.1``), a
    single loopback listener is started. Otherwise the server listens on the
    cluster address and on ``127.0.0.1`` using the same port number, because
    ``ClusterSetup.ensureClusterPeer`` dials 127.0.0.1 for entries whose
    address is the cluster address.

    Args:
        callback: ``asyncio.start_server`` client handler, called with
            ``(reader, writer)`` for each inbound connection.

    Returns:
        ``{"server": ..., "port": ...}``. Closing ``server`` also closes the
        loopback twin listener, if one was created.

    Raises:
        OSError: if binding the cluster address itself fails.
        Exception: if no port was free on both addresses after 5 attempts.
    """
    SCRYPTED_CLUSTER_ADDRESS = os.getenv("SCRYPTED_CLUSTER_ADDRESS")
    if not SCRYPTED_CLUSTER_ADDRESS or SCRYPTED_CLUSTER_ADDRESS == "127.0.0.1":
        # Without a cluster address only same-host peers can reach us (they
        # dial 127.0.0.1), so don't expose the RPC server on every interface.
        # A single explicit host also yields a single socket; with host=None
        # asyncio may bind IPv4 and IPv6 sockets to different ephemeral ports,
        # and only sockets[0]'s port would be reported.
        server = await asyncio.start_server(callback, host="127.0.0.1", port=0)
        port = server.sockets[0].getsockname()[1]
        return {
            "server": server,
            "port": port,
        }

    # need to listen on the cluster address and 127.0.0.1 on the same port.
    for _attempt in range(5):
        cluster_server = await asyncio.start_server(
            callback, host=SCRYPTED_CLUSTER_ADDRESS, port=0
        )
        port = cluster_server.sockets[0].getsockname()[1]

        try:
            local_server = await asyncio.start_server(
                callback, host="127.0.0.1", port=port
            )
        except OSError:
            # Port is already taken on loopback: release it and try another.
            cluster_server.close()
            continue
        except BaseException:
            cluster_server.close()
            raise

        # Only cluster_server is handed back, so close the loopback twin when
        # it closes. Done callbacks receive the finished future as an argument.
        closed = asyncio.ensure_future(cluster_server.wait_closed())
        closed.add_done_callback(lambda _fut: local_server.close())

        return {
            "server": cluster_server,
            "port": port,
        }

    raise Exception("failed to bind to cluster address.")
|