From 39c637a95f9d5afff1ac39fc2a20e9c78adfc1b6 Mon Sep 17 00:00:00 2001 From: Koushik Dutta Date: Mon, 22 Apr 2024 12:57:11 -0700 Subject: [PATCH] coreml: wip refactor text recognition --- plugins/coreml/src/coreml/__init__.py | 47 +++- plugins/coreml/src/coreml/text_recognition.py | 39 +++ plugins/coreml/src/requirements.optional.txt | 1 + .../src/predict/craft_utils.py | 259 ++++++++++++++++++ .../tensorflow-lite/src/predict/recognize.py | 2 + .../src/predict/text_recognize.py | 107 ++++++++ 6 files changed, 443 insertions(+), 12 deletions(-) create mode 100644 plugins/coreml/src/coreml/text_recognition.py create mode 100644 plugins/tensorflow-lite/src/predict/craft_utils.py create mode 100644 plugins/tensorflow-lite/src/predict/text_recognize.py diff --git a/plugins/coreml/src/coreml/__init__.py b/plugins/coreml/src/coreml/__init__.py index c362d1c85..a1d5478cb 100644 --- a/plugins/coreml/src/coreml/__init__.py +++ b/plugins/coreml/src/coreml/__init__.py @@ -14,6 +14,11 @@ from scrypted_sdk import Setting, SettingValue from common import yolo from coreml.recognition import CoreMLRecognition + +try: + from coreml.text_recognition import CoreMLTextRecognition +except: + CoreMLTextRecognition = None from predict import Prediction, PredictPlugin from predict.rectangle import Rectangle @@ -131,25 +136,43 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv async def prepareRecognitionModels(self): try: + devices = [ + { + "nativeId": "recognition", + "type": scrypted_sdk.ScryptedDeviceType.Builtin.value, + "interfaces": [ + scrypted_sdk.ScryptedInterface.ObjectDetection.value, + ], + "name": "CoreML Recognition", + }, + ] + + if CoreMLTextRecognition: + devices.append( + { + "nativeId": "textrecognition", + "type": scrypted_sdk.ScryptedDeviceType.Builtin.value, + "interfaces": [ + scrypted_sdk.ScryptedInterface.ObjectDetection.value, + ], + "name": "CoreML Text Recognition", + }, + ) + await scrypted_sdk.deviceManager.onDevicesChanged( { - "devices": [ - { - "nativeId": "recognition", - "type": scrypted_sdk.ScryptedDeviceType.Builtin.value, - "interfaces": [ - scrypted_sdk.ScryptedInterface.ObjectDetection.value, - ], - "name": "CoreML Recognition", - } - ] + "devices": devices, } ) except: pass async def getDevice(self, nativeId: str) -> Any: - return CoreMLRecognition(nativeId) + if nativeId == "recognition": + return CoreMLRecognition(nativeId) + if nativeId == "textrecognition": + return CoreMLTextRecognition(nativeId) + raise Exception("unknown device") async def getSettings(self) -> list[Setting]: model = self.storage.getItem("model") or "Default" @@ -174,7 +197,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv def get_input_size(self) -> Tuple[float, float]: return (self.inputwidth, self.inputheight) - + async def detect_batch(self, inputs: List[Any]) -> List[Any]: out_dicts = await asyncio.get_event_loop().run_in_executor( predictExecutor, lambda: self.model.predict(inputs) diff --git a/plugins/coreml/src/coreml/text_recognition.py b/plugins/coreml/src/coreml/text_recognition.py new file mode 100644 index 000000000..bc05de7ed --- /dev/null +++ b/plugins/coreml/src/coreml/text_recognition.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import os + +import coremltools as ct + +from predict.text_recognize import TextRecognition + + +class CoreMLTextRecognition(TextRecognition): + def __init__(self, nativeId: str | None = None): + super().__init__(nativeId=nativeId) + + def downloadModel(self, model: str): + model_version = "v7" + mlmodel = "model" + + files = [ + f"{model}/{model}.mlpackage/Data/com.apple.CoreML/weights/weight.bin", + f"{model}/{model}.mlpackage/Data/com.apple.CoreML/{mlmodel}.mlmodel", + f"{model}/{model}.mlpackage/Manifest.json", + ] + + for f in files: + p = self.downloadFile( + f"https://github.com/koush/coreml-models/raw/main/{f}", + f"{model_version}/{f}", + ) + modelFile = os.path.dirname(p) + + model = ct.models.MLModel(modelFile) + inputName = model.get_spec().description.input[0].name + return model, inputName + + def predictDetectModel(self, input): + model, inputName = self.detectModel + out_dict = model.predict({inputName: input}) + results = list(out_dict.values())[0] + return results diff --git a/plugins/coreml/src/requirements.optional.txt b/plugins/coreml/src/requirements.optional.txt index e69de29bb..0dd006bbc 100644 --- a/plugins/coreml/src/requirements.optional.txt +++ b/plugins/coreml/src/requirements.optional.txt @@ -0,0 +1 @@ +opencv-python diff --git a/plugins/tensorflow-lite/src/predict/craft_utils.py b/plugins/tensorflow-lite/src/predict/craft_utils.py new file mode 100644 index 000000000..c3de45691 --- /dev/null +++ b/plugins/tensorflow-lite/src/predict/craft_utils.py @@ -0,0 +1,259 @@ +""" +Copyright (c) 2019-present NAVER Corp. +MIT License +""" + +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import math + +def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): + # should be RGB order + img = in_img.copy().astype(np.float32) + + img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32) + img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32) + return img + +""" auxiliary functions """ +# unwarp corodinates +def warpCoord(Minv, pt): + out = np.matmul(Minv, (pt[0], pt[1], 1)) + return np.array([out[0]/out[2], out[1]/out[2]]) +""" end of auxiliary functions """ + + +def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text, estimate_num_chars=False): + # prepare data + linkmap = linkmap.copy() + textmap = textmap.copy() + img_h, img_w = textmap.shape + + """ labeling method """ + ret, text_score = cv2.threshold(textmap, low_text, 1, 0) + ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) + + text_score_comb = np.clip(text_score + link_score, 0, 1) + nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4) + + det = [] + mapper = [] + for k in range(1,nLabels): + # size filtering + size = stats[k, cv2.CC_STAT_AREA] + if size < 10: continue + + # thresholding + if np.max(textmap[labels==k]) < text_threshold: continue + + # make segmentation map + segmap = np.zeros(textmap.shape, dtype=np.uint8) + segmap[labels==k] = 255 + if estimate_num_chars: + from scipy.ndimage import label + _, character_locs = cv2.threshold((textmap - linkmap) * segmap /255., text_threshold, 1, 0) + _, n_chars = label(character_locs) + mapper.append(n_chars) + else: + mapper.append(k) + segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area + x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] + w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] + niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) + sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1 + # boundary check + if sx < 0 : sx = 0 + if sy < 0 : sy = 0 + if ex >= img_w: ex = img_w + if ey >= img_h: ey = img_h + kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter)) + segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) + + # make box + np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2) + rectangle = cv2.minAreaRect(np_contours) + box = cv2.boxPoints(rectangle) + + # align diamond-shape + w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) + box_ratio = max(w, h) / (min(w, h) + 1e-5) + if abs(1 - box_ratio) <= 0.1: + l, r = min(np_contours[:,0]), max(np_contours[:,0]) + t, b = min(np_contours[:,1]), max(np_contours[:,1]) + box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) + + # make clock-wise order + startidx = box.sum(axis=1).argmin() + box = np.roll(box, 4-startidx, 0) + box = np.array(box) + + det.append(box) + + return det, labels, mapper + +def getPoly_core(boxes, labels, mapper, linkmap): + # configs + num_cp = 5 + max_len_ratio = 0.7 + expand_ratio = 1.45 + max_r = 2.0 + step_r = 0.2 + + polys = [] + for k, box in enumerate(boxes): + # size filter for small instance + w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1) + if w < 10 or h < 10: + polys.append(None); continue + + # warp image + tar = np.float32([[0,0],[w,0],[w,h],[0,h]]) + M = cv2.getPerspectiveTransform(box, tar) + word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST) + try: + Minv = np.linalg.inv(M) + except: + polys.append(None); continue + + # binarization for selected label + cur_label = mapper[k] + word_label[word_label != cur_label] = 0 + word_label[word_label > 0] = 1 + + """ Polygon generation """ + # find top/bottom contours + cp = [] + max_len = -1 + for i in range(w): + region = np.where(word_label[:,i] != 0)[0] + if len(region) < 2 : continue + cp.append((i, region[0], region[-1])) + length = region[-1] - region[0] + 1 + if length > max_len: max_len = length + + # pass if max_len is similar to h + if h * max_len_ratio < max_len: + polys.append(None); continue + + # get pivot points with fixed length + tot_seg = num_cp * 2 + 1 + seg_w = w / tot_seg # segment width + pp = [None] * num_cp # init pivot points + cp_section = [[0, 0]] * tot_seg + seg_height = [0] * num_cp + seg_num = 0 + num_sec = 0 + prev_h = -1 + for i in range(0,len(cp)): + (x, sy, ey) = cp[i] + if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: + # average previous segment + if num_sec == 0: break + cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec] + num_sec = 0 + + # reset variables + seg_num += 1 + prev_h = -1 + + # accumulate center points + cy = (sy + ey) * 0.5 + cur_h = ey - sy + 1 + cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy] + num_sec += 1 + + if seg_num % 2 == 0: continue # No polygon area + + if prev_h < cur_h: + pp[int((seg_num - 1)/2)] = (x, cy) + seg_height[int((seg_num - 1)/2)] = cur_h + prev_h = cur_h + + # processing last segment + if num_sec != 0: + cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec] + + # pass if num of pivots is not sufficient or segment width is smaller than character height + if None in pp or seg_w < np.max(seg_height) * 0.25: + polys.append(None); continue + + # calc median maximum of pivot points + half_char_h = np.median(seg_height) * expand_ratio / 2 + + # calc gradiant and apply to make horizontal pivots + new_pp = [] + for i, (x, cy) in enumerate(pp): + dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] + dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] + if dx == 0: # gradient if zero + new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) + continue + rad = - math.atan2(dy, dx) + c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) + new_pp.append([x - s, cy - c, x + s, cy + c]) + + # get edge points to cover character heatmaps + isSppFound, isEppFound = False, False + grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]) + grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0]) + for r in np.arange(0.5, max_r, step_r): + dx = 2 * half_char_h * r + if not isSppFound: + line_img = np.zeros(word_label.shape, dtype=np.uint8) + dy = grad_s * dx + p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) + cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) + if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: + spp = p + isSppFound = True + if not isEppFound: + line_img = np.zeros(word_label.shape, dtype=np.uint8) + dy = grad_e * dx + p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) + cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) + if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: + epp = p + isEppFound = True + if isSppFound and isEppFound: + break + + # pass if boundary of polygon is not found + if not (isSppFound and isEppFound): + polys.append(None); continue + + # make final polygon + poly = [] + poly.append(warpCoord(Minv, (spp[0], spp[1]))) + for p in new_pp: + poly.append(warpCoord(Minv, (p[0], p[1]))) + poly.append(warpCoord(Minv, (epp[0], epp[1]))) + poly.append(warpCoord(Minv, (epp[2], epp[3]))) + for p in reversed(new_pp): + poly.append(warpCoord(Minv, (p[2], p[3]))) + poly.append(warpCoord(Minv, (spp[2], spp[3]))) + + # add to final result + polys.append(np.array(poly)) + + return polys + +def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False, estimate_num_chars=False): + if poly and estimate_num_chars: + raise Exception("Estimating the number of characters not currently supported with poly.") + boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text, estimate_num_chars) + + if poly: + polys = getPoly_core(boxes, labels, mapper, linkmap) + else: + polys = [None] * len(boxes) + + return boxes, polys, mapper + +def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2): + if len(polys) > 0: + polys = np.array(polys) + for k in range(len(polys)): + if polys[k] is not None: + polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) + return polys diff --git a/plugins/tensorflow-lite/src/predict/recognize.py b/plugins/tensorflow-lite/src/predict/recognize.py index 7efbd3c3a..ab1330ca1 100644 --- a/plugins/tensorflow-lite/src/predict/recognize.py +++ b/plugins/tensorflow-lite/src/predict/recognize.py @@ -204,6 +204,8 @@ class RecognizeDetection(PredictPlugin): futures.append(asyncio.ensure_future(self.setEmbedding(d, image))) elif d["className"] == "plate": futures.append(asyncio.ensure_future(self.setLabel(d, image))) + # elif d['className'] == 'text': + # futures.append(asyncio.ensure_future(self.setLabel(d, image))) if len(futures): await asyncio.wait(futures) diff --git a/plugins/tensorflow-lite/src/predict/text_recognize.py b/plugins/tensorflow-lite/src/predict/text_recognize.py new file mode 100644 index 000000000..bc18e47a0 --- /dev/null +++ b/plugins/tensorflow-lite/src/predict/text_recognize.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +from typing import Any, List, Tuple + +import numpy as np +import scrypted_sdk +from PIL import Image + +from predict import Prediction, PredictPlugin +from predict.craft_utils import normalizeMeanVariance +from predict.rectangle import Rectangle + +from .craft_utils import adjustResultCoordinates, getDetBoxes + +predictExecutor = concurrent.futures.ThreadPoolExecutor(1, "TextDetect") + +class TextRecognition(PredictPlugin): + def __init__(self, nativeId: str | None = None): + super().__init__(nativeId=nativeId) + + self.inputheight = 640 + self.inputwidth = 640 + + self.labels = { + 0: "text", + } + self.loop = asyncio.get_event_loop() + self.minThreshold = 0.1 + + self.detectModel = self.downloadModel("craft") + + + def downloadModel(self, model: str): + pass + + def predictDetectModel(self, input): + pass + + async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss) -> scrypted_sdk.ObjectsDetected: + image_tensor = normalizeMeanVariance(np.array(input)) + # reshape to c w h + image_tensor = image_tensor.transpose([2, 0, 1]) + # add extra dimension to tensor + image_tensor = np.expand_dims(image_tensor, axis=0) + + y = await asyncio.get_event_loop().run_in_executor( + predictExecutor, lambda: self.predictDetectModel(image_tensor) + ) + + estimate_num_chars = False + ratio_h = ratio_w = 1 + text_threshold = .7 + link_threshold = .7 + low_text = .4 + poly = False + + boxes_list, polys_list = [], [] + for out in y: + # make score and link map + score_text = out[:, :, 0] + score_link = out[:, :, 1] + + # Post-processing + boxes, polys, mapper = getDetBoxes( + score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars) + if not len(boxes): + continue + + # coordinate adjustment + boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h) + polys = adjustResultCoordinates(polys, ratio_w, ratio_h) + if estimate_num_chars: + boxes = list(boxes) + polys = list(polys) + for k in range(len(polys)): + if estimate_num_chars: + boxes[k] = (boxes[k], mapper[k]) + if polys[k] is None: + polys[k] = boxes[k] + boxes_list.append(boxes) + polys_list.append(polys) + + preds: List[Prediction] = [] + for boxes in boxes_list: + for box in boxes: + tl, tr, br, bl = box + l = tl[0] + t = tl[1] + r = br[0] + b = br[1] + + pred = Prediction(0, 1, Rectangle(l, t, r, b)) + preds.append(pred) + + return self.create_detection_result(preds, src_size, cvss) + + # width, height, channels + def get_input_details(self) -> Tuple[int, int, int]: + return (self.inputwidth, self.inputheight, 3) + + def get_input_size(self) -> Tuple[float, float]: + return (self.inputwidth, self.inputheight) + + def get_input_format(self) -> str: + return "rgb" \ No newline at end of file