tensorflow-lite: refactor

This commit is contained in:
Koushik Dutta
2021-12-10 11:39:31 -08:00
parent 37e420ed2f
commit 13207fc57f
5 changed files with 255 additions and 233 deletions

View File

@@ -1,6 +1,6 @@
{
"name": "@scrypted/tensorflow-lite",
"description": "Object Detection Service.",
"description": "Scrypted Object Detection Service.",
"keywords": [
"scrypted",
"plugin",
@@ -27,7 +27,7 @@
"scrypted-webpack": "scrypted-webpack"
},
"scrypted": {
"name": "TensorFlow Lite + Coral TPU Object Detection Plugin",
"name": "TensorFlow Lite (and Coral) Object Detection",
"runtime": "python",
"type": "API",
"interfaces": [

View File

@@ -0,0 +1,209 @@
from __future__ import annotations
from asyncio.events import AbstractEventLoop, TimerHandle
from asyncio.futures import Future
from typing import Any, Mapping, List
from .safe_set_result import safe_set_result
import scrypted_sdk
import json
import asyncio
import time
import os
import binascii
from urllib.parse import urlparse
import multiprocessing
from . import gstreamer
from scrypted_sdk.types import FFMpegInput, MediaObject, ObjectDetection, ObjectDetectionModel, ObjectDetectionSession, ObjectsDetected, ScryptedInterface, ScryptedMimeTypes
class DetectionSession:
    """Book-keeping for a single object-detection session.

    A session owns a future that marks the end of the session and an
    optional timer that resolves that future after a timeout.

    ``id``, ``loop`` and ``score_threshold`` are not set by ``__init__``;
    the code that creates the session is expected to assign them before
    calling ``setTimeout``.
    """

    id: str
    # Handle of the currently scheduled timeout, or None when no timeout is pending.
    timerHandle: TimerHandle | None
    # Passed to safe_set_result() when the session times out.
    future: Future
    # Event loop used to schedule the timeout.
    loop: AbstractEventLoop
    score_threshold: float
    running: bool
    thread: Any

    def __init__(self) -> None:
        self.timerHandle = None
        self.future = Future()
        self.running = False

    def cancel(self):
        """Cancel the pending timeout, if any, and forget its handle."""
        if self.timerHandle:
            self.timerHandle.cancel()
            self.timerHandle = None

    def timedOut(self):
        """Timer callback: drop the fired handle and signal the session future."""
        # The handle has already fired, so there is nothing left to cancel.
        self.timerHandle = None
        safe_set_result(self.future)

    def setTimeout(self, duration: float):
        """(Re)schedule the session timeout.

        Args:
            duration: Delay in seconds, as accepted by ``loop.call_later``.
        """
        self.cancel()
        # Keep the handle: without it cancel() has nothing to cancel and a
        # refreshed timeout would leave the previous timer running.
        self.timerHandle = self.loop.call_later(duration, self.timedOut)
class DetectPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
    """Base class for object-detection plugins.

    Handles session bookkeeping (creation, timeout refresh, ending), event
    reporting, and starting the gstreamer pipeline. Subclasses supply the
    actual inference by overriding ``run_detection_jpeg``,
    ``run_detection_gstsample`` and ``get_detection_input_size`` (and
    optionally ``create_detection_session`` / ``get_pixel_format``).
    """

    # Declared for type checkers only; the objects are created per instance
    # in __init__ so that plugin instances do not share session state.
    detection_sessions: dict[str, DetectionSession]
    session_mutex: Any

    def __init__(self, nativeId: str | None = None):
        super().__init__(nativeId=nativeId)
        self.detection_sessions = {}
        self.session_mutex = multiprocessing.Lock()

    async def getInferenceModels(self) -> list[ObjectDetectionModel]:
        """Return the list of models offered by this plugin."""
        ret: List[ObjectDetectionModel] = [
            {
                'id': 'opencv',
                'name': 'OpenCV',
                'classes': ['motion'],
                # 'inputShape': [int(width), int(height), int(channels)],
            },
        ]
        return ret

    def detection_event(self, detection_session: DetectionSession, detection_result: ObjectsDetected, event_buffer: bytes = None):
        """Stamp ``detection_result`` and report it as an ObjectDetection event.

        The session id and the current time in milliseconds are written into
        ``detection_result``; the event coroutine is then submitted to the
        session's event loop. ``event_buffer`` is currently unused.
        """
        detection_result['detectionId'] = detection_session.id
        detection_result['timestamp'] = int(time.time() * 1000)
        asyncio.run_coroutine_threadsafe(self.onDeviceEvent(
            ScryptedInterface.ObjectDetection.value, detection_result), loop=detection_session.loop)

    def end_session(self, detection_session: DetectionSession):
        """Stop ``detection_session`` and report ``running: False`` once.

        This can be reached twice for the same session: directly from
        ``detectObjects`` and again from the done-callback registered on the
        session future. Only the call that actually unregisters the session
        logs and emits the final event.
        """
        # Performed on every call, exactly as before.
        detection_session.cancel()
        safe_set_result(detection_session.future)

        with self.session_mutex:
            registered = self.detection_sessions.get(detection_session.id)
            if registered is not detection_session:
                # Already ended, or a newer session reuses this id: leave it alone.
                return
            del self.detection_sessions[detection_session.id]

        print('detection ended', detection_session.id)
        # detection_event() adds the detection id and timestamp.
        detection_result: ObjectsDetected = {'running': False}
        self.detection_event(detection_session, detection_result)

    def create_detection_result_status(self, detection_id: str, running: bool):
        """Build a status-only result (no detections) for ``detection_id``."""
        detection_result: ObjectsDetected = {
            'detectionId': detection_id,
            'running': running,
            'timestamp': int(time.time() * 1000),
        }
        return detection_result

    def run_detection_jpeg(self, detection_session: DetectionSession, image_bytes: bytes, min_score: float) -> ObjectsDetected:
        """Hook: run detection on JPEG bytes. The base implementation returns None.

        ``detection_session`` may be None when a single image is submitted
        without a detection id.
        """
        pass

    def get_detection_input_size(self, src_size):
        """Hook: return the ``(width, height)`` the detector expects.

        The base implementation returns None, which ``run_pipeline`` cannot
        unpack, so subclasses that process video must override this.
        """
        pass

    def create_detection_session(self):
        """Create the session object for a new detection; subclasses may override."""
        return DetectionSession()

    def run_detection_gstsample(self, detection_session: DetectionSession, gst_sample, min_score: float, src_size, inference_box, scale)-> ObjectsDetected:
        """Hook: run detection on one gstreamer sample. The base implementation returns None."""
        pass

    async def detectObjects(self, mediaObject: MediaObject, session: ObjectDetectionSession = None) -> ObjectsDetected:
        """Start, refresh, or end a detection session, or detect on one image.

        Args:
            mediaObject: Image or stream to analyse. May be omitted when only
                refreshing or ending an existing session.
            session: Optional settings read with ``.get``: ``detectionId``,
                ``duration`` (milliseconds) and ``minScore``.

        Returns:
            The detection result for an image, otherwise a status result
            describing whether the session is running.

        Raises:
            Exception: if a new session would have to be created but no
                ``mediaObject`` was supplied.
        """
        score_threshold = None
        duration = None
        detection_id = None
        detection_session = None

        if session:
            detection_id = session.get('detectionId', None)
            duration = session.get('duration', None)
            score_threshold = session.get('minScore', None)

        is_image = mediaObject and mediaObject.mimeType.startswith('image/')

        ending = False
        with self.session_mutex:
            if not is_image and not detection_id:
                detection_id = binascii.b2a_hex(os.urandom(15)).decode('utf8')

            if detection_id:
                detection_session = self.detection_sessions.get(
                    detection_id, None)

            if not duration and not is_image:
                # A non-image request without a duration ends the session.
                ending = True
            elif detection_id and not detection_session:
                if not mediaObject:
                    raise Exception(
                        'session %s inactive and no mediaObject provided' % detection_id)

                detection_session = self.create_detection_session()
                detection_session.id = detection_id
                detection_session.score_threshold = score_threshold or - \
                    float('inf')
                detection_session.loop = asyncio.get_event_loop()
                self.detection_sessions[detection_id] = detection_session

                # Clean up whenever the session future completes.
                detection_session.future.add_done_callback(
                    lambda _: self.end_session(detection_session))

        if ending:
            if detection_session:
                self.end_session(detection_session)
            return self.create_detection_result_status(detection_id, False)

        if is_image:
            image_bytes = bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg'))
            return self.run_detection_jpeg(detection_session, image_bytes, score_threshold)

        new_session = not detection_session.running
        if new_session:
            detection_session.running = True

        # NOTE(review): SOURCE lost its indentation; the timeout is assumed to
        # be refreshed for existing sessions as well as new ones - confirm.
        detection_session.setTimeout(duration / 1000)
        if score_threshold is not None:
            detection_session.score_threshold = score_threshold

        if not new_session:
            print("existing session", detection_session.id)
            return self.create_detection_result_status(detection_id, detection_session.running)

        print('detection starting', detection_id)
        b = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, ScryptedMimeTypes.MediaStreamUrl.value)
        s = b.decode('utf8')
        j: FFMpegInput = json.loads(s)
        container = j.get('container', None)
        videosrc = j['url']
        if container == 'mpegts' and videosrc.startswith('tcp://'):
            # Replace the tcp:// URL with a gstreamer source description.
            parsed_url = urlparse(videosrc)
            videosrc = 'tcpclientsrc port=%s host=%s ! tsdemux' % (
                parsed_url.port, parsed_url.hostname)

        size = j['mediaStreamOptions']['video']
        src_size = (size['width'], size['height'])

        self.run_pipeline(detection_session, duration, src_size, videosrc)

        return self.create_detection_result_status(detection_id, True)

    def get_pixel_format(self):
        """Return the pixel format requested from the pipeline's appsink."""
        return 'RGB'

    def run_pipeline(self, detection_session: DetectionSession, duration, src_size, video_input):
        """Start the gstreamer pipeline that feeds samples to the detector.

        Args:
            detection_session: Session whose future is handed to the pipeline.
            duration: Session duration; a falsy value ends the session after
                the first processed sample.
            src_size: ``(width, height)`` of the source video.
            video_input: Pipeline source description or URL.
        """
        inference_size = self.get_detection_input_size(src_size)
        width, height = inference_size
        w, h = src_size
        scale = (width / w, height / h)

        def user_callback(gst_sample, src_size, inference_box):
            detection_result = self.run_detection_gstsample(
                detection_session, gst_sample, detection_session.score_threshold, src_size, inference_box, scale)
            if detection_result:
                self.detection_event(detection_session, detection_result)

            # The former "not detection_session" clause could never usefully
            # be true: the session is dereferenced above and on the next line.
            if not duration:
                safe_set_result(detection_session.future)

        pipeline = gstreamer.run_pipeline(detection_session.future, user_callback,
                                          src_size,
                                          appsink_size=inference_size,
                                          video_input=video_input,
                                          pixel_format=self.get_pixel_format())
        task = pipeline.run()
        asyncio.ensure_future(task)

View File

@@ -13,14 +13,13 @@
# limitations under the License.
from asyncio.futures import Future
import sys
import threading
import gi
gi.require_version('Gst', '1.0')
gi.require_version('GstBase', '1.0')
from safe_set_result import safe_set_result
from detect.safe_set_result import safe_set_result
from gi.repository import GLib, GObject, Gst
GObject.threads_init()
@@ -123,16 +122,7 @@ class GstPipeline:
gstsample = self.gstsample
self.gstsample = None
# Passing Gst.Buffer as input tensor avoids 2 copies of it.
gstbuffer = gstsample.get_buffer()
svg = self.user_function(gstbuffer, self.src_size, self.get_box())
if svg:
if self.overlay:
self.overlay.set_property('data', svg)
if self.gloverlay:
self.gloverlay.emit('set-svg', svg, gstbuffer.pts)
if self.overlaysink:
self.overlaysink.set_property('svg', svg)
self.user_function(gstsample, self.src_size, self.get_box())
def get_dev_board_model():
try:
@@ -148,45 +138,25 @@ def run_pipeline(finished,
user_function,
src_size,
appsink_size,
videosrc='/dev/video1',
videofmt='raw'):
if videofmt == 'h264':
SRC_CAPS = 'video/x-h264,width={width},height={height},framerate=30/1'
elif videofmt == 'jpeg':
SRC_CAPS = 'image/jpeg,width={width},height={height},framerate=30/1'
else:
SRC_CAPS = 'video/x-raw,width={width},height={height},framerate=30/1'
if videosrc.startswith('/dev/video'):
PIPELINE = 'v4l2src device=%s ! {src_caps}'%videosrc
elif videosrc.startswith('http'):
PIPELINE = 'souphttpsrc location=%s'%videosrc
elif videosrc.startswith('rtsp'):
PIPELINE = 'rtspsrc location=%s'%videosrc
else:
demux = 'avidemux' if videosrc.endswith('avi') else 'qtdemux'
PIPELINE = """filesrc location=%s ! %s name=demux demux.video_0
! queue ! decodebin ! videorate
! videoconvert n-threads=4 ! videoscale n-threads=4
! {src_caps} ! {leaky_q} """ % (videosrc, demux)
if videofmt == 'gst':
PIPELINE = videosrc
video_input,
pixel_format):
PIPELINE = video_input
scale = min(appsink_size[0] / src_size[0], appsink_size[1] / src_size[1])
scale = tuple(int(x * scale) for x in src_size)
scale_caps = 'video/x-raw,width={width},height={height}'.format(width=scale[0], height=scale[1])
# scale_caps = 'video/x-raw,width={width},height={height}'.format(width=appsink_size[0], height=appsink_size[1])
PIPELINE += """ ! decodebin ! queue leaky=downstream max-size-buffers=10 ! videoconvert ! videoscale
! {scale_caps} ! videobox name=box autocrop=true ! queue leaky=downstream max-size-buffers=1 ! {sink_caps} ! {sink_element}
"""
SINK_ELEMENT = 'appsink name=appsink emit-signals=true max-buffers=1 drop=true sync=false'
SINK_CAPS = 'video/x-raw,format=RGB,width={width},height={height}'
SINK_CAPS = 'video/x-raw,format={pixel_format},width={width},height={height}'
LEAKY_Q = 'queue max-size-buffers=100 leaky=upstream'
src_caps = SRC_CAPS.format(width=src_size[0], height=src_size[1])
sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1])
sink_caps = SINK_CAPS.format(width=appsink_size[0], height=appsink_size[1], pixel_format=pixel_format)
pipeline = PIPELINE.format(leaky_q=LEAKY_Q,
src_caps=src_caps, sink_caps=sink_caps,
sink_caps=sink_caps,
sink_element=SINK_ELEMENT, scale_caps=scale_caps)
print('Gstreamer pipeline:\n', pipeline)

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import matplotlib
from detect import DetectionSession, DetectPlugin
matplotlib.use('Agg')
from asyncio.events import AbstractEventLoop, TimerHandle
from asyncio.futures import Future
from typing import Any, Mapping, List
from safe_set_result import safe_set_result
from typing import List
import scrypted_sdk
import numpy as np
import re
@@ -19,20 +18,17 @@ from pycoral.adapters import detect
from PIL import Image
import common
import io
import gstreamer
import json
import asyncio
import time
import os
import binascii
from urllib.parse import urlparse
from gi.repository import Gst
import multiprocessing
from third_party.sort import Sort
import threading
from detect.safe_set_result import safe_set_result
import asyncio
from scrypted_sdk.types import FFMpegInput, Lock, MediaObject, ObjectDetection, ObjectDetectionModel, ObjectDetectionResult, ObjectDetectionSession, OnOff, ObjectsDetected, ScryptedInterface, ScryptedMimeTypes
from scrypted_sdk.types import ObjectDetectionModel, ObjectDetectionResult, ObjectsDetected
class TrackerDetectionSession(DetectionSession):
    """Detection session that additionally carries its own ``Sort`` tracker."""

    def __init__(self) -> None:
        # Set up the base session state first, then attach a fresh tracker.
        super().__init__()
        self.tracker = Sort()
def parse_label_contents(contents: str):
lines = contents.splitlines()
@@ -45,39 +41,7 @@ def parse_label_contents(contents: str):
ret[row_number] = content.strip()
return ret
class DetectionSession:
id: str
timerHandle: TimerHandle
future: Future
loop: AbstractEventLoop
score_threshold: float
running: bool
thread: Any
def __init__(self) -> None:
self.timerHandle = None
self.future = Future()
self.tracker = Sort()
self.running = False
def cancel(self):
if self.timerHandle:
self.timerHandle.cancel()
self.timerHandle = None
def timedOut(self):
safe_set_result(self.future)
def setTimeout(self, duration: float):
self.cancel()
self.loop.call_later(duration, lambda: self.timedOut())
class CoralPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
detection_sessions: Mapping[str, DetectionSession] = {}
session_mutex = multiprocessing.Lock()
class CoralPlugin(DetectPlugin):
def __init__(self, nativeId: str | None = None):
super().__init__(nativeId=nativeId)
labels_contents = scrypted_sdk.zip.open(
@@ -165,160 +129,39 @@ class CoralPlugin(scrypted_sdk.ScryptedDeviceBase, ObjectDetection):
return detection_result
def detection_event(self, detection_session: DetectionSession, detection_result: ObjectsDetected, event_buffer: bytes = None):
detection_result['detectionId'] = detection_session.id
detection_result['timestamp'] = int(time.time() * 1000)
asyncio.run_coroutine_threadsafe(self.onDeviceEvent(
ScryptedInterface.ObjectDetection.value, detection_result), loop=detection_session.loop)
def run_detection_jpeg(self, detection_session: TrackerDetectionSession, image_bytes: bytes, min_score: float) -> ObjectsDetected:
stream = io.BytesIO(image_bytes)
image = Image.open(stream)
def end_session(self, detection_session: DetectionSession):
print('detection ended', detection_session.id)
detection_session.cancel()
safe_set_result(detection_session.future)
with self.session_mutex:
self.detection_sessions.pop(detection_session.id, None)
_, scale = common.set_resized_input(
self.interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS))
detection_result: ObjectsDetected = {}
detection_result['running'] = False
detection_result['timestamp'] = int(time.time() * 1000)
tracker = None
if detection_session:
tracker = detection_session.tracker
self.detection_event(detection_session, detection_result)
with self.mutex:
self.interpreter.invoke()
objs = detect.get_objects(
self.interpreter, score_threshold=min_score or -float('inf'), image_scale=scale)
def create_detection_result_status(self, detection_id: str, running: bool):
detection_result: ObjectsDetected = {}
detection_result['detectionId'] = detection_id
detection_result['running'] = running
detection_result['timestamp'] = int(time.time() * 1000)
return detection_result
return self.create_detection_result(objs, image.size, tracker=tracker)
async def detectObjects(self, mediaObject: MediaObject, session: ObjectDetectionSession = None) -> ObjectsDetected:
score_threshold = None
duration = None
detection_id = None
detection_session = None
def get_detection_input_size(self, src_size):
return input_size(self.interpreter)
if session:
detection_id = session.get('detectionId', None)
duration = session.get('duration', None)
score_threshold = session.get('minScore', None)
def run_detection_gstsample(self, detection_session: TrackerDetectionSession, gstsample, min_score: float, src_size, inference_box, scale)-> ObjectsDetected:
gst_buffer = gstsample.get_buffer()
with self.mutex:
run_inference(self.interpreter, gst_buffer)
objs = detect.get_objects(
self.interpreter, score_threshold=min_score, image_scale=scale)
is_image = mediaObject and mediaObject.mimeType.startswith('image/')
return self.create_detection_result(objs, src_size, detection_session.tracker)
ending = False
with self.session_mutex:
if not is_image and not detection_id:
detection_id = binascii.b2a_hex(os.urandom(15)).decode('utf8')
if detection_id:
detection_session = self.detection_sessions.get(
detection_id, None)
if not duration and not is_image:
ending = True
elif detection_id and not detection_session:
if not mediaObject:
raise Exception(
'session %s inactive and no mediaObject provided' % detection_id)
detection_session = DetectionSession()
detection_session.id = detection_id
detection_session.score_threshold = score_threshold or - \
float('inf')
loop = asyncio.get_event_loop()
detection_session.loop = loop
self.detection_sessions[detection_id] = detection_session
detection_session.future.add_done_callback(
lambda _: self.end_session(detection_session))
if ending:
if detection_session:
self.end_session(detection_session)
return self.create_detection_result_status(detection_id, False)
if is_image:
stream = io.BytesIO(bytes(await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, 'image/jpeg')))
image = Image.open(stream)
_, scale = common.set_resized_input(
self.interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS))
tracker = None
if detection_session:
tracker = detection_session.tracker
with self.mutex:
self.interpreter.invoke()
objs = detect.get_objects(
self.interpreter, score_threshold=score_threshold or -float('inf'), image_scale=scale)
return self.create_detection_result(objs, image.size, tracker=tracker)
new_session = not detection_session.running
if new_session:
detection_session.running = True
detection_session.setTimeout(duration / 1000)
if score_threshold != None:
detection_session.score_threshold = score_threshold
if not new_session:
print("existing session", detection_session.id)
return self.create_detection_result_status(detection_id, detection_session.running)
print('detection starting', detection_id)
b = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, ScryptedMimeTypes.MediaStreamUrl.value)
s = b.decode('utf8')
j: FFMpegInput = json.loads(s)
container = j.get('container', None)
videofmt = 'raw'
videosrc = j['url']
if container == 'mpegts' and videosrc.startswith('tcp://'):
parsed_url = urlparse(videosrc)
videofmt = 'gst'
videosrc = 'tcpclientsrc port=%s host=%s ! tsdemux' % (
parsed_url.port, parsed_url.hostname)
size = j['mediaStreamOptions']['video']
inference_size = input_size(self.interpreter)
width, height = inference_size
w, h = (size['width'], size['height'])
scale = (width / w, height / h)
def user_callback(input_tensor, src_size, inference_box):
with self.mutex:
run_inference(self.interpreter, input_tensor)
objs = detect.get_objects(
self.interpreter, score_threshold=detection_session.score_threshold, image_scale=scale)
# (result, mapinfo) = input_tensor.map(Gst.MapFlags.READ)
try:
detection_result = self.create_detection_result(objs,
src_size, detection_session.tracker)
# self.detection_event(detection_session, detection_result, mapinfo.data.tobytes())
self.detection_event(detection_session, detection_result)
if not session or not duration:
safe_set_result(detection_session.future)
finally:
# input_tensor.unmap(mapinfo)
pass
pipeline = gstreamer.run_pipeline(detection_session.future, user_callback,
src_size=(
size['width'], size['height']),
appsink_size=inference_size,
videosrc=videosrc,
videofmt=videofmt)
task = pipeline.run()
asyncio.ensure_future(task)
# detection_session.thread = threading.Thread(target=lambda: pipeline.run())
# detection_session.thread.start()
return self.create_detection_result_status(detection_id, True)
def create_detection_session(self):
    """Return a new ``TrackerDetectionSession`` for this plugin."""
    tracker_session = TrackerDetectionSession()
    return tracker_session
def create_scrypted_plugin():
    """Build and return a new ``CoralPlugin`` instance."""
    plugin = CoralPlugin()
    return plugin
#