Files
scrypted/plugins/tensorflow-lite/src/detect/__init__.py
2021-12-10 11:39:31 -08:00

210 lines
7.9 KiB
Python

from __future__ import annotations
from asyncio.events import AbstractEventLoop, TimerHandle
from asyncio.futures import Future
from typing import Any, Mapping, List
from .safe_set_result import safe_set_result
import scrypted_sdk
import json
import asyncio
import time
import os
import binascii
from urllib.parse import urlparse
import multiprocessing
from . import gstreamer
from scrypted_sdk.types import FFMpegInput, MediaObject, ObjectDetection, ObjectDetectionModel, ObjectDetectionSession, ObjectsDetected, ScryptedInterface, ScryptedMimeTypes
class DetectionSession:
id: str
timerHandle: TimerHandle
future: Future
loop: AbstractEventLoop
score_threshold: float
running: bool
thread: Any
def __init__(self) -> None:
self.timerHandle = None
self.future = Future()
self.running = False
def cancel(self):
if self.timerHandle:
self.timerHandle.cancel()
self.timerHandle = None
def timedOut(self):
safe_set_result(self.future)
def setTimeout(self, duration: float):
self.cancel()
self.loop.call_later(duration, lambda: self.timedOut())
class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
# derp these are class statics, fix this
detection_sessions: Mapping[str, DetectionSession] = {}
session_mutex = multiprocessing.Lock()
def __init__(self, nativeId: str | None = None):
super().__init__(nativeId=nativeId)
async def getInferenceModels(self) -> list[ObjectDetectionModel]:
ret: List[ObjectDetectionModel] = []
d = {
'id': 'opencv',
'name': 'OpenCV',
'classes': ['motion'],
# 'inputShape': [int(width), int(height), int(channels)],
}
ret.append(d)
return ret
def detection_event(self, detection_session: DetectionSession, detection_result: ObjectsDetected, event_buffer: bytes = None):
detection_result['detectionId'] = detection_session.id
detection_result['timestamp'] = int(time.time() * 1000)
asyncio.run_coroutine_threadsafe(self.onDeviceEvent(
ScryptedInterface.ObjectDetection.value, detection_result), loop=detection_session.loop)
def end_session(self, detection_session: DetectionSession):
print('detection ended', detection_session.id)
detection_session.cancel()
safe_set_result(detection_session.future)
with self.session_mutex:
self.detection_sessions.pop(detection_session.id, None)
detection_result: ObjectsDetected = {}
detection_result['running'] = False
detection_result['timestamp'] = int(time.time() * 1000)
self.detection_event(detection_session, detection_result)
def create_detection_result_status(self, detection_id: str, running: bool):
detection_result: ObjectsDetected = {}
detection_result['detectionId'] = detection_id
detection_result['running'] = running
detection_result['timestamp'] = int(time.time() * 1000)
return detection_result
def run_detection_jpeg(self, detection_session: DetectionSession, image_bytes: bytes, min_score: float) -> ObjectsDetected:
pass
def get_detection_input_size(self, src_size):
pass
def create_detection_session(self):
return DetectionSession()
def run_detection_gstsample(self, detection_session: DetectionSession, gst_sample, min_score: float, src_size, inference_box, scale)-> ObjectsDetected:
pass
async def detectObjects(self, mediaObject: MediaObject, session: ObjectDetectionSession = None) -> ObjectsDetected:
score_threshold = None
duration = None
detection_id = None
detection_session = None
if session:
detection_id = session.get('detectionId', None)
duration = session.get('duration', None)
score_threshold = session.get('minScore', None)
is_image = mediaObject and mediaObject.mimeType.startswith('image/')
ending = False
with self.session_mutex:
if not is_image and not detection_id:
detection_id = binascii.b2a_hex(os.urandom(15)).decode('utf8')
if detection_id:
detection_session = self.detection_sessions.get(
detection_id, None)
if not duration and not is_image:
ending = True
elif detection_id and not detection_session:
if not mediaObject:
raise Exception(
'session %s inactive and no mediaObject provided' % detection_id)
detection_session = self.create_detection_session()
detection_session.id = detection_id
detection_session.score_threshold = score_threshold or - \
float('inf')
loop = asyncio.get_event_loop()
detection_session.loop = loop
self.detection_sessions[detection_id] = detection_session
detection_session.future.add_done_callback(
lambda _: self.end_session(detection_session))
if ending:
if detection_session:
self.end_session(detection_session)
return self.create_detection_result_status(detection_id, False)
if is_image:
return self.run_detection_jpeg(detection_session, bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg')), score_threshold)
new_session = not detection_session.running
if new_session:
detection_session.running = True
detection_session.setTimeout(duration / 1000)
if score_threshold != None:
detection_session.score_threshold = score_threshold
if not new_session:
print("existing session", detection_session.id)
return self.create_detection_result_status(detection_id, detection_session.running)
print('detection starting', detection_id)
b = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, ScryptedMimeTypes.MediaStreamUrl.value)
s = b.decode('utf8')
j: FFMpegInput = json.loads(s)
container = j.get('container', None)
videosrc = j['url']
if container == 'mpegts' and videosrc.startswith('tcp://'):
parsed_url = urlparse(videosrc)
videosrc = 'tcpclientsrc port=%s host=%s ! tsdemux' % (
parsed_url.port, parsed_url.hostname)
size = j['mediaStreamOptions']['video']
src_size = (size['width'], size['height'])
self.run_pipeline(detection_session, duration, src_size, videosrc)
return self.create_detection_result_status(detection_id, True)
def get_pixel_format(self):
return 'RGB'
def run_pipeline(self, detection_session: DetectionSession, duration, src_size, video_input):
inference_size = self.get_detection_input_size(src_size)
width, height = inference_size
w, h = src_size
scale = (width / w, height / h)
def user_callback(gst_sample, src_size, inference_box):
try:
detection_result = self.run_detection_gstsample(
detection_session, gst_sample, detection_session.score_threshold, src_size, inference_box, scale)
if detection_result:
self.detection_event(detection_session, detection_result)
if not detection_session or not duration:
safe_set_result(detection_session.future)
finally:
pass
pipeline = gstreamer.run_pipeline(detection_session.future, user_callback,
src_size,
appsink_size=inference_size,
video_input=video_input,
pixel_format=self.get_pixel_format())
task = pipeline.run()
asyncio.ensure_future(task)