predict: refactor, add support for yolov8 on tflite

This commit is contained in:
Koushik Dutta
2023-06-16 12:08:04 -07:00
parent b10b4d047e
commit 2b9a0f082d
10 changed files with 182 additions and 89 deletions

View File

@@ -1,12 +1,12 @@
{
"name": "@scrypted/coreml",
"version": "0.1.18",
"version": "0.1.19",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "@scrypted/coreml",
"version": "0.1.18",
"version": "0.1.19",
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
}

View File

@@ -40,5 +40,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.18"
"version": "0.1.19"
}

View File

@@ -39,7 +39,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Set
if model == "Default":
# model = "ssdlite_mobilenet_v2"
if "arm" in platform.processor():
model = "yolov8"
model = "yolov8n"
else:
model = "ssdlite_mobilenet_v2"
self.yolo = "yolo" in model

View File

@@ -2,7 +2,7 @@
{
// docker installation
// "scrypted.debugHost": "koushik-ubuntu",
// "scrypted.serverRoot": "/server",
"scrypted.serverRoot": "/server",
// pi local installation
// "scrypted.debugHost": "192.168.2.119",
@@ -12,7 +12,7 @@
// "scrypted.debugHost": "127.0.0.1",
// "scrypted.serverRoot": "/Users/koush/.scrypted",
"scrypted.debugHost": "koushik-windows",
"scrypted.serverRoot": "C:\\Users\\koush\\.scrypted",
// "scrypted.serverRoot": "C:\\Users\\koush\\.scrypted",
"scrypted.pythonRemoteRoot": "${config:scrypted.serverRoot}/volume/plugin.zip",
"python.analysis.extraPaths": [

View File

@@ -4,19 +4,19 @@ import numpy as np
from predict import Prediction, Rectangle
def parse_yolov8(results):
def parse_yolov8(results, scale = 1):
objs = []
keep = np.argwhere(results[4:] > 0.2)
keep = np.argwhere(results[4:] > .2)
for indices in keep:
class_id = indices[0]
index = indices[1]
confidence = results[class_id + 4, index]
x = results[0][index].astype(float)
y = results[1][index].astype(float)
w = results[2][index].astype(float)
h = results[3][index].astype(float)
x = results[0][index].astype(float) * scale
y = results[1][index].astype(float) * scale
w = results[2][index].astype(float) * scale
h = results[3][index].astype(float) * scale
obj = Prediction(
class_id,
int(class_id),
confidence.astype(float),
Rectangle(
x - w / 2,

View File

@@ -2,16 +2,16 @@
{
// docker installation
// "scrypted.debugHost": "koushik-ubuntu",
// "scrypted.serverRoot": "/server",
"scrypted.serverRoot": "/server",
// pi local installation
// "scrypted.debugHost": "192.168.2.119",
// "scrypted.serverRoot": "/home/pi/.scrypted",
// local checkout
"scrypted.debugHost": "127.0.0.1",
"scrypted.serverRoot": "/Users/koush/.scrypted",
// "scrypted.debugHost": "koushik-windows",
// "scrypted.debugHost": "127.0.0.1",
// "scrypted.serverRoot": "/Users/koush/.scrypted",
"scrypted.debugHost": "koushik-windows",
// "scrypted.serverRoot": "C:\\Users\\koush\\.scrypted",
"scrypted.pythonRemoteRoot": "${config:scrypted.serverRoot}/volume/plugin.zip",

View File

@@ -1,12 +1,12 @@
{
"name": "@scrypted/tensorflow-lite",
"version": "0.1.17",
"version": "0.1.19",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "@scrypted/tensorflow-lite",
"version": "0.1.17",
"version": "0.1.19",
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
}

View File

@@ -49,5 +49,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.17"
"version": "0.1.19"
}

View File

@@ -1,101 +1,166 @@
from __future__ import annotations
from .common import *
from PIL import Image
from pycoral.adapters import detect
from .common import *
loaded_py_coral = False
try:
from pycoral.utils.edgetpu import list_edge_tpus
from pycoral.utils.edgetpu import make_interpreter
from pycoral.utils.edgetpu import list_edge_tpus, make_interpreter
loaded_py_coral = True
print('coral edge tpu library loaded successfully')
print("coral edge tpu library loaded successfully")
except Exception as e:
print('coral edge tpu library load failed', e)
print("coral edge tpu library load failed", e)
pass
import tflite_runtime.interpreter as tflite
import re
import scrypted_sdk
from scrypted_sdk.types import Setting
from typing import Any, Tuple
from predict import PredictPlugin
import asyncio
import concurrent.futures
import queue
import asyncio
import re
import traceback
from typing import Any, Tuple
import scrypted_sdk
import tflite_runtime.interpreter as tflite
from scrypted_sdk.types import Setting, SettingValue
import yolo
from predict import PredictPlugin
def parse_label_contents(contents: str):
lines = contents.splitlines()
ret = {}
for row_number, content in enumerate(lines):
pair = re.split(r'[:\s]+', content.strip(), maxsplit=1)
pair = re.split(r"[:\s]+", content.strip(), maxsplit=1)
if len(pair) == 2 and pair[0].strip().isdigit():
ret[int(pair[0])] = pair[1].strip()
else:
ret[row_number] = content.strip()
return ret
class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Settings):
class TensorFlowLitePlugin(
PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Settings
):
def __init__(self, nativeId: str | None = None):
super().__init__(nativeId=nativeId)
tfliteFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/ssd_mobilenet_v2_coco_quant_postprocess.tflite', 'ssd_mobilenet_v2_coco_quant_postprocess.tflite')
edgetpuFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite', 'ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite')
labelsFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/coco_labels.txt', 'coco_labels.txt')
edge_tpus = None
try:
edge_tpus = list_edge_tpus()
print("edge tpus", edge_tpus)
if not len(edge_tpus):
raise Exception("no edge tpu found")
except Exception as e:
print("unable to use Coral Edge TPU", e)
edge_tpus = None
pass
labels_contents = open(labelsFile, 'r').read()
model = self.storage.getItem("model") or "Default"
if model == "Default":
if edge_tpus:
model = "yolov8n_full_integer_quant"
else:
model = "ssd_mobilenet_v2_coco_quant_postprocess"
self.yolo = "yolo" in model
self.yolov8 = "yolov8" in model
print(f'model: {model}')
model_version = "v5"
if self.yolo:
labelsFile = self.downloadFile(
"https://raw.githubusercontent.com/koush/tflite-models/main/coco_80cl.txt",
f"{model_version}/coco_80cl.txt",
)
else:
labelsFile = self.downloadFile(
"https://raw.githubusercontent.com/koush/tflite-models/main/coco_labels.txt",
f"{model_version}/coco_labels.txt",
)
labels_contents = open(labelsFile, "r").read()
self.labels = parse_label_contents(labels_contents)
self.interpreters = queue.Queue()
self.interpreter_count = 0
try:
edge_tpus = list_edge_tpus()
print('edge tpus', edge_tpus)
if not len(edge_tpus):
raise Exception('no edge tpu found')
self.edge_tpu_found = str(edge_tpus)
# todo co-compile
# https://coral.ai/docs/edgetpu/compiler/#co-compiling-multiple-models
# face_model = scrypted_sdk.zip.open(
# 'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
for idx, edge_tpu in enumerate(edge_tpus):
try:
interpreter = make_interpreter(edgetpuFile, ":%s" % idx)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[
0]['shape']
self.input_details = int(width), int(height), int(channels)
self.interpreters.put(interpreter)
self.interpreter_count = self.interpreter_count + 1
print('added tpu %s' % (edge_tpu))
except Exception as e:
print('unable to use Coral Edge TPU', e)
def downloadModel():
return self.downloadFile(
f"https://github.com/koush/tflite-models/raw/main/{model}/{model}{suffix}.tflite",
f"{model_version}/{model}{suffix}.tflite",
)
if not self.interpreter_count:
raise Exception('all tpus failed to load')
# self.face_interpreter = make_interpreter(face_model)
try:
if edge_tpus:
suffix = "_edgetpu"
modelFile = downloadModel()
self.edge_tpu_found = str(edge_tpus)
for idx, edge_tpu in enumerate(edge_tpus):
try:
interpreter = make_interpreter(modelFile, ":%s" % idx)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[0][
"shape"
]
self.input_details = int(width), int(height), int(channels)
self.interpreters.put(interpreter)
self.interpreter_count = self.interpreter_count + 1
print("added tpu %s" % (edge_tpu))
except Exception as e:
print("unable to use Coral Edge TPU", e)
if not self.interpreter_count:
raise Exception("all tpus failed to load")
else:
raise Exception()
except Exception as e:
print('unable to use Coral Edge TPU', e)
self.edge_tpu_found = 'Edge TPU not found'
# face_model = scrypted_sdk.zip.open(
# 'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
interpreter = tflite.Interpreter(model_path=tfliteFile)
self.edge_tpu_found = "Edge TPU not found"
suffix = ""
modelFile = downloadModel()
interpreter = tflite.Interpreter(model_path=modelFile)
interpreter.allocate_tensors()
_, height, width, channels = interpreter.get_input_details()[
0]['shape']
_, height, width, channels = interpreter.get_input_details()[0]["shape"]
self.input_details = int(width), int(height), int(channels)
self.interpreters.put(interpreter)
self.interpreter_count = self.interpreter_count + 1
# self.face_interpreter = make_interpreter(face_model)
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.interpreter_count, thread_name_prefix="tflite", )
print(modelFile, labelsFile)
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=self.interpreter_count,
thread_name_prefix="tflite",
)
async def putSetting(self, key: str, value: SettingValue):
self.storage.setItem(key, value)
await self.onDeviceEvent(scrypted_sdk.ScryptedInterface.Settings.value, None)
await scrypted_sdk.deviceManager.requestRestart()
async def getSettings(self) -> list[Setting]:
coral: Setting = {
'title': 'Detected Edge TPU',
'description': 'The device paths of the Coral Edge TPUs that will be used for detections.',
'value': self.edge_tpu_found,
'readonly': True,
'key': 'coral',
}
return [coral]
model = self.storage.getItem("model") or "Default"
return [
{
"title": "Detected Edge TPU",
"description": "The device paths of the Coral Edge TPUs that will be used for detections.",
"value": self.edge_tpu_found,
"readonly": True,
"key": "coral",
},
{
"key": "model",
"title": "Model",
"description": "The detection model used to find objects.",
"choices": [
"Default",
"ssd_mobilenet_v2_coco_quant_postprocess",
"yolov8n_full_integer_quant",
],
"value": model,
},
]
# width, height, channels
def get_input_details(self) -> Tuple[int, int, int]:
@@ -108,17 +173,44 @@ class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted
def predict():
interpreter = self.interpreters.get()
try:
common.set_input(
interpreter, input)
scale = (1, 1)
# _, scale = common.set_resized_input(
# self.interpreter, cropped.size, lambda size: cropped.resize(size, Image.ANTIALIAS))
interpreter.invoke()
objs = detect.get_objects(
interpreter, score_threshold=.2, image_scale=scale)
if self.yolo:
tensor_index = input_details(interpreter, 'index')
im = np.stack([input])
i = interpreter.get_input_details()[0]
if i['dtype'] == np.int8:
im = im.view(np.int8)
else:
im = im.astype(np.float32) / 255.0
interpreter.set_tensor(tensor_index, im)
interpreter.invoke()
output_details = interpreter.get_output_details()
y = []
for output in output_details:
x = interpreter.get_tensor(output['index'])
if output['dtype'] == np.int8:
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
objs = yolo.parse_yolov8(y[0][0], scale=640)
else:
common.set_input(interpreter, input)
interpreter.invoke()
objs = detect.get_objects(
interpreter, score_threshold=0.2, image_scale=(1, 1)
)
return objs
except:
print('tensorflow-lite encountered an error while detecting. requesting plugin restart.')
traceback.print_exc()
print(
"tensorflow-lite encountered an error while detecting. requesting plugin restart."
)
self.requestRestart()
raise e
finally:

View File

@@ -0,0 +1 @@
../../openvino/src/yolo