mirror of
https://github.com/koush/scrypted.git
synced 2026-06-20 16:40:30 +01:00
predict: cluster for should enforce compute/darwin label
This commit is contained in:
@@ -41,6 +41,7 @@
|
||||
],
|
||||
"labels": {
|
||||
"require": [
|
||||
"compute",
|
||||
"darwin"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -147,6 +147,9 @@ class CoreMLPlugin(
|
||||
if not self.forked:
|
||||
asyncio.ensure_future(self.prepareRecognitionModels(), loop=self.loop)
|
||||
|
||||
def getClusterLabels(self):
|
||||
return {"labels": {"require": ["compute", "darwin"]}}
|
||||
|
||||
async def prepareRecognitionModels(self):
|
||||
try:
|
||||
devices = [
|
||||
|
||||
@@ -34,6 +34,9 @@ class CoreMLFaceRecognition(FaceRecognizeDetection):
|
||||
self.detectExecutor = concurrent.futures.ThreadPoolExecutor(1, "detect-face")
|
||||
self.recogExecutor = concurrent.futures.ThreadPoolExecutor(1, "recog-face")
|
||||
|
||||
def getClusterLabels(self):
|
||||
return {"labels": {"require": ["compute", "darwin"]}}
|
||||
|
||||
def downloadModel(self, model: str):
|
||||
model_version = "v7"
|
||||
mlmodel = "model"
|
||||
|
||||
@@ -19,6 +19,9 @@ class CoreMLTextRecognition(TextRecognition):
|
||||
self.detectExecutor = concurrent.futures.ThreadPoolExecutor(1, "detect-text")
|
||||
self.recogExecutor = concurrent.futures.ThreadPoolExecutor(1, "recog-text")
|
||||
|
||||
def getClusterLabels(self):
|
||||
return {"labels": {"require": ["compute", "darwin"]}}
|
||||
|
||||
def downloadModel(self, model: str):
|
||||
model_version = "v8"
|
||||
mlmodel = "model"
|
||||
|
||||
@@ -310,6 +310,9 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.ClusterForkInterface):
|
||||
finally:
|
||||
data.close()
|
||||
|
||||
def getClusterLabels(self):
|
||||
return {"labels": {"require": ["compute"]}}
|
||||
|
||||
async def forkInterfaceInternal(self, options: dict):
|
||||
if self.plugin:
|
||||
return await self.plugin.forkInterfaceInternal(options)
|
||||
@@ -323,9 +326,7 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.ClusterForkInterface):
|
||||
|
||||
forked = self.forks.get(clusterWorkerId, None)
|
||||
if not forked:
|
||||
forked = scrypted_sdk.fork(
|
||||
{"labels": {"require": ["compute"]}, **(options or {})}
|
||||
)
|
||||
forked = scrypted_sdk.fork({**self.getClusterLabels(), **(options or {})})
|
||||
|
||||
def clusterWorkerExit(result):
|
||||
print("cluster worker exit", clusterWorkerId)
|
||||
@@ -380,7 +381,8 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.ClusterForkInterface):
|
||||
{"clusterWorkerId": clusterWorkerId}
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
# traceback.print_exc()
|
||||
pass
|
||||
|
||||
asyncio.ensure_future(startClusterWorker(), loop=self.loop)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user