diff --git a/server/python/rpc.py b/server/python/rpc.py index f8d4a90ee..8d2370be6 100644 --- a/server/python/rpc.py +++ b/server/python/rpc.py @@ -1,8 +1,10 @@ from asyncio.futures import Future -from typing import Callable, Mapping, List +from typing import Any, Callable, Mapping, List import traceback import inspect +from typing_extensions import TypedDict import weakref +import sys jsonSerializable = set() jsonSerializable.add(float) @@ -45,21 +47,35 @@ class RpcProxyMethod: return self.__proxy.__apply__(self.__proxy_method_name, args) +class LocalProxiedEntry(TypedDict): + id: str + finalizerId: str + + class RpcProxy(object): - def __init__(self, peer, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]): - self.__dict__['__proxy_id'] = proxyId + def __init__(self, peer, entry: LocalProxiedEntry, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]): + self.__dict__['__proxy_id'] = entry['id'] + self.__dict__['__proxy_entry'] = entry self.__dict__['__proxy_constructor'] = proxyConstructorName self.__dict__['__proxy_peer'] = peer self.__dict__['__proxy_props'] = proxyProps self.__dict__['__proxy_oneway_methods'] = proxyOneWayMethods def __getattr__(self, name): + if name == '__proxy_finalizer_id': + return self.dict['__proxy_entry']['finalizerId'] if name in self.__dict__: return self.__dict__[name] if self.__dict__['__proxy_props'] and name in self.__dict__['__proxy_props']: return self.__dict__['__proxy_props'][name] return RpcProxyMethod(self, name) + def __setattr__(self, name: str, value: Any) -> None: + if name == '__proxy_finalizer_id': + self.dict['__proxy_entry']['finalizerId'] = value + + return super().__setattr__(name, value) + def __call__(self, *args, **kwargs): print('call') pass @@ -69,20 +85,18 @@ class RpcProxy(object): class RpcPeer: - # todo: these are all class statics lol, fix this. - idCounter = 1 - peerName = 'Unnamed Peer' - params: Mapping[str, any] = {} - localProxied: Mapping[any, str] = {} - localProxyMap: Mapping[str, any] = {} - constructorSerializerMap = {} - proxyCounter = 1 - pendingResults: Mapping[str, Future] = {} - remoteWeakProxies: Mapping[str, any] = {} - nameDeserializerMap: Mapping[str, RpcSerializer] = {} - def __init__(self, send: Callable[[object, Callable[[Exception], None]], None]) -> None: self.send = send + self.idCounter = 1 + self.peerName = 'Unnamed Peer' + self.params: Mapping[str, any] = {} + self.localProxied: Mapping[any, LocalProxiedEntry] = {} + self.localProxyMap: Mapping[str, any] = {} + self.constructorSerializerMap = {} + self.proxyCounter = 1 + self.pendingResults: Mapping[str, Future] = {} + self.remoteWeakProxies: Mapping[str, any] = {} + self.nameDeserializerMap: Mapping[str, RpcSerializer] = {} def __apply__(self, proxyId: str, oneWayMethods: List[str], method: str, args: list): serializedArgs = [] @@ -120,12 +134,17 @@ class RpcPeer: def serialize(self, value, requireProxy): if (not value or (not requireProxy and type(value) in jsonSerializable)): return value + __remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr( value, '__proxy_constructor') else type(value).__name__ - proxyId = self.localProxied.get(value, None) - if proxyId: + + proxiedEntry = self.localProxied.get(value, None) + if proxiedEntry: + proxiedEntry['finalizerId'] = str(self.proxyCounter) + self.proxyCounter = self.proxyCounter + 1 ret = { - '__remote_proxy_id': proxyId, + '__remote_proxy_id': proxiedEntry['id'], + '__remote_proxy_finalizer_id': proxiedEntry['finalizerId'], '__remote_constructor_name': __remote_constructor_name, '__remote_proxy_props': getattr(value, '__proxy_props', None), '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), @@ -140,13 +159,15 @@ class RpcPeer: } return ret - serializerMapName = self.constructorSerializerMap.get(type(value), None) + serializerMapName = self.constructorSerializerMap.get( + type(value), None) if serializerMapName: __remote_constructor_name = serializerMapName serializer = self.nameDeserializerMap.get(serializerMapName, None) serialized = serializer.serialize(value) ret = { '__remote_proxy_id': None, + '__remote_proxy_finalizer_id': None, '__remote_constructor_name': __remote_constructor_name, '__remote_proxy_props': getattr(value, '__proxy_props', None), '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), @@ -156,11 +177,16 @@ class RpcPeer: proxyId = str(self.proxyCounter) self.proxyCounter = self.proxyCounter + 1 - self.localProxied[value] = proxyId + proxiedEntry = { + 'id': proxyId, + 'finalizerId': proxyId, + } + self.localProxied[value] = proxiedEntry self.localProxyMap[proxyId] = value ret = { '__remote_proxy_id': proxyId, + '__remote_proxy_finalizer_id': proxyId, '__remote_constructor_name': __remote_constructor_name, '__remote_proxy_props': getattr(value, '__proxy_props', None), '__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None), @@ -168,20 +194,26 @@ class RpcPeer: return ret - def finalize(self, id: str): + def finalize(self, localProxiedEntry: LocalProxiedEntry): + id = localProxiedEntry['id'] self.remoteWeakProxies.pop(id, None) rpcFinalize = { '__local_proxy_id': id, + '__local_proxy_finalizer_id': localProxiedEntry['finalizerId'], 'type': 'finalize', } self.send(rpcFinalize) def newProxy(self, proxyId: str, proxyConstructorName: str, proxyProps: any, proxyOneWayMethods: List[str]): - proxy = RpcProxy(self, proxyId, proxyConstructorName, + localProxiedEntry: LocalProxiedEntry = { + 'id': proxyId, + 'finalizerId': None, + } + proxy = RpcProxy(self, localProxiedEntry, proxyConstructorName, proxyProps, proxyOneWayMethods) wr = weakref.ref(proxy) self.remoteWeakProxies[proxyId] = wr - weakref.finalize(proxy, lambda: self.finalize(proxyId)) + weakref.finalize(proxy, lambda: self.finalize(localProxiedEntry)) return proxy def deserialize(self, value): @@ -192,6 +224,8 @@ class RpcPeer: return value __remote_proxy_id = value.get('__remote_proxy_id', None) + __remote_proxy_finalizer_id = value.get( + '__remote_proxy_finalizer_id', None) __local_proxy_id = value.get('__local_proxy_id', None) __remote_constructor_name = value.get( '__remote_constructor_name', None) @@ -206,6 +240,7 @@ class RpcPeer: if not proxy: proxy = self.newProxy(__remote_proxy_id, __remote_constructor_name, __remote_proxy_props, __remote_proxy_oneway_methods) + proxy.__proxy_finalizer_id = __remote_proxy_finalizer_id return proxy if __local_proxy_id: @@ -297,9 +332,16 @@ class RpcPeer: future.set_result(self.deserialize( message.get('result', None))) elif messageType == 'finalize': - local = self.localProxyMap.pop( - message['__local_proxy_id'], None) - self.localProxied.pop(local, None) + finalizerId = message.get('__local_proxy_finalizer_id', None) + proxyId = message['__local_proxy_id'] + local = self.localProxyMap.get(proxyId, None) + if local: + localProxiedEntry = self.localProxied.get(local) + if localProxiedEntry and finalizerId and localProxiedEntry['finalizerId'] != finalizerId: + print('mismatch finalizer id', file=sys.stderr) + return + self.localProxied.pop(local, None) + local = self.localProxyMap.pop(proxyId, None) else: raise RpcResultException( None, 'unknown rpc message type %s' % type) diff --git a/server/src/rpc.ts b/server/src/rpc.ts index 3bb7b5016..fb7b81bff 100644 --- a/server/src/rpc.ts +++ b/server/src/rpc.ts @@ -1,5 +1,7 @@ import vm from 'vm'; +const finalizerIdSymbol = Symbol('rpcFinalizerId'); + function getDefaultTransportSafeArgumentTypes() { const jsonSerializable = new Set(); jsonSerializable.add(Number.name); @@ -40,6 +42,7 @@ interface RpcOob extends RpcMessage { interface RpcRemoteProxyValue { __remote_proxy_id: string; + __remote_proxy_finalizer_id: string; __remote_constructor_name: string; __remote_proxy_props: any; __remote_proxy_oneway_methods: string[]; @@ -52,6 +55,7 @@ interface RpcLocalProxyValue { interface RpcFinalize extends RpcMessage { __local_proxy_id: string; + __local_proxy_finalizer_id: string; } interface Deferred { @@ -78,20 +82,17 @@ export const PROPERTY_PROXY_PROPERTIES = '__proxy_props'; export const PROPERTY_JSON_COPY_SERIALIZE_CHILDREN = '__json_copy_serialize_children'; class RpcProxy implements ProxyHandler { + constructor(public peer: RpcPeer, - public id: string, + public entry: LocalProxiedEntry, public constructorName: string, public proxyProps: any, public proxyOneWayMethods: string[]) { - this.peer = peer; - this.id = id; - this.constructorName = constructorName; - this.proxyProps = proxyProps; } get(target: any, p: PropertyKey, receiver: any): any { if (p === '__proxy_id') - return this.id; + return this.entry.id; if (p === '__proxy_constructor') return this.constructorName; if (p === '__proxy_peer') @@ -114,6 +115,12 @@ class RpcProxy implements ProxyHandler { return new Proxy(() => p, this); } + set(target: any, p: string | symbol, value: any, receiver: any): boolean { + if (p === finalizerIdSymbol) + this.entry.finalizerId = value; + return true; + } + apply(target: any, thisArg: any, argArray?: any): any { const method = target(); const args: any[] = []; @@ -124,7 +131,7 @@ class RpcProxy implements ProxyHandler { const rpcApply: RpcApply = { type: "apply", id: undefined, - proxyId: this.id, + proxyId: this.entry.id, args, method, }; @@ -187,16 +194,21 @@ export interface RpcSerializer { deserialize(serialized: any): any; } +interface LocalProxiedEntry { + id: string; + finalizerId: string; +} + export class RpcPeer { idCounter = 1; onOob: (oob: any) => void; params: { [name: string]: any } = {}; pendingResults: { [id: string]: Deferred } = {}; proxyCounter = 1; - localProxied = new Map(); + localProxied = new Map(); localProxyMap: { [id: string]: any } = {}; remoteWeakProxies: { [id: string]: WeakRef } = {}; - finalizers = new FinalizationRegistry(id => this.finalize(id as string)); + finalizers = new FinalizationRegistry(entry => this.finalize(entry as LocalProxiedEntry)); nameDeserializerMap = new Map(); constructorSerializerMap = new Map(); transportSafeArgumentTypes = getDefaultTransportSafeArgumentTypes(); @@ -238,10 +250,11 @@ export class RpcPeer { this.constructorSerializerMap.set(ctr, name); } - finalize(id: string) { - delete this.remoteWeakProxies[id]; + finalize(entry: LocalProxiedEntry) { + delete this.remoteWeakProxies[entry.id]; const rpcFinalize: RpcFinalize = { - __local_proxy_id: id, + __local_proxy_id: entry.id, + __local_proxy_finalizer_id: entry.finalizerId, type: 'finalize', } this.send(rpcFinalize); @@ -294,9 +307,12 @@ export class RpcPeer { return ret; } - const { __remote_proxy_id, __local_proxy_id, __remote_constructor_name, __serialized_value, __remote_proxy_props, __remote_proxy_oneway_methods } = value; + const { __remote_proxy_id, __remote_proxy_finalizer_id, __local_proxy_id, __remote_constructor_name, __serialized_value, __remote_proxy_props, __remote_proxy_oneway_methods } = value; if (__remote_proxy_id) { - const proxy = this.remoteWeakProxies[__remote_proxy_id]?.deref() || this.newProxy(__remote_proxy_id, __remote_constructor_name, __remote_proxy_props, __remote_proxy_oneway_methods); + let proxy = this.remoteWeakProxies[__remote_proxy_id]?.deref(); + if (!proxy) + proxy = this.newProxy(__remote_proxy_id, __remote_constructor_name, __remote_proxy_props, __remote_proxy_oneway_methods); + proxy[finalizerIdSymbol] = __remote_proxy_finalizer_id; return proxy; } @@ -329,10 +345,13 @@ export class RpcPeer { let __remote_constructor_name = value.__proxy_constructor || value.constructor?.name?.toString(); - let proxyId = this.localProxied.get(value); - if (proxyId) { + let proxiedEntry = this.localProxied.get(value); + if (proxiedEntry) { + const __remote_proxy_finalizer_id = (this.proxyCounter++).toString(); + proxiedEntry.finalizerId = __remote_proxy_finalizer_id; const ret: RpcRemoteProxyValue = { - __remote_proxy_id: proxyId, + __remote_proxy_id: proxiedEntry.id, + __remote_proxy_finalizer_id, __remote_constructor_name, __remote_proxy_props: value?.[PROPERTY_PROXY_PROPERTIES], __remote_proxy_oneway_methods: value?.[PROPERTY_PROXY_ONEWAY_METHODS], @@ -355,6 +374,7 @@ export class RpcPeer { const serialized = serializer.serialize(value); const ret: RpcRemoteProxyValue = { __remote_proxy_id: undefined, + __remote_proxy_finalizer_id: undefined, __remote_constructor_name, __remote_proxy_props: value?.[PROPERTY_PROXY_PROPERTIES], __remote_proxy_oneway_methods: value?.[PROPERTY_PROXY_ONEWAY_METHODS], @@ -363,12 +383,17 @@ export class RpcPeer { return ret; } - proxyId = (this.proxyCounter++).toString(); - this.localProxied.set(value, proxyId); - this.localProxyMap[proxyId] = value; + const __remote_proxy_id = (this.proxyCounter++).toString(); + proxiedEntry = { + id: __remote_proxy_id, + finalizerId: __remote_proxy_id, + }; + this.localProxied.set(value, proxiedEntry); + this.localProxyMap[__remote_proxy_id] = value; const ret: RpcRemoteProxyValue = { - __remote_proxy_id: proxyId, + __remote_proxy_id, + __remote_proxy_finalizer_id: __remote_proxy_id, __remote_constructor_name, __remote_proxy_props: value?.[PROPERTY_PROXY_PROPERTIES], __remote_proxy_oneway_methods: value?.[PROPERTY_PROXY_ONEWAY_METHODS], @@ -378,12 +403,16 @@ export class RpcPeer { } newProxy(proxyId: string, proxyConstructorName: string, proxyProps: any, proxyOneWayMethods: string[]) { - const rpc = new RpcProxy(this, proxyId, proxyConstructorName, proxyProps, proxyOneWayMethods); + const localProxiedEntry: LocalProxiedEntry = { + id: proxyId, + finalizerId: undefined, + } + const rpc = new RpcProxy(this, localProxiedEntry, proxyConstructorName, proxyProps, proxyOneWayMethods); const target = proxyConstructorName === 'Function' || proxyConstructorName === 'AsyncFunction' ? function () { } : rpc; const proxy = new Proxy(target, rpc); const weakref = new WeakRef(proxy); this.remoteWeakProxies[proxyId] = weakref; - this.finalizers.register(rpc, proxyId); + this.finalizers.register(rpc, localProxiedEntry); global.gc?.(); return proxy; } @@ -460,8 +489,16 @@ export class RpcPeer { case 'finalize': { const rpcFinalize = message as RpcFinalize; const local = this.localProxyMap[rpcFinalize.__local_proxy_id]; - delete this.localProxyMap[rpcFinalize.__local_proxy_id]; - this.localProxied.delete(local); + if (local) { + const localProxiedEntry = this.localProxied.get(local); + // if a finalizer id is specified, it must match. + if (rpcFinalize.__local_proxy_finalizer_id && rpcFinalize.__local_proxy_finalizer_id !== localProxiedEntry?.finalizerId) { + console.error(this.selfName, this.peerName, 'finalizer mismatch') + break; + } + delete this.localProxyMap[rpcFinalize.__local_proxy_id]; + this.localProxied.delete(local); + } break; } case 'oob': {