server: fix race condition around rpc object finalization

This commit is contained in:
Koushik Dutta
2021-12-16 22:11:29 -08:00
parent 3be72f8786
commit d644fe1122
2 changed files with 130 additions and 51 deletions

View File

@@ -1,8 +1,10 @@
from asyncio.futures import Future
from typing import Callable, Mapping, List
from typing import Any, Callable, Mapping, List
import traceback
import inspect
from typing_extensions import TypedDict
import weakref
import sys
jsonSerializable = set()
jsonSerializable.add(float)
@@ -45,21 +47,35 @@ class RpcProxyMethod:
return self.__proxy.__apply__(self.__proxy_method_name, args)
class LocalProxiedEntry(TypedDict):
id: str
finalizerId: str
class RpcProxy(object):
def __init__(self, peer, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
self.__dict__['__proxy_id'] = proxyId
def __init__(self, peer, entry: LocalProxiedEntry, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
self.__dict__['__proxy_id'] = entry['id']
self.__dict__['__proxy_entry'] = entry
self.__dict__['__proxy_constructor'] = proxyConstructorName
self.__dict__['__proxy_peer'] = peer
self.__dict__['__proxy_props'] = proxyProps
self.__dict__['__proxy_oneway_methods'] = proxyOneWayMethods
def __getattr__(self, name):
if name == '__proxy_finalizer_id':
return self.dict['__proxy_entry']['finalizerId']
if name in self.__dict__:
return self.__dict__[name]
if self.__dict__['__proxy_props'] and name in self.__dict__['__proxy_props']:
return self.__dict__['__proxy_props'][name]
return RpcProxyMethod(self, name)
def __setattr__(self, name: str, value: Any) -> None:
if name == '__proxy_finalizer_id':
self.dict['__proxy_entry']['finalizerId'] = value
return super().__setattr__(name, value)
def __call__(self, *args, **kwargs):
print('call')
pass
@@ -69,20 +85,18 @@ class RpcProxy(object):
class RpcPeer:
# todo: these are all class statics lol, fix this.
idCounter = 1
peerName = 'Unnamed Peer'
params: Mapping[str, any] = {}
localProxied: Mapping[any, str] = {}
localProxyMap: Mapping[str, any] = {}
constructorSerializerMap = {}
proxyCounter = 1
pendingResults: Mapping[str, Future] = {}
remoteWeakProxies: Mapping[str, any] = {}
nameDeserializerMap: Mapping[str, RpcSerializer] = {}
def __init__(self, send: Callable[[object, Callable[[Exception], None]], None]) -> None:
self.send = send
self.idCounter = 1
self.peerName = 'Unnamed Peer'
self.params: Mapping[str, any] = {}
self.localProxied: Mapping[any, LocalProxiedEntry] = {}
self.localProxyMap: Mapping[str, any] = {}
self.constructorSerializerMap = {}
self.proxyCounter = 1
self.pendingResults: Mapping[str, Future] = {}
self.remoteWeakProxies: Mapping[str, any] = {}
self.nameDeserializerMap: Mapping[str, RpcSerializer] = {}
def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list):
serializedArgs = []
@@ -120,12 +134,17 @@ class RpcPeer:
def serialize(self, value, requireProxy):
if (not value or (not requireProxy and type(value) in jsonSerializable)):
return value
__remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr(
value, '__proxy_constructor') else type(value).__name__
proxyId = self.localProxied.get(value, None)
if proxyId:
proxiedEntry = self.localProxied.get(value, None)
if proxiedEntry:
proxiedEntry['finalizerId'] = str(self.proxyCounter)
self.proxyCounter = self.proxyCounter + 1
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_id': proxiedEntry['id'],
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
@@ -140,13 +159,15 @@ class RpcPeer:
}
return ret
serializerMapName = self.constructorSerializerMap.get(type(value), None)
serializerMapName = self.constructorSerializerMap.get(
type(value), None)
if serializerMapName:
__remote_constructor_name = serializerMapName
serializer = self.nameDeserializerMap.get(serializerMapName, None)
serialized = serializer.serialize(value)
ret = {
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
@@ -156,11 +177,16 @@ class RpcPeer:
proxyId = str(self.proxyCounter)
self.proxyCounter = self.proxyCounter + 1
self.localProxied[value] = proxyId
proxiedEntry = {
'id': proxyId,
'finalizerId': proxyId,
}
self.localProxied[value] = proxiedEntry
self.localProxyMap[proxyId] = value
ret = {
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
@@ -168,20 +194,26 @@ class RpcPeer:
return ret
def finalize(self, id: str):
def finalize(self, localProxiedEntry: LocalProxiedEntry):
id = localProxiedEntry['id']
self.remoteWeakProxies.pop(id, None)
rpcFinalize = {
'__local_proxy_id': id,
'__local_proxy_finalizer_id': localProxiedEntry['finalizerId'],
'type': 'finalize',
}
self.send(rpcFinalize)
def newProxy(self, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]):
proxy = RpcProxy(self, proxyId, proxyConstructorName,
localProxiedEntry: LocalProxiedEntry = {
'id': proxyId,
'finalizerId': None,
}
proxy = RpcProxy(self, localProxiedEntry, proxyConstructorName,
proxyProps, proxyOneWayMethods)
wr = weakref.ref(proxy)
self.remoteWeakProxies[proxyId] = wr
weakref.finalize(proxy, lambda: self.finalize(proxyId))
weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry))
return proxy
def deserialize(self, value):
@@ -192,6 +224,8 @@ class RpcPeer:
return value
__remote_proxy_id = value.get('__remote_proxy_id', None)
__remote_proxy_finalizer_id = value.get(
'__remote_proxy_finalizer_id', None)
__local_proxy_id = value.get('__local_proxy_id', None)
__remote_constructor_name = value.get(
'__remote_constructor_name', None)
@@ -206,6 +240,7 @@ class RpcPeer:
if not proxy:
proxy = self.newProxy(__remote_proxy_id, __remote_constructor_name,
__remote_proxy_props, __remote_proxy_oneway_methods)
proxy.__proxy_finalizer_id = __remote_proxy_finalizer_id
return proxy
if __local_proxy_id:
@@ -297,9 +332,16 @@ class RpcPeer:
future.set_result(self.deserialize(
message.get('result', None)))
elif messageType == 'finalize':
local = self.localProxyMap.pop(
message['__local_proxy_id'], None)
self.localProxied.pop(local, None)
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']
local = self.localProxyMap.get(proxyId, None)
if local:
localProxiedEntry = self.localProxied.get(local)
if localProxiedEntry and finalizerId and localProxiedEntry['finalizerId'] != finalizerId:
print('mismatch finalizer id', file=sys.stderr)
return
self.localProxied.pop(local, None)
local = self.localProxyMap.pop(proxyId, None)
else:
raise RpcResultException(
None, 'unknown rpc message type %s' % type)

