diff --git a/plugins/tensorflow-lite/package.json b/plugins/tensorflow-lite/package.json index a43a7a307..28db2fdc9 100644 --- a/plugins/tensorflow-lite/package.json +++ b/plugins/tensorflow-lite/package.json @@ -1,6 +1,6 @@ { "name": "@scrypted/tensorflow-lite", - "description": "Object Detection Service.", + "description": "Scrypted Object Detection Service.", "keywords": [ "scrypted", "plugin", @@ -27,7 +27,7 @@ "scrypted-webpack": "scrypted-webpack" }, "scrypted": { - "name": "TensorFlow Lite + Coral TPU Object Detection Plugin", + "name": "TensorFlow Lite (and Coral) Object Detection", "runtime": "python", "type": "API", "interfaces": [ diff --git a/plugins/tensorflow-lite/src/detect/__init__.py b/plugins/tensorflow-lite/src/detect/__init__.py new file mode 100644 index 000000000..9cc18c2be --- /dev/null +++ b/plugins/tensorflow-lite/src/detect/__init__.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from asyncio.events import AbstractEventLoop, TimerHandle +from asyncio.futures import Future +from typing import Any, Mapping, List +from .safe_set_result import safe_set_result +import scrypted_sdk +import json +import asyncio +import time +import os +import binascii +from urllib.parse import urlparse +import multiprocessing +from . import gstreamer + +from scrypted_sdk.types import FFMpegInput, MediaObject, ObjectDetection, ObjectDetectionModel, ObjectDetectionSession, ObjectsDetected, ScryptedInterface, ScryptedMimeTypes + +class DetectionSession: + id: str + timerHandle: TimerHandle + future: Future + loop: AbstractEventLoop + score_threshold: float + running: bool + thread: Any + + def __init__(self) -> None: + self.timerHandle = None + self.future = Future() + self.running = False + + def cancel(self): + if self.timerHandle: + self.timerHandle.cancel() + self.timerHandle = None + + def timedOut(self): + safe_set_result(self.future) + + def setTimeout(self, duration: float): + self.cancel() + self.loop.call_later(duration, lambda: self.timedOut()) + + +class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection): + # derp these are class statics, fix this + detection_sessions: Mapping[str, DetectionSession] = {} + session_mutex = multiprocessing.Lock() + + def __init__(self, nativeId: str | None = None): + super().__init__(nativeId=nativeId) + + async def getInferenceModels(self) -> list[ObjectDetectionModel]: + ret: List[ObjectDetectionModel] = [] + + d = { + 'id': 'opencv', + 'name': 'OpenCV', + 'classes': ['motion'], + # 'inputShape': [int(width), int(height), int(channels)], + } + ret.append(d) + return ret + + def detection_event(self, detection_session: DetectionSession, detection_result: ObjectsDetected, event_buffer: bytes = None): + detection_result['detectionId'] = detection_session.id + detection_result['timestamp'] = int(time.time() * 1000) + asyncio.run_coroutine_threadsafe(self.onDeviceEvent( + ScryptedInterface.ObjectDetection.value, detection_result), loop=detection_session.loop) + + def end_session(self, detection_session: DetectionSession): + print('detection ended', detection_session.id) + detection_session.cancel() + safe_set_result(detection_session.future) + with self.session_mutex: + self.detection_sessions.pop(detection_session.id, None) + + detection_result: ObjectsDetected = {} + detection_result['running'] = False + detection_result['timestamp'] = int(time.time() * 1000) + + self.detection_event(detection_session, detection_result) + + def create_detection_result_status(self, detection_id: str, running: bool): + detection_result: ObjectsDetected = {} + detection_result['detectionId'] = detection_id + detection_result['running'] = running + detection_result['timestamp'] = int(time.time() * 1000) + return detection_result + + def run_detection_jpeg(self, detection_session: DetectionSession, image_bytes: bytes, min_score: float) -> ObjectsDetected: + pass + + def get_detection_input_size(self, src_size): + pass + + def create_detection_session(self): + return DetectionSession() + + def run_detection_gstsample(self, detection_session: DetectionSession, gst_sample, min_score: float, src_size, inference_box, scale)-> ObjectsDetected: + pass + + async def detectObjects(self, mediaObject: MediaObject, session: ObjectDetectionSession = None) -> ObjectsDetected: + score_threshold = None + duration = None + detection_id = None + detection_session = None + + if session: + detection_id = session.get('detectionId', None) + duration = session.get('duration', None) + score_threshold = session.get('minScore', None) + + is_image = mediaObject and mediaObject.mimeType.startswith('image/') + + ending = False + with self.session_mutex: + if not is_image and not detection_id: + detection_id = binascii.b2a_hex(os.urandom(15)).decode('utf8') + + if detection_id: + detection_session = self.detection_sessions.get( + detection_id, None) + + if not duration and not is_image: + ending = True + elif detection_id and not detection_session: + if not mediaObject: + raise Exception( + 'session %s inactive and no mediaObject provided' % detection_id) + + detection_session = self.create_detection_session() + detection_session.id = detection_id + detection_session.score_threshold = score_threshold or - \ + float('inf') + loop = asyncio.get_event_loop() + detection_session.loop = loop + self.detection_sessions[detection_id] = detection_session + + detection_session.future.add_done_callback( + lambda _: self.end_session(detection_session)) + + if ending: + if detection_session: + self.end_session(detection_session) + return self.create_detection_result_status(detection_id, False) + + if is_image: + return self.run_detection_jpeg(detection_session, bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg')), score_threshold) + + new_session = not detection_session.running + if new_session: + detection_session.running = True + + detection_session.setTimeout(duration / 1000) + if score_threshold != None: + detection_session.score_threshold = score_threshold + + if not new_session: + print("existing session", detection_session.id) + return self.create_detection_result_status(detection_id, detection_session.running) + + print('detection starting', detection_id) + b = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, ScryptedMimeTypes.MediaStreamUrl.value) + s = b.decode('utf8') + j: FFMpegInput = json.loads(s) + container = j.get('container', None) + videosrc = j['url'] + if container == 'mpegts' and videosrc.startswith('tcp://'): + parsed_url = urlparse(videosrc) + videosrc = 'tcpclientsrc port=%s host=%s ! tsdemux' % ( + parsed_url.port, parsed_url.hostname) + + size = j['mediaStreamOptions']['video'] + src_size = (size['width'], size['height']) + + self.run_pipeline(detection_session, duration, src_size, videosrc) + + return self.create_detection_result_status(detection_id, True) + + def get_pixel_format(self): + return 'RGB' + + def run_pipeline(self, detection_session: DetectionSession, duration, src_size, video_input): + inference_size = self.get_detection_input_size(src_size) + width, height = inference_size + w, h = src_size + scale = (width / w, height / h) + + def user_callback(gst_sample, src_size, inference_box): + try: + detection_result = self.run_detection_gstsample( + detection_session, gst_sample, detection_session.score_threshold, src_size, inference_box, scale) + if detection_result: + self.detection_event(detection_session, detection_result) + + if not detection_session or not duration: + safe_set_result(detection_session.future) + finally: + pass + + pipeline = gstreamer.run_pipeline(detection_session.future, user_callback, + src_size, + appsink_size=inference_size, + video_input=video_input, + pixel_format=self.get_pixel_format()) + task = pipeline.run() + asyncio.ensure_future(task) diff --git a/plugins/tensorflow-lite/src/gstreamer.py b/plugins/tensorflow-lite/src/detect/gstreamer.py similarity index 75% rename from plugins/tensorflow-lite/src/gstreamer.py rename to plugins/tensorflow-lite/src/detect/gstreamer.py index ef4bfb40e..aca8733bf 100644 --- a/plugins/tensorflow-lite/src/gstreamer.py +++ b/plugins/tensorflow-lite/src/detect/gstreamer.py @@ -13,14 +13,13 @@ # limitations under the License. from asyncio.futures import Future -import sys import threading import gi gi.require_version('Gst', '1.0') gi.require_version('GstBase', '1.0') -from safe_set_result import safe_set_result +from detect.safe_set_result import safe_set_result from gi.repository import GLib, GObject, Gst GObject.threads_init() @@ -123,16 +122,7 @@ class GstPipeline: gstsample = self.gstsample self.gstsample = None - # Passing Gst.Buffer as input tensor avoids 2 copies of it. - gstbuffer = gstsample.get_buffer() - svg = self.user_function(gstbuffer, self.src_size, self.get_box()) - if svg: - if self.overlay: - self.overlay.set_property('data', svg) - if self.gloverlay: - self.gloverlay.emit('set-svg', svg, gstbuffer.pts) - if self.overlaysink: - self.overlaysink.set_property('svg', svg) + self.user_function(gstsample, self.src_size, self.get_box()) def get_dev_board_model(): try: @@ -148,45 +138,25 @@ def run_pipeline(finished, user_function, src_size, appsink_size, - videosrc='/dev/video1', - videofmt='raw'): - if videofmt == 'h264': - SRC_CAPS = 'video/x-h264,width={width},height={height},framerate=30/1' - elif videofmt == 'jpeg': - SRC_CAPS = 'image/jpeg,width={width},height={height},framerate=30/1' - else: - SRC_CAPS = 'video/x-raw,width={width},height={height},framerate=30/1' - if videosrc.startswith('/dev/video'): - PIPELINE = 'v4l2src device=%s ! {src_caps}'%videosrc - elif videosrc.startswith('http'): - PIPELINE = 'souphttpsrc location=%s'%videosrc - elif videosrc.startswith('rtsp'): - PIPELINE = 'rtspsrc location=%s'%videosrc - else: - demux = 'avidemux' if videosrc.endswith('avi') else 'qtdemux' - PIPELINE = """filesrc location=%s ! %s name=demux demux.video_0 - ! queue ! decodebin ! videorate - ! videoconvert n-threads=4 ! videoscale n-threads=4 - ! {src_caps} ! {leaky_q} """ % (videosrc, demux) - - if videofmt == 'gst': - PIPELINE = videosrc + video_input, + pixel_format): + PIPELINE = video_input scale = min(appsink_size[0] / src_size[0], appsink_size[1] / src_size[1]) scale = tuple(int(x * scale) for x in src_size) scale_caps = 'video/x-raw,width={width},height={height}'.format(width=scale[0], height=scale[1]) + # scale_caps = 'video/x-raw,width={width},height={height}'.format(width=appsink_size[0], height=appsink_size[1]) PIPELINE += """ ! decodebin ! queue leaky=downstream max-size-buffers=10 ! videoconvert ! videoscale ! {scale_caps} ! videobox name=box autocrop=true ! queue leaky=downstream max-size-buffers=1 ! {sink_caps} ! {sink_element} """ SINK_ELEMENT = 'appsink name=appsink emit-signals=true max-buffers=1 drop=true sync=false' - SINK_CAPS = 'video/x-raw,format=RGB,width={width},height={height}' + SINK_CAPS = 'video/x-raw,format={pixel_format},width={width},height={height}' LEAKY_Q = 'queue max-size-buffers=100 leaky=upstream' - src_caps = SRC_CAPS.format(width=src_size[0], height=src_size[1]) - sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1]) + sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1], pixel_format=pixel_format) pipeline = PIPELINE.format(leaky_q=LEAKY_Q, - src_caps=src_caps, sink_caps=sink_caps, + sink_caps=sink_caps, sink_element=SINK_ELEMENT, scale_caps=scale_caps) print('Gstreamer pipeline:\n', pipeline) diff --git a/plugins/tensorflow-lite/src/safe_set_result.py b/plugins/tensorflow-lite/src/detect/safe_set_result.py similarity index 100% rename from plugins/tensorflow-lite/src/safe_set_result.py rename to plugins/tensorflow-lite/src/detect/safe_set_result.py diff --git a/plugins/tensorflow-lite/src/main.py b/plugins/tensorflow-lite/src/main.py index 036579551..d01dd1bd0 100644 --- a/plugins/tensorflow-lite/src/main.py +++ b/plugins/tensorflow-lite/src/main.py @@ -1,12 +1,11 @@ from __future__ import annotations import matplotlib + +from detect import DetectionSession, DetectPlugin matplotlib.use('Agg') -from asyncio.events import AbstractEventLoop, TimerHandle -from asyncio.futures import Future -from typing import Any, Mapping, List -from safe_set_result import safe_set_result +from typing import List import scrypted_sdk import numpy as np import re @@ -19,20 +18,17 @@ from pycoral.adapters import detect from PIL import Image import common import io -import gstreamer -import json -import asyncio -import time -import os -import binascii -from urllib.parse import urlparse -from gi.repository import Gst import multiprocessing from third_party.sort import Sort -import threading +from detect.safe_set_result import safe_set_result +import asyncio -from scrypted_sdk.types import FFMpegInput, Lock, MediaObject, ObjectDetection, ObjectDetectionModel, ObjectDetectionResult, ObjectDetectionSession, OnOff, ObjectsDetected, ScryptedInterface, ScryptedMimeTypes +from scrypted_sdk.types import ObjectDetectionModel, ObjectDetectionResult, ObjectsDetected +class TrackerDetectionSession(DetectionSession): + def __init__(self) -> None: + super().__init__() + self.tracker = Sort() def parse_label_contents(contents: str): lines = contents.splitlines() @@ -45,39 +41,7 @@ def parse_label_contents(contents: str): ret[row_number] = content.strip() return ret - -class DetectionSession: - id: str - timerHandle: TimerHandle - future: Future - loop: AbstractEventLoop - score_threshold: float - running: bool - thread: Any - - def __init__(self) -> None: - self.timerHandle = None - self.future = Future() - self.tracker = Sort() - self.running = False - - def cancel(self): - if self.timerHandle: - self.timerHandle.cancel() - self.timerHandle = None - - def timedOut(self): - safe_set_result(self.future) - - def setTimeout(self, duration: float): - self.cancel() - self.loop.call_later(duration, lambda: self.timedOut()) - - -class CoralPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection): - detection_sessions: Mapping[str, DetectionSession] = {} - session_mutex = multiprocessing.Lock() - +class CoralPlugin(DetectPlugin): def __init__(self, nativeId: str | None = None): super().__init__(nativeId=nativeId) labels_contents = scrypted_sdk.zip.open( @@ -165,160 +129,39 @@ class CoralPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection): return detection_result - def detection_event(self, detection_session: DetectionSession, detection_result: ObjectsDetected, event_buffer: bytes = None): - detection_result['detectionId'] = detection_session.id - detection_result['timestamp'] = int(time.time() * 1000) - asyncio.run_coroutine_threadsafe(self.onDeviceEvent( - ScryptedInterface.ObjectDetection.value, detection_result), loop=detection_session.loop) + def run_detection_jpeg(self, detection_session: TrackerDetectionSession, image_bytes: bytes, min_score: float) -> ObjectsDetected: + stream = io.BytesIO(image_bytes) + image = Image.open(stream) - def end_session(self, detection_session: DetectionSession): - print('detection ended', detection_session.id) - detection_session.cancel() - safe_set_result(detection_session.future) - with self.session_mutex: - self.detection_sessions.pop(detection_session.id, None) + _, scale = common.set_resized_input( + self.interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS)) - detection_result: ObjectsDetected = {} - detection_result['running'] = False - detection_result['timestamp'] = int(time.time() * 1000) + tracker = None + if detection_session: + tracker = detection_session.tracker - self.detection_event(detection_session, detection_result) + with self.mutex: + self.interpreter.invoke() + objs = detect.get_objects( + self.interpreter, score_threshold=min_score or -float('inf'), image_scale=scale) - def create_detection_result_status(self, detection_id: str, running: bool): - detection_result: ObjectsDetected = {} - detection_result['detectionId'] = detection_id - detection_result['running'] = running - detection_result['timestamp'] = int(time.time() * 1000) - return detection_result + return self.create_detection_result(objs, image.size, tracker=tracker) - async def detectObjects(self, mediaObject: MediaObject, session: ObjectDetectionSession = None) -> ObjectsDetected: - score_threshold = None - duration = None - detection_id = None - detection_session = None + def get_detection_input_size(self, src_size): + return input_size(self.interpreter) - if session: - detection_id = session.get('detectionId', None) - duration = session.get('duration', None) - score_threshold = session.get('minScore', None) + def run_detection_gstsample(self, detection_session: TrackerDetectionSession, gstsample, min_score: float, src_size, inference_box, scale)-> ObjectsDetected: + gst_buffer = gstsample.get_buffer() + with self.mutex: + run_inference(self.interpreter, gst_buffer) + objs = detect.get_objects( + self.interpreter, score_threshold=min_score, image_scale=scale) - is_image = mediaObject and mediaObject.mimeType.startswith('image/') + return self.create_detection_result(objs, src_size, detection_session.tracker) - ending = False - with self.session_mutex: - if not is_image and not detection_id: - detection_id = binascii.b2a_hex(os.urandom(15)).decode('utf8') - - if detection_id: - detection_session = self.detection_sessions.get( - detection_id, None) - - if not duration and not is_image: - ending = True - elif detection_id and not detection_session: - if not mediaObject: - raise Exception( - 'session %s inactive and no mediaObject provided' % detection_id) - - detection_session = DetectionSession() - detection_session.id = detection_id - detection_session.score_threshold = score_threshold or - \ - float('inf') - loop = asyncio.get_event_loop() - detection_session.loop = loop - self.detection_sessions[detection_id] = detection_session - - detection_session.future.add_done_callback( - lambda _: self.end_session(detection_session)) - - if ending: - if detection_session: - self.end_session(detection_session) - return self.create_detection_result_status(detection_id, False) - - if is_image: - stream = io.BytesIO(bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg'))) - image = Image.open(stream) - - _, scale = common.set_resized_input( - self.interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS)) - - tracker = None - if detection_session: - tracker = detection_session.tracker - - with self.mutex: - self.interpreter.invoke() - objs = detect.get_objects( - self.interpreter, score_threshold=score_threshold or -float('inf'), image_scale=scale) - - return self.create_detection_result(objs, image.size, tracker=tracker) - - new_session = not detection_session.running - if new_session: - detection_session.running = True - - detection_session.setTimeout(duration / 1000) - if score_threshold != None: - detection_session.score_threshold = score_threshold - - if not new_session: - print("existing session", detection_session.id) - return self.create_detection_result_status(detection_id, detection_session.running) - - print('detection starting', detection_id) - b = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, ScryptedMimeTypes.MediaStreamUrl.value) - s = b.decode('utf8') - j: FFMpegInput = json.loads(s) - container = j.get('container', None) - videofmt = 'raw' - videosrc = j['url'] - if container == 'mpegts' and videosrc.startswith('tcp://'): - parsed_url = urlparse(videosrc) - videofmt = 'gst' - videosrc = 'tcpclientsrc port=%s host=%s ! tsdemux' % ( - parsed_url.port, parsed_url.hostname) - - size = j['mediaStreamOptions']['video'] - inference_size = input_size(self.interpreter) - width, height = inference_size - w, h = (size['width'], size['height']) - scale = (width / w, height / h) - - def user_callback(input_tensor, src_size, inference_box): - with self.mutex: - run_inference(self.interpreter, input_tensor) - objs = detect.get_objects( - self.interpreter, score_threshold=detection_session.score_threshold, image_scale=scale) - - # (result, mapinfo) = input_tensor.map(Gst.MapFlags.READ) - - try: - detection_result = self.create_detection_result(objs, - src_size, detection_session.tracker) - # self.detection_event(detection_session, detection_result, mapinfo.data.tobytes()) - self.detection_event(detection_session, detection_result) - - if not session or not duration: - safe_set_result(detection_session.future) - finally: - # input_tensor.unmap(mapinfo) - pass - - pipeline = gstreamer.run_pipeline(detection_session.future, user_callback, - src_size=( - size['width'], size['height']), - appsink_size=inference_size, - videosrc=videosrc, - videofmt=videofmt) - task = pipeline.run() - asyncio.ensure_future(task) - # detection_session.thread = threading.Thread(target=lambda: pipeline.run()) - # detection_session.thread.start() - - return self.create_detection_result_status(detection_id, True) + def create_detection_session(self): + return TrackerDetectionSession() def create_scrypted_plugin(): return CoralPlugin() -#