detection: support libav throughout

This commit is contained in:
Koushik Dutta
2023-02-11 08:24:45 -08:00
parent d1d4f10039
commit cb9c5f26a9
6 changed files with 97 additions and 154 deletions

View File

@@ -1,13 +1,13 @@
from __future__ import annotations
from time import sleep
from detect import DetectionSession, DetectPlugin
from typing import Any, List
from typing import Any, List, Tuple
import numpy as np
import cv2
import imutils
from gi.repository import Gst
from scrypted_sdk.types import ObjectDetectionModel, ObjectDetectionResult, ObjectsDetected, Setting
from PIL import Image
class OpenCVDetectionSession(DetectionSession):
def __init__(self) -> None:
@@ -45,11 +45,11 @@ class OpenCVPlugin(DetectPlugin):
self.pixelFormat = "BGRA"
self.pixelFormatChannelCount = 4
async def getDetectionModel(self) -> ObjectDetectionModel:
d: ObjectDetectionModel = {
'name': '@scrypted/opencv',
'classes': ['motion'],
}
def getClasses(self) -> list[str]:
    """Return the detection class names this plugin reports; motion only."""
    classes = ['motion']
    return classes
async def getSettings(self) -> list[Setting]:
settings = [
{
'title': "Motion Area",
@@ -99,8 +99,8 @@ class OpenCVPlugin(DetectPlugin):
],
}
]
d['settings'] = settings
return d
return settings
def get_pixel_format(self):
    """Return the pixel format string configured on this plugin instance."""
    fmt = self.pixelFormat
    return fmt
@@ -175,9 +175,6 @@ class OpenCVPlugin(DetectPlugin):
return detection_result
def run_detection_jpeg(self, detection_session: DetectionSession, image_bytes: bytes, min_score: float) -> ObjectsDetected:
    """Unsupported: motion detection compares successive frames, so a lone JPEG cannot be analyzed."""
    message = 'can not run motion detection on image'
    raise Exception(message)
def get_detection_input_size(self, src_size):
# The initial implementation of this plugin used BGRA
# because it seemed impossible to pull the Y frame out of I420 without corruption.
@@ -211,6 +208,24 @@ class OpenCVPlugin(DetectPlugin):
detection_session.cap = None
return super().end_session(detection_session)
def run_detection_image(self, detection_session: DetectionSession, image: Image.Image, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
    """Unsupported for this plugin: motion detection needs a frame sequence, not one image."""
    # todo
    err = Exception('can not run motion detection on image')
    raise err
def run_detection_avframe(self, detection_session: DetectionSession, avframe, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
    """Run motion detection on a libav (PyAV) video frame.

    Frames already in 4:2:0 planar YUV (yuv420p/yuvj420p) are converted to
    an ndarray in their native format (format=None); anything else is
    reformatted to yuvj420p first.

    Returns:
        (detections, None) when motion was found, otherwise (None, None)
        after sleeping for the configured detection interval.
    """
    # PyAV's VideoFrame.format is a VideoFormat object; its .name is the
    # string ('yuv420p', ...). Fall back to the raw value if it is already
    # a string — TODO confirm against the libav callers.
    format = getattr(avframe.format, 'name', avframe.format)
    # BUGFIX: this condition used `or`, which is always true, so every
    # frame was reformatted and the native-format fast path never ran.
    if format != 'yuv420p' and format != 'yuvj420p':
        format = 'yuvj420p'
    else:
        # None tells to_ndarray to keep the frame's native format.
        format = None
    mat = avframe.to_ndarray(format=format)
    detections = self.detect(
        detection_session, mat, settings, src_size, convert_to_src_size)
    if not detections or not len(detections['detections']):
        # Nothing detected: throttle before the next frame is processed.
        self.detection_sleep(settings)
        return None, None
    return detections, None
def run_detection_gstsample(self, detection_session: OpenCVDetectionSession, gst_sample, settings: Any, src_size, convert_to_src_size) -> ObjectsDetected:
buf = gst_sample.get_buffer()
caps = gst_sample.get_caps()

View File

@@ -1,4 +1,5 @@
# plugin
numpy>=1.16.2
PyGObject>=3.30.4
imutils>=0.5.0
av>=10.0.0; sys_platform != 'linux' or platform_machine == 'x86_64' or platform_machine == 'aarch64'
imutils>=0.5.0

View File

@@ -47,55 +47,6 @@ def optional_chain(root, *keys):
return result
class PipelineValve:
    """Choke/unchoke a named GStreamer valve element by (un)blocking its src pad.

    Instead of toggling the valve's drop property, this installs and removes
    blocking pad probes so that upstream buffer flow is stopped entirely
    while closed.
    """

    # Written once in __init__ and never read inside this class; retained
    # in case external code inspects it — TODO confirm it is still needed.
    allowPacketCounter: int

    def __init__(self, gst, name) -> None:
        """Look up `name + "Valve"` in the pipeline and block it after one buffer."""
        self.allowPacketCounter = 1
        self.mutex = threading.Lock()
        valve = gst.get_by_name(name + "Valve")
        self.pad = valve.get_static_pad("src")
        self.name = name

        needRemove = False

        def probe(pad, info):
            nonlocal needRemove
            if needRemove:
                # Second buffer: swap in the permanent blocking probe and drop.
                self.close()
                return Gst.PadProbeReturn.DROP
            # REMOVE - remove this probe, passing the data.
            needRemove = True
            return Gst.PadProbeReturn.PASS

        # need one buffer to go through to go into flowing state
        self.probe = self.pad.add_probe(
            Gst.PadProbeType.BLOCK | Gst.PadProbeType.BUFFER | Gst.PadProbeType.BUFFER_LIST, probe)

    def open(self):
        """Remove any installed probe so buffers flow through the valve again."""
        with self.mutex:
            # BUGFIX (idiom): compare to None with `is not`, not `!=`.
            if self.probe is not None:
                self.pad.remove_probe(self.probe)
                self.probe = None

    def close(self):
        """Replace any installed probe with a blocking probe, choking the pad."""
        with self.mutex:
            if self.probe is not None:
                self.pad.remove_probe(self.probe)
                self.probe = None

            def probe(pad, info):
                return Gst.PadProbeReturn.OK

            self.probe = self.pad.add_probe(
                Gst.PadProbeType.BLOCK | Gst.PadProbeType.BUFFER | Gst.PadProbeType.BUFFER_LIST, probe)
def setupPipelineValve(name: str, gst: Any) -> PipelineValve:
    """Create a PipelineValve for the pipeline element named `name + "Valve"`."""
    return PipelineValve(gst, name)
class DetectionSession:
id: str
timerHandle: TimerHandle
@@ -110,9 +61,7 @@ class DetectionSession:
self.timerHandle = None
self.future = Future()
self.running = False
self.attached = False
self.mutex = threading.Lock()
self.valve: PipelineValve = None
self.last_sample = time.time()
def clearTimeoutLocked(self):
@@ -161,33 +110,42 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
async def putSetting(self, key: str, value: scrypted_sdk.SettingValue) -> None:
    """No-op default; subclasses that expose writable settings override this."""
    pass
def getClasses(self) -> list[str]:
    """Return the detection class names; implemented by subclasses."""
    pass
def get_input_details(self) -> Tuple[int, int, int]:
    """Return the model's three input dimensions; implemented by subclasses.

    NOTE(review): dimension ordering varies between call sites in this
    file — confirm against the concrete implementation.
    """
    pass
async def getDetectionModel(self, settings: Any = None) -> ObjectDetectionModel:
height, width, channels = self.get_input_details()
d: ObjectDetectionModel = {
'name': self.pluginId,
'classes': list(self.labels.values()),
'inputSize': [int(width), int(height), int(channels)],
'classes': self.getClasses(),
'inputSize': self.get_input_details(),
'settings': [],
}
decoderSetting: Setting = {
'title': "Decoder",
'description': "The gstreamer element used to decode the stream",
'combobox': True,
'value': 'Default',
'placeholder': 'Default',
'key': 'decoder',
'choices': [
'Default',
'decodebin',
'vtdec_hw',
'nvh264dec',
'vaapih264dec',
],
}
if Gst:
decoderSetting: Setting = {
'title': "Decoder",
'description': "The toolto use to decode the stream. The may be libav or the gstreamer element.",
'combobox': True,
'value': 'Default',
'placeholder': 'Default',
'key': 'decoder',
'choices': [
'Default',
'decodebin',
'vtdec_hw',
'nvh264dec',
'vaapih264dec',
],
}
if av:
decoderSetting['choices'].append('libav')
d['settings'].append(decoderSetting)
d['settings'] = [
decoderSetting,
]
return d
async def detection_event(self, detection_session: DetectionSession, detection_result: ObjectsDetected, redetect: Any = None, mediaObject = None):
@@ -209,19 +167,11 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
print('detection ended', detection_session.id)
detection_session.clearTimeout()
if detection_session.attached:
with detection_session.mutex:
if detection_session.running:
print("choked session", detection_session.id)
detection_session.running = False
if detection_session.valve:
detection_session.valve.close()
else:
# leave detection_session.running as True to avoid race conditions.
# the removal from detection_sessions will restart it.
safe_set_result(detection_session.loop, detection_session.future)
with self.session_mutex:
self.detection_sessions.pop(detection_session.id, None)
# leave detection_session.running as True to avoid race conditions.
# the removal from detection_sessions will restart it.
safe_set_result(detection_session.loop, detection_session.future)
with self.session_mutex:
self.detection_sessions.pop(detection_session.id, None)
detection_result: ObjectsDetected = {}
detection_result['running'] = False
@@ -247,6 +197,10 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
def run_detection_gstsample(self, detection_session: DetectionSession, gst_sample, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
    """Run detection on a GStreamer sample; implemented by subclasses."""
    pass
def run_detection_avframe(self, detection_session: DetectionSession, avframe, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
    """Default libav path: rasterize the frame to a PIL image and delegate
    to the image-based detection entry point."""
    return self.run_detection_image(
        detection_session, avframe.to_image(), settings, src_size, convert_to_src_size)
def run_detection_image(self, detection_session: DetectionSession, image: Image.Image, settings: Any, src_size, convert_to_src_size) -> Tuple[ObjectsDetected, Any]:
    """Run detection on a PIL image; implemented by subclasses."""
    pass
@@ -315,13 +269,6 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
if not new_session:
print("existing session", detection_session.id)
if detection_session.attached:
with detection_session.mutex:
if not detection_session.running:
print("unchoked session", detection_session.id)
detection_session.running = True
if detection_session.valve:
detection_session.valve.open()
return (False, detection_session, self.create_detection_result_status(detection_id, detection_session.running))
return (True, detection_session, None)
@@ -363,8 +310,24 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
container = j.get('container', None)
videosrc = j['url']
if not Gst:
user_callback = self.create_user_callback(self.run_detection_image, detection_session, duration)
decoder = settings and settings.get('decoder')
if decoder == 'libav' and not av:
decoder = None
elif decoder != 'libav' and not Gst:
decoder = None
decoder = decoder or 'Default'
if decoder == 'Default':
if Gst:
if platform.system() == 'Darwin':
decoder = 'vtdec_hw'
else:
decoder = 'decodebin'
elif av:
decoder = 'libav'
if decoder == 'libav':
user_callback = self.create_user_callback(self.run_detection_avframe, detection_session, duration)
async def inference_loop():
options = {
@@ -388,11 +351,11 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
# print('too slow, skipping frame')
continue
# print(frame)
pil: Image.Image = frame.to_image()
def convert_to_src_size(point, normalize):
size = (frame.width, frame.height)
def convert_to_src_size(point, normalize = False):
x, y = point
return (int(math.ceil(x)), int(math.ceil(y)), True)
await user_callback(pil, pil.size, convert_to_src_size)
await user_callback(frame, size, convert_to_src_size)
def thread_main():
loop = asyncio.new_event_loop()
@@ -401,6 +364,9 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
thread = threading.Thread(target=thread_main)
thread.start()
return self.create_detection_result_status(detection_id, True)
if not Gst:
raise Exception('Gstreamer is unavailable')
videoCodec = optional_chain(j, 'mediaStreamOptions', 'video', 'codec')
@@ -419,14 +385,6 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
if videoCodec == 'h264':
videosrc += ' ! rtph264depay ! h264parse'
decoder = settings and settings.get('decoder', 'decodebin')
decoder = decoder or 'Default'
if decoder == 'Default':
if platform.system() == 'Darwin':
decoder = 'vtdec_hw'
else:
decoder = 'decodebin'
# decoder = 'decodebin'
videosrc += " ! %s" % decoder
width = optional_chain(j, 'mediaStreamOptions',
@@ -533,43 +491,10 @@ class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
return user_callback
def attach_pipeline(self, gstPipeline: GstPipelineBase, session: ObjectDetectionSession, valveName: str = None):
    """Attach a detection session to an existing GStreamer pipeline.

    Returns a 4-tuple (create, detection_session, objects_detected, pipeline);
    pipeline is None when the session already existed.
    """
    create, detection_session, objects_detected = self.ensure_session(
        'video/dummy', session)
    # When a valve element name is given, wrap it so the session can
    # choke/unchoke buffer flow.
    if detection_session and valveName:
        valve = setupPipelineValve(valveName, gstPipeline.gst)
        detection_session.valve = valve
    if not create:
        # Existing session: nothing new to attach.
        return create, detection_session, objects_detected, None
    detection_session.attached = True
    duration = None
    if session:
        duration = session.get('duration', None)
    # Build a pipeline that feeds gst samples into this plugin's detector.
    pipeline = GstPipeline(gstPipeline.loop, gstPipeline.finished, type(
        self).__name__, self.create_user_callback(self.run_detection_gstsample, detection_session, duration))
    pipeline.attach_launch(gstPipeline.gst)
    return create, detection_session, objects_detected, pipeline
def detach_pipeline(self, detection_id: str):
    """Stop and unregister an attached detection session.

    Raises:
        Exception: if no session with detection_id is registered.
    """
    detection_session: DetectionSession = None
    with self.session_mutex:
        # BUGFIX: pop() without a default raises KeyError for a missing id,
        # so the "already detached" guard below was dead code. Supplying
        # None lets the intended error surface.
        detection_session = self.detection_sessions.pop(detection_id, None)
    if not detection_session:
        raise Exception("pipeline already detached?")
    with detection_session.mutex:
        detection_session.running = False
    detection_session.clearTimeout()
def run_pipeline(self, detection_session: DetectionSession, duration, src_size, video_input):
inference_size = self.get_detection_input_size(src_size)
pipeline = run_pipeline(detection_session.loop, detection_session.future, self.create_user_callback(detection_session, duration),
pipeline = run_pipeline(detection_session.loop, detection_session.future, self.create_user_callback(self.run_detection_gstsample, detection_session, duration),
appsink_name=type(self).__name__,
appsink_size=inference_size,
video_input=video_input,

View File

@@ -111,6 +111,9 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Set
loop = asyncio.get_event_loop()
loop.call_later(4 * 60 * 60, lambda: self.requestRestart())
def getClasses(self) -> list[str]:
    """Return the label names known to the loaded model, in label order."""
    return [label for label in self.labels.values()]
async def createMedia(self, data: RawImage) -> scrypted_sdk.MediaObject:
    """Wrap raw image data in a MediaObject tagged with this plugin's mime type."""
    return await scrypted_sdk.mediaManager.createMediaObject(data, self.fromMimeType)

View File

@@ -5,7 +5,6 @@ numpy>=1.16.2
Pillow>=5.4.1
pycoral~=2.0
PyGObject>=3.30.4; sys_platform != 'win32'
# roughly the available wheels.
av>=10.0.0; sys_platform != 'linux' or platform_machine == 'x86_64' or platform_machine == 'aarch64'
tflite-runtime==2.5.0.post1

View File

@@ -78,7 +78,7 @@ class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted
with self.mutex:
_, height, width, channels = self.interpreter.get_input_details()[
0]['shape']
return height, width, channels
return int(width), int(height), int(channels)
def get_input_size(self) -> Tuple[float, float]:
    # Delegates to the coral/tflite input_size helper for the interpreter;
    # presumably (width, height) — TODO confirm against callers.
    return input_size(self.interpreter)