mirror of
https://github.com/koush/scrypted.git
synced 2026-06-21 00:50:30 +01:00
server: limit address binding in cluster mode
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
import rpc
|
||||
import rpc_reader
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, Callable
|
||||
|
||||
|
||||
class ClusterObject(TypedDict):
|
||||
@@ -21,6 +21,7 @@ class ClusterObject(TypedDict):
|
||||
sourceKey: str
|
||||
sha256: str
|
||||
|
||||
|
||||
def isClusterAddress(address: str):
|
||||
return not address or address == os.environ.get("SCRYPTED_CLUSTER_ADDRESS", None)
|
||||
|
||||
@@ -130,11 +131,8 @@ class ClusterSetup:
|
||||
peer.kill("cluster client killed")
|
||||
writer.close()
|
||||
|
||||
listenAddress = "0.0.0.0" if self.SCRYPTED_CLUSTER_ADDRESS else "127.0.0.1"
|
||||
clusterRpcServer = await asyncio.start_server(
|
||||
handleClusterClient, listenAddress, 0
|
||||
)
|
||||
self.clusterPort = clusterRpcServer.sockets[0].getsockname()[1]
|
||||
clusterRpcServerInfo = await cluster_listen_zero(handleClusterClient)
|
||||
self.clusterPort = clusterRpcServerInfo["port"]
|
||||
self.peer.onProxySerialization = lambda value: self.onProxySerialization(
|
||||
self.peer, value, None
|
||||
)
|
||||
@@ -238,3 +236,49 @@ class ClusterSetup:
|
||||
return newValue
|
||||
except Exception as e:
|
||||
return value
|
||||
|
||||
|
||||
class ClusterServerListener(TypedDict):
|
||||
server: asyncio.Server
|
||||
port: int
|
||||
|
||||
|
||||
async def cluster_listen_zero(
|
||||
callback: Callable[[asyncio.StreamReader, asyncio.StreamWriter]]
|
||||
) -> ClusterServerListener:
|
||||
SCRYPTED_CLUSTER_ADDRESS = os.getenv("SCRYPTED_CLUSTER_ADDRESS")
|
||||
if not SCRYPTED_CLUSTER_ADDRESS:
|
||||
server = await asyncio.start_server(callback, host=None, port=0)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
return {
|
||||
"server": server,
|
||||
"port": port,
|
||||
}
|
||||
|
||||
# need to listen on the cluster address and 127.0.0.1 on the same port.
|
||||
retries = 5
|
||||
while retries > 0:
|
||||
cluster_server = await asyncio.start_server(
|
||||
callback, host=SCRYPTED_CLUSTER_ADDRESS, port=0
|
||||
)
|
||||
port = cluster_server.sockets[0].getsockname()[1]
|
||||
|
||||
try:
|
||||
print('trying to bind to port', port)
|
||||
local_server = await asyncio.start_server(
|
||||
callback, host="127.0.0.1", port=port
|
||||
)
|
||||
|
||||
future = asyncio.ensure_future(local_server.wait_closed())
|
||||
future.add_done_callback(lambda: local_server.close())
|
||||
|
||||
return {
|
||||
"server": cluster_server,
|
||||
"port": port,
|
||||
}
|
||||
except:
|
||||
# Port may be in use, keep trying.
|
||||
cluster_server.close()
|
||||
retries -= 1
|
||||
|
||||
raise Exception("failed to bind to cluster address.")
|
||||
|
||||
@@ -198,6 +198,7 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.settimeout(None)
|
||||
# TODO: this should use an equivalent to cluster_listen_zero
|
||||
sock.bind(("0.0.0.0" if os.getenv("SCRYPTED_CLUSTER_ADDRESS") else "127.0.0.1", 0))
|
||||
sock.listen()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user