server: limit address binding in cluster mode

This commit is contained in:
Koushik Dutta
2025-03-02 14:57:56 -08:00
parent 16a9abeb9e
commit fe1b677381
6 changed files with 102 additions and 32 deletions

View File

@@ -10,7 +10,7 @@ from typing import Any
import rpc
import rpc_reader
from typing import TypedDict
from typing import TypedDict, Callable
class ClusterObject(TypedDict):
@@ -21,6 +21,7 @@ class ClusterObject(TypedDict):
sourceKey: str
sha256: str
def isClusterAddress(address: str):
return not address or address == os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None)
@@ -130,11 +131,8 @@ class ClusterSetup:
peer.kill("cluster client killed")
writer.close()
listenAddress = "0.0.0.0" if self.SCRYPTED_CLUSTER_ADDRESS else "127.0.0.1"
clusterRpcServer = await asyncio.start_server(
handleClusterClient, listenAddress, 0
)
self.clusterPort = clusterRpcServer.sockets[0].getsockname()[1]
clusterRpcServerInfo = await cluster_listen_zero(handleClusterClient)
self.clusterPort = clusterRpcServerInfo["port"]
self.peer.onProxySerialization = lambda value: self.onProxySerialization(
self.peer, value, None
)
@@ -238,3 +236,49 @@ class ClusterSetup:
return newValue
except Exception as e:
return value
class ClusterServerListener(TypedDict):
server: asyncio.Server
port: int
async def cluster_listen_zero(
callback: Callable[[asyncio.StreamReader, asyncio.StreamWriter]]
) -> ClusterServerListener:
SCRYPTED_CLUSTER_ADDRESS = os.getenv("SCRYPTED_CLUSTER_ADDRESS")
if not SCRYPTED_CLUSTER_ADDRESS:
server = await asyncio.start_server(callback, host=None, port=0)
port = server.sockets[0].getsockname()[1]
return {
"server": server,
"port": port,
}
# need to listen on the cluster address and 127.0.0.1 on the same port.
retries = 5
while retries > 0:
cluster_server = await asyncio.start_server(
callback, host=SCRYPTED_CLUSTER_ADDRESS, port=0
)
port = cluster_server.sockets[0].getsockname()[1]
try:
print('trying to bind to port', port)
local_server = await asyncio.start_server(
callback, host="127.0.0.1", port=port
)
future = asyncio.ensure_future(local_server.wait_closed())
future.add_done_callback(lambda: local_server.close())
return {
"server": cluster_server,
"port": port,
}
except:
# Port may be in use, keep trying.
cluster_server.close()
retries -= 1
raise Exception("failed to bind to cluster address.")

View File

@@ -198,6 +198,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(None)
# TODO: this should use an equivalent to cluster_listen_zero
sock.bind(("0.0.0.0" if os.getenv("SCRYPTED_CLUSTER_ADDRESS") else "127.0.0.1", 0))
sock.listen()