mirror of
https://github.com/koush/scrypted.git
synced 2026-03-16 23:22:07 +00:00
coreml: clip
This commit is contained in:
@@ -15,6 +15,7 @@ from scrypted_sdk import Setting, SettingValue
|
||||
from common import yolo
|
||||
from coreml.face_recognition import CoreMLFaceRecognition
|
||||
from coreml.custom_detection import CoreMLCustomDetection
|
||||
from coreml.clip_embedding import CoreMLClipEmbedding
|
||||
|
||||
try:
|
||||
from coreml.text_recognition import CoreMLTextRecognition
|
||||
@@ -146,6 +147,7 @@ class CoreMLPlugin(
|
||||
|
||||
self.faceDevice = None
|
||||
self.textDevice = None
|
||||
self.clipDevice = None
|
||||
|
||||
if not self.forked:
|
||||
asyncio.ensure_future(self.prepareRecognitionModels(), loop=self.loop)
|
||||
@@ -177,6 +179,19 @@ class CoreMLPlugin(
|
||||
},
|
||||
)
|
||||
|
||||
await scrypted_sdk.deviceManager.onDeviceDiscovered(
|
||||
{
|
||||
"nativeId": "clipembedding",
|
||||
"type": scrypted_sdk.ScryptedDeviceType.Builtin.value,
|
||||
"interfaces": [
|
||||
scrypted_sdk.ScryptedInterface.ClusterForkInterface.value,
|
||||
scrypted_sdk.ScryptedInterface.ObjectDetection.value,
|
||||
scrypted_sdk.ScryptedInterface.TextEmbedding.value,
|
||||
scrypted_sdk.ScryptedInterface.ImageEmbedding.value,
|
||||
],
|
||||
"name": "CoreML CLIP Embedding",
|
||||
}
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -184,9 +199,12 @@ class CoreMLPlugin(
|
||||
if nativeId == "facerecognition":
|
||||
self.faceDevice = self.faceDevice or CoreMLFaceRecognition(self, nativeId)
|
||||
return self.faceDevice
|
||||
if nativeId == "textrecognition":
|
||||
elif nativeId == "textrecognition":
|
||||
self.textDevice = self.textDevice or CoreMLTextRecognition(self, nativeId)
|
||||
return self.textDevice
|
||||
elif nativeId == "clipembedding":
|
||||
self.clipDevice = self.clipDevice or CoreMLClipEmbedding(self, nativeId)
|
||||
return self.clipDevice
|
||||
custom_model = self.custom_models.get(nativeId, None)
|
||||
if custom_model:
|
||||
return custom_model
|
||||
|
||||
84 lines — plugins/coreml/src/coreml/clip_embedding.py (new file)
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import coremltools as ct
|
||||
from predict.clip import ClipEmbedding
|
||||
from scrypted_sdk import ObjectsDetected
|
||||
import os
|
||||
import concurrent.futures
|
||||
|
||||
class CoreMLClipEmbedding(ClipEmbedding):
    """CLIP text/image embedding backed by a pair of Core ML models.

    The downloaded model bundle contains two .mlpackage directories: a
    text encoder and a vision encoder. `loadModel` returns them as a
    (textModel, visionModel) tuple, and the embedding methods serialize
    the resulting float32 vectors to raw bytes so they can flow through
    the existing detection/embedding plumbing.
    """

    def __init__(self, plugin, nativeId: str):
        super().__init__(plugin=plugin, nativeId=nativeId)
        # Single worker: Core ML predict calls are serialized off the
        # event loop.
        # NOTE(review): thread-name prefix "detect-custom" appears
        # copy-pasted from the custom-detection plugin; consider renaming
        # to something like "clip-embedding" for debuggability.
        self.predictExecutor = concurrent.futures.ThreadPoolExecutor(1, "detect-custom")

    def getFiles(self):
        """Return the relative paths of every file both .mlpackage
        bundles need on disk (manifest, weights, model spec)."""
        return [
            "text.mlpackage/Manifest.json",
            "text.mlpackage/Data/com.apple.CoreML/weights/weight.bin",
            "text.mlpackage/Data/com.apple.CoreML/model.mlmodel",

            "vision.mlpackage/Manifest.json",
            "vision.mlpackage/Data/com.apple.CoreML/weights/weight.bin",
            "vision.mlpackage/Data/com.apple.CoreML/model.mlmodel",
        ]

    @staticmethod
    def _find_manifest(files, suffix: str) -> str:
        """Return the first entry of *files* ending with *suffix*
        (case-insensitive).

        Raises:
            ValueError: when no matching manifest is present.
        """
        found = next((f for f in files if f.lower().endswith(suffix)), None)
        if found is None:
            # The original message said "No XML model file found" — a
            # copy-paste from the OpenVINO plugin; report what we
            # actually searched for.
            raise ValueError(f"No {suffix} found in the provided files list")
        return found

    def loadModel(self, files):
        """Locate both manifests in *files* and load the text and vision
        Core ML models.

        Returns:
            (textModel, visionModel) tuple of ct.models.MLModel.
        """
        text_manifest = self._find_manifest(files, 'text.mlpackage/manifest.json')
        vision_manifest = self._find_manifest(files, 'vision.mlpackage/manifest.json')

        # MLModel expects the .mlpackage directory itself, i.e. the
        # parent directory of the Manifest.json we located.
        textModel = ct.models.MLModel(os.path.dirname(text_manifest))
        visionModel = ct.models.MLModel(os.path.dirname(vision_manifest))

        return textModel, visionModel

    async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
        """Compute the CLIP image embedding for *input*, wrapped in an
        ObjectsDetected result so it can reuse the existing image
        massaging infrastructure."""
        def predict():
            # NOTE(review): padding/truncation are tokenizer options and
            # should be inert for image-only preprocessing — confirm they
            # are ignored by the processor.
            inputs = self.processor(images=input, return_tensors="np", padding="max_length", truncation=True)
            _, vision_model = self.model
            vision_predictions = vision_model.predict({'x': inputs['pixel_values']})
            # presumably 'var_877' is the auto-generated output name of
            # the converted vision model — verify against the .mlpackage.
            image_embeds = vision_predictions['var_877']
            # this is a hack to utilize the existing image massaging infrastructure
            embedding = bytearray(image_embeds.astype(np.float32).tobytes())
            ret: ObjectsDetected = {
                "detections": [
                    {
                        "embedding": embedding,
                    }
                ],
                "inputDimensions": src_size,
            }

            return ret

        # Run the blocking Core ML prediction on the dedicated executor;
        # get_running_loop() is the non-deprecated form inside a coroutine.
        return await asyncio.get_running_loop().run_in_executor(
            self.predictExecutor, predict
        )

    async def getTextEmbedding(self, input):
        """Compute the CLIP text embedding for *input* and return it as
        raw float32 bytes."""
        def predict():
            inputs = self.processor(text=input, return_tensors="np", padding="max_length", truncation=True)
            text_model, _ = self.model
            text_predictions = text_model.predict({'input_ids_1': inputs['input_ids'].astype(np.float32), 'attention_mask_1': inputs['attention_mask'].astype(np.float32)})
            # presumably 'var_1050' is the auto-generated output name of
            # the converted text model — verify against the .mlpackage.
            text_embeds = text_predictions['var_1050']
            return bytearray(text_embeds.astype(np.float32).tobytes())

        return await asyncio.get_running_loop().run_in_executor(
            self.predictExecutor, predict
        )
|
||||
@@ -1,3 +1,5 @@
|
||||
coremltools==8.0
|
||||
Pillow==10.3.0
|
||||
opencv-python-headless==4.10.0.84
|
||||
|
||||
transformers==4.52.4
|
||||
|
||||
Reference in New Issue
Block a user