predict: cluster for should enforce compute/darwin label

This commit is contained in:
Koushik Dutta
2024-11-22 17:08:43 -08:00
parent ea065f506c
commit ab4738973d
5 changed files with 16 additions and 4 deletions

View File

@@ -41,6 +41,7 @@
],
"labels": {
"require": [
"compute",
"darwin"
]
}

View File

@@ -147,6 +147,9 @@ class CoreMLPlugin(
if not self.forked:
asyncio.ensure_future(self.prepareRecognitionModels(), loop=self.loop)
def getClusterLabels(self):
return {"labels": {"require": ["compute", "darwin"]}}
async def prepareRecognitionModels(self):
try:
devices = [

View File

@@ -34,6 +34,9 @@ class CoreMLFaceRecognition(FaceRecognizeDetection):
self.detectExecutor = concurrent.futures.ThreadPoolExecutor(1, "detect-face")
self.recogExecutor = concurrent.futures.ThreadPoolExecutor(1, "recog-face")
def getClusterLabels(self):
return {"labels": {"require": ["compute", "darwin"]}}
def downloadModel(self, model: str):
model_version = "v7"
mlmodel = "model"

View File

@@ -19,6 +19,9 @@ class CoreMLTextRecognition(TextRecognition):
self.detectExecutor = concurrent.futures.ThreadPoolExecutor(1, "detect-text")
self.recogExecutor = concurrent.futures.ThreadPoolExecutor(1, "recog-text")
def getClusterLabels(self):
return {"labels": {"require": ["compute", "darwin"]}}
def downloadModel(self, model: str):
model_version = "v8"
mlmodel = "model"

View File

@@ -310,6 +310,9 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.ClusterForkInterface):
finally:
data.close()
def getClusterLabels(self):
return {"labels": {"require": ["compute"]}}
async def forkInterfaceInternal(self, options: dict):
if self.plugin:
return await self.plugin.forkInterfaceInternal(options)
@@ -323,9 +326,7 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.ClusterForkInterface):
forked = self.forks.get(clusterWorkerId, None)
if not forked:
forked = scrypted_sdk.fork(
{"labels": {"require": ["compute"]}, **(options or {})}
)
forked = scrypted_sdk.fork({**self.getClusterLabels(), **(options or {})})
def clusterWorkerExit(result):
print("cluster worker exit", clusterWorkerId)
@@ -380,7 +381,8 @@ class PredictPlugin(DetectPlugin, scrypted_sdk.ClusterForkInterface):
{"clusterWorkerId": clusterWorkerId}
)
except:
traceback.print_exc()
# traceback.print_exc()
pass
asyncio.ensure_future(startClusterWorker(), loop=self.loop)