View File

@@ -1,5 +1,7 @@
import vm from 'vm';
const finalizerIdSymbol = Symbol('rpcFinalizerId');
function getDefaultTransportSafeArgumentTypes() {
const jsonSerializable = new Set<string>();
jsonSerializable.add(Number.name);
@@ -40,6 +42,7 @@ interface RpcOob extends RpcMessage {
interface RpcRemoteProxyValue {
__remote_proxy_id: string;
__remote_proxy_finalizer_id: string;
__remote_constructor_name: string;
__remote_proxy_props: any;
__remote_proxy_oneway_methods: string[];
@@ -52,6 +55,7 @@ interface RpcLocalProxyValue {
interface RpcFinalize extends RpcMessage {
__local_proxy_id: string;
__local_proxy_finalizer_id: string;
}
interface Deferred {
@@ -78,20 +82,17 @@ export const PROPERTY_PROXY_PROPERTIES = '__proxy_props';
export const PROPERTY_JSON_COPY_SERIALIZE_CHILDREN = '__json_copy_serialize_children';
class RpcProxy implements ProxyHandler<any> {
constructor(public peer: RpcPeer,
public id: string,
public entry: LocalProxiedEntry,
public constructorName: string,
public proxyProps: any,
public proxyOneWayMethods: string[]) {
this.peer = peer;
this.id = id;
this.constructorName = constructorName;
this.proxyProps = proxyProps;
}
get(target: any, p: PropertyKey, receiver: any): any {
if (p === '__proxy_id')
return this.id;
return this.entry.id;
if (p === '__proxy_constructor')
return this.constructorName;
if (p === '__proxy_peer')
@@ -114,6 +115,12 @@ class RpcProxy implements ProxyHandler<any> {
return new Proxy(() => p, this);
}
set(target: any, p: string | symbol, value: any, receiver: any): boolean {
if (p === finalizerIdSymbol)
this.entry.finalizerId = value;
return true;
}
apply(target: any, thisArg: any, argArray?: any): any {
const method = target();
const args: any[] = [];
@@ -124,7 +131,7 @@ class RpcProxy implements ProxyHandler<any> {
const rpcApply: RpcApply = {
type: "apply",
id: undefined,
proxyId: this.id,
proxyId: this.entry.id,
args,
method,
};
@@ -187,16 +194,21 @@ export interface RpcSerializer {
deserialize(serialized: any): any;
}
interface LocalProxiedEntry {
id: string;
finalizerId: string;
}
export class RpcPeer {
idCounter = 1;
onOob: (oob: any) => void;
params: { [name: string]: any } = {};
pendingResults: { [id: string]: Deferred } = {};
proxyCounter = 1;
localProxied = new Map<any, string>();
localProxied = new Map<any, LocalProxiedEntry>();
localProxyMap: { [id: string]: any } = {};
remoteWeakProxies: { [id: string]: WeakRef<any> } = {};
finalizers = new FinalizationRegistry(id => this.finalize(id as string));
finalizers = new FinalizationRegistry(entry => this.finalize(entry as LocalProxiedEntry));
nameDeserializerMap = new Map<string, RpcSerializer>();
constructorSerializerMap = new Map<string, string>();
transportSafeArgumentTypes = getDefaultTransportSafeArgumentTypes();
@@ -238,10 +250,11 @@ export class RpcPeer {
this.constructorSerializerMap.set(ctr, name);
}
finalize(id: string) {
delete this.remoteWeakProxies[id];
finalize(entry: LocalProxiedEntry) {
delete this.remoteWeakProxies[entry.id];
const rpcFinalize: RpcFinalize = {
__local_proxy_id: id,
__local_proxy_id: entry.id,
__local_proxy_finalizer_id: entry.finalizerId,
type: 'finalize',
}
this.send(rpcFinalize);
@@ -294,9 +307,12 @@ export class RpcPeer {
return ret;
}
const { __remote_proxy_id, __local_proxy_id, __remote_constructor_name, __serialized_value, __remote_proxy_props, __remote_proxy_oneway_methods } = value;
const { __remote_proxy_id, __remote_proxy_finalizer_id, __local_proxy_id, __remote_constructor_name, __serialized_value, __remote_proxy_props, __remote_proxy_oneway_methods } = value;
if (__remote_proxy_id) {
const proxy = this.remoteWeakProxies[__remote_proxy_id]?.deref() || this.newProxy(__remote_proxy_id, __remote_constructor_name, __remote_proxy_props, __remote_proxy_oneway_methods);
let proxy = this.remoteWeakProxies[__remote_proxy_id]?.deref();
if (!proxy)
proxy = this.newProxy(__remote_proxy_id, __remote_constructor_name, __remote_proxy_props, __remote_proxy_oneway_methods);
proxy[finalizerIdSymbol] = __remote_proxy_finalizer_id;
return proxy;
}
@@ -329,10 +345,13 @@ export class RpcPeer {
let __remote_constructor_name = value.__proxy_constructor || value.constructor?.name?.toString();
let proxyId = this.localProxied.get(value);
if (proxyId) {
let proxiedEntry = this.localProxied.get(value);
if (proxiedEntry) {
const __remote_proxy_finalizer_id = (this.proxyCounter++).toString();
proxiedEntry.finalizerId = __remote_proxy_finalizer_id;
const ret: RpcRemoteProxyValue = {
__remote_proxy_id: proxyId,
__remote_proxy_id: proxiedEntry.id,
__remote_proxy_finalizer_id,
__remote_constructor_name,
__remote_proxy_props: value?.[PROPERTY_PROXY_PROPERTIES],
__remote_proxy_oneway_methods: value?.[PROPERTY_PROXY_ONEWAY_METHODS],
@@ -355,6 +374,7 @@ export class RpcPeer {
const serialized = serializer.serialize(value);
const ret: RpcRemoteProxyValue = {
__remote_proxy_id: undefined,
__remote_proxy_finalizer_id: undefined,
__remote_constructor_name,
__remote_proxy_props: value?.[PROPERTY_PROXY_PROPERTIES],
__remote_proxy_oneway_methods: value?.[PROPERTY_PROXY_ONEWAY_METHODS],
@@ -363,12 +383,17 @@ export class RpcPeer {
return ret;
}
proxyId = (this.proxyCounter++).toString();
this.localProxied.set(value, proxyId);
this.localProxyMap[proxyId] = value;
const __remote_proxy_id = (this.proxyCounter++).toString();
proxiedEntry = {
id: __remote_proxy_id,
finalizerId: __remote_proxy_id,
};
this.localProxied.set(value, proxiedEntry);
this.localProxyMap[__remote_proxy_id] = value;
const ret: RpcRemoteProxyValue = {
__remote_proxy_id: proxyId,
__remote_proxy_id,
__remote_proxy_finalizer_id: __remote_proxy_id,
__remote_constructor_name,
__remote_proxy_props: value?.[PROPERTY_PROXY_PROPERTIES],
__remote_proxy_oneway_methods: value?.[PROPERTY_PROXY_ONEWAY_METHODS],
@@ -378,12 +403,16 @@ export class RpcPeer {
}
newProxy(proxyId: string, proxyConstructorName: string, proxyProps: any, proxyOneWayMethods: string[]) {
const rpc = new RpcProxy(this, proxyId, proxyConstructorName, proxyProps, proxyOneWayMethods);
const localProxiedEntry: LocalProxiedEntry = {
id: proxyId,
finalizerId: undefined,
}
const rpc = new RpcProxy(this, localProxiedEntry, proxyConstructorName, proxyProps, proxyOneWayMethods);
const target = proxyConstructorName === 'Function' || proxyConstructorName === 'AsyncFunction' ? function () { } : rpc;
const proxy = new Proxy(target, rpc);
const weakref = new WeakRef(proxy);
this.remoteWeakProxies[proxyId] = weakref;
this.finalizers.register(rpc, proxyId);
this.finalizers.register(rpc, localProxiedEntry);
global.gc?.();
return proxy;
}
@@ -460,8 +489,16 @@ export class RpcPeer {
case 'finalize': {
const rpcFinalize = message as RpcFinalize;
const local = this.localProxyMap[rpcFinalize.__local_proxy_id];
delete this.localProxyMap[rpcFinalize.__local_proxy_id];
this.localProxied.delete(local);
if (local) {
const localProxiedEntry = this.localProxied.get(local);
// if a finalizer id is specified, it must match.
if (rpcFinalize.__local_proxy_finalizer_id && rpcFinalize.__local_proxy_finalizer_id !== localProxiedEntry?.finalizerId) {
console.error(this.selfName, this.peerName, 'finalizer mismatch')
break;
}
delete this.localProxyMap[rpcFinalize.__local_proxy_id];
this.localProxied.delete(local);
}
break;
}
case 'oob': {