mirror of
https://github.com/koush/scrypted.git
synced 2026-06-20 16:40:30 +01:00
openvino/coreml: wip refactor text recognition
This commit is contained in:
@@ -14,6 +14,11 @@ from scrypted_sdk import Setting, SettingValue
|
||||
|
||||
from common import yolo
|
||||
from coreml.recognition import CoreMLRecognition
|
||||
|
||||
try:
|
||||
from coreml.text_recognition import CoreMLTextRecognition
|
||||
except:
|
||||
CoreMLTextRecognition = None
|
||||
from predict import Prediction, PredictPlugin
|
||||
from predict.rectangle import Rectangle
|
||||
|
||||
@@ -131,25 +136,43 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
|
||||
async def prepareRecognitionModels(self):
    """Announce the recognition child devices to Scrypted.

    The "recognition" device is always announced. The "textrecognition"
    device is announced only when the module-level CoreMLTextRecognition
    name is truthy, i.e. the guarded import at the top of the file
    succeeded.

    Registration is best-effort: a failure is reported on stderr and
    otherwise ignored, so this coroutine never raises for ordinary errors.
    """
    try:
        devices = [
            {
                "nativeId": "recognition",
                "type": scrypted_sdk.ScryptedDeviceType.Builtin.value,
                "interfaces": [
                    scrypted_sdk.ScryptedInterface.ObjectDetection.value,
                ],
                "name": "CoreML Recognition",
            },
        ]

        if CoreMLTextRecognition:
            devices.append(
                {
                    "nativeId": "textrecognition",
                    "type": scrypted_sdk.ScryptedDeviceType.Builtin.value,
                    "interfaces": [
                        scrypted_sdk.ScryptedInterface.ObjectDetection.value,
                    ],
                    "name": "CoreML Text Recognition",
                },
            )

        # A single "devices" entry: the list built above replaces the old
        # inline literal that only ever described the recognition device.
        await scrypted_sdk.deviceManager.onDevicesChanged(
            {
                "devices": devices,
            }
        )
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        # Catching Exception (not a bare except) lets task cancellation
        # propagate.
        import traceback

        traceback.print_exc()
|
||||
|
||||
async def getDevice(self, nativeId: str) -> Any:
    """Create the child device that backs ``nativeId``.

    Args:
        nativeId: "recognition" or "textrecognition".

    Returns:
        A CoreMLRecognition or CoreMLTextRecognition instance.

    Raises:
        Exception: if ``nativeId`` is not one of the known ids, or if text
            recognition was requested but its optional import failed.
    """
    if nativeId == "recognition":
        return CoreMLRecognition(nativeId)
    if nativeId == "textrecognition":
        # CoreMLTextRecognition is None when its guarded import failed;
        # calling None would surface as an opaque TypeError.
        if CoreMLTextRecognition is None:
            raise Exception("text recognition is unavailable")
        return CoreMLTextRecognition(nativeId)
    raise Exception("unknown device")
|
||||
|
||||
async def getSettings(self) -> list[Setting]:
|
||||
model = self.storage.getItem("model") or "Default"
|
||||
@@ -174,7 +197,7 @@ class CoreMLPlugin(PredictPlugin, scrypted_sdk.Settings, scrypted_sdk.DeviceProv
|
||||
|
||||
def get_input_size(self) -> Tuple[float, float]:
    """Return the input size as a ``(inputwidth, inputheight)`` tuple."""
    return (self.inputwidth, self.inputheight)
|
||||
|
||||
|
||||
async def detect_batch(self, inputs: List[Any]) -> List[Any]:
|
||||
out_dicts = await asyncio.get_event_loop().run_in_executor(
|
||||
predictExecutor, lambda: self.model.predict(inputs)
|
||||
|
||||
39
plugins/coreml/src/coreml/text_recognition.py
Normal file
39
plugins/coreml/src/coreml/text_recognition.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import coremltools as ct
|
||||
|
||||
from predict.text_recognize import TextRecognition
|
||||
|
||||
|
||||
class CoreMLTextRecognition(TextRecognition):
    """TextRecognition backend that runs the detection model through Core ML."""

    def __init__(self, nativeId: str | None = None):
        super().__init__(nativeId=nativeId)

    def downloadModel(self, model: str):
        """Download the ``.mlpackage`` for ``model`` and load it.

        Returns:
            A ``(MLModel, input_name)`` tuple, where ``input_name`` is the
            name of the first input in the model spec.
        """
        model_version = "v7"
        mlmodel = "model"

        package_files = [
            f"{model}/{model}.mlpackage/Data/com.apple.CoreML/weights/weight.bin",
            f"{model}/{model}.mlpackage/Data/com.apple.CoreML/{mlmodel}.mlmodel",
            f"{model}/{model}.mlpackage/Manifest.json",
        ]

        # Download in list order; every local path is kept so the last one
        # (Manifest.json) can locate the package directory.
        local_paths = [
            self.downloadFile(
                f"https://github.com/koush/coreml-models/raw/main/{name}",
                f"{model_version}/{name}",
            )
            for name in package_files
        ]

        # The directory holding Manifest.json is the .mlpackage itself.
        package_dir = os.path.dirname(local_paths[-1])

        loaded = ct.models.MLModel(package_dir)
        input_name = loaded.get_spec().description.input[0].name
        return loaded, input_name

    def predictDetectModel(self, input):
        """Run the detect model on ``input`` and return its first output value."""
        loaded, input_name = self.detectModel
        outputs = loaded.predict({input_name: input})
        return list(outputs.values())[0]
|
||||
@@ -0,0 +1 @@
|
||||
opencv-python
|
||||
|
||||
@@ -18,6 +18,10 @@ from predict import Prediction, PredictPlugin
|
||||
from predict.rectangle import Rectangle
|
||||
|
||||
from .recognition import OpenVINORecognition
|
||||
try:
|
||||
from .text_recognition import OpenVINOTextRecognition
|
||||
except:
|
||||
OpenVINOTextRecognition = None
|
||||
|
||||
predictExecutor = concurrent.futures.ThreadPoolExecutor(1, "OpenVINO-Predict")
|
||||
|
||||
@@ -326,22 +330,40 @@ class OpenVINOPlugin(
|
||||
|
||||
async def prepareRecognitionModels(self):
    """Announce the recognition child devices to Scrypted.

    The "recognition" device is always announced. The "textrecognition"
    device is announced only when the module-level OpenVINOTextRecognition
    name is truthy, i.e. the guarded import at the top of the file
    succeeded.

    Registration is best-effort: a failure is reported on stderr and
    otherwise ignored, so this coroutine never raises for ordinary errors.
    """
    try:
        devices = [
            {
                "nativeId": "recognition",
                "type": scrypted_sdk.ScryptedDeviceType.Builtin.value,
                "interfaces": [
                    scrypted_sdk.ScryptedInterface.ObjectDetection.value,
                ],
                "name": "OpenVINO Recognition",
            },
        ]

        if OpenVINOTextRecognition:
            devices.append(
                {
                    "nativeId": "textrecognition",
                    "type": scrypted_sdk.ScryptedDeviceType.Builtin.value,
                    "interfaces": [
                        scrypted_sdk.ScryptedInterface.ObjectDetection.value,
                    ],
                    "name": "OpenVINO Text Recognition",
                },
            )

        # A single "devices" entry: the list built above replaces the old
        # inline literal that only ever described the recognition device.
        await scrypted_sdk.deviceManager.onDevicesChanged(
            {
                "devices": devices,
            }
        )
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        # Catching Exception (not a bare except) lets task cancellation
        # propagate.
        import traceback

        traceback.print_exc()
|
||||
|
||||
async def getDevice(self, nativeId: str) -> Any:
    """Create the child device that backs ``nativeId``.

    Args:
        nativeId: "recognition" or "textrecognition".

    Returns:
        An OpenVINORecognition or OpenVINOTextRecognition instance bound to
        this plugin.

    Raises:
        Exception: if ``nativeId`` is not one of the known ids, or if text
            recognition was requested but its optional import failed.
    """
    if nativeId == "recognition":
        return OpenVINORecognition(self, nativeId)
    if nativeId == "textrecognition":
        # OpenVINOTextRecognition is None when its guarded import failed;
        # calling None would surface as an opaque TypeError.
        if OpenVINOTextRecognition is None:
            raise Exception("text recognition is unavailable")
        return OpenVINOTextRecognition(self, nativeId)
    raise Exception("unknown device")
|
||||
|
||||
36
plugins/openvino/src/ov/text_recognition.py
Normal file
36
plugins/openvino/src/ov/text_recognition.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import openvino.runtime as ov
|
||||
|
||||
from predict.text_recognize import TextRecognition
|
||||
|
||||
|
||||
class OpenVINOTextRecognition(TextRecognition):
    """TextRecognition backend that runs the detection model through OpenVINO."""

    def __init__(self, plugin, nativeId: str | None = None):
        # The plugin must be stored before the base initializer runs: the
        # base __init__ calls downloadModel(), which reads self.plugin.
        self.plugin = plugin

        super().__init__(nativeId=nativeId)

    def downloadModel(self, model: str):
        """Download the IR files for ``model`` and compile them.

        Both the ``.xml`` and the ``.bin`` file are downloaded, then the
        ``.xml`` path is compiled with the plugin's core and mode.
        """
        ovmodel = "best"
        precision = self.plugin.precision
        model_version = "v5"

        xml_path = self.downloadFile(
            f"https://raw.githubusercontent.com/koush/openvino-models/main/{model}/{precision}/{ovmodel}.xml",
            f"{model_version}/{model}/{precision}/{ovmodel}.xml",
        )
        bin_path = self.downloadFile(
            f"https://raw.githubusercontent.com/koush/openvino-models/main/{model}/{precision}/{ovmodel}.bin",
            f"{model_version}/{model}/{precision}/{ovmodel}.bin",
        )
        print(xml_path, bin_path)

        core = self.plugin.core
        return core.compile_model(xml_path, self.plugin.mode)

    def predictDetectModel(self, input):
        """Run one inference on ``input`` and return the first output's data."""
        request = self.detectModel.create_infer_request()
        request.set_input_tensor(ov.Tensor(array=input))
        request.start_async()
        request.wait()
        first_output = request.output_tensors[0]
        return first_output.data
|
||||
1
plugins/openvino/src/requirements.optional.txt
Normal file
1
plugins/openvino/src/requirements.optional.txt
Normal file
@@ -0,0 +1 @@
|
||||
opencv-python
|
||||
@@ -4,3 +4,4 @@ openvino==2024.0.0
|
||||
# pillow-simd confirmed not building with arm64 linux or apple silicon
|
||||
Pillow>=5.4.1; sys_platform != 'linux' or platform_machine != 'x86_64'
|
||||
pillow-simd; sys_platform == 'linux' and platform_machine == 'x86_64'
|
||||
|
||||
|
||||
259
plugins/tensorflow-lite/src/predict/craft_utils.py
Normal file
259
plugins/tensorflow-lite/src/predict/craft_utils.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Copyright (c) 2019-present NAVER Corp.
|
||||
MIT License
|
||||
"""
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import cv2
|
||||
import math
|
||||
|
||||
def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
    """Return a float32 copy of ``in_img`` normalized per channel.

    Each of the three channels is shifted by ``mean * 255`` and divided by
    ``variance * 255``. Channels are expected in RGB order. The input array
    is not modified.
    """
    # should be RGB order
    offset = np.array([channel * 255.0 for channel in mean[:3]], dtype=np.float32)
    scale = np.array([channel * 255.0 for channel in variance[:3]], dtype=np.float32)

    normalized = in_img.copy().astype(np.float32)
    normalized -= offset
    normalized /= scale
    return normalized
|
||||
|
||||
""" auxiliary functions """
|
||||
# unwarp coordinates
|
||||
def warpCoord(Minv, pt):
    """Apply the 3x3 transform ``Minv`` to the 2-D point ``pt``.

    The point is extended with a homogeneous coordinate of 1, multiplied by
    ``Minv``, and divided by the resulting third component.
    """
    homogeneous = (pt[0], pt[1], 1)
    x, y, w = np.matmul(Minv, homogeneous)
    return np.array([x / w, y / w])
|
||||
""" end of auxiliary functions """
|
||||
|
||||
|
||||
def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text, estimate_num_chars=False):
    """Turn a text score map and a link score map into one box per text region.

    Both maps are thresholded, combined, and split into 4-connected
    components. Each component that is large enough and confident enough is
    dilated and enclosed in a minimum-area rectangle.

    Args:
        textmap: 2-D text score map (its shape is unpacked as height, width).
        linkmap: link score map; it is combined element-wise with ``textmap``.
        text_threshold: a component is kept only if its peak ``textmap``
            value reaches this value.
        link_threshold: binarization threshold applied to ``linkmap``.
        low_text: binarization threshold applied to ``textmap``.
        estimate_num_chars: when true, ``mapper`` holds a per-box count from
            ``scipy.ndimage.label`` instead of the component label.

    Returns:
        ``(det, labels, mapper)``: ``det`` is a list of 4x2 corner arrays,
        ``labels`` is the label image from
        ``cv2.connectedComponentsWithStats``, and ``mapper`` has one entry
        per box (component label, or the count described above).
    """
    # prepare data
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    """ labeling method """
    # threshold type 0 is cv2.THRESH_BINARY with a max value of 1
    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    # a pixel belongs to a region if either binarized map is set
    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    # label 0 is skipped (OpenCV reserves it for the background)
    for k in range(1,nLabels):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        # thresholding
        if np.max(textmap[labels==k]) < text_threshold: continue

        # make segmentation map
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        if estimate_num_chars:
            # imported lazily so scipy is only needed when this option is used
            from scipy.ndimage import label
            _, character_locs = cv2.threshold((textmap - linkmap) * segmap /255., text_threshold, 1, 0)
            _, n_chars = label(character_locs)
            mapper.append(n_chars)
        else:
            mapper.append(k)
        segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        # dilation amount derived from the component area and its bounding-box sides
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        # boundary check
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
        # dilate only inside the padded bounding box of this component
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)

        # make box
        # np.where yields (rows, cols); the roll swaps them so points are (x, y)
        np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape
        # near-square boxes are replaced by the axis-aligned extent of the points
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order
        # rotate the corner list so the corner with the smallest x + y comes first
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        det.append(box)

    return det, labels, mapper
|
||||
|
||||
def getPoly_core(boxes, labels, mapper, linkmap):
    """Try to refine each box into a polygon that follows its text region.

    Each box is warped to an upright ``w`` x ``h`` patch of ``labels``. The
    pixels matching the box's ``mapper`` entry are traced column by column,
    and pivot points sampled along the width are turned into a polygon that
    is mapped back through the inverse warp.

    Args:
        boxes: sequence of 4x2 corner arrays.
        labels: label image that is warped for every box.
        mapper: per-box value selecting which label to keep in the patch.
        linkmap: accepted but not used by this function.

    Returns:
        A list with one entry per box: ``None`` when no polygon could be
        built, otherwise an array of ``2 * num_cp + 4`` points.
    """
    # configs
    num_cp = 5
    max_len_ratio = 0.7
    expand_ratio = 1.45
    max_r = 2.0
    step_r = 0.2

    polys = []
    for k, box in enumerate(boxes):
        # size filter for small instance
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 10 or h < 10:
            polys.append(None); continue

        # warp image
        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            # the inverse maps patch coordinates back to the original image
            Minv = np.linalg.inv(M)
        except:
            polys.append(None); continue

        # binarization for selected label
        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        """ Polygon generation """
        # find top/bottom contours
        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:,i] != 0)[0]
            if len(region) < 2 : continue
            # (column, first non-zero row, last non-zero row)
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len: max_len = length

        # pass if max_len is similar to h
        if h * max_len_ratio < max_len:
            polys.append(None); continue

        # get pivot points with fixed length
        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg # segment width
        pp = [None] * num_cp # init pivot points
        # every slot starts as the same [0, 0] list; slots are only ever
        # replaced with new lists below, never mutated in place
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0,len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0: break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0: continue # No polygon area

            # within an odd segment, keep the column with the tallest extent as the pivot
            if prev_h < cur_h:
                pp[int((seg_num - 1)/2)] = (x, cy)
                seg_height[int((seg_num - 1)/2)] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if num of pivots is not sufficient or segment width is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None); continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradient and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            # slope between the averaged centers of the two even segments around this pivot
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0: # gradient is zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                # push the first pivot line outward until it no longer
                # crosses the label mask (or r is near max_r)
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                # same search for the last pivot line, in the opposite direction
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        # pass if boundary of polygon is not found
        if not (isSppFound and isEppFound):
            polys.append(None); continue

        # make final polygon
        # first endpoints run start -> end, second endpoints run back end -> start
        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys
|
||||
|
||||
def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False, estimate_num_chars=False):
    """Detect text boxes and, optionally, a polygon for each box.

    Returns:
        ``(boxes, polys, mapper)``. When ``poly`` is false, ``polys`` is a
        list of ``None`` with one entry per box.

    Raises:
        Exception: if ``poly`` and ``estimate_num_chars`` are both set.
    """
    if poly and estimate_num_chars:
        raise Exception("Estimating the number of characters not currently supported with poly.")

    boxes, labels, mapper = getDetBoxes_core(
        textmap,
        linkmap,
        text_threshold,
        link_threshold,
        low_text,
        estimate_num_chars,
    )

    polys = getPoly_core(boxes, labels, mapper, linkmap) if poly else [None] * len(boxes)
    return boxes, polys, mapper
|
||||
|
||||
def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    """Scale detected points back to image coordinates.

    Every non-``None`` entry has its x values multiplied by
    ``ratio_w * ratio_net`` and its y values by ``ratio_h * ratio_net``.

    Args:
        polys: sequence of point arrays (last axis is x, y); entries may be
            ``None``.
        ratio_w: horizontal scale factor.
        ratio_h: vertical scale factor.
        ratio_net: extra factor applied on both axes.

    Returns:
        ``polys`` itself when it is empty. Otherwise a new NumPy array: a
        regular stacked array when all entries are present and share one
        shape, else a 1-D object array in which missing entries stay
        ``None``. The input entries are not modified.
    """
    if len(polys) == 0:
        return polys

    scale = (ratio_w * ratio_net, ratio_h * ratio_net)

    uniform = all(p is not None for p in polys) and len({np.shape(p) for p in polys}) == 1
    if uniform:
        scaled = np.array(polys)
        # in-place multiply keeps the dtype of the input points
        scaled *= scale
        return scaled

    # Mixed None/array (or ragged) input: np.array() on such a list raises
    # ValueError on NumPy >= 1.24, so build the object array explicitly.
    # np.empty with dtype=object starts out filled with None.
    scaled = np.empty(len(polys), dtype=object)
    for k, p in enumerate(polys):
        if p is not None:
            pts = np.array(p)
            pts *= scale
            scaled[k] = pts
    return scaled
|
||||
@@ -204,6 +204,8 @@ class RecognizeDetection(PredictPlugin):
|
||||
futures.append(asyncio.ensure_future(self.setEmbedding(d, image)))
|
||||
elif d["className"] == "plate":
|
||||
futures.append(asyncio.ensure_future(self.setLabel(d, image)))
|
||||
# elif d['className'] == 'text':
|
||||
# futures.append(asyncio.ensure_future(self.setLabel(d, image)))
|
||||
|
||||
if len(futures):
|
||||
await asyncio.wait(futures)
|
||||
|
||||
107
plugins/tensorflow-lite/src/predict/text_recognize.py
Normal file
107
plugins/tensorflow-lite/src/predict/text_recognize.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import scrypted_sdk
|
||||
from PIL import Image
|
||||
|
||||
from predict import Prediction, PredictPlugin
|
||||
from predict.craft_utils import normalizeMeanVariance
|
||||
from predict.rectangle import Rectangle
|
||||
|
||||
from .craft_utils import adjustResultCoordinates, getDetBoxes
|
||||
|
||||
predictExecutor = concurrent.futures.ThreadPoolExecutor(1, "TextDetect")
|
||||
|
||||
class TextRecognition(PredictPlugin):
    """Text detector base class shared by the backend-specific plugins.

    Subclasses supply the model by overriding ``downloadModel`` and
    ``predictDetectModel``; this class handles pre-processing, box
    extraction and conversion to detection results.
    """

    def __init__(self, nativeId: str | None = None):
        super().__init__(nativeId=nativeId)

        self.inputheight = 640
        self.inputwidth = 640

        self.labels = {
            0: "text",
        }
        self.loop = asyncio.get_event_loop()
        self.minThreshold = 0.1

        # Subclass hook; any state it needs must be set before this
        # initializer is called.
        self.detectModel = self.downloadModel("craft")

    def downloadModel(self, model: str):
        """Fetch and load the model named ``model``.

        Backend hook: the base implementation does nothing and returns None.
        """
        pass

    def predictDetectModel(self, input):
        """Run the detect model on a pre-processed input tensor.

        Backend hook: the base implementation does nothing and returns None.
        """
        pass

    async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss) -> scrypted_sdk.ObjectsDetected:
        """Detect text regions in ``input`` and return them as detections.

        The image is normalized and converted to a batched channels-first
        tensor, then run through ``predictDetectModel`` on the dedicated
        executor. Every detected box becomes a ``Prediction`` with class 0
        and score 1.
        """
        image_tensor = normalizeMeanVariance(np.array(input))
        # move the channel axis first: (h, w, c) -> (c, h, w)
        image_tensor = image_tensor.transpose([2, 0, 1])
        # add extra dimension to tensor
        image_tensor = np.expand_dims(image_tensor, axis=0)

        y = await asyncio.get_event_loop().run_in_executor(
            predictExecutor, lambda: self.predictDetectModel(image_tensor)
        )

        ratio_h = ratio_w = 1
        text_threshold = 0.7
        link_threshold = 0.7
        low_text = 0.4

        preds: List[Prediction] = []
        for out in y:
            # make score and link map
            score_text = out[:, :, 0]
            score_link = out[:, :, 1]

            # Post-processing; polygons and character counts are not
            # requested, so only the boxes are used.
            boxes, _, _ = getDetBoxes(
                score_text,
                score_link,
                text_threshold,
                link_threshold,
                low_text,
                poly=False,
                estimate_num_chars=False,
            )
            if not len(boxes):
                continue

            # coordinate adjustment
            boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)

            for box in boxes:
                # Boxes can be rotated, so take the extent over all four
                # corners rather than only the first and third corner.
                corners = np.asarray(box)
                l, t = corners.min(axis=0)
                r, b = corners.max(axis=0)
                preds.append(Prediction(0, 1, Rectangle(l, t, r, b)))

        return self.create_detection_result(preds, src_size, cvss)

    # width, height, channels
    def get_input_details(self) -> Tuple[int, int, int]:
        """Return ``(inputwidth, inputheight, 3)``."""
        return (self.inputwidth, self.inputheight, 3)

    def get_input_size(self) -> Tuple[float, float]:
        """Return ``(inputwidth, inputheight)``."""
        return (self.inputwidth, self.inputheight)

    def get_input_format(self) -> str:
        """Return the pixel format string, "rgb"."""
        return "rgb"
|
||||
Reference in New Issue
Block a user