rpc: implement python async iterator

This commit is contained in:
Koushik Dutta
2023-03-02 21:03:29 -08:00
parent b2e5801426
commit 096c036ea2
7 changed files with 211 additions and 93 deletions

View File

@@ -4,7 +4,6 @@ import traceback
import inspect
from typing_extensions import TypedDict
import weakref
import sys
jsonSerializable = set()
jsonSerializable.add(float)
@@ -16,14 +15,16 @@ jsonSerializable.add(list)
async def maybe_await(value):
if (inspect.iscoroutinefunction(value) or inspect.iscoroutine(value)):
if (inspect.isawaitable(value)):
return await value
return value
class RpcResultException(Exception):
name = None
stack = None
class RPCResultError(Exception):
name: str
stack: str
message: str
caught: Exception
def __init__(self, caught, message):
self.caught = caught
@@ -85,6 +86,8 @@ class RpcProxy(object):
class RpcPeer:
RPC_RESULT_ERROR_NAME = 'RPCResultError'
def __init__(self, send: Callable[[object, Callable[[Exception], None], Dict], None]) -> None:
self.send = send
self.idCounter = 1
@@ -127,10 +130,52 @@ class RpcPeer:
def kill(self):
self.killed = True
def createErrorResult(self, result: Any, name: str, message: str, tb: str):
result['stack'] = tb if tb else 'no stack'
result['result'] = name if name else 'no name'
result['message'] = message if message else 'no message'
def createErrorResult(self, result: Any, e: Exception):
s = self.serializeError(e)
result['result'] = s
result['throw'] = True
# TODO 3/2/2023 deprecate these properties
tb = traceback.format_exc()
message = str(e)
result['stack'] = tb or '[no stack]',
result['message'] = message or '[no message]',
# END TODO
def deserializeError(e: Dict) -> RPCResultError:
error = RPCResultError(None, e.get('message'))
error.stack = e.get('stack')
error.name = e.get('name')
return error
def serializeError(self, e: Exception):
tb = traceback.format_exc()
name = type(e).__name__
message = str(e)
serialized = {
'stack': tb or '[no stack]',
'name': name or '[no name]',
'message': message or '[no message]',
}
return {
'__remote_constructor_name': RpcPeer.RPC_RESULT_ERROR_NAME,
'__serialized_value': serialized,
}
def getProxyProperties(self, value):
if not hasattr(value, '__aiter__') or not hasattr(value, '__anext__'):
return getattr(value, '__proxy_props', None)
props = getattr(value, '__proxy_props', None) or {}
props['Symbol(Symbol.asyncIterator)'] = {
'next': '__anext__',
'throw': 'athrow',
'return': 'asend',
}
return props
def serialize(self, value, requireProxy, serializationContext: Dict):
if (not value or (not requireProxy and type(value) in jsonSerializable)):
@@ -139,6 +184,9 @@ class RpcPeer:
__remote_constructor_name = 'Function' if callable(value) else value.__proxy_constructor if hasattr(
value, '__proxy_constructor') else type(value).__name__
if isinstance(value, Exception):
return self.serializeError(value)
proxiedEntry = self.localProxied.get(value, None)
if proxiedEntry:
proxiedEntry['finalizerId'] = str(self.proxyCounter)
@@ -147,7 +195,7 @@ class RpcPeer:
'__remote_proxy_id': proxiedEntry['id'],
'__remote_proxy_finalizer_id': proxiedEntry['finalizerId'],
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
}
return ret
@@ -170,7 +218,7 @@ class RpcPeer:
'__remote_proxy_id': None,
'__remote_proxy_finalizer_id': None,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
'__serialized_value': serialized,
}
@@ -189,7 +237,7 @@ class RpcPeer:
'__remote_proxy_id': proxyId,
'__remote_proxy_finalizer_id': proxyId,
'__remote_constructor_name': __remote_constructor_name,
'__remote_proxy_props': getattr(value, '__proxy_props', None),
'__remote_proxy_props': self.getProxyProperties(value),
'__remote_proxy_oneway_methods': getattr(value, '__proxy_oneway_methods', None),
}
@@ -235,6 +283,9 @@ class RpcPeer:
__remote_proxy_oneway_methods = value.get(
'__remote_proxy_oneway_methods', None)
if __remote_constructor_name == RpcPeer.RPC_RESULT_ERROR_NAME:
return self.deserializeError(__serialized_value);
if __remote_proxy_id:
weakref = self.remoteWeakProxies.get('__remote_proxy_id', None)
proxy = weakref() if weakref else None
@@ -247,7 +298,7 @@ class RpcPeer:
if __local_proxy_id:
ret = self.localProxyMap.get(__local_proxy_id, None)
if not ret:
raise RpcResultException(
raise RPCResultError(
None, 'invalid local proxy id %s' % __local_proxy_id)
return ret
@@ -258,7 +309,7 @@ class RpcPeer:
return value
async def handleMessage(self, message: Any, deserializationContext: Dict):
async def handleMessage(self, message: Dict, deserializationContext: Dict):
try:
messageType = message['type']
if messageType == 'param':
@@ -310,11 +361,10 @@ class RpcPeer:
value = await maybe_await(target(*args))
result['result'] = self.serialize(value, False, serializationContext)
except StopAsyncIteration as e:
self.createErrorResult(result, e)
except Exception as e:
tb = traceback.format_exc()
# print('failure', method, e, tb)
self.createErrorResult(
result, type(e).__name__, str(e), tb)
self.createErrorResult(result, e)
if not message.get('oneway', False):
self.send(result, None, serializationContext)
@@ -323,18 +373,21 @@ class RpcPeer:
id = message['id']
future = self.pendingResults.get(id, None)
if not future:
raise RpcResultException(
raise RPCResultError(
None, 'unknown result %s' % id)
del self.pendingResults[id]
if hasattr(message, 'message') or hasattr(message, 'stack'):
e = RpcResultException(
if (hasattr(message, 'message') or hasattr(message, 'stack')) and not hasattr(message, 'throw'):
e = RPCResultError(
None, message.get('message', None))
e.stack = message.get('stack', None)
e.name = message.get('name', None)
future.set_exception(e)
return
future.set_result(self.deserialize(
message.get('result', None), deserializationContext))
deserialized = self.deserialize(message.get('result', None), deserializationContext)
if message.get('throw'):
future.set_exception(deserialized)
else:
future.set_result(deserialized)
elif messageType == 'finalize':
finalizerId = message.get('__local_proxy_finalizer_id', None)
proxyId = message['__local_proxy_id']
@@ -347,7 +400,7 @@ class RpcPeer:
self.localProxied.pop(local, None)
local = self.localProxyMap.pop(proxyId, None)
else:
raise RpcResultException(
raise RPCResultError(
None, 'unknown rpc message type %s' % type)
except Exception as e:
print("unhandled rpc error", self.peerName, e)
@@ -361,7 +414,7 @@ class RpcPeer:
self.idCounter = self.idCounter + 1
future = Future()
self.pendingResults[id] = future
await cb(id, lambda e: future.set_exception(RpcResultException(e, None)))
await cb(id, lambda e: future.set_exception(RPCResultError(e, None)))
return await future
async def getParam(self, param):