From 62e23880fda3c9287f838d2f99f8c2f181a87cd6 Mon Sep 17 00:00:00 2001
From: Koushik Dutta
Date: Fri, 19 Apr 2024 10:33:37 -0700
Subject: [PATCH] coreml: handle batching hint failures

---
 plugins/coreml/package-lock.json                |  4 +--
 plugins/coreml/package.json                     |  2 +-
 .../tensorflow-lite/src/predict/__init__.py     | 31 ++++++++++++++-----
 3 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/plugins/coreml/package-lock.json b/plugins/coreml/package-lock.json
index fb1d14291..31b7d86db 100644
--- a/plugins/coreml/package-lock.json
+++ b/plugins/coreml/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "@scrypted/coreml",
-  "version": "0.1.42",
+  "version": "0.1.43",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "@scrypted/coreml",
-      "version": "0.1.42",
+      "version": "0.1.43",
       "devDependencies": {
         "@scrypted/sdk": "file:../../sdk"
       }
diff --git a/plugins/coreml/package.json b/plugins/coreml/package.json
index b79b98057..b1bbbc6f4 100644
--- a/plugins/coreml/package.json
+++ b/plugins/coreml/package.json
@@ -42,5 +42,5 @@
   "devDependencies": {
     "@scrypted/sdk": "file:../../sdk"
   },
-  "version": "0.1.42"
+  "version": "0.1.43"
 }
diff --git a/plugins/tensorflow-lite/src/predict/__init__.py b/plugins/tensorflow-lite/src/predict/__init__.py
index 3d9105bc6..18301aa4a 100644
--- a/plugins/tensorflow-lite/src/predict/__init__.py
+++ b/plugins/tensorflow-lite/src/predict/__init__.py
@@ -47,6 +47,7 @@ class PredictPlugin(DetectPlugin):
 
         self.batch: List[Tuple[Any, asyncio.Future]] = []
         self.batching = 0
+        self.batch_flush = None
 
     def downloadFile(self, url: str, filename: str):
         try:
@@ -144,20 +145,36 @@ class PredictPlugin(DetectPlugin):
     async def detect_batch(self, inputs: List[Any]) -> List[Any]:
         pass
 
+    async def run_batch(self):
+        batch = self.batch
+        self.batch = []
+        self.batching = 0
+
+        if len(batch):
+            inputs = [x[0] for x in batch]
+            try:
+                results = await self.detect_batch(inputs)
+                for i, result in enumerate(results):
+                    batch[i][1].set_result(result)
+            except Exception as e:
+                for item in batch:
+                    item[1].set_exception(e)
+
+    async def flush_batch(self):
+        self.batch_flush = None
+        await self.run_batch()
+
     async def queue_batch(self, input: Any) -> List[Any]:
         future = asyncio.Future(loop = asyncio.get_event_loop())
         self.batch.append((input, future))
         if self.batching:
             self.batching = self.batching - 1
             if self.batching:
+                # flush after a timeout in case of error or backlog and the rest of the batch never arrives.
+                if not self.batch_flush:
+                    self.batch_flush = self.loop.call_later(.5, lambda: asyncio.ensure_future(self.flush_batch()))
                 return await future
 
-        batch = self.batch
-        self.batch = []
-        if len(batch):
-            inputs = [x[0] for x in batch]
-            results = await self.detect_batch(inputs)
-            for i, result in enumerate(results):
-                batch[i][1].set_result(result)
+        await self.run_batch()
         return await future
 
     async def safe_detect_once(self, input: Image.Image, settings: Any, src_size, cvss) -> ObjectsDetected: