coreml: recognition fixes

This commit is contained in:
Koushik Dutta
2024-04-12 12:22:37 -07:00
parent 2ab74bc0f8
commit 4684ea6592
4 changed files with 17 additions and 15 deletions

View File

@@ -1,12 +1,12 @@
{
"name": "@scrypted/coreml",
"version": "0.1.35",
"version": "0.1.37",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "@scrypted/coreml",
"version": "0.1.35",
"version": "0.1.37",
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
}

View File

@@ -42,5 +42,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.35"
"version": "0.1.37"
}

View File

@@ -22,6 +22,8 @@ availableModels = [
"Default",
"scrypted_yolov9c_320",
"scrypted_yolov9c",
"scrypted_yolov8n_320",
"scrypted_yolov8n",
"ssdlite_mobilenet_v2",
"yolov4-tiny",
]
@@ -63,10 +65,10 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
self.storage.setItem("model", "Default")
model = "scrypted_yolov9c_320"
self.yolo = "yolo" in model
self.yolov9 = "yolov9" in model
self.scrypted_yolo = "scrypted_yolo" in model
self.scrypted_model = "scrypted" in model
model_version = "v5"
mlmodel = "model" if self.yolov9 else model
model_version = "v7"
mlmodel = "model" if self.scrypted_yolo else model
print(f"model: {model}")
@@ -77,7 +79,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
f"{model}.mlmodel",
)
else:
if self.yolov9:
if self.scrypted_yolo:
files = [
f"{model}/{model}.mlpackage/Data/com.apple.CoreML/weights/weight.bin",
f"{model}/{model}.mlpackage/Data/com.apple.CoreML/{mlmodel}.mlmodel",
@@ -112,6 +114,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
self.inputdesc = self.modelspec.description.input[0]
self.inputheight = self.inputdesc.type.imageType.height
self.inputwidth = self.inputdesc.type.imageType.width
self.input_name = self.model.get_spec().description.input[0].name
self.labels = parse_labels(self.modelspec.description.metadata.userDefined)
self.loop = asyncio.get_event_loop()
@@ -134,7 +137,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
"interfaces": [
scrypted_sdk.ScryptedInterface.ObjectDetection.value,
],
"name": "Vision Framework",
"name": "CoreML Recognition",
}
]
}
@@ -176,15 +179,14 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
# run in executor if this is the plugin loop
if self.yolo:
input_name = "image" if self.yolov9 else "input_1"
if asyncio.get_event_loop() is self.loop:
out_dict = await asyncio.get_event_loop().run_in_executor(
predictExecutor, lambda: self.model.predict({input_name: input})
predictExecutor, lambda: self.model.predict({self.input_name: input})
)
else:
out_dict = self.model.predict({input_name: input})
out_dict = self.model.predict({self.input_name: input})
if self.yolov9:
if self.scrypted_yolo:
results = list(out_dict.values())[0][0]
objs = yolo.parse_yolov9(results)
ret = self.create_detection_result(objs, src_size, cvss)

View File

@@ -100,7 +100,7 @@ class VisionPlugin(PredictPlugin):
self.loop = asyncio.get_event_loop()
self.minThreshold = 0.7
self.detectModel = self.downloadModel("scrypted_yolov9c_flt_320")
self.detectModel = self.downloadModel("scrypted_yolov8n_flt_320")
self.detectInput = self.detectModel.get_spec().description.input[0].name
self.textModel = self.downloadModel("vgg_english_g2")
@@ -110,7 +110,7 @@ class VisionPlugin(PredictPlugin):
self.faceInput = self.faceModel.get_spec().description.input[0].name
def downloadModel(self, model: str):
model_version = "v3"
model_version = "v7"
mlmodel = "model"
files = [
@@ -244,7 +244,7 @@ class VisionPlugin(PredictPlugin):
out_dict = await asyncio.get_event_loop().run_in_executor(
predictExecutor,
lambda: self.faceModel.predict({self.textInput: processed_tensor}),
lambda: self.faceModel.predict({self.faceInput: processed_tensor}),
)
output = out_dict["var_2167"][0]