server: fix race condition around rpc object finalization

This commit is contained in:
Koushik Dutta
2021-12-16 22:11:29 -08:00
parent 3be72f8786
commit d644fe1122
2 changed files with 130 additions and 51 deletions

View File

@@ -1,8 +1,10 @@
from asyncio.futures import Future
from typing import Callable, Mapping, List
from typing import Any, Callable, Mapping, List
import traceback
import inspect
from typing_extensions import TypedDict
import weakref
import sys
jsonSerializable = set()
jsonSerializable.add(float)
@@ -45,21 +47,35 @@ class RpcProxyMethod:
return self.__proxy.__apply__(self.__proxy_method_name, args)
class LocalProxiedEntry(TypedDict):
id: str
finalizerId: str
class RpcProxy(object):
def __init__(self, peer, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
self.__dict__['__proxy_id'] = proxyId
def __init__(self, peer, entry: LocalProxiedEntry, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
self.__dict__['__proxy_id'] = entry['id']
self.__dict__['__proxy_entry'] = entry
self.__dict__['__proxy_constructor'] = proxyConstructorName
self.__dict__['__proxy_peer'] = peer
self.__dict__['__proxy_props'] = proxyProps
self.__dict__['__proxy_oneway_methods'] = proxyOneWayMethods
def __getattr__(self, name):
if name == '__proxy_finalizer_id':
return self.dict['__proxy_entry']['finalizerId']
if name in self.__dict__:
return self.__dict__[name]
if self.__dict__['__proxy_props'] and name in self.__dict__['__proxy_props']:
return self.__dict__['__proxy_props'][name]
return RpcProxyMethod(self, name)
def __setattr__(self, name: str, value: Any) -> None:
if name == '__proxy_finalizer_id':
self.dict['__proxy_entry']['finalizerId'] = value
return super().__setattr__(name, value)
def __call__(self, *args, **kwargs):
print('call')
pass
@@ -69,20 +85,18 @@ class RpcProxy(object):
class RpcPeer:
# todo: these are all class statics lol, fix this.
idCounter = 1
peerName = 'Unnamed Peer'
params: Mapping[str, any] = {}
localProxied: Mapping[any, str] = {}
localProxyMap: Mapping[str, any] = {}
constructorSerializerMap = {}
proxyCounter = 1
pendingResults: Mapping[str, Future] = {}
remoteWeakProxies: Mapping[str, any] = {}
nameDeserializerMap: Mapping[str, RpcSerializer] = {}
def __init__(self, send: Callable[[object, Callable[[Exception], None]], None]) -> None:
self.send = send
self.idCounter = 1
self.peerName = 'Unnamed Peer'
self.params: Mapping[str, any] = {}
self.localProxied: Mapping[any, LocalProxiedEntry] = {}
self.localProxyMap: Mapping[str, any] = {}
self.constructorSerializerMap = {}
self.proxyCounter = 1
self.pendingResults: Mapping[str, Future] = {}
self.remoteWeakProxies: Mapping[str, any] = {}
self.nameDeserializerMap: Mapping[str, RpcSerializer] = {}
def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list):
serializedArgs = []
@@ -120,12 +134,17 @@ class RpcPeer:
def serialize(self, value, requireProxy):
if (not value or (not requireProxy and type(value) in jsonSerializable)):
return value
__remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr(
value, '__proxy_constructor') else type(value).__name__
proxyId = self.localProxied.get(value, None)
if proxyId:
proxiedEntry = self.localProxied.get(value, None)
if proxiedEntry:
proxiedEntry['finalizerId'] = str(self.proxyCounter)
self.proxyCounter = self.proxyCounter + 1
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_id': proxiedEntry['id'],
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
@@ -140,13 +159,15 @@ class RpcPeer:
}
return ret
serializerMapName = self.constructorSerializerMap.get(type(value), None)
serializerMapName = self.constructorSerializerMap.get(
type(value), None)
if serializerMapName:
__remote_constructor_name = serializerMapName
serializer = self.nameDeserializerMap.get(serializerMapName, None)
serialized = serializer.serialize(value)
ret = {
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
@@ -156,11 +177,16 @@ class RpcPeer:
proxyId = str(self.proxyCounter)
self.proxyCounter = self.proxyCounter + 1
self.localProxied[value] = proxyId
proxiedEntry = {
'id': proxyId,
'finalizerId': proxyId,
}
self.localProxied[value] = proxiedEntry
self.localProxyMap[proxyId] = value
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
@@ -168,20 +194,26 @@ class RpcPeer:
return ret
def finalize(self, id: str):
def finalize(self, localProxiedEntry: LocalProxiedEntry):
id = localProxiedEntry['id']
self.remoteWeakProxies.pop(id, None)
rpcFinalize = {
'__local_proxy_id': id,
'__local_proxy_finalizer_id': localProxiedEntry['finalizerId'],
'type': 'finalize',
}
self.send(rpcFinalize)
def newProxy(self, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
proxy = RpcProxy(self, proxyId, proxyConstructorName,
localProxiedEntry: LocalProxiedEntry = {
'id': proxyId,
'finalizerId': None,
}
proxy = RpcProxy(self, localProxiedEntry, proxyConstructorName,
proxyProps, proxyOneWayMethods)
wr = weakref.ref(proxy)
self.remoteWeakProxies[proxyId] = wr
weakref.finalize(proxy, lambda: self.finalize(proxyId))
weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry))
return proxy
def deserialize(self, value):
@@ -192,6 +224,8 @@ class RpcPeer:
return value
__remote_proxy_id = value.get('__remote_proxy_id', None)
__remote_proxy_finalizer_id = value.get(
'__remote_proxy_finalizer_id', None)
__local_proxy_id = value.get('__local_proxy_id', None)
__remote_constructor_name = value.get(
'__remote_constructor_name', None)
@@ -206,6 +240,7 @@ class RpcPeer:
if not proxy:
proxy = self.newProxy(__remote_proxy_id, __remote_constructor_name,
__remote_proxy_props, __remote_proxy_oneway_methods)
proxy.__proxy_finalizer_id = __remote_proxy_finalizer_id
return proxy
if __local_proxy_id:
@@ -297,9 +332,16 @@ class RpcPeer:
future.set_result(self.deserialize(
message.get('result', None)))
elif messageType == 'finalize':
local = self.localProxyMap.pop(
message['__local_proxy_id'], None)
self.localProxied.pop(local, None)
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']
local = self.localProxyMap.get(proxyId, None)
if local:
localProxiedEntry = self.localProxied.get(local)
if localProxiedEntry and finalizerId and localProxiedEntry['finalizerId'] != finalizerId:
print('mismatch finalizer id', file=sys.stderr)
return
self.localProxied.pop(local, None)
local = self.localProxyMap.pop(proxyId, None)
else:
raise RpcResultException(
None, 'unknown rpc message type %s' % type)