parent
3ef890c4c5
commit
0e8d2c3c4f
After Width: | Height: | Size: 42 KiB |
@ -0,0 +1 @@
|
|||||||
|
<html><body><table><thead><tr><td></td><td></td><td>版本</td></tr></thead><tbody><tr><td>上海长江云息</td><td>文档编号</td><td>密级</td></tr><tr><td></td><td>CJYX202211240001<</td><td>V1.0e 机密</td></tr><tr><td>」数字科技有限公司</td><td>再生资源智慧管家平台之货客智慧管控平台一</td><td>共51页</td></tr></tbody></table></body></html>
|
Binary file not shown.
After Width: | Height: | Size: 23 KiB |
After Width: | Height: | Size: 28 KiB |
After Width: | Height: | Size: 19 KiB |
@ -0,0 +1,83 @@
|
|||||||
|
package jnpf.ocr_sdk;
|
||||||
|
|
||||||
|
import ai.djl.ModelException;
|
||||||
|
import ai.djl.inference.Predictor;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.opencv.OpenCVImageFactory;
|
||||||
|
import ai.djl.repository.zoo.ModelZoo;
|
||||||
|
import ai.djl.repository.zoo.ZooModel;
|
||||||
|
import ai.djl.translate.TranslateException;
|
||||||
|
import jnpf.ocr_sdk.utils.common.ImageUtils;
|
||||||
|
import jnpf.ocr_sdk.utils.common.RotatedBox;
|
||||||
|
import jnpf.ocr_sdk.utils.detection.OcrV3Detection;
|
||||||
|
import jnpf.ocr_sdk.utils.opencv.OpenCVUtils;
|
||||||
|
import jnpf.ocr_sdk.utils.recognition.OcrV3Recognition;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OCR V3模型 文字识别. 支持文本有旋转角度
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
* @date 2022-10-07
|
||||||
|
* @email 179209347@qq.com
|
||||||
|
*/
|
||||||
|
public final class OcrV3RecognitionExample {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OcrV3RecognitionExample.class);
|
||||||
|
|
||||||
|
private OcrV3RecognitionExample() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
||||||
|
// Path imageFile = Paths.get("src/test/resources/7.jpg");
|
||||||
|
String relativelyPath=System.getProperty("user.dir");
|
||||||
|
// Path imageFile = Paths.get("C:\\Users\\admin\\Desktop\\图像\\AISDK\\AIAS\\1_image_sdks\\text_recognition\\ocr_sdk\\src\\test\\resources\\7.jpg");
|
||||||
|
// Path imageFile = Paths.get("src/test/resources/7.jpg");
|
||||||
|
StringBuffer ocrStr = new StringBuffer("本次识别的内容:");
|
||||||
|
Path imageFile = Paths.get("C:/Users/admin/Desktop/AAAA.png");
|
||||||
|
Image image = OpenCVImageFactory.getInstance().fromFile(imageFile);
|
||||||
|
|
||||||
|
|
||||||
|
OcrV3Detection detection = new OcrV3Detection();
|
||||||
|
OcrV3Recognition recognition = new OcrV3Recognition();
|
||||||
|
try (ZooModel detectionModel = ModelZoo.loadModel(detection.detectCriteria());
|
||||||
|
Predictor<Image, NDList> detector = detectionModel.newPredictor();
|
||||||
|
ZooModel recognitionModel = ModelZoo.loadModel(recognition.recognizeCriteria());
|
||||||
|
Predictor<Image, String> recognizer = recognitionModel.newPredictor()) {
|
||||||
|
|
||||||
|
long timeInferStart = System.currentTimeMillis();
|
||||||
|
List<RotatedBox> detections = recognition.predict(image, detector, recognizer);
|
||||||
|
|
||||||
|
// for (int i = 0; i < 1000; i++) {
|
||||||
|
// detections = recognition.predict(image, detector, recognizer);
|
||||||
|
// System.out.println("time: " + i);
|
||||||
|
// }
|
||||||
|
|
||||||
|
long timeInferEnd = System.currentTimeMillis();
|
||||||
|
System.out.println("time: " + (timeInferEnd - timeInferStart));
|
||||||
|
|
||||||
|
for (RotatedBox result : detections) {
|
||||||
|
System.out.println(result.getText());
|
||||||
|
ocrStr.append(result.getText());
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferedImage bufferedImage = OpenCVUtils.mat2Image((org.opencv.core.Mat) image.getWrappedImage());
|
||||||
|
for (RotatedBox result : detections) {
|
||||||
|
ImageUtils.drawImageRectWithText(bufferedImage, result.getBox(), result.getText());
|
||||||
|
}
|
||||||
|
image = ImageFactory.getInstance().fromImage(OpenCVUtils.image2Mat(bufferedImage));
|
||||||
|
ImageUtils.saveImage(image, "ocr_result.png", "build/output");
|
||||||
|
logger.info("{}", detections);
|
||||||
|
logger.info(ocrStr.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,141 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.cls;
|
||||||
|
|
||||||
|
import ai.djl.inference.Predictor;
|
||||||
|
import ai.djl.modality.Classifications;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.modality.cv.output.Rectangle;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDManager;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import ai.djl.translate.TranslateException;
|
||||||
|
import jnpf.ocr_sdk.utils.detection.PpWordDetectionTranslator;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public final class OcrDirectionDetection {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OcrDirectionDetection.class);
|
||||||
|
|
||||||
|
public OcrDirectionDetection() {}
|
||||||
|
|
||||||
|
public DetectedObjects predict(
|
||||||
|
Image image,
|
||||||
|
Predictor<Image, DetectedObjects> detector,
|
||||||
|
Predictor<Image, Classifications> rotateClassifier)
|
||||||
|
throws TranslateException {
|
||||||
|
DetectedObjects detections = detector.predict(image);
|
||||||
|
|
||||||
|
List<DetectedObjects.DetectedObject> boxes = detections.items();
|
||||||
|
|
||||||
|
List<String> names = new ArrayList<>();
|
||||||
|
List<Double> prob = new ArrayList<>();
|
||||||
|
List<BoundingBox> rect = new ArrayList<>();
|
||||||
|
|
||||||
|
for (int i = 0; i < boxes.size(); i++) {
|
||||||
|
Image subImg = getSubImage(image, boxes.get(i).getBoundingBox());
|
||||||
|
Classifications.Classification result = null;
|
||||||
|
if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
|
||||||
|
subImg = rotateImg(subImg);
|
||||||
|
result = rotateClassifier.predict(subImg).best();
|
||||||
|
prob.add(result.getProbability());
|
||||||
|
if (result.getClassName().equalsIgnoreCase("Rotate")) {
|
||||||
|
names.add("90");
|
||||||
|
} else {
|
||||||
|
names.add("270");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result = rotateClassifier.predict(subImg).best();
|
||||||
|
prob.add(result.getProbability());
|
||||||
|
if (result.getClassName().equalsIgnoreCase("No Rotate")) {
|
||||||
|
names.add("0");
|
||||||
|
} else {
|
||||||
|
names.add("180");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rect.add(boxes.get(i).getBoundingBox());
|
||||||
|
}
|
||||||
|
DetectedObjects detectedObjects = new DetectedObjects(names, prob, rect);
|
||||||
|
|
||||||
|
return detectedObjects;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, DetectedObjects> detectCriteria() {
|
||||||
|
Criteria<Image, DetectedObjects> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, DetectedObjects.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv2_det_infer.zip")
|
||||||
|
// .optModelUrls(
|
||||||
|
// "/Users/calvin/Documents/build/paddle_models/ppocr/ch_PP-OCRv2_det_infer")
|
||||||
|
// .optDevice(Device.cpu())
|
||||||
|
.optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, Classifications> clsCriteria() {
|
||||||
|
|
||||||
|
Criteria<Image, Classifications> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, Classifications.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_ppocr_mobile_v2.0_cls_infer.zip")
|
||||||
|
// .optModelUrls(
|
||||||
|
// "/Users/calvin/Documents/build/paddle_models/ppocr/ch_ppocr_mobile_v2.0_cls_infer")
|
||||||
|
.optTranslator(new PpWordRotateTranslator())
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image getSubImage(Image img, BoundingBox box) {
|
||||||
|
Rectangle rect = box.getBounds();
|
||||||
|
double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());
|
||||||
|
int width = img.getWidth();
|
||||||
|
int height = img.getHeight();
|
||||||
|
int[] recovered = {
|
||||||
|
(int) (extended[0] * width),
|
||||||
|
(int) (extended[1] * height),
|
||||||
|
(int) (extended[2] * width),
|
||||||
|
(int) (extended[3] * height)
|
||||||
|
};
|
||||||
|
return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] extendRect(double xmin, double ymin, double width, double height) {
|
||||||
|
double centerx = xmin + width / 2;
|
||||||
|
double centery = ymin + height / 2;
|
||||||
|
if (width > height) {
|
||||||
|
width += height * 2.0;
|
||||||
|
height *= 3.0;
|
||||||
|
} else {
|
||||||
|
height += width * 2.0;
|
||||||
|
width *= 3.0;
|
||||||
|
}
|
||||||
|
double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;
|
||||||
|
double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;
|
||||||
|
double newWidth = newX + width > 1 ? 1 - newX : width;
|
||||||
|
double newHeight = newY + height > 1 ? 1 - newY : height;
|
||||||
|
return new double[] {newX, newY, newWidth, newHeight};
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image rotateImg(Image image) {
|
||||||
|
try (NDManager manager = NDManager.newBaseManager()) {
|
||||||
|
NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
|
||||||
|
return ImageFactory.getInstance().fromNDArray(rotated);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,69 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.cls;
|
||||||
|
|
||||||
|
import ai.djl.modality.Classifications;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.index.NDIndex;
|
||||||
|
import ai.djl.ndarray.types.Shape;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class PpWordRotateTranslator implements Translator<Image, Classifications> {
|
||||||
|
List<String> classes = Arrays.asList("No Rotate", "Rotate");
|
||||||
|
|
||||||
|
public PpWordRotateTranslator() {}
|
||||||
|
|
||||||
|
public Classifications processOutput(TranslatorContext ctx, NDList list) {
|
||||||
|
NDArray prob = list.singletonOrThrow();
|
||||||
|
return new Classifications(this.classes, prob);
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager());
|
||||||
|
img = NDImageUtils.resize(img, 192, 48);
|
||||||
|
img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F);
|
||||||
|
img = img.expandDims(0);
|
||||||
|
return new NDList(new NDArray[]{img});
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDList processInputBak(TranslatorContext ctx, Image input) throws Exception {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager());
|
||||||
|
int imgC = 3;
|
||||||
|
int imgH = 48;
|
||||||
|
int imgW = 192;
|
||||||
|
|
||||||
|
NDArray array = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW));
|
||||||
|
|
||||||
|
int h = input.getHeight();
|
||||||
|
int w = input.getWidth();
|
||||||
|
int resized_w = 0;
|
||||||
|
|
||||||
|
float ratio = (float) w / (float) h;
|
||||||
|
if (Math.ceil(imgH * ratio) > imgW) {
|
||||||
|
resized_w = imgW;
|
||||||
|
} else {
|
||||||
|
resized_w = (int) (Math.ceil(imgH * ratio));
|
||||||
|
}
|
||||||
|
|
||||||
|
img = NDImageUtils.resize(img, resized_w, imgH);
|
||||||
|
|
||||||
|
img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F);
|
||||||
|
// img = img.transpose(2, 0, 1);
|
||||||
|
|
||||||
|
array.set(new NDIndex(":,:,0:" + resized_w), img);
|
||||||
|
|
||||||
|
array = array.expandDims(0);
|
||||||
|
|
||||||
|
return new NDList(new NDArray[] {array});
|
||||||
|
}
|
||||||
|
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,99 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.common;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
|
||||||
|
public class DJLImageUtils {
|
||||||
|
|
||||||
|
public static Image bufferedImage2DJLImage(BufferedImage img) {
|
||||||
|
return ImageFactory.getInstance().fromImage(img);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void saveImage(BufferedImage img, String name, String path) {
|
||||||
|
|
||||||
|
Image djlImg = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式,自动适配
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||||
|
try {
|
||||||
|
djlImg.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void saveDJLImage(Image img, String name, String path) {
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||||
|
try {
|
||||||
|
img.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void saveBoundingBoxImage(
|
||||||
|
Image img, DetectedObjects detection, String name, String path) throws IOException {
|
||||||
|
// Make imageName copy with alpha channel because original imageName was jpg
|
||||||
|
img.drawBoundingBoxes(detection);
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Files.createDirectories(outputDir);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK can't save jpg with alpha channel
|
||||||
|
img.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) {
|
||||||
|
// 将绘制图像转换为Graphics2D
|
||||||
|
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||||
|
try {
|
||||||
|
g.setColor(new Color(246, 96, 0));
|
||||||
|
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||||
|
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||||
|
g.setStroke(bStroke);
|
||||||
|
g.drawRect(x, y, width, height);
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
g.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void drawImageRect(
|
||||||
|
BufferedImage image, int x, int y, int width, int height, Color c) {
|
||||||
|
// 将绘制图像转换为Graphics2D
|
||||||
|
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||||
|
try {
|
||||||
|
g.setColor(c);
|
||||||
|
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||||
|
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||||
|
g.setStroke(bStroke);
|
||||||
|
g.drawRect(x, y, width, height);
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
g.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void drawImageText(BufferedImage image, String text) {
|
||||||
|
Graphics graphics = image.getGraphics();
|
||||||
|
int fontSize = 100;
|
||||||
|
Font font = new Font("楷体", Font.PLAIN, fontSize);
|
||||||
|
try {
|
||||||
|
graphics.setFont(font);
|
||||||
|
graphics.setColor(new Color(246, 96, 0));
|
||||||
|
int strWidth = graphics.getFontMetrics().stringWidth(text);
|
||||||
|
graphics.drawString(text, fontSize - (strWidth / 2), fontSize + 30);
|
||||||
|
} finally {
|
||||||
|
graphics.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.common;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
|
||||||
|
public class ImageInfo {
|
||||||
|
private String name;
|
||||||
|
private Double prob;
|
||||||
|
private Image image;
|
||||||
|
private BoundingBox box;
|
||||||
|
|
||||||
|
public ImageInfo(Image image, BoundingBox box) {
|
||||||
|
this.image = image;
|
||||||
|
this.box = box;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setName(String name) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getProb() {
|
||||||
|
return prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setProb(Double prob) {
|
||||||
|
this.prob = prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Image getImage() {
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setImage(Image image) {
|
||||||
|
this.image = image;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BoundingBox getBox() {
|
||||||
|
return box;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBox(BoundingBox box) {
|
||||||
|
this.box = box;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,262 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.common;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.modality.cv.output.Rectangle;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
|
||||||
|
public class ImageUtils {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* BufferedImage图片格式转DJL图片格式
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static Image convert(BufferedImage img) {
|
||||||
|
return ImageFactory.getInstance().fromImage(img);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存BufferedImage图片
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void saveImage(BufferedImage img, String name, String path) {
|
||||||
|
Image djlImg = ImageFactory.getInstance().fromImage(img); // 支持多种图片格式,自动适配
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||||
|
try {
|
||||||
|
djlImg.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存DJL图片
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void saveImage(Image img, String name, String path) {
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||||
|
try {
|
||||||
|
img.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存图片,含检测框
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void saveBoundingBoxImage(
|
||||||
|
Image img, DetectedObjects detection, String name, String path) throws IOException {
|
||||||
|
// Make image copy with alpha channel because original image was jpg
|
||||||
|
img.drawBoundingBoxes(detection);
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Files.createDirectories(outputDir);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK can't save jpg with alpha channel
|
||||||
|
img.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 绘制人脸关键点
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void drawLandmark(Image img, BoundingBox box, float[] array) {
|
||||||
|
for (int i = 0; i < array.length / 2; i++) {
|
||||||
|
int x = getX(img, box, array[2 * i]);
|
||||||
|
int y = getY(img, box, array[2 * i + 1]);
|
||||||
|
Color c = new Color(0, 255, 0);
|
||||||
|
drawImageRect((BufferedImage) img.getWrappedImage(), x, y, 1, 1, c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 画检测框(有倾斜角)
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void drawImageRect(BufferedImage image, NDArray box) {
|
||||||
|
float[] points = box.toFloatArray();
|
||||||
|
int[] xPoints = new int[5];
|
||||||
|
int[] yPoints = new int[5];
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
xPoints[i] = (int) points[2 * i];
|
||||||
|
yPoints[i] = (int) points[2 * i + 1];
|
||||||
|
}
|
||||||
|
xPoints[4] = xPoints[0];
|
||||||
|
yPoints[4] = yPoints[0];
|
||||||
|
|
||||||
|
// 将绘制图像转换为Graphics2D
|
||||||
|
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||||
|
try {
|
||||||
|
g.setColor(new Color(0, 255, 0));
|
||||||
|
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||||
|
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||||
|
g.setStroke(bStroke);
|
||||||
|
g.drawPolyline(xPoints, yPoints, 5); // xPoints, yPoints, nPoints
|
||||||
|
} finally {
|
||||||
|
g.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 画检测框(有倾斜角)和文本
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void drawImageRectWithText(BufferedImage image, NDArray box, String text) {
|
||||||
|
float[] points = box.toFloatArray();
|
||||||
|
int[] xPoints = new int[5];
|
||||||
|
int[] yPoints = new int[5];
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
xPoints[i] = (int) points[2 * i];
|
||||||
|
yPoints[i] = (int) points[2 * i + 1];
|
||||||
|
}
|
||||||
|
xPoints[4] = xPoints[0];
|
||||||
|
yPoints[4] = yPoints[0];
|
||||||
|
|
||||||
|
// 将绘制图像转换为Graphics2D
|
||||||
|
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||||
|
try {
|
||||||
|
int fontSize = 32;
|
||||||
|
Font font = new Font("楷体", Font.PLAIN, fontSize);
|
||||||
|
g.setFont(font);
|
||||||
|
g.setColor(new Color(0, 0, 255));
|
||||||
|
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||||
|
BasicStroke bStroke = new BasicStroke(2, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||||
|
g.setStroke(bStroke);
|
||||||
|
g.drawPolyline(xPoints, yPoints, 5); // xPoints, yPoints, nPoints
|
||||||
|
g.drawString(text, xPoints[0], yPoints[0]);
|
||||||
|
} finally {
|
||||||
|
g.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 画检测框
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void drawImageRect(BufferedImage image, int x, int y, int width, int height) {
|
||||||
|
// 将绘制图像转换为Graphics2D
|
||||||
|
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||||
|
try {
|
||||||
|
g.setColor(new Color(0, 255, 0));
|
||||||
|
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||||
|
BasicStroke bStroke = new BasicStroke(2, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||||
|
g.setStroke(bStroke);
|
||||||
|
g.drawRect(x, y, width, height);
|
||||||
|
} finally {
|
||||||
|
g.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 画检测框
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void drawImageRect(
|
||||||
|
BufferedImage image, int x, int y, int width, int height, Color c) {
|
||||||
|
// 将绘制图像转换为Graphics2D
|
||||||
|
Graphics2D g = (Graphics2D) image.getGraphics();
|
||||||
|
try {
|
||||||
|
g.setColor(c);
|
||||||
|
// 声明画笔属性 :粗 细(单位像素)末端无修饰 折线处呈尖角
|
||||||
|
BasicStroke bStroke = new BasicStroke(1, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
|
||||||
|
g.setStroke(bStroke);
|
||||||
|
g.drawRect(x, y, width, height);
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
g.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 显示文字
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static void drawImageText(BufferedImage image, String text, int x, int y) {
|
||||||
|
Graphics graphics = image.getGraphics();
|
||||||
|
int fontSize = 32;
|
||||||
|
Font font = new Font("楷体", Font.PLAIN, fontSize);
|
||||||
|
try {
|
||||||
|
graphics.setFont(font);
|
||||||
|
graphics.setColor(new Color(0, 0, 255));
|
||||||
|
int strWidth = graphics.getFontMetrics().stringWidth(text);
|
||||||
|
graphics.drawString(text, x, y);
|
||||||
|
} finally {
|
||||||
|
graphics.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 返回外扩人脸 factor = 1, 100%, factor = 0.2, 20%
|
||||||
|
*
|
||||||
|
* @author Calvin
|
||||||
|
*/
|
||||||
|
public static Image getSubImage(Image img, BoundingBox box, float factor) {
|
||||||
|
Rectangle rect = box.getBounds();
|
||||||
|
// 左上角坐标
|
||||||
|
int x1 = (int) (rect.getX() * img.getWidth());
|
||||||
|
int y1 = (int) (rect.getY() * img.getHeight());
|
||||||
|
// 宽度,高度
|
||||||
|
int w = (int) (rect.getWidth() * img.getWidth());
|
||||||
|
int h = (int) (rect.getHeight() * img.getHeight());
|
||||||
|
// 左上角坐标
|
||||||
|
int x2 = x1 + w;
|
||||||
|
int y2 = y1 + h;
|
||||||
|
|
||||||
|
// 外扩大100%,防止对齐后人脸出现黑边
|
||||||
|
int new_x1 = Math.max((int) (x1 + x1 * factor / 2 - x2 * factor / 2), 0);
|
||||||
|
int new_x2 = Math.min((int) (x2 + x2 * factor / 2 - x1 * factor / 2), img.getWidth() - 1);
|
||||||
|
int new_y1 = Math.max((int) (y1 + y1 * factor / 2 - y2 * factor / 2), 0);
|
||||||
|
int new_y2 = Math.min((int) (y2 + y2 * factor / 2 - y1 * factor / 2), img.getHeight() - 1);
|
||||||
|
int new_w = new_x2 - new_x1;
|
||||||
|
int new_h = new_y2 - new_y1;
|
||||||
|
|
||||||
|
return img.getSubImage(new_x1, new_y1, new_w, new_h);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int getX(Image img, BoundingBox box, float x) {
|
||||||
|
Rectangle rect = box.getBounds();
|
||||||
|
// 左上角坐标
|
||||||
|
int x1 = (int) (rect.getX() * img.getWidth());
|
||||||
|
// 宽度
|
||||||
|
int w = (int) (rect.getWidth() * img.getWidth());
|
||||||
|
|
||||||
|
return (int) (x * w + x1);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int getY(Image img, BoundingBox box, float y) {
|
||||||
|
Rectangle rect = box.getBounds();
|
||||||
|
// 左上角坐标
|
||||||
|
int y1 = (int) (rect.getY() * img.getHeight());
|
||||||
|
// 高度
|
||||||
|
int h = (int) (rect.getHeight() * img.getHeight());
|
||||||
|
|
||||||
|
return (int) (y * h + y1);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,29 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.common;
|
||||||
|
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
|
||||||
|
public class RotatedBox {
|
||||||
|
private NDArray box;
|
||||||
|
private String text;
|
||||||
|
|
||||||
|
public RotatedBox(NDArray box, String text) {
|
||||||
|
this.box = box;
|
||||||
|
this.text = text;
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDArray getBox() {
|
||||||
|
return box;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBox(NDArray box) {
|
||||||
|
this.box = box;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getText() {
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setText(String text) {
|
||||||
|
this.text = text;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.detection;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public final class OcrV3Detection {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OcrV3Detection.class);
|
||||||
|
|
||||||
|
public OcrV3Detection() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, NDList> detectCriteria() {
|
||||||
|
Criteria<Image, NDList> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, NDList.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv3_det_infer.zip")
|
||||||
|
// .optModelUrls(
|
||||||
|
// "/Users/calvin/Documents/build/paddle_models/ppocr/ch_PP-OCRv2_det_infer")
|
||||||
|
.optTranslator(new OCRDetectionTranslator(new ConcurrentHashMap<String, String>()))
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,120 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
|
||||||
|
* with the License. A copy of the License is located at
|
||||||
|
*
|
||||||
|
* http://aws.amazon.com/apache2.0/
|
||||||
|
*
|
||||||
|
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
|
||||||
|
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
|
||||||
|
* and limitations under the License.
|
||||||
|
*/
|
||||||
|
package jnpf.ocr_sdk.utils.detection;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.types.DataType;
|
||||||
|
import ai.djl.ndarray.types.Shape;
|
||||||
|
import ai.djl.paddlepaddle.zoo.cv.objectdetection.BoundFinder;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.IntStream;
|
||||||
|
|
||||||
|
public class PpWordDetectionTranslator implements Translator<Image, DetectedObjects> {
|
||||||
|
|
||||||
|
private final int max_side_len;
|
||||||
|
|
||||||
|
public PpWordDetectionTranslator(Map<String, ?> arguments) {
|
||||||
|
max_side_len =
|
||||||
|
arguments.containsKey("maxLength")
|
||||||
|
? Integer.parseInt(arguments.get("maxLength").toString())
|
||||||
|
: 960;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
|
||||||
|
NDArray result = list.singletonOrThrow();
|
||||||
|
result = result.squeeze().mul(255f).toType(DataType.UINT8, true).gt(0.3); // thresh=0.3
|
||||||
|
boolean[] flattened = result.toBooleanArray();
|
||||||
|
Shape shape = result.getShape();
|
||||||
|
int w = (int) shape.get(0);
|
||||||
|
int h = (int) shape.get(1);
|
||||||
|
boolean[][] grid = new boolean[w][h];
|
||||||
|
IntStream.range(0, flattened.length)
|
||||||
|
.parallel()
|
||||||
|
.forEach(i -> grid[i / h][i % h] = flattened[i]);
|
||||||
|
List<BoundingBox> boxes = new BoundFinder(grid).getBoxes();
|
||||||
|
List<String> names = new ArrayList<>();
|
||||||
|
List<Double> probs = new ArrayList<>();
|
||||||
|
int boxSize = boxes.size();
|
||||||
|
for (int i = 0; i < boxSize; i++) {
|
||||||
|
names.add("word");
|
||||||
|
probs.add(1.0);
|
||||||
|
}
|
||||||
|
return new DetectedObjects(names, probs, boxes);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager());
|
||||||
|
int h = input.getHeight();
|
||||||
|
int w = input.getWidth();
|
||||||
|
int resize_w = w;
|
||||||
|
int resize_h = h;
|
||||||
|
|
||||||
|
// limit the max side
|
||||||
|
float ratio = 1.0f;
|
||||||
|
if (Math.max(resize_h, resize_w) > max_side_len) {
|
||||||
|
if (resize_h > resize_w) {
|
||||||
|
ratio = (float) max_side_len / (float) resize_h;
|
||||||
|
} else {
|
||||||
|
ratio = (float) max_side_len / (float) resize_w;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resize_h = (int) (resize_h * ratio);
|
||||||
|
resize_w = (int) (resize_w * ratio);
|
||||||
|
|
||||||
|
if (resize_h % 32 == 0) {
|
||||||
|
resize_h = resize_h;
|
||||||
|
} else if (Math.floor((float) resize_h / 32f) <= 1) {
|
||||||
|
resize_h = 32;
|
||||||
|
} else {
|
||||||
|
resize_h = (int) Math.floor((float) resize_h / 32f) * 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resize_w % 32 == 0) {
|
||||||
|
resize_w = resize_w;
|
||||||
|
} else if (Math.floor((float) resize_w / 32f) <= 1) {
|
||||||
|
resize_w = 32;
|
||||||
|
} else {
|
||||||
|
resize_w = (int) Math.floor((float) resize_w / 32f) * 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
img = NDImageUtils.resize(img, resize_w, resize_h);
|
||||||
|
img = NDImageUtils.toTensor(img);
|
||||||
|
img =
|
||||||
|
NDImageUtils.normalize(
|
||||||
|
img,
|
||||||
|
new float[]{0.485f, 0.456f, 0.406f},
|
||||||
|
new float[]{0.229f, 0.224f, 0.225f});
|
||||||
|
img = img.expandDims(0);
|
||||||
|
return new NDList(img);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.layout;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
public final class LayoutDetection {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(LayoutDetection.class);
|
||||||
|
|
||||||
|
public LayoutDetection() {}
|
||||||
|
|
||||||
|
public Criteria<Image, DetectedObjects> criteria() {
|
||||||
|
|
||||||
|
Criteria<Image, DetectedObjects> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, DetectedObjects.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ppyolov2_r50vd_dcn_365e_publaynet_infer.zip")
|
||||||
|
// .optModelUrls(
|
||||||
|
// "/Users/calvin/.paddledet/inference_model/ppyolov2_r50vd_dcn_365e_publaynet/ppyolov2_r50vd_dcn_365e_publaynet_infer")
|
||||||
|
.optTranslator(new LayoutDetectionTranslator())
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,104 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.layout;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.modality.cv.output.Rectangle;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.types.DataType;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class LayoutDetectionTranslator implements Translator<Image, DetectedObjects> {
|
||||||
|
|
||||||
|
private int width;
|
||||||
|
private int height;
|
||||||
|
|
||||||
|
public LayoutDetectionTranslator() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
|
||||||
|
NDArray result = list.get(0); // np_boxes
|
||||||
|
long rows = result.size(0);
|
||||||
|
|
||||||
|
List<BoundingBox> boxes = new ArrayList<>();
|
||||||
|
List<String> names = new ArrayList<>();
|
||||||
|
List<Double> probs = new ArrayList<>();
|
||||||
|
|
||||||
|
for (long i = 0; i < rows; i++) {
|
||||||
|
NDArray row = result.get(i);
|
||||||
|
float[] array = row.toFloatArray();
|
||||||
|
if (array[1] <= 0.5 || array[0] <= -1) continue;
|
||||||
|
int clsid = (int) array[0];
|
||||||
|
double score = array[1];
|
||||||
|
String name = "";
|
||||||
|
switch (clsid) {
|
||||||
|
case 0:
|
||||||
|
name = "Text";
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
name = "Title";
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
name = "List";
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
name = "Table";
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
name = "Figure";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
name = "Unknown";
|
||||||
|
}
|
||||||
|
|
||||||
|
float x = array[2] / width;
|
||||||
|
float y = array[3] / height;
|
||||||
|
float w = (array[4] - array[2]) / width;
|
||||||
|
float h = (array[5] - array[3]) / height;
|
||||||
|
|
||||||
|
Rectangle rect = new Rectangle(x, y, w, h);
|
||||||
|
boxes.add(rect);
|
||||||
|
names.add(name);
|
||||||
|
probs.add(score);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new DetectedObjects(names, probs, boxes);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager());
|
||||||
|
width = input.getWidth();
|
||||||
|
height = input.getHeight();
|
||||||
|
|
||||||
|
img = NDImageUtils.resize(img, 640, 640);
|
||||||
|
img = img.transpose(2, 0, 1).div(255);
|
||||||
|
img =
|
||||||
|
NDImageUtils.normalize(
|
||||||
|
img, new float[] {0.485f, 0.456f, 0.406f}, new float[] {0.229f, 0.224f, 0.225f});
|
||||||
|
img = img.expandDims(0);
|
||||||
|
|
||||||
|
NDArray scale_factor = ctx.getNDManager().create(new float[] {640f / height, 640f / width});
|
||||||
|
scale_factor = scale_factor.toType(DataType.FLOAT32, false);
|
||||||
|
scale_factor = scale_factor.expandDims(0);
|
||||||
|
|
||||||
|
NDArray im_shape = ctx.getNDManager().create(new float[] {640f, 640f});
|
||||||
|
im_shape = im_shape.toType(DataType.FLOAT32, false);
|
||||||
|
im_shape = im_shape.expandDims(0);
|
||||||
|
|
||||||
|
// im_shape, image, scale_factor
|
||||||
|
return new NDList(im_shape, img, scale_factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.opencv;
|
||||||
|
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import org.bytedeco.javacpp.indexer.DoubleRawIndexer;
|
||||||
|
import org.bytedeco.opencv.global.opencv_core;
|
||||||
|
import org.bytedeco.opencv.opencv_core.Mat;
|
||||||
|
import org.bytedeco.opencv.opencv_core.Point2f;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class NDArrayUtils {
|
||||||
|
// NDArray 转 opencv_core.Mat
|
||||||
|
public static Mat toOpenCVMat(NDArray points, int rows, int cols) {
|
||||||
|
double[] doubleArray = points.toDoubleArray();
|
||||||
|
// CV_32F = FloatRawIndexer
|
||||||
|
// CV_64F = DoubleRawIndexer
|
||||||
|
Mat mat = new Mat(rows, cols, opencv_core.CV_64F);
|
||||||
|
|
||||||
|
DoubleRawIndexer ldIdx = mat.createIndexer();
|
||||||
|
for (int i = 0; i < rows; i++) {
|
||||||
|
for (int j = 0; j < cols; j++) {
|
||||||
|
ldIdx.put(i, j, doubleArray[i * cols + j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ldIdx.release();
|
||||||
|
|
||||||
|
return mat;
|
||||||
|
}
|
||||||
|
|
||||||
|
// NDArray 转 opencv_core.Point2f
|
||||||
|
public static Point2f toOpenCVPoint2f(NDArray points, int rows) {
|
||||||
|
double[] doubleArray = points.toDoubleArray();
|
||||||
|
Point2f points2f = new Point2f(rows);
|
||||||
|
|
||||||
|
for (int i = 0; i < rows; i++) {
|
||||||
|
points2f.position(i).x((float) doubleArray[i * 2]).y((float) doubleArray[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return points2f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Double array 转 opencv_core.Point2f
|
||||||
|
public static Point2f toOpenCVPoint2f(double[] doubleArray, int rows) {
|
||||||
|
Point2f points2f = new Point2f(rows);
|
||||||
|
|
||||||
|
for (int i = 0; i < rows; i++) {
|
||||||
|
points2f.position(i).x((float) doubleArray[i * 2]).y((float) doubleArray[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return points2f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// list 转 opencv_core.Point2f
|
||||||
|
public static Point2f toOpenCVPoint2f(List<ai.djl.modality.cv.output.Point> points, int rows) {
|
||||||
|
Point2f points2f = new Point2f(points.size());
|
||||||
|
|
||||||
|
for (int i = 0; i < rows; i++) {
|
||||||
|
ai.djl.modality.cv.output.Point point = points.get(i);
|
||||||
|
points2f.position(i).x((float) point.getX()).y((float) point.getY());
|
||||||
|
}
|
||||||
|
|
||||||
|
return points2f;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,72 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.opencv;
|
||||||
|
|
||||||
|
import org.bytedeco.opencv.global.opencv_imgproc;
|
||||||
|
import org.bytedeco.opencv.opencv_core.Mat;
|
||||||
|
import org.bytedeco.opencv.opencv_core.Point2f;
|
||||||
|
import org.opencv.core.CvType;
|
||||||
|
import org.opencv.imgproc.Imgproc;
|
||||||
|
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.awt.image.DataBufferByte;
|
||||||
|
import java.awt.image.WritableRaster;
|
||||||
|
|
||||||
|
|
||||||
|
public class OpenCVUtils {
|
||||||
|
|
||||||
|
public static Mat perspectiveTransform(
|
||||||
|
Mat src, Point2f srcPoints, Point2f dstPoints) {
|
||||||
|
Mat dst = src.clone();
|
||||||
|
Mat warp_mat = opencv_imgproc.getPerspectiveTransform(srcPoints.position(0), dstPoints.position(0));
|
||||||
|
opencv_imgproc.warpPerspective(src, dst, warp_mat, dst.size());
|
||||||
|
warp_mat.release();
|
||||||
|
|
||||||
|
return dst;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mat to BufferedImage
|
||||||
|
*
|
||||||
|
* @param mat
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static BufferedImage mat2Image(org.opencv.core.Mat mat) {
|
||||||
|
int width = mat.width();
|
||||||
|
int height = mat.height();
|
||||||
|
byte[] data = new byte[width * height * (int) mat.elemSize()];
|
||||||
|
Imgproc.cvtColor(mat, mat, 4);
|
||||||
|
mat.get(0, 0, data);
|
||||||
|
BufferedImage ret = new BufferedImage(width, height, 5);
|
||||||
|
ret.getRaster().setDataElements(0, 0, width, height, data);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BufferedImage matToBufferedImage(org.opencv.core.Mat frame) {
|
||||||
|
int type = 0;
|
||||||
|
if (frame.channels() == 1) {
|
||||||
|
type = BufferedImage.TYPE_BYTE_GRAY;
|
||||||
|
} else if (frame.channels() == 3) {
|
||||||
|
type = BufferedImage.TYPE_3BYTE_BGR;
|
||||||
|
}
|
||||||
|
BufferedImage image = new BufferedImage(frame.width(), frame.height(), type);
|
||||||
|
WritableRaster raster = image.getRaster();
|
||||||
|
DataBufferByte dataBuffer = (DataBufferByte) raster.getDataBuffer();
|
||||||
|
byte[] data = dataBuffer.getData();
|
||||||
|
frame.get(0, 0, data);
|
||||||
|
return image;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* BufferedImage to Mat
|
||||||
|
*
|
||||||
|
* @param img
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static org.opencv.core.Mat image2Mat(BufferedImage img) {
|
||||||
|
int width = img.getWidth();
|
||||||
|
int height = img.getHeight();
|
||||||
|
byte[] data = ((DataBufferByte) img.getRaster().getDataBuffer()).getData();
|
||||||
|
org.opencv.core.Mat mat = new org.opencv.core.Mat(height, width, CvType.CV_8UC3);
|
||||||
|
mat.put(0, 0, data);
|
||||||
|
return mat;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,129 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.recognition;
|
||||||
|
|
||||||
|
import ai.djl.inference.Predictor;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.modality.cv.output.Rectangle;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDManager;
|
||||||
|
import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import ai.djl.translate.TranslateException;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public final class OcrV3AlignedRecognition {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OcrV3AlignedRecognition.class);
|
||||||
|
|
||||||
|
public OcrV3AlignedRecognition() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public DetectedObjects predict(
|
||||||
|
Image image, Predictor<Image, DetectedObjects> detector, Predictor<Image, String> recognizer)
|
||||||
|
throws TranslateException {
|
||||||
|
DetectedObjects detections = detector.predict(image);
|
||||||
|
|
||||||
|
List<DetectedObjects.DetectedObject> boxes = detections.items();
|
||||||
|
|
||||||
|
List<String> names = new ArrayList<>();
|
||||||
|
List<Double> prob = new ArrayList<>();
|
||||||
|
List<BoundingBox> rect = new ArrayList<>();
|
||||||
|
|
||||||
|
long timeInferStart = System.currentTimeMillis();
|
||||||
|
for (int i = 0; i < boxes.size(); i++) {
|
||||||
|
Image subImg = getSubImage(image, boxes.get(i).getBoundingBox());
|
||||||
|
if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
|
||||||
|
subImg = rotateImg(subImg);
|
||||||
|
}
|
||||||
|
// ImageUtils.saveImage(subImg, i + ".png", "build/output");
|
||||||
|
String name = recognizer.predict(subImg);
|
||||||
|
names.add(name);
|
||||||
|
prob.add(-1.0);
|
||||||
|
rect.add(boxes.get(i).getBoundingBox());
|
||||||
|
}
|
||||||
|
long timeInferEnd = System.currentTimeMillis();
|
||||||
|
System.out.println("time: " + (timeInferEnd - timeInferStart));
|
||||||
|
|
||||||
|
DetectedObjects detectedObjects = new DetectedObjects(names, prob, rect);
|
||||||
|
|
||||||
|
return detectedObjects;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, DetectedObjects> detectCriteria() {
|
||||||
|
Criteria<Image, DetectedObjects> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, DetectedObjects.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv3_det_infer.zip")
|
||||||
|
// .optModelUrls(
|
||||||
|
// "/Users/calvin/Documents/build/paddle_models/ppocr/ch_PP-OCRv2_det_infer")
|
||||||
|
.optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, String> recognizeCriteria() {
|
||||||
|
Criteria<Image, String> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, String.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv3_rec_infer.zip")
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap<String, String>())))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image getSubImage(Image img, BoundingBox box) {
|
||||||
|
Rectangle rect = box.getBounds();
|
||||||
|
double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());
|
||||||
|
int width = img.getWidth();
|
||||||
|
int height = img.getHeight();
|
||||||
|
int[] recovered = {
|
||||||
|
(int) (extended[0] * width),
|
||||||
|
(int) (extended[1] * height),
|
||||||
|
(int) (extended[2] * width),
|
||||||
|
(int) (extended[3] * height)
|
||||||
|
};
|
||||||
|
|
||||||
|
return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] extendRect(double xmin, double ymin, double width, double height) {
|
||||||
|
double centerx = xmin + width / 2;
|
||||||
|
double centery = ymin + height / 2;
|
||||||
|
if (width > height) {
|
||||||
|
width += height * 2.0;
|
||||||
|
height *= 3.0;
|
||||||
|
} else {
|
||||||
|
height += width * 2.0;
|
||||||
|
width *= 3.0;
|
||||||
|
}
|
||||||
|
double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;
|
||||||
|
double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;
|
||||||
|
double newWidth = newX + width > 1 ? 1 - newX : width;
|
||||||
|
double newHeight = newY + height > 1 ? 1 - newY : height;
|
||||||
|
return new double[]{newX, newY, newWidth, newHeight};
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image rotateImg(Image image) {
|
||||||
|
try (NDManager manager = NDManager.newBaseManager()) {
|
||||||
|
NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
|
||||||
|
return ImageFactory.getInstance().fromNDArray(rotated);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,194 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.recognition;
|
||||||
|
|
||||||
|
import ai.djl.inference.Predictor;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.DetectedObjects;
|
||||||
|
import ai.djl.modality.cv.output.Rectangle;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDManager;
|
||||||
|
import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.repository.zoo.ZooModel;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import ai.djl.translate.TranslateException;
|
||||||
|
import jnpf.ocr_sdk.utils.common.ImageInfo;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.*;
|
||||||
|
|
||||||
|
public final class OcrV3MultiThreadRecognition {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OcrV3MultiThreadRecognition.class);
|
||||||
|
|
||||||
|
public OcrV3MultiThreadRecognition() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public DetectedObjects predict(
|
||||||
|
Image image, List<ZooModel> recModels, Predictor<Image, DetectedObjects> detector, int threadNum)
|
||||||
|
throws TranslateException {
|
||||||
|
DetectedObjects detections = detector.predict(image);
|
||||||
|
|
||||||
|
List<DetectedObjects.DetectedObject> boxes = detections.items();
|
||||||
|
|
||||||
|
ConcurrentLinkedQueue<ImageInfo> queue = new ConcurrentLinkedQueue<>();
|
||||||
|
for (int i = 0; i < boxes.size(); i++) {
|
||||||
|
BoundingBox box = boxes.get(i).getBoundingBox();
|
||||||
|
Image subImg = getSubImage(image, box);
|
||||||
|
if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
|
||||||
|
subImg = rotateImg(subImg);
|
||||||
|
}
|
||||||
|
ImageInfo imageInfo = new ImageInfo(subImg, box);
|
||||||
|
queue.add(imageInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
List<InferCallable> callables = new ArrayList<>(threadNum);
|
||||||
|
for (int i = 0; i < threadNum; i++) {
|
||||||
|
callables.add(new InferCallable(recModels.get(i), queue));
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecutorService es = Executors.newFixedThreadPool(threadNum);
|
||||||
|
List<ImageInfo> resultList = new ArrayList<>();
|
||||||
|
try {
|
||||||
|
List<Future<List<ImageInfo>>> futures = new ArrayList<>();
|
||||||
|
long timeInferStart = System.currentTimeMillis();
|
||||||
|
for (InferCallable callable : callables) {
|
||||||
|
futures.add(es.submit(callable));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Future<List<ImageInfo>> future : futures) {
|
||||||
|
List<ImageInfo> subList = future.get();
|
||||||
|
if (subList != null) {
|
||||||
|
resultList.addAll(subList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
long timeInferEnd = System.currentTimeMillis();
|
||||||
|
System.out.println("time: " + (timeInferEnd - timeInferStart));
|
||||||
|
|
||||||
|
for (InferCallable callable : callables) {
|
||||||
|
callable.close();
|
||||||
|
}
|
||||||
|
} catch (InterruptedException | ExecutionException e) {
|
||||||
|
logger.error("", e);
|
||||||
|
} finally {
|
||||||
|
es.shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> names = new ArrayList<>();
|
||||||
|
List<Double> prob = new ArrayList<>();
|
||||||
|
List<BoundingBox> rect = new ArrayList<>();
|
||||||
|
for (ImageInfo imageInfo : resultList) {
|
||||||
|
names.add(imageInfo.getName());
|
||||||
|
prob.add(imageInfo.getProb());
|
||||||
|
rect.add(imageInfo.getBox());
|
||||||
|
}
|
||||||
|
DetectedObjects detectedObjects = new DetectedObjects(names, prob, rect);
|
||||||
|
|
||||||
|
return detectedObjects;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, DetectedObjects> detectCriteria() {
|
||||||
|
Criteria<Image, DetectedObjects> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, DetectedObjects.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv3_det_infer.zip")
|
||||||
|
// .optModelUrls(
|
||||||
|
// "/Users/calvin/Documents/build/paddle_models/ppocr/ch_PP-OCRv2_det_infer")
|
||||||
|
.optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, String> recognizeCriteria() {
|
||||||
|
Criteria<Image, String> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, String.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv3_rec_infer.zip")
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap<String, String>())))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class InferCallable implements Callable<List<ImageInfo>> {
|
||||||
|
private Predictor<Image, String> recognizer;
|
||||||
|
private ConcurrentLinkedQueue<ImageInfo> queue;
|
||||||
|
private List<ImageInfo> resultList = new ArrayList<>();
|
||||||
|
|
||||||
|
public InferCallable(ZooModel recognitionModel, ConcurrentLinkedQueue<ImageInfo> queue){
|
||||||
|
recognizer = recognitionModel.newPredictor();
|
||||||
|
this.queue = queue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<ImageInfo> call() {
|
||||||
|
try {
|
||||||
|
ImageInfo imageInfo = queue.poll();
|
||||||
|
while (imageInfo != null) {
|
||||||
|
String name = recognizer.predict(imageInfo.getImage());
|
||||||
|
imageInfo.setName(name);
|
||||||
|
imageInfo.setProb(-1.0);
|
||||||
|
resultList.add(imageInfo);
|
||||||
|
imageInfo = queue.poll();
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
return resultList;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void close() {
|
||||||
|
recognizer.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image getSubImage(Image img, BoundingBox box) {
|
||||||
|
Rectangle rect = box.getBounds();
|
||||||
|
double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());
|
||||||
|
int width = img.getWidth();
|
||||||
|
int height = img.getHeight();
|
||||||
|
int[] recovered = {
|
||||||
|
(int) (extended[0] * width),
|
||||||
|
(int) (extended[1] * height),
|
||||||
|
(int) (extended[2] * width),
|
||||||
|
(int) (extended[3] * height)
|
||||||
|
};
|
||||||
|
return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] extendRect(double xmin, double ymin, double width, double height) {
|
||||||
|
double centerx = xmin + width / 2;
|
||||||
|
double centery = ymin + height / 2;
|
||||||
|
if (width > height) {
|
||||||
|
width += height * 2.0;
|
||||||
|
height *= 3.0;
|
||||||
|
} else {
|
||||||
|
height += width * 2.0;
|
||||||
|
width *= 3.0;
|
||||||
|
}
|
||||||
|
double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;
|
||||||
|
double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;
|
||||||
|
double newWidth = newX + width > 1 ? 1 - newX : width;
|
||||||
|
double newHeight = newY + height > 1 ? 1 - newY : height;
|
||||||
|
return new double[]{newX, newY, newWidth, newHeight};
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image rotateImg(Image image) {
|
||||||
|
try (NDManager manager = NDManager.newBaseManager()) {
|
||||||
|
NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
|
||||||
|
return ImageFactory.getInstance().fromNDArray(rotated);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,140 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.recognition;
|
||||||
|
|
||||||
|
import ai.djl.inference.Predictor;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.Point;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.NDManager;
|
||||||
|
import ai.djl.opencv.OpenCVImageFactory;
|
||||||
|
import ai.djl.repository.zoo.Criteria;
|
||||||
|
import ai.djl.training.util.ProgressBar;
|
||||||
|
import ai.djl.translate.TranslateException;
|
||||||
|
import jnpf.ocr_sdk.utils.common.RotatedBox;
|
||||||
|
import jnpf.ocr_sdk.utils.opencv.NDArrayUtils;
|
||||||
|
import jnpf.ocr_sdk.utils.opencv.OpenCVUtils;
|
||||||
|
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||||
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
|
import org.bytedeco.opencv.opencv_core.Point2f;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public final class OcrV3Recognition {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(OcrV3Recognition.class);
|
||||||
|
|
||||||
|
public OcrV3Recognition() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria<Image, String> recognizeCriteria() {
|
||||||
|
Criteria<Image, String> criteria =
|
||||||
|
Criteria.builder()
|
||||||
|
.optEngine("PaddlePaddle")
|
||||||
|
.setTypes(Image.class, String.class)
|
||||||
|
.optModelUrls(
|
||||||
|
"https://aias-home.oss-cn-beijing.aliyuncs.com/models/ocr_models/ch_PP-OCRv3_rec_infer.zip")
|
||||||
|
.optProgress(new ProgressBar())
|
||||||
|
.optTranslator(new PpWordRecognitionTranslator((new ConcurrentHashMap<String, String>())))
|
||||||
|
.build();
|
||||||
|
return criteria;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<RotatedBox> predict(
|
||||||
|
Image image, Predictor<Image, NDList> detector, Predictor<Image, String> recognizer)
|
||||||
|
throws TranslateException {
|
||||||
|
NDList boxes = detector.predict(image);
|
||||||
|
|
||||||
|
List<RotatedBox> result = new ArrayList<>();
|
||||||
|
long timeInferStart = System.currentTimeMillis();
|
||||||
|
|
||||||
|
OpenCVFrameConverter.ToMat cv = new OpenCVFrameConverter.ToMat();
|
||||||
|
OpenCVFrameConverter.ToMat converter1 = new OpenCVFrameConverter.ToMat();
|
||||||
|
OpenCVFrameConverter.ToOrgOpenCvCoreMat converter2 = new OpenCVFrameConverter.ToOrgOpenCvCoreMat();
|
||||||
|
|
||||||
|
for (int i = 0; i < boxes.size(); i++) {
|
||||||
|
NDArray box = boxes.get(i);
|
||||||
|
// BufferedImage bufferedImage = get_rotate_crop_image(image, box);
|
||||||
|
|
||||||
|
float[] pointsArr = box.toFloatArray();
|
||||||
|
float[] lt = java.util.Arrays.copyOfRange(pointsArr, 0, 2);
|
||||||
|
float[] rt = java.util.Arrays.copyOfRange(pointsArr, 2, 4);
|
||||||
|
float[] rb = java.util.Arrays.copyOfRange(pointsArr, 4, 6);
|
||||||
|
float[] lb = java.util.Arrays.copyOfRange(pointsArr, 6, 8);
|
||||||
|
int img_crop_width = (int) Math.max(distance(lt, rt), distance(rb, lb));
|
||||||
|
int img_crop_height = (int) Math.max(distance(lt, lb), distance(rt, rb));
|
||||||
|
List<Point> srcPoints = new ArrayList<>();
|
||||||
|
srcPoints.add(new Point(lt[0], lt[1]));
|
||||||
|
srcPoints.add(new Point(rt[0], rt[1]));
|
||||||
|
srcPoints.add(new Point(rb[0], rb[1]));
|
||||||
|
srcPoints.add(new Point(lb[0], lb[1]));
|
||||||
|
List<Point> dstPoints = new ArrayList<>();
|
||||||
|
dstPoints.add(new Point(0, 0));
|
||||||
|
dstPoints.add(new Point(img_crop_width, 0));
|
||||||
|
dstPoints.add(new Point(img_crop_width, img_crop_height));
|
||||||
|
dstPoints.add(new Point(0, img_crop_height));
|
||||||
|
|
||||||
|
Point2f srcPoint2f = NDArrayUtils.toOpenCVPoint2f(srcPoints, 4);
|
||||||
|
Point2f dstPoint2f = NDArrayUtils.toOpenCVPoint2f(dstPoints, 4);
|
||||||
|
|
||||||
|
BufferedImage bufferedImage = OpenCVUtils.matToBufferedImage((org.opencv.core.Mat) image.getWrappedImage());
|
||||||
|
// try {
|
||||||
|
// File outputfile = new File("build/output/srcImage.jpg");
|
||||||
|
// ImageIO.write(bufferedImage, "jpg", outputfile);
|
||||||
|
// } catch (IOException e) {
|
||||||
|
// e.printStackTrace();
|
||||||
|
// }
|
||||||
|
org.bytedeco.opencv.opencv_core.Mat mat = cv.convertToMat(new Java2DFrameConverter().convert(bufferedImage));
|
||||||
|
org.bytedeco.opencv.opencv_core.Mat dstMat = OpenCVUtils.perspectiveTransform(mat, srcPoint2f, dstPoint2f);
|
||||||
|
org.opencv.core.Mat cvMat = converter2.convert(converter1.convert(dstMat));
|
||||||
|
Image subImg = OpenCVImageFactory.getInstance().fromImage(cvMat);
|
||||||
|
// ImageUtils.saveImage(subImg, i + ".png", "build/output");
|
||||||
|
|
||||||
|
subImg = subImg.getSubImage(0,0,img_crop_width,img_crop_height);
|
||||||
|
if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {
|
||||||
|
subImg = rotateImg(subImg);
|
||||||
|
}
|
||||||
|
|
||||||
|
String name = recognizer.predict(subImg);
|
||||||
|
RotatedBox rotatedBox = new RotatedBox(box, name);
|
||||||
|
result.add(rotatedBox);
|
||||||
|
|
||||||
|
mat.release();
|
||||||
|
dstMat.release();
|
||||||
|
cvMat.release();
|
||||||
|
srcPoint2f.releaseReference();
|
||||||
|
dstPoint2f.releaseReference();
|
||||||
|
}
|
||||||
|
cv.close();
|
||||||
|
converter1.close();
|
||||||
|
converter2.close();
|
||||||
|
long timeInferEnd = System.currentTimeMillis();
|
||||||
|
System.out.println("time: " + (timeInferEnd - timeInferStart));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private BufferedImage get_rotate_crop_image(Image image, NDArray box) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private float distance(float[] point1, float[] point2) {
|
||||||
|
float disX = point1[0] - point2[0];
|
||||||
|
float disY = point1[1] - point2[1];
|
||||||
|
float dis = (float) Math.sqrt(disX * disX + disY * disY);
|
||||||
|
return dis;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Image rotateImg(Image image) {
|
||||||
|
try (NDManager manager = NDManager.newBaseManager()) {
|
||||||
|
NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);
|
||||||
|
return ImageFactory.getInstance().fromNDArray(rotated);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,108 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.recognition;
|
||||||
|
|
||||||
|
import ai.djl.Model;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.index.NDIndex;
|
||||||
|
import ai.djl.ndarray.types.DataType;
|
||||||
|
import ai.djl.ndarray.types.Shape;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
import ai.djl.util.Utils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class PpWordRecognitionTranslator implements Translator<Image, String> {
|
||||||
|
private List<String> table;
|
||||||
|
private final boolean use_space_char;
|
||||||
|
|
||||||
|
public PpWordRecognitionTranslator(Map<String, ?> arguments) {
|
||||||
|
use_space_char =
|
||||||
|
arguments.containsKey("use_space_char")
|
||||||
|
? Boolean.parseBoolean(arguments.get("use_space_char").toString())
|
||||||
|
: false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void prepare(TranslatorContext ctx) throws IOException {
|
||||||
|
Model model = ctx.getModel();
|
||||||
|
try (InputStream is = model.getArtifact("ppocr_keys_v1.txt").openStream()) {
|
||||||
|
table = Utils.readLines(is, true);
|
||||||
|
table.add(0, "blank");
|
||||||
|
if(use_space_char)
|
||||||
|
table.add(" ");
|
||||||
|
else
|
||||||
|
table.add("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String processOutput(TranslatorContext ctx, NDList list) throws IOException {
|
||||||
|
StringBuilder sb = new StringBuilder();
|
||||||
|
NDArray tokens = list.singletonOrThrow();
|
||||||
|
|
||||||
|
long[] indices = tokens.get(0).argMax(1).toLongArray();
|
||||||
|
boolean[] selection = new boolean[indices.length];
|
||||||
|
Arrays.fill(selection, true);
|
||||||
|
for (int i = 1; i < indices.length; i++) {
|
||||||
|
if (indices[i] == indices[i - 1]) {
|
||||||
|
selection[i] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 字符置信度
|
||||||
|
// float[] probs = new float[indices.length];
|
||||||
|
// for (int row = 0; row < indices.length; row++) {
|
||||||
|
// NDArray value = tokens.get(0).get(new NDIndex(""+ row +":" + (row + 1) +"," + indices[row] +":" + ( indices[row] + 1)));
|
||||||
|
// probs[row] = value.toFloatArray()[0];
|
||||||
|
// }
|
||||||
|
|
||||||
|
int lastIdx = 0;
|
||||||
|
for (int i = 0; i < indices.length; i++) {
|
||||||
|
if (selection[i] == true && indices[i] > 0 && !(i > 0 && indices[i] == lastIdx)) {
|
||||||
|
sb.append(table.get((int) indices[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
|
||||||
|
int imgC = 3;
|
||||||
|
int imgH = 48;
|
||||||
|
int imgW = 320;//192 320
|
||||||
|
|
||||||
|
int h = input.getHeight();
|
||||||
|
int w = input.getWidth();
|
||||||
|
float ratio = (float) w / (float) h;
|
||||||
|
imgW = (int)(imgH * ratio);
|
||||||
|
|
||||||
|
int resized_w;
|
||||||
|
if (Math.ceil(imgH * ratio) > imgW) {
|
||||||
|
resized_w = imgW;
|
||||||
|
} else {
|
||||||
|
resized_w = (int) (Math.ceil(imgH * ratio));
|
||||||
|
}
|
||||||
|
img = NDImageUtils.resize(img, resized_w, imgH);
|
||||||
|
img = img.transpose(2, 0, 1).div(255).sub(0.5f).div(0.5f);
|
||||||
|
NDArray padding_im = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW), DataType.FLOAT32);
|
||||||
|
padding_im.set(new NDIndex(":,:,0:" + resized_w), img);
|
||||||
|
|
||||||
|
padding_im = padding_im.expandDims(0);
|
||||||
|
return new NDList(padding_im);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,69 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.rotation;
|
||||||
|
|
||||||
|
import ai.djl.modality.Classifications;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.index.NDIndex;
|
||||||
|
import ai.djl.ndarray.types.Shape;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class PpWordRotateTranslator implements Translator<Image, Classifications> {
|
||||||
|
List<String> classes = Arrays.asList("No Rotate", "Rotate");
|
||||||
|
|
||||||
|
public PpWordRotateTranslator() {}
|
||||||
|
|
||||||
|
public Classifications processOutput(TranslatorContext ctx, NDList list) {
|
||||||
|
NDArray prob = list.singletonOrThrow();
|
||||||
|
return new Classifications(this.classes, prob);
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager());
|
||||||
|
img = NDImageUtils.resize(img, 192, 48);
|
||||||
|
img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F);
|
||||||
|
img = img.expandDims(0);
|
||||||
|
return new NDList(new NDArray[]{img});
|
||||||
|
}
|
||||||
|
|
||||||
|
public NDList processInputBak(TranslatorContext ctx, Image input) throws Exception {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager());
|
||||||
|
int imgC = 3;
|
||||||
|
int imgH = 48;
|
||||||
|
int imgW = 192;
|
||||||
|
|
||||||
|
NDArray array = ctx.getNDManager().zeros(new Shape(imgC, imgH, imgW));
|
||||||
|
|
||||||
|
int h = input.getHeight();
|
||||||
|
int w = input.getWidth();
|
||||||
|
int resized_w = 0;
|
||||||
|
|
||||||
|
float ratio = (float) w / (float) h;
|
||||||
|
if (Math.ceil(imgH * ratio) > imgW) {
|
||||||
|
resized_w = imgW;
|
||||||
|
} else {
|
||||||
|
resized_w = (int) (Math.ceil(imgH * ratio));
|
||||||
|
}
|
||||||
|
|
||||||
|
img = NDImageUtils.resize(img, resized_w, imgH);
|
||||||
|
|
||||||
|
img = NDImageUtils.toTensor(img).sub(0.5F).div(0.5F);
|
||||||
|
// img = img.transpose(2, 0, 1);
|
||||||
|
|
||||||
|
array.set(new NDIndex(":,:,0:" + resized_w), img);
|
||||||
|
|
||||||
|
array = array.expandDims(0);
|
||||||
|
|
||||||
|
return new NDList(new NDArray[] {array});
|
||||||
|
}
|
||||||
|
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,235 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.table;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.math.NumberUtils;
|
||||||
|
import org.apache.poi.hssf.usermodel.*;
|
||||||
|
import org.apache.poi.ss.usermodel.BorderStyle;
|
||||||
|
import org.apache.poi.ss.usermodel.CellType;
|
||||||
|
import org.apache.poi.ss.usermodel.HorizontalAlignment;
|
||||||
|
import org.apache.poi.ss.usermodel.VerticalAlignment;
|
||||||
|
import org.apache.poi.ss.util.CellRangeAddress;
|
||||||
|
import org.dom4j.Document;
|
||||||
|
import org.dom4j.DocumentException;
|
||||||
|
import org.dom4j.DocumentHelper;
|
||||||
|
import org.dom4j.Element;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @Auther: xiaoqiang
|
||||||
|
* @Date: 2020/12/9 9:16
|
||||||
|
* @Description:
|
||||||
|
*/
|
||||||
|
public class ConvertHtml2Excel {
|
||||||
|
/**
|
||||||
|
* html表格转excel
|
||||||
|
*
|
||||||
|
* @param tableHtml 如
|
||||||
|
* <table>
|
||||||
|
* ..
|
||||||
|
* </table>
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static HSSFWorkbook table2Excel(String tableHtml) {
|
||||||
|
HSSFWorkbook wb = new HSSFWorkbook();
|
||||||
|
HSSFSheet sheet = wb.createSheet();
|
||||||
|
List<CrossRangeCellMeta> crossRowEleMetaLs = new ArrayList<>();
|
||||||
|
int rowIndex = 0;
|
||||||
|
try {
|
||||||
|
Document data = DocumentHelper.parseText(tableHtml);
|
||||||
|
// 生成表头
|
||||||
|
Element thead = data.getRootElement().element("thead");
|
||||||
|
HSSFCellStyle titleStyle = getTitleStyle(wb);
|
||||||
|
int ls=0;//列数
|
||||||
|
if (thead != null) {
|
||||||
|
List<Element> trLs = thead.elements("tr");
|
||||||
|
for (Element trEle : trLs) {
|
||||||
|
HSSFRow row = sheet.createRow(rowIndex);
|
||||||
|
List<Element> thLs = trEle.elements("td");
|
||||||
|
ls=thLs.size();
|
||||||
|
makeRowCell(thLs, rowIndex, row, 0, titleStyle, crossRowEleMetaLs);
|
||||||
|
rowIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 生成表体
|
||||||
|
Element tbody = data.getRootElement().element("tbody");
|
||||||
|
HSSFCellStyle contentStyle = getContentStyle(wb);
|
||||||
|
if (tbody != null) {
|
||||||
|
List<Element> trLs = tbody.elements("tr");
|
||||||
|
for (Element trEle : trLs) {
|
||||||
|
HSSFRow row = sheet.createRow(rowIndex);
|
||||||
|
List<Element> thLs = trEle.elements("th");
|
||||||
|
int cellIndex = makeRowCell(thLs, rowIndex, row, 0, titleStyle, crossRowEleMetaLs);
|
||||||
|
List<Element> tdLs = trEle.elements("td");
|
||||||
|
makeRowCell(tdLs, rowIndex, row, cellIndex, contentStyle, crossRowEleMetaLs);
|
||||||
|
rowIndex++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 合并表头
|
||||||
|
for (CrossRangeCellMeta crcm : crossRowEleMetaLs) {
|
||||||
|
sheet.addMergedRegion(new CellRangeAddress(crcm.getFirstRow(), crcm.getLastRow(), crcm.getFirstCol(), crcm.getLastCol()));
|
||||||
|
setRegionStyle(sheet, new CellRangeAddress(crcm.getFirstRow(), crcm.getLastRow(), crcm.getFirstCol(), crcm.getLastCol()),titleStyle);
|
||||||
|
}
|
||||||
|
for(int i=0;i<sheet.getRow(0).getPhysicalNumberOfCells();i++){
|
||||||
|
sheet.autoSizeColumn(i, true);//设置列宽
|
||||||
|
if(sheet.getColumnWidth(i)<255*256){
|
||||||
|
sheet.setColumnWidth(i, sheet.getColumnWidth(i) < 9000 ? 9000 : sheet.getColumnWidth(i));
|
||||||
|
}else{
|
||||||
|
sheet.setColumnWidth(i, 15000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (DocumentException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
|
||||||
|
return wb;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生产行内容
|
||||||
|
*
|
||||||
|
* @return 最后一列的cell index
|
||||||
|
*/
|
||||||
|
/**
|
||||||
|
* @param tdLs th或者td集合
|
||||||
|
* @param rowIndex 行号
|
||||||
|
* @param row POI行对象
|
||||||
|
* @param startCellIndex
|
||||||
|
* @param cellStyle 样式
|
||||||
|
* @param crossRowEleMetaLs 跨行元数据集合
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private static int makeRowCell(List<Element> tdLs, int rowIndex, HSSFRow row, int startCellIndex, HSSFCellStyle cellStyle,
|
||||||
|
List<CrossRangeCellMeta> crossRowEleMetaLs) {
|
||||||
|
int i = startCellIndex;
|
||||||
|
for (int eleIndex = 0; eleIndex < tdLs.size(); i++, eleIndex++) {
|
||||||
|
int captureCellSize = getCaptureCellSize(rowIndex, i, crossRowEleMetaLs);
|
||||||
|
while (captureCellSize > 0) {
|
||||||
|
for (int j = 0; j < captureCellSize; j++) {// 当前行跨列处理(补单元格)
|
||||||
|
row.createCell(i);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
captureCellSize = getCaptureCellSize(rowIndex, i, crossRowEleMetaLs);
|
||||||
|
}
|
||||||
|
Element thEle = tdLs.get(eleIndex);
|
||||||
|
String val = thEle.getTextTrim();
|
||||||
|
if (StringUtils.isBlank(val)) {
|
||||||
|
Element e = thEle.element("a");
|
||||||
|
if (e != null) {
|
||||||
|
val = e.getTextTrim();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
HSSFCell c = row.createCell(i);
|
||||||
|
if (NumberUtils.isNumber(val)) {
|
||||||
|
c.setCellValue(Double.parseDouble(val));
|
||||||
|
c.setCellType(CellType.NUMERIC);
|
||||||
|
} else {
|
||||||
|
c.setCellValue(val);
|
||||||
|
}
|
||||||
|
int rowSpan = NumberUtils.toInt(thEle.attributeValue("rowspan"), 1);
|
||||||
|
int colSpan = NumberUtils.toInt(thEle.attributeValue("colspan"), 1);
|
||||||
|
c.setCellStyle(cellStyle);
|
||||||
|
if (rowSpan > 1 || colSpan > 1) { // 存在跨行或跨列
|
||||||
|
crossRowEleMetaLs.add(new CrossRangeCellMeta(rowIndex, i, rowSpan, colSpan));
|
||||||
|
}
|
||||||
|
if (colSpan > 1) {// 当前行跨列处理(补单元格)
|
||||||
|
for (int j = 1; j < colSpan; j++) {
|
||||||
|
i++;
|
||||||
|
row.createCell(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置合并单元格的边框样式
|
||||||
|
*
|
||||||
|
* @param sheet
|
||||||
|
* @param region
|
||||||
|
* @param cs
|
||||||
|
*/
|
||||||
|
public static void setRegionStyle(HSSFSheet sheet, CellRangeAddress region, HSSFCellStyle cs) {
|
||||||
|
for (int i = region.getFirstRow(); i <= region.getLastRow(); i++) {
|
||||||
|
HSSFRow row = sheet.getRow(i);
|
||||||
|
for (int j = region.getFirstColumn(); j <= region.getLastColumn(); j++) {
|
||||||
|
HSSFCell cell = row.getCell(j);
|
||||||
|
cell.setCellStyle(cs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得因rowSpan占据的单元格
|
||||||
|
*
|
||||||
|
* @param rowIndex 行号
|
||||||
|
* @param colIndex 列号
|
||||||
|
* @param crossRowEleMetaLs 跨行列元数据
|
||||||
|
* @return 当前行在某列需要占据单元格
|
||||||
|
*/
|
||||||
|
private static int getCaptureCellSize(int rowIndex, int colIndex, List<CrossRangeCellMeta> crossRowEleMetaLs) {
|
||||||
|
int captureCellSize = 0;
|
||||||
|
for (CrossRangeCellMeta crossRangeCellMeta : crossRowEleMetaLs) {
|
||||||
|
if (crossRangeCellMeta.getFirstRow() < rowIndex && crossRangeCellMeta.getLastRow() >= rowIndex) {
|
||||||
|
if (crossRangeCellMeta.getFirstCol() <= colIndex && crossRangeCellMeta.getLastCol() >= colIndex) {
|
||||||
|
captureCellSize = crossRangeCellMeta.getLastCol() - colIndex + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return captureCellSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得标题样式
|
||||||
|
*
|
||||||
|
* @param workbook
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private static HSSFCellStyle getTitleStyle(HSSFWorkbook workbook) {
|
||||||
|
//short titlebackgroundcolor = IndexedColors.GREY_25_PERCENT.index;
|
||||||
|
short fontSize = 12;
|
||||||
|
String fontName = "宋体";
|
||||||
|
HSSFCellStyle style = workbook.createCellStyle();
|
||||||
|
style.setVerticalAlignment(VerticalAlignment.CENTER);
|
||||||
|
style.setAlignment(HorizontalAlignment.CENTER);
|
||||||
|
style.setBorderBottom(BorderStyle.THIN); //下边框
|
||||||
|
style.setBorderLeft(BorderStyle.THIN);//左边框
|
||||||
|
style.setBorderTop(BorderStyle.THIN);//上边框
|
||||||
|
style.setBorderRight(BorderStyle.THIN);//右边框
|
||||||
|
//style.setFillPattern(FillPatternType.SOLID_FOREGROUND);
|
||||||
|
//style.setFillForegroundColor(titlebackgroundcolor);// 背景色
|
||||||
|
|
||||||
|
HSSFFont font = workbook.createFont();
|
||||||
|
font.setFontName(fontName);
|
||||||
|
font.setFontHeightInPoints(fontSize);
|
||||||
|
font.setBold(true);
|
||||||
|
style.setFont(font);
|
||||||
|
return style;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得内容样式
|
||||||
|
*
|
||||||
|
* @param wb
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private static HSSFCellStyle getContentStyle(HSSFWorkbook wb) {
|
||||||
|
short fontSize = 12;
|
||||||
|
String fontName = "宋体";
|
||||||
|
HSSFCellStyle style = wb.createCellStyle();
|
||||||
|
style.setBorderBottom(BorderStyle.THIN); //下边框
|
||||||
|
style.setBorderLeft(BorderStyle.THIN);//左边框
|
||||||
|
style.setBorderTop(BorderStyle.THIN);//上边框
|
||||||
|
style.setBorderRight(BorderStyle.THIN);//右边框
|
||||||
|
HSSFFont font = wb.createFont();
|
||||||
|
font.setFontName(fontName);
|
||||||
|
font.setFontHeightInPoints(fontSize);
|
||||||
|
style.setFont(font);
|
||||||
|
style.setAlignment(HorizontalAlignment.CENTER);//水平居中
|
||||||
|
style.setVerticalAlignment(VerticalAlignment.CENTER);//垂直居中
|
||||||
|
style.setWrapText(true);
|
||||||
|
return style;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,42 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.table;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @Auther: xiaoqiang
|
||||||
|
* @Date: 2020/12/9 9:17
|
||||||
|
* @Description:
|
||||||
|
*/
|
||||||
|
public class CrossRangeCellMeta {
|
||||||
|
public CrossRangeCellMeta(int firstRowIndex, int firstColIndex, int rowSpan, int colSpan) {
|
||||||
|
super();
|
||||||
|
this.firstRowIndex = firstRowIndex;
|
||||||
|
this.firstColIndex = firstColIndex;
|
||||||
|
this.rowSpan = rowSpan;
|
||||||
|
this.colSpan = colSpan;
|
||||||
|
}
|
||||||
|
|
||||||
|
private int firstRowIndex;
|
||||||
|
private int firstColIndex;
|
||||||
|
private int rowSpan;// 跨越行数
|
||||||
|
private int colSpan;// 跨越列数
|
||||||
|
|
||||||
|
public int getFirstRow() {
|
||||||
|
return firstRowIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getLastRow() {
|
||||||
|
return firstRowIndex + rowSpan - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getFirstCol() {
|
||||||
|
return firstColIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getLastCol() {
|
||||||
|
return firstColIndex + colSpan - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getColSpan(){
|
||||||
|
return colSpan;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,31 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.table;
|
||||||
|
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class TableResult {
|
||||||
|
private List<String> structure_str_list;
|
||||||
|
private List<BoundingBox> boxes;
|
||||||
|
|
||||||
|
public TableResult(List<String> structure_str_list, List<BoundingBox> boxes) {
|
||||||
|
this.structure_str_list = structure_str_list;
|
||||||
|
this.boxes = boxes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getStructure_str_list() {
|
||||||
|
return structure_str_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setStructure_str_list(List<String> structure_str_list) {
|
||||||
|
this.structure_str_list = structure_str_list;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<BoundingBox> getBoxes() {
|
||||||
|
return boxes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setBoxes(List<BoundingBox> boxes) {
|
||||||
|
this.boxes = boxes;
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,246 @@
|
|||||||
|
package jnpf.ocr_sdk.utils.table;
|
||||||
|
|
||||||
|
import ai.djl.Model;
|
||||||
|
import ai.djl.modality.cv.Image;
|
||||||
|
import ai.djl.modality.cv.ImageFactory;
|
||||||
|
import ai.djl.modality.cv.output.BoundingBox;
|
||||||
|
import ai.djl.modality.cv.output.Rectangle;
|
||||||
|
import ai.djl.modality.cv.util.NDImageUtils;
|
||||||
|
import ai.djl.ndarray.NDArray;
|
||||||
|
import ai.djl.ndarray.NDArrays;
|
||||||
|
import ai.djl.ndarray.NDList;
|
||||||
|
import ai.djl.ndarray.index.NDIndex;
|
||||||
|
import ai.djl.ndarray.types.DataType;
|
||||||
|
import ai.djl.ndarray.types.Shape;
|
||||||
|
import ai.djl.translate.Batchifier;
|
||||||
|
import ai.djl.translate.Translator;
|
||||||
|
import ai.djl.translate.TranslatorContext;
|
||||||
|
import ai.djl.util.Utils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
public class TableStructTranslator implements Translator<Image, TableResult> {
|
||||||
|
|
||||||
|
private final int maxLength;
|
||||||
|
private int height;
|
||||||
|
private int width;
|
||||||
|
private float xScale = 1.0f;
|
||||||
|
private float yScale = 1.0f;
|
||||||
|
|
||||||
|
public TableStructTranslator(Map<String, ?> arguments) {
|
||||||
|
maxLength =
|
||||||
|
arguments.containsKey("maxLength")
|
||||||
|
? Integer.parseInt(arguments.get("maxLength").toString())
|
||||||
|
: 488;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> dict_idx_character = new ConcurrentHashMap<>();
|
||||||
|
private Map<String, String> dict_character_idx = new ConcurrentHashMap<>();
|
||||||
|
private Map<String, String> dict_idx_elem = new ConcurrentHashMap<>();
|
||||||
|
private Map<String, String> dict_elem_idx = new ConcurrentHashMap<>();
|
||||||
|
private String beg_str = "sos";
|
||||||
|
private String end_str = "eos";
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void prepare(TranslatorContext ctx) throws IOException {
|
||||||
|
Model model = ctx.getModel();
|
||||||
|
// ppocr_keys_v1.txt
|
||||||
|
try (InputStream is = model.getArtifact("table_structure_dict.txt").openStream()) {
|
||||||
|
List<String> lines = Utils.readLines(is, false);
|
||||||
|
String[] substr = lines.get(0).trim().split("\\t");
|
||||||
|
int characterNum = Integer.parseInt(substr[0]);
|
||||||
|
int elemNum = Integer.parseInt(substr[1]);
|
||||||
|
|
||||||
|
List<String> listCharacter = new ArrayList<>();
|
||||||
|
List<String> listElem = new ArrayList<>();
|
||||||
|
for (int i = 1; i < 1 + characterNum; i++) {
|
||||||
|
listCharacter.add(lines.get(i).trim());
|
||||||
|
}
|
||||||
|
for (int i = 1 + characterNum; i < 1 + characterNum + elemNum; i++) {
|
||||||
|
listElem.add(lines.get(i).trim());
|
||||||
|
}
|
||||||
|
listCharacter.add(0, beg_str);
|
||||||
|
listCharacter.add(end_str);
|
||||||
|
listElem.add(0, beg_str);
|
||||||
|
listElem.add(end_str);
|
||||||
|
|
||||||
|
for (int i = 0; i < listCharacter.size(); i++) {
|
||||||
|
dict_idx_character.put("" + i, listCharacter.get(i));
|
||||||
|
dict_character_idx.put(listCharacter.get(i), "" + i);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < listElem.size(); i++) {
|
||||||
|
dict_idx_elem.put("" + i, listElem.get(i));
|
||||||
|
dict_elem_idx.put(listElem.get(i), "" + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||||
|
NDArray img = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
|
||||||
|
height = input.getHeight();
|
||||||
|
width = input.getWidth();
|
||||||
|
|
||||||
|
// img = ResizeTableImage(img, height, width, maxLength);
|
||||||
|
// img = PaddingTableImage(ctx, img, maxLength);
|
||||||
|
|
||||||
|
img = NDImageUtils.resize(img, 488, 488);
|
||||||
|
|
||||||
|
// img = NDImageUtils.toTensor(img);
|
||||||
|
img = img.transpose(2, 0, 1).div(255).flip(0);
|
||||||
|
img =
|
||||||
|
NDImageUtils.normalize(
|
||||||
|
img, new float[] {0.485f, 0.456f, 0.406f}, new float[] {0.229f, 0.224f, 0.225f});
|
||||||
|
|
||||||
|
img = img.expandDims(0);
|
||||||
|
return new NDList(img);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TableResult processOutput(TranslatorContext ctx, NDList list) {
|
||||||
|
NDArray locPreds = list.get(0);
|
||||||
|
NDArray structureProbs = list.get(1);
|
||||||
|
NDArray structure_idx = structureProbs.argMax(2);
|
||||||
|
NDArray structure_probs = structureProbs.max(new int[] {2});
|
||||||
|
|
||||||
|
List<List<String>> result_list = new ArrayList<>();
|
||||||
|
List<List<String>> result_pos_list = new ArrayList<>();
|
||||||
|
List<List<String>> result_score_list = new ArrayList<>();
|
||||||
|
List<List<String>> result_elem_idx_list = new ArrayList<>();
|
||||||
|
List<String> res_html_code_list = new ArrayList<>();
|
||||||
|
List<NDArray> res_loc_list = new ArrayList<>();
|
||||||
|
|
||||||
|
// get ignored tokens
|
||||||
|
int beg_idx = Integer.parseInt(dict_elem_idx.get(beg_str));
|
||||||
|
int end_idx = Integer.parseInt(dict_elem_idx.get(end_str));
|
||||||
|
|
||||||
|
long batch_size = structure_idx.size(0); // len(text_index)
|
||||||
|
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
|
||||||
|
List<String> char_list = new ArrayList<>();
|
||||||
|
List<String> elem_pos_list = new ArrayList<>();
|
||||||
|
List<String> elem_idx_list = new ArrayList<>();
|
||||||
|
List<String> score_list = new ArrayList<>();
|
||||||
|
|
||||||
|
long len = structure_idx.get(batch_idx).size();
|
||||||
|
for (int idx = 0; idx < len; idx++) {
|
||||||
|
int tmp_elem_idx = (int) structure_idx.get(batch_idx).get(idx).toLongArray()[0];
|
||||||
|
if (idx > 0 && tmp_elem_idx == end_idx) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (tmp_elem_idx == beg_idx || tmp_elem_idx == end_idx) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
char_list.add(dict_idx_elem.get("" + tmp_elem_idx));
|
||||||
|
elem_pos_list.add("" + idx);
|
||||||
|
score_list.add("" + structure_probs.get(batch_idx, idx).toFloatArray()[0]);
|
||||||
|
elem_idx_list.add("" + tmp_elem_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
result_list.add(char_list); // structure_str
|
||||||
|
result_pos_list.add(elem_pos_list);
|
||||||
|
result_score_list.add(score_list);
|
||||||
|
result_elem_idx_list.add(elem_idx_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
int batch_num = result_list.size();
|
||||||
|
for (int bno = 0; bno < batch_num; bno++) {
|
||||||
|
NDList res_loc = new NDList();
|
||||||
|
int len = result_list.get(bno).size();
|
||||||
|
for (int sno = 0; sno < len; sno++) {
|
||||||
|
String text = result_list.get(bno).get(sno);
|
||||||
|
if (text.equals("<td>") || text.equals("<td")) {
|
||||||
|
int pos = Integer.parseInt(result_pos_list.get(bno).get(sno));
|
||||||
|
res_loc.add(locPreds.get(bno, pos));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
String res_html_code = StringUtils.join(result_list.get(bno), "");
|
||||||
|
res_html_code_list.add(res_html_code);
|
||||||
|
NDArray array = NDArrays.stack(res_loc);
|
||||||
|
res_loc_list.add(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
// structure_str_list result_list
|
||||||
|
// res_loc res_loc_list
|
||||||
|
List<BoundingBox> boxes = new ArrayList<>();
|
||||||
|
|
||||||
|
long rows = res_loc_list.get(0).size(0);
|
||||||
|
for (int rno = 0; rno < rows; rno++) {
|
||||||
|
float[] arr = res_loc_list.get(0).get(rno).toFloatArray();
|
||||||
|
Rectangle rect = new Rectangle(arr[0], arr[1], (arr[2] - arr[0]), (arr[3] - arr[1]));
|
||||||
|
boxes.add(rect);
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> structure_str_list = result_list.get(0);
|
||||||
|
structure_str_list.add(0, "<table>");
|
||||||
|
structure_str_list.add(0, "<body>");
|
||||||
|
structure_str_list.add(0, "<html>");
|
||||||
|
structure_str_list.add("</table>");
|
||||||
|
structure_str_list.add("</body>");
|
||||||
|
structure_str_list.add("</html>");
|
||||||
|
|
||||||
|
TableResult result = new TableResult(structure_str_list, boxes);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Batchifier getBatchifier() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NDArray ResizeTableImage(NDArray img, int height, int width, int maxLen) {
|
||||||
|
int localMax = Math.max(height, width);
|
||||||
|
float ratio = maxLen * 1.0f / localMax;
|
||||||
|
int resize_h = (int) (height * ratio);
|
||||||
|
int resize_w = (int) (width * ratio);
|
||||||
|
if(width > height){
|
||||||
|
xScale = 1.0f;
|
||||||
|
yScale = ratio;
|
||||||
|
} else{
|
||||||
|
xScale = ratio;
|
||||||
|
yScale = 1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
img = NDImageUtils.resize(img, resize_w, resize_h);
|
||||||
|
return img;
|
||||||
|
}
|
||||||
|
|
||||||
|
private NDArray PaddingTableImage(TranslatorContext ctx, NDArray img, int maxLen) {
|
||||||
|
|
||||||
|
Image srcImg = ImageFactory.getInstance().fromNDArray(img.duplicate());
|
||||||
|
saveImage(srcImg, "img.png", "build/output");
|
||||||
|
|
||||||
|
NDArray paddingImg = ctx.getNDManager().zeros(new Shape(maxLen, maxLen, 3), DataType.UINT8);
|
||||||
|
// NDManager manager = NDManager.newBaseManager();
|
||||||
|
// NDArray paddingImg = manager.zeros(new Shape(maxLen, maxLen, 3), DataType.UINT8);
|
||||||
|
paddingImg.set(
|
||||||
|
new NDIndex("0:" + img.getShape().get(0) + ",0:" + img.getShape().get(1) + ",:"), img);
|
||||||
|
Image image = ImageFactory.getInstance().fromNDArray(paddingImg);
|
||||||
|
|
||||||
|
saveImage(image, "paddingImg.png", "build/output");
|
||||||
|
|
||||||
|
return paddingImg;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveImage(Image img, String name, String path) {
|
||||||
|
Path outputDir = Paths.get(path);
|
||||||
|
Path imagePath = outputDir.resolve(name);
|
||||||
|
// OpenJDK 不能保存 jpg 图片的 alpha channel
|
||||||
|
try {
|
||||||
|
img.save(Files.newOutputStream(imagePath), "png");
|
||||||
|
} catch (IOException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue