reimplement with ptpython repl

This commit is contained in:
Brett Jia
2024-03-05 12:30:56 -05:00
parent 424f91c696
commit 77ce4f4d2e
2 changed files with 82 additions and 240 deletions

View File

@@ -1,162 +1,22 @@
import ast
import asyncio
from asyncio import futures
import code
import concurrent
import inspect
import os
import platform
import pty
import signal
import concurrent.futures
from prompt_toolkit import print_formatted_text
from prompt_toolkit.contrib.telnet.server import TelnetServer
from ptpython.repl import embed, PythonRepl
import socket
import sys
import telnetlib
import threading
import traceback
import types
from typing import List, Dict, Any
from scrypted_python.scrypted_sdk import ScryptedStatic, ScryptedDevice
from rpc import maybe_await
import connect_to_repl
def is_pid_alive(pid):
if platform.system() == 'Windows':
# On Windows, use os.kill with signal 0 to check if the process exists
import ctypes
kernel32 = ctypes.windll.kernel32
handle = kernel32.OpenProcess(1, 0, pid)
if handle:
kernel32.CloseHandle(handle)
return True
else:
return False
else:
# On Unix-like systems, use os.kill with signal 0 to check if the process exists
try:
os.kill(pid, 0)
except OSError:
return False
else:
return True
# This section is a bit of a hack - the REPL's eval capabilities triggers
# sys.displayhook to print the result of the eval. We want to capture the
# result and send it to the correct Scrypted REPL connection instead of printing
# it to the default Scrypted server console.
REPL_WRITER_KEY = "__scrypted_repl_writer__"
default_displayhook = sys.displayhook
def repl_displayhook(value):
stack = inspect.stack()
writer = None
for f in stack:
if REPL_WRITER_KEY in f.frame.f_locals:
writer = f.frame.f_locals[REPL_WRITER_KEY]
break
if not writer:
default_displayhook(value)
return
writer.write(repr(value) + "\n")
writer.flush()
sys.displayhook = repl_displayhook
class REPL(code.InteractiveConsole):
# based on AsyncIOInteractiveConsole and InteractiveConsole from Python source code
def __init__(self, locals, loop, reader, writer):
super().__init__(locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
self.loop = loop
self.reader = reader
self.writer = writer
def runcode(self, code):
future = concurrent.futures.Future()
def callback():
self.repl_future = None
self.repl_future_interrupted = False
func = types.FunctionType(code, self.locals)
try:
coro = func()
except SystemExit:
raise
except KeyboardInterrupt as ex:
self.repl_future_interrupted = True
future.set_exception(ex)
return
except BaseException as ex:
future.set_exception(ex)
return
if not inspect.iscoroutine(coro):
future.set_result(coro)
return
try:
self.repl_future = self.loop.create_task(coro)
futures._chain_future(self.repl_future, future)
except BaseException as exc:
future.set_exception(exc)
self.loop.call_soon_threadsafe(callback)
try:
result = future.result()
return result
except SystemExit:
raise
except BaseException:
if self.repl_future_interrupted:
self.write("\nKeyboardInterrupt\n")
else:
self.showtraceback()
def showsyntaxerror(self, filename=None):
type, value, tb = sys.exc_info()
sys.last_type = type
sys.last_value = value
sys.last_traceback = tb
if filename and type is SyntaxError:
# Work hard to stuff the correct filename in the exception
try:
msg, (dummy_filename, lineno, offset, line) = value.args
except ValueError:
# Not the format we expect; leave it alone
pass
else:
# Stuff in the right filename
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
lines = traceback.format_exception_only(type, value)
self.write(''.join(lines))
def showtraceback(self) -> types.NoneType:
sys.last_type, sys.last_value, last_tb = ei = sys.exc_info()
sys.last_traceback = last_tb
try:
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
self.write(''.join(lines))
finally:
last_tb = ei = None
def raw_input(self, prompt: str = "") -> str:
self.write(prompt)
while not self.reader.closed:
try:
return self.reader.readline()
except:
pass
def write(self, data: str) -> None:
self.writer.write(data)
self.writer.flush()
def configure(repl: PythonRepl) -> None:
repl.confirm_exit = False
repl.enable_system_bindings = False
repl.enable_mouse_support = False
async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
@@ -164,9 +24,14 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
systemManager = sdk.systemManager
mediaManager = sdk.mediaManager
async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
filter = await reader.read(4096)
filter = filter.decode("utf-8").strip()
# Create the proxy server to handle initial control messages
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(None)
sock.bind(('localhost', 0))
sock.listen(1)
async def start_telnet_repl(future, filter) -> None:
if filter == "undefined":
filter = None
@@ -187,95 +52,85 @@ async def createREPLServer(sdk: ScryptedStatic, plugin: ScryptedDevice) -> int:
realDevice = systemManager.getDeviceById(device.id)
loop = asyncio.get_event_loop()
# Select a free port for the telnet server
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('localhost', 0))
telnet_port = s.getsockname()[1]
s.close()
# start tcp server
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('localhost', 0))
sock.listen()
sock.settimeout(None)
def repl_thread():
conn, addr = sock.accept()
conn_reader = conn.makefile("r")
conn_writer = conn.makefile("w")
builtins = {}
builtins.update(__builtins__)
# redirect print to our repl connection
builtins["print"] = lambda *args, **kwargs: print(*args, **kwargs, file=conn_writer)
# these builtins cause problems with the repl
del builtins["input"]
del builtins["help"]
del builtins["license"]
locals = {
async def interact(connection) -> None:
global_dict = {**globals(), "print": print_formatted_text}
locals_dict = {
"device": device,
"realDevice": realDevice,
"sdk": sdk,
"mediaManager": mediaManager,
"systemManager": systemManager,
"deviceManager": deviceManager,
"mediaManager": mediaManager,
"sdk": sdk,
"realDevice": realDevice
}
vars_prompt = '\n'.join([f" {k}" for k in locals.keys()])
vars_prompt = '\n'.join([f" {k}" for k in locals_dict.keys()])
banner = f"Python REPL variables:\n{vars_prompt}"
console = REPL(
locals={
**locals,
REPL_WRITER_KEY: conn_writer,
"__builtins__": builtins,
},
loop=loop,
reader=conn_reader,
writer=conn_writer,
)
console.interact(banner=banner)
conn.close()
t = threading.Thread(target=repl_thread, daemon=True)
t.start()
print_formatted_text(banner)
await embed(return_asyncio_coroutine=True, globals=global_dict, locals=locals_dict, configure=configure)
addr = sock.getsockname()
port = addr[1]
# Start the REPL server
telnet_server = TelnetServer(interact=interact, port=telnet_port, enable_cpr=False)
telnet_server.start()
print(f"Running telnet server on port {telnet_port}...")
# fork a pty and subprocess to connect to the repl
pid, fd = pty.fork()
if pid == 0:
# child
os.execv(sys.executable, [sys.executable, connect_to_repl.__file__, "localhost", str(port)])
future.set_result(telnet_port)
# read from p in separate thread
q = asyncio.Queue()
def reader_thread():
while is_pid_alive(pid):
try:
data = os.read(fd, 4096)
loop.call_soon_threadsafe(q.put_nowait, data)
except:
pass
loop.call_soon_threadsafe(q.put_nowait, None)
t = threading.Thread(target=reader_thread, daemon=True)
t.start()
loop = asyncio.get_event_loop()
async def forward():
def handle_connection(conn):
filter = conn.recv(1024).decode()
print(f"Filter: {filter}")
future = concurrent.futures.Future()
loop.call_soon_threadsafe(loop.create_task, start_telnet_repl(future, filter))
telnet_port = future.result()
telnet_client = telnetlib.Telnet('localhost', telnet_port, timeout=None)
def telnet_negotiation_cb(telnet_socket, command, option):
pass # ignore telnet negotiation
telnet_client.set_option_negotiation_callback(telnet_negotiation_cb)
print('Connected to telnet server')
# initialize telnet terminal
telnet_client.get_socket().sendall(b'\xff\xfb\x18\xff\xfa\x18\x00\x61\x6e\x73\x69\xff\xf0')
telnet_client.get_socket().sendall(b'\r\n')
#telnet_client.get_socket().sendall(b'\xff\xfa\x18\x39\x36\x2c\x32\x34\xff\xf0')
#telnet_client.get_socket().sendall(b'\r\n')
# Bridge the connection to the telnet server, two way
def forward_to_telnet():
while True:
data = await reader.read(4096)
data = conn.recv(1024)
if not data:
break
os.write(fd, data)
async def backward():
telnet_client.write(data)
def forward_to_socket():
while True:
data = await q.get()
data = telnet_client.read_some()
if not data:
conn.sendall('REPL exited'.encode())
break
writer.write(data)
await writer.drain()
await asyncio.gather(forward(), backward())
os.kill(pid, signal.SIGKILL)
print(data)
conn.sendall(data)
server = await asyncio.start_server(handler, 'localhost', 0)
addr = server.sockets[0].getsockname()
port = addr[1]
return port
threading.Thread(target=forward_to_telnet).start()
threading.Thread(target=forward_to_socket).start()
def accept_connection():
while True:
conn, addr = sock.accept()
threading.Thread(target=handle_connection, args=(conn,)).start()
threading.Thread(target=accept_connection).start()
proxy_port = sock.getsockname()[1]
print(f"Running proxy server on port {proxy_port}...")
return proxy_port