server: enable stable cluster proxyIds

This commit is contained in:
Koushik Dutta
2024-07-29 18:34:18 -07:00
parent 946d88236c
commit 82908b82c0
4 changed files with 50 additions and 32 deletions

View File

@@ -414,9 +414,10 @@ class PluginRemote:
m.update(bytes(f"{o['id']}{o['port']}{o.get('sourcePort') or ''}{o['proxyId']}{clusterSecret}", 'utf8'))
return base64.b64encode(m.digest()).decode('utf-8')
def onProxySerialization(value: Any, proxyId: str, sourcePeerPort: int = None):
def onProxySerialization(value: Any, sourcePeerPort: int = None):
properties: dict = rpc.RpcPeer.prepareProxyProperties(value) or {}
clusterEntry = properties.get('__cluster', None)
proxyId: str = (clusterEntry and clusterEntry.get('proxyId', None)) or rpc.RpcPeer.generateId()
if clusterEntry and clusterPort == clusterEntry['port'] and sourcePeerPort != clusterEntry.get('sourcePort', None):
clusterEntry = None
@@ -431,7 +432,7 @@ class PluginRemote:
clusterEntry['sha256'] = computeClusterObjectHash(clusterEntry)
properties['__cluster'] = clusterEntry
return properties
return proxyId, properties
self.peer.onProxySerialization = onProxySerialization
@@ -448,8 +449,8 @@ class PluginRemote:
rpcTransport = rpc_reader.RpcStreamTransport(reader, writer)
peer: rpc.RpcPeer
peer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport)
peer.onProxySerialization = lambda value, proxyId: onProxySerialization(
value, proxyId, clusterPeerPort)
peer.onProxySerialization = lambda value: onProxySerialization(
value, clusterPeerPort)
future: asyncio.Future[rpc.RpcPeer] = asyncio.Future()
future.set_result(peer)
clusterPeers[clusterPeerPort] = future
@@ -483,9 +484,8 @@ class PluginRemote:
rpcTransport = rpc_reader.RpcStreamTransport(
reader, writer)
clusterPeer, peerReadLoop = await rpc_reader.prepare_peer_readloop(self.loop, rpcTransport)
clusterPeer.tags['localPort'] = clusterPeerPort
clusterPeer.onProxySerialization = lambda value, proxyId: onProxySerialization(
value, proxyId, clusterPeerPort)
clusterPeer.onProxySerialization = lambda value: onProxySerialization(
value, clusterPeerPort)
async def run_loop():
try:
@@ -519,8 +519,11 @@ class PluginRemote:
try:
clusterPeer = await clusterPeerPromise
if clusterPeer.tags.get('localPort') == sourcePort:
return value
weakref = clusterPeer.remoteWeakProxies.get(proxyId, None)
existing = weakref() if weakref else None
if existing:
return existing
peerConnectRPCObject = clusterPeer.tags.get('connectRPCObject')
if not peerConnectRPCObject:
peerConnectRPCObject = await clusterPeer.getParam('connectRPCObject')

View File

@@ -126,7 +126,7 @@ class RpcPeer:
self.pendingResults: Mapping[str, Future] = {}
self.remoteWeakProxies: Mapping[str, any] = {}
self.nameDeserializerMap: Mapping[str, RpcSerializer] = {}
self.onProxySerialization: Callable[[Any, str], Any] = None
self.onProxySerialization: Callable[[Any, str], tuple[str, Any]] = None
self.killed = False
self.tags = {}
@@ -274,7 +274,7 @@ class RpcPeer:
proxiedEntry = self.localProxied.get(value, None)
if proxiedEntry:
proxiedEntry['finalizerId'] = self.generateId()
proxiedEntry['finalizerId'] = RpcPeer.generateId()
ret = {
'__remote_proxy_id': proxiedEntry['id'],
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
@@ -292,7 +292,12 @@ class RpcPeer:
}
return ret
proxyId = self.generateId()
if self.onProxySerialization:
proxyId, __remote_proxy_props = self.onProxySerialization(value)
else:
__remote_proxy_props = RpcPeer.prepareProxyProperties(value)
proxyId = RpcPeer.generateId()
proxiedEntry = {
'id': proxyId,
'finalizerId': proxyId,
@@ -300,11 +305,6 @@ class RpcPeer:
self.localProxied[value] = proxiedEntry
self.localProxyMap[proxyId] = value
if self.onProxySerialization:
__remote_proxy_props = self.onProxySerialization(value, proxyId)
else:
__remote_proxy_props = RpcPeer.prepareProxyProperties(value)
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
@@ -491,7 +491,7 @@ class RpcPeer:
randomDigits = string.ascii_uppercase + string.ascii_lowercase + string.digits
def generateId(self):
def generateId():
return ''.join(random.choices(RpcPeer.randomDigits, k=8))
async def createPendingResult(self, cb: Callable[[str, Callable[[Exception], None]], None]):
@@ -500,7 +500,7 @@ class RpcPeer:
future.set_exception(RPCResultError(None, 'RpcPeer has been killed (createPendingResult)'))
return future
id = self.generateId()
id = RpcPeer.generateId()
self.pendingResults[id] = future
await cb(id, lambda e: future.set_exception(RPCResultError(e, None)))
return await future