Files
scrypted/plugins/openvino/src/predict/clip.py
The Beholder bfb8c233f4 openvino: avoid CLIP startup timeout by loading HF cache first (#1949)
Scrypted could restart the OpenVINO plugin on startup in offline/firewalled setups because CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") triggers HuggingFace Hub network checks/retries that exceed the plugin startup watchdog.
Update predict/clip.py to:
- Load the CLIP processor from the local HF cache first (local_files_only=True) so startup is fast/offline-safe.
- Refresh the processor cache online asynchronously in a background thread (asyncio.to_thread) so update checks don’t block startup.
- Add simple log prints to indicate cache load vs refresh success/failure.
2025-12-26 18:38:13 -08:00

91 lines
2.9 KiB
Python

from __future__ import annotations
import asyncio
import base64
import os
from typing import Tuple
import scrypted_sdk
from transformers import CLIPProcessor
from predict import PredictPlugin
class ClipEmbedding(PredictPlugin, scrypted_sdk.TextEmbedding, scrypted_sdk.ImageEmbedding):
    """CLIP text/image embedding plugin node.

    The HuggingFace CLIP processor is loaded from the local cache first
    (``local_files_only=True``) so startup stays fast in offline/firewalled
    environments, then a background task refreshes the cache online so the
    Hub update check never blocks plugin startup.
    """

    def __init__(self, plugin: PredictPlugin, nativeId: str):
        super().__init__(nativeId=nativeId, plugin=plugin)
        hf_id = "openai/clip-vit-base-patch32"
        self.inputwidth = 224
        self.inputheight = 224
        self.labels = {}
        self.loop = asyncio.get_event_loop()
        self.minThreshold = 0.5
        self.model = self.initModel()

        # Keep the HF cache inside the plugin volume so it persists across restarts.
        cache_dir = os.path.join(os.environ["SCRYPTED_PLUGIN_VOLUME"], "files", "hf")
        os.makedirs(cache_dir, exist_ok=True)

        self.processor = None
        print("Loading CLIP processor from local cache.")
        try:
            # local_files_only avoids HuggingFace Hub network checks/retries
            # that could exceed the plugin startup watchdog when offline.
            self.processor = CLIPProcessor.from_pretrained(
                hf_id,
                cache_dir=cache_dir,
                local_files_only=True,
            )
            print("Loaded CLIP processor from local cache.")
        except Exception:
            # Not fatal: the background refresh below will populate the cache.
            print("CLIP processor not available in local cache yet.")

        # Schedule the online cache refresh without blocking startup.
        # BUG FIX: asyncio.ensure_future(coro, loop=...) — the `loop` keyword
        # was deprecated in 3.8 and removed in Python 3.10, raising TypeError.
        # create_task on the captured loop is the supported equivalent.
        self.loop.create_task(self.refreshClipProcessor(hf_id, cache_dir))

    async def refreshClipProcessor(self, hf_id: str, cache_dir: str):
        """Refresh the CLIP processor cache from the Hub in a worker thread.

        On success, swaps the freshly downloaded processor into
        ``self.processor``; on failure, keeps whatever (if anything) was
        loaded from the local cache.
        """
        try:
            print("Refreshing CLIP processor cache (online).")
            # from_pretrained performs blocking network/disk I/O; run it off
            # the event loop so it cannot stall other plugin work.
            processor = await asyncio.to_thread(
                CLIPProcessor.from_pretrained,
                hf_id,
                cache_dir=cache_dir,
            )
            self.processor = processor
            print("Refreshed CLIP processor cache.")
        except Exception as e:
            # Best effort: log the reason instead of silently discarding it.
            print(f"CLIP processor cache refresh failed: {e}")

    def getFiles(self):
        # Overridden by subclasses to list the model files to download.
        pass

    def initModel(self):
        """Download the model files listed by getFiles() and load the model."""
        local_files: list[str] = []
        for file in self.getFiles():
            remote_file = "https://huggingface.co/koushd/clip/resolve/main/" + file
            localFile = self.downloadFile(remote_file, f"{self.id}/{file}")
            local_files.append(localFile)
        return self.loadModel(local_files)

    def loadModel(self, files: list[str]):
        # Overridden by subclasses to build the runtime model from local files.
        pass

    async def getImageEmbedding(self, input):
        """Return the raw embedding of the first detection for the input image."""
        detections = await super().detectObjects(input, None)
        return detections["detections"][0]["embedding"]

    async def detectObjects(self, mediaObject, session = None):
        """Run detection, base64-encoding the embedding for RPC transport."""
        ret = await super().detectObjects(mediaObject, session)
        # The embedding is raw bytes; encode so it survives JSON serialization.
        embedding = ret["detections"][0]['embedding']
        ret["detections"][0]['embedding'] = base64.b64encode(embedding).decode("utf-8")
        return ret

    # width, height, channels
    def get_input_details(self) -> Tuple[int, int, int]:
        return (self.inputwidth, self.inputheight, 3)

    def get_input_size(self) -> Tuple[float, float]:
        return (self.inputwidth, self.inputheight)

    def get_input_format(self) -> str:
        return "rgb"