mirror of
https://github.com/koush/scrypted.git
synced 2026-06-21 00:50:30 +01:00
predict: make models a separate download
This commit is contained in:
@@ -1,20 +0,0 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
rm -rf all_models
|
||||
mkdir -p all_models
|
||||
cd all_models
|
||||
wget https://github.com/koush/coreml-survival-guide/raw/master/MobileNetV2%2BSSDLite/ObjectDetection/ObjectDetection/MobileNetV2_SSDLite.mlmodel
|
||||
wget https://raw.githubusercontent.com/koush/coreml-survival-guide/master/MobileNetV2%2BSSDLite/coco_labels.txt
|
||||
@@ -1 +0,0 @@
|
||||
../all_models/MobileNetV2_SSDLite.mlmodel
|
||||
@@ -1 +0,0 @@
|
||||
../all_models/coco_labels.txt
|
||||
@@ -29,16 +29,17 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Set
|
||||
def __init__(self, nativeId: str | None = None):
|
||||
super().__init__(MIME_TYPE, nativeId=nativeId)
|
||||
|
||||
modelPath = os.path.join(os.environ['SCRYPTED_PLUGIN_VOLUME'], 'zip', 'unzipped', 'fs', 'MobileNetV2_SSDLite.mlmodel')
|
||||
self.model = ct.models.MLModel(modelPath)
|
||||
labelsFile = self.downloadFile('https://raw.githubusercontent.com/koush/coreml-survival-guide/master/MobileNetV2%2BSSDLite/coco_labels.txt', 'coco_labels.txt')
|
||||
modelFile = self.downloadFile('https://github.com/koush/coreml-survival-guide/raw/master/MobileNetV2%2BSSDLite/ObjectDetection/ObjectDetection/MobileNetV2_SSDLite.mlmodel', 'MobileNetV2_SSDLite.mlmodel')
|
||||
|
||||
self.model = ct.models.MLModel(modelFile)
|
||||
|
||||
self.modelspec = self.model.get_spec()
|
||||
self.inputdesc = self.modelspec.description.input[0]
|
||||
self.inputheight = self.inputdesc.type.imageType.height
|
||||
self.inputwidth = self.inputdesc.type.imageType.width
|
||||
|
||||
labels_contents = scrypted_sdk.zip.open(
|
||||
'fs/coco_labels.txt').read().decode('utf8')
|
||||
labels_contents = open(labelsFile, 'r').read()
|
||||
self.labels = parse_label_contents(labels_contents)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2019 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
mkdir -p all_models
|
||||
wget https://dl.google.com/coral/canned_models/all_models.tar.gz
|
||||
tar -C all_models -xvzf all_models.tar.gz
|
||||
rm -f all_models.tar.gz
|
||||
@@ -1 +0,0 @@
|
||||
../all_models/coco_labels.txt
|
||||
@@ -1 +0,0 @@
|
||||
../all_models/mobilenet_ssd_v2_coco_quant_postprocess.tflite
|
||||
@@ -1 +0,0 @@
|
||||
../all_models/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite
|
||||
@@ -8,6 +8,8 @@ from typing import Any, List, Tuple, Mapping
|
||||
import asyncio
|
||||
import time
|
||||
from .rectangle import Rectangle, intersect_area, intersect_rect, to_bounding_box, from_bounding_box, combine_rect
|
||||
import urllib.request
|
||||
import os
|
||||
|
||||
from detect import DetectionSession, DetectPlugin
|
||||
|
||||
@@ -126,6 +128,17 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Set
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.call_later(4 * 60 * 60, lambda: self.requestRestart())
|
||||
|
||||
def downloadFile(self, url: str, filename: str):
|
||||
filesPath = os.path.join(os.environ['SCRYPTED_PLUGIN_VOLUME'], 'files')
|
||||
fullpath = os.path.join(filesPath, filename)
|
||||
if os.path.isfile(fullpath):
|
||||
return fullpath
|
||||
os.makedirs(filesPath, exist_ok=True)
|
||||
tmp = fullpath + '.tmp'
|
||||
urllib.request.urlretrieve(url, tmp)
|
||||
os.rename(tmp, fullpath)
|
||||
return fullpath
|
||||
|
||||
def getClasses(self) -> list[str]:
|
||||
return list(self.labels.values())
|
||||
|
||||
|
||||
@@ -40,8 +40,11 @@ class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted
|
||||
def __init__(self, nativeId: str | None = None):
|
||||
super().__init__(MIME_TYPE, nativeId=nativeId)
|
||||
|
||||
labels_contents = scrypted_sdk.zip.open(
|
||||
'fs/coco_labels.txt').read().decode('utf8')
|
||||
tfliteFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/ssd_mobilenet_v2_coco_quant_postprocess.tflite', 'ssd_mobilenet_v2_coco_quant_postprocess.tflite')
|
||||
edgetpuFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite', 'ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite')
|
||||
labelsFile = self.downloadFile('https://raw.githubusercontent.com/google-coral/test_data/master/coco_labels.txt', 'coco_labels.txt')
|
||||
|
||||
labels_contents = open(labelsFile, 'r').read()
|
||||
self.labels = parse_label_contents(labels_contents)
|
||||
self.interpreters = queue.Queue()
|
||||
self.interpreter_count = 0
|
||||
@@ -54,13 +57,11 @@ class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted
|
||||
self.edge_tpu_found = str(edge_tpus)
|
||||
# todo co-compile
|
||||
# https://coral.ai/docs/edgetpu/compiler/#co-compiling-multiple-models
|
||||
model = scrypted_sdk.zip.open(
|
||||
'fs/mobilenet_ssd_v2_coco_quant_postprocess_edgetpu.tflite').read()
|
||||
# face_model = scrypted_sdk.zip.open(
|
||||
# 'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
|
||||
for idx, edge_tpu in enumerate(edge_tpus):
|
||||
try:
|
||||
interpreter = make_interpreter(model, ":%s" % idx)
|
||||
interpreter = make_interpreter(edgetpuFile, ":%s" % idx)
|
||||
interpreter.allocate_tensors()
|
||||
_, height, width, channels = interpreter.get_input_details()[
|
||||
0]['shape']
|
||||
@@ -77,11 +78,9 @@ class TensorFlowLitePlugin(PredictPlugin, scrypted_sdk.BufferConverter, scrypted
|
||||
except Exception as e:
|
||||
print('unable to use Coral Edge TPU', e)
|
||||
self.edge_tpu_found = 'Edge TPU not found'
|
||||
model = scrypted_sdk.zip.open(
|
||||
'fs/mobilenet_ssd_v2_coco_quant_postprocess.tflite').read()
|
||||
# face_model = scrypted_sdk.zip.open(
|
||||
# 'fs/mobilenet_ssd_v2_face_quant_postprocess.tflite').read()
|
||||
interpreter = tflite.Interpreter(model_content=model)
|
||||
interpreter = tflite.Interpreter(model_path=tfliteFile)
|
||||
interpreter.allocate_tensors()
|
||||
_, height, width, channels = interpreter.get_input_details()[
|
||||
0]['shape']
|
||||
|
||||
Reference in New Issue
Block a user