mirror of
https://github.com/koush/scrypted.git
synced 2026-05-04 21:30:30 +01:00
coreml: recognition fixes
This commit is contained in:
4
plugins/coreml/package-lock.json
generated
4
plugins/coreml/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "@scrypted/coreml",
|
||||
"version": "0.1.35",
|
||||
"version": "0.1.37",
|
||||
"lockfileVersion": 2,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "@scrypted/coreml",
|
||||
"version": "0.1.35",
|
||||
"version": "0.1.37",
|
||||
"devDependencies": {
|
||||
"@scrypted/sdk": "file:../../sdk"
|
||||
}
|
||||
|
||||
@@ -42,5 +42,5 @@
|
||||
"devDependencies": {
|
||||
"@scrypted/sdk": "file:../../sdk"
|
||||
},
|
||||
"version": "0.1.35"
|
||||
"version": "0.1.37"
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ availableModels = [
|
||||
"Default",
|
||||
"scrypted_yolov9c_320",
|
||||
"scrypted_yolov9c",
|
||||
"scrypted_yolov8n_320",
|
||||
"scrypted_yolov8n",
|
||||
"ssdlite_mobilenet_v2",
|
||||
"yolov4-tiny",
|
||||
]
|
||||
@@ -63,10 +65,10 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
self.storage.setItem("model", "Default")
|
||||
model = "scrypted_yolov9c_320"
|
||||
self.yolo = "yolo" in model
|
||||
self.yolov9 = "yolov9" in model
|
||||
self.scrypted_yolo = "scrypted_yolo" in model
|
||||
self.scrypted_model = "scrypted" in model
|
||||
model_version = "v5"
|
||||
mlmodel = "model" if self.yolov9 else model
|
||||
model_version = "v7"
|
||||
mlmodel = "model" if self.scrypted_yolo else model
|
||||
|
||||
print(f"model: {model}")
|
||||
|
||||
@@ -77,7 +79,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
f"{model}.mlmodel",
|
||||
)
|
||||
else:
|
||||
if self.yolov9:
|
||||
if self.scrypted_yolo:
|
||||
files = [
|
||||
f"{model}/{model}.mlpackage/Data/com.apple.CoreML/weights/weight.bin",
|
||||
f"{model}/{model}.mlpackage/Data/com.apple.CoreML/{mlmodel}.mlmodel",
|
||||
@@ -112,6 +114,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
self.inputdesc = self.modelspec.description.input[0]
|
||||
self.inputheight = self.inputdesc.type.imageType.height
|
||||
self.inputwidth = self.inputdesc.type.imageType.width
|
||||
self.input_name = self.model.get_spec().description.input[0].name
|
||||
|
||||
self.labels = parse_labels(self.modelspec.description.metadata.userDefined)
|
||||
self.loop = asyncio.get_event_loop()
|
||||
@@ -134,7 +137,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
"interfaces": [
|
||||
scrypted_sdk.ScryptedInterface.ObjectDetection.value,
|
||||
],
|
||||
"name": "Vision Framework",
|
||||
"name": "CoreML Recognition",
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -176,15 +179,14 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
|
||||
# run in executor if this is the plugin loop
|
||||
if self.yolo:
|
||||
input_name = "image" if self.yolov9 else "input_1"
|
||||
if asyncio.get_event_loop() is self.loop:
|
||||
out_dict = await asyncio.get_event_loop().run_in_executor(
|
||||
predictExecutor, lambda: self.model.predict({input_name: input})
|
||||
predictExecutor, lambda: self.model.predict({self.input_name: input})
|
||||
)
|
||||
else:
|
||||
out_dict = self.model.predict({input_name: input})
|
||||
out_dict = self.model.predict({self.input_name: input})
|
||||
|
||||
if self.yolov9:
|
||||
if self.scrypted_yolo:
|
||||
results = list(out_dict.values())[0][0]
|
||||
objs = yolo.parse_yolov9(results)
|
||||
ret = self.create_detection_result(objs, src_size, cvss)
|
||||
|
||||
@@ -100,7 +100,7 @@ class VisionPlugin(PredictPlugin):
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.minThreshold = 0.7
|
||||
|
||||
self.detectModel = self.downloadModel("scrypted_yolov9c_flt_320")
|
||||
self.detectModel = self.downloadModel("scrypted_yolov8n_flt_320")
|
||||
self.detectInput = self.detectModel.get_spec().description.input[0].name
|
||||
|
||||
self.textModel = self.downloadModel("vgg_english_g2")
|
||||
@@ -110,7 +110,7 @@ class VisionPlugin(PredictPlugin):
|
||||
self.faceInput = self.faceModel.get_spec().description.input[0].name
|
||||
|
||||
def downloadModel(self, model: str):
|
||||
model_version = "v3"
|
||||
model_version = "v7"
|
||||
mlmodel = "model"
|
||||
|
||||
files = [
|
||||
@@ -244,7 +244,7 @@ class VisionPlugin(PredictPlugin):
|
||||
|
||||
out_dict = await asyncio.get_event_loop().run_in_executor(
|
||||
predictExecutor,
|
||||
lambda: self.faceModel.predict({self.textInput: processed_tensor}),
|
||||
lambda: self.faceModel.predict({self.faceInput: processed_tensor}),
|
||||
)
|
||||
|
||||
output = out_dict["var_2167"][0]
|
||||
|
||||
Reference in New Issue
Block a user