server: additional python rpc transport fixes

This commit is contained in:
Koushik Dutta
2023-03-17 23:21:07 -07:00
parent 6e08f11578
commit 997a4732ec
2 changed files with 23 additions and 8 deletions

View File

@@ -13,6 +13,7 @@ import rpc
import concurrent.futures
import json
class BufferSerializer(rpc.RpcSerializer):
def serialize(self, value, serializationContext):
return base64.b64encode(value).decode('utf8')
@@ -35,6 +36,7 @@ class SidebandBufferSerializer(rpc.RpcSerializer):
buffer = buffers.pop()
return buffer
class RpcTransport:
async def prepare(self):
pass
@@ -48,6 +50,7 @@ class RpcTransport:
def writeJSON(self, json, reject):
pass
class RpcFileTransport(RpcTransport):
reader: asyncio.StreamReader
@@ -71,7 +74,7 @@ class RpcFileTransport(RpcTransport):
return data
message = json.loads(data)
return message
def writeMessage(self, type: int, buffer, reject):
length = len(buffer) + 1
lb = length.to_bytes(4, 'big')
@@ -88,15 +91,24 @@ class RpcFileTransport(RpcTransport):
def writeBuffer(self, buffer, reject):
return self.writeMessage(1, buffer, reject)
class RpcStreamTransport(RpcTransport):
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__()
self.reader = reader
self.writer = writer
async def read(self, n: int):
return await self.reader.readexactly(n)
async def read(self):
lengthBytes = await self.reader.readexactly(4)
typeBytes = await self.reader.readexactly(1)
type = typeBytes[0]
length = int.from_bytes(lengthBytes, 'big')
data = await self.reader.readexactly(length - 1)
if type == 1:
return data
message = json.loads(data)
return message
def writeMessage(self, type: int, buffer, reject):
length = len(buffer) + 1
lb = length.to_bytes(4, 'big')
@@ -113,6 +125,7 @@ class RpcStreamTransport(RpcTransport):
def writeBuffer(self, buffer, reject):
return self.writeMessage(1, buffer, reject)
class RpcConnectionTransport(RpcTransport):
def __init__(self, connection: multiprocessing.connection.Connection) -> None:
super().__init__()
@@ -121,7 +134,7 @@ class RpcConnectionTransport(RpcTransport):
async def read(self):
return await asyncio.get_event_loop().run_in_executor(self.executor, lambda: self.connection.recv())
def writeMessage(self, json, reject):
try:
self.connection.send(json)
@@ -131,10 +144,11 @@ class RpcConnectionTransport(RpcTransport):
def writeJSON(self, json, reject):
return self.writeMessage(json, reject)
def writeBuffer(self, buffer, reject):
return self.writeMessage(bytes(buffer), reject)
async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport):
deserializationContext = {
'buffers': []
@@ -154,6 +168,7 @@ async def readLoop(loop, peer: rpc.RpcPeer, rpcTransport: RpcTransport):
'buffers': []
}
async def prepare_peer_readloop(loop: AbstractEventLoop, rpcTransport: RpcTransport):
await rpcTransport.prepare()