Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mindocr/data/transforms/det_east_transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ast
import json
import math

import cv2
Expand Down Expand Up @@ -414,8 +416,17 @@ def _extract_vertices(self, data_labels):
"""
vertices_list = []
labels_list = []
data_labels = eval(data_labels)
for data_label in data_labels:
try:
parsed_data = json.loads(data_labels)
except json.JSONDecodeError:
try:
parsed_data = ast.literal_eval(data_labels)
except (ValueError, SyntaxError) as e:
raise ValueError(f"Invalid data format: {str(e)}") from e

if not isinstance(parsed_data, list):
raise ValueError("Data labels should be a list")
for data_label in parsed_data:
vertices = data_label["points"]
vertices = [item for point in vertices for item in point]
vertices_list.append(vertices)
Expand Down
17 changes: 13 additions & 4 deletions mindocr/data/transforms/svtr_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,10 +546,19 @@ def __init__(self, max_text_length, character_dict_path=None, use_space_char=Fal

self.ctc_encode = CTCLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs)
self.gtc_encode_type = gtc_encode
# Explicitly register every supported gtc_encode class in this dict; class objects are looked up by name here.
supported_gtc_encode = {}
if gtc_encode is None:
self.gtc_encode = SARLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs)
else:
self.gtc_encode = eval(gtc_encode)(max_text_length, character_dict_path, use_space_char, **kwargs)
# MindOCR currently has no module that passes a custom `gtc_encode`, so this branch is not reached at
# present. If custom encoders are supported later, obtain the class reference from an explicit dict;
# do not use the `eval` function.
if gtc_encode not in supported_gtc_encode:
raise ValueError(f"Get unsupported gtc_encode {gtc_encode}")
self.gtc_encode = supported_gtc_encode[gtc_encode](
max_text_length, character_dict_path, use_space_char, **kwargs
)

def __call__(self, data):
data_ctc = copy.deepcopy(data)
Expand Down Expand Up @@ -925,7 +934,7 @@ def __init__(
jitter_prob=0.4,
blur_prob=0.4,
hsv_aug_prob=0.4,
**kwargs
**kwargs,
):
self.crop_prob = crop_prob
self.reverse_prob = reverse_prob
Expand Down Expand Up @@ -973,7 +982,7 @@ def __init__(
jitter_prob=0.4,
blur_prob=0.4,
hsv_aug_prob=0.4,
**kwargs
**kwargs,
):
self.tia_prob = tia_prob
self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob, jitter_prob, blur_prob, hsv_aug_prob)
Expand Down Expand Up @@ -1078,7 +1087,7 @@ def __init__(
character_dict_path=".mindocr/utils/dict/ch_dict.txt",
padding=True,
width_downsample_ratio=0.125,
**kwargs
**kwargs,
):
self.image_shape = image_shape
self.infer_mode = infer_mode
Expand Down
66 changes: 60 additions & 6 deletions mindocr/data/transforms/transforms_factory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""
Create and run transformations from a config or predefined transformation pipeline
"""
import logging
from typing import Dict, List

import numpy as np

from .det_east_transforms import *
from .det_transforms import *
from .general_transforms import *
Expand All @@ -15,6 +12,63 @@
from .svtr_transform import *
from .table_transform import *

# Explicit allow-list mapping each transform name, as written in a transform pipeline config, to the
# object that implements it. The values are the names brought in by the star imports above.
# `create_transforms` looks names up in this dict instead of resolving them with `eval`, so a new
# transform must be added here before it can be referenced by name.
SUPPORTED_TRANSFORMS = {
"EASTProcessTrain": EASTProcessTrain,
"DetLabelEncode": DetLabelEncode,
"BorderMap": BorderMap,
"ShrinkBinaryMap": ShrinkBinaryMap,
"expand_poly": expand_poly,
"PSEGtDecode": PSEGtDecode,
"ValidatePolygons": ValidatePolygons,
"RandomCropWithBBox": RandomCropWithBBox,
"RandomCropWithMask": RandomCropWithMask,
"DetResize": DetResize,
"DecodeImage": DecodeImage,
"NormalizeImage": NormalizeImage,
"ToCHWImage": ToCHWImage,
"PackLoaderInputs": PackLoaderInputs,
"RandomScale": RandomScale,
"RandomColorAdjust": RandomColorAdjust,
"RandomRotate": RandomRotate,
"RandomHorizontalFlip": RandomHorizontalFlip,
"LayoutResize": LayoutResize,
"ImageStridePad": ImageStridePad,
"VQATokenLabelEncode": VQATokenLabelEncode,
"VQATokenPad": VQATokenPad,
"VQASerTokenChunk": VQASerTokenChunk,
"VQAReTokenRelation": VQAReTokenRelation,
"VQAReTokenChunk": VQAReTokenChunk,
"TensorizeEntitiesRelations": TensorizeEntitiesRelations,
"ABINetTransforms": ABINetTransforms,
"ABINetRecAug": ABINetRecAug,
"ABINetEval": ABINetEval,
"ABINetEvalTransforms": ABINetEvalTransforms,
"RecCTCLabelEncode": RecCTCLabelEncode,
"RecAttnLabelEncode": RecAttnLabelEncode,
"RecMasterLabelEncode": RecMasterLabelEncode,
"VisionLANLabelEncode": VisionLANLabelEncode,
"RecResizeImg": RecResizeImg,
"RecResizeNormForInfer": RecResizeNormForInfer,
"SVTRRecResizeImg": SVTRRecResizeImg,
"Rotate90IfVertical": Rotate90IfVertical,
"ClsLabelEncode": ClsLabelEncode,
"SARLabelEncode": SARLabelEncode,
"RobustScannerRecResizeImg": RobustScannerRecResizeImg,
"SVTRRecAug": SVTRRecAug,
"MultiLabelEncode": MultiLabelEncode,
"RecConAug": RecConAug,
"RecAug": RecAug,
"RecResizeImgForSVTR": RecResizeImgForSVTR,
"BaseRecLabelEncode": BaseRecLabelEncode,
"AttnLabelEncode": AttnLabelEncode,
"TableLabelEncode": TableLabelEncode,
"TableMasterLabelEncode": TableMasterLabelEncode,
"ResizeTableImage": ResizeTableImage,
"PaddingTableImage": PaddingTableImage,
"TableBoxEncode": TableBoxEncode,
"TableImageNorm": TableImageNorm,
}

__all__ = ["create_transforms", "run_transforms", "transforms_dbnet_icdar15"]
_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -45,9 +99,9 @@ def create_transforms(transform_pipeline: List, global_config: Dict = None):
param = {} if transform_config[trans_name] is None else transform_config[trans_name]
if global_config is not None:
param.update(global_config)
# TODO: assert undefined transform class

transform = eval(trans_name)(**param)
# For security reasons, we no longer use the eval function to dynamically obtain class objects.
# If you need to add a new transform class, please explicitly add it to the ``SUPPORTED_TRANSFORMS`` dict.
transform = SUPPORTED_TRANSFORMS[trans_name](**param)
transforms.append(transform)
elif callable(transform_config):
transforms.append(transform_config)
Expand Down
7 changes: 4 additions & 3 deletions mindocr/models/backbones/mindcv_models/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ def extract_archive(self, from_path: str, to_path: str = None) -> str:
def download_file(self, url: str, file_path: str, chunk_size: int = 1024):
"""Download a file."""

# no check certificate
# For security reasons, this repository does not provide an option to disable SSL certificate verification.
# If you really need to skip verification, disable it yourself in your own code.
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
# ctx.check_hostname = False
# ctx.verify_mode = ssl.CERT_NONE

# Define request headers.
headers = {"User-Agent": self.USER_AGENT}
Expand Down
38 changes: 22 additions & 16 deletions mindocr/postprocess/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,24 @@

__all__ = ["build_postprocess"]

supported_postprocess = (
det_db_postprocess.__all__
+ det_pse_postprocess.__all__
+ det_east_postprocess.__all__
+ rec_postprocess.__all__
+ cls_postprocess.__all__
+ rec_abinet_postprocess.__all__
+ kie_ser_postprocess.__all__
+ kie_re_postprocess.__all__
+ layout_postprocess.__all__
+ table_postprocess.__all__
)
# Explicit allow-list mapping each postprocess name to the class that implements it.
# `build_postprocess` looks up the config's "name" entry in this dict and instantiates the match;
# a name that is not a key here is rejected with a ValueError.
SUPPORTED_POSTPROCESS = {
"DBPostprocess": DBPostprocess,
"PSEPostprocess": PSEPostprocess,
"EASTPostprocess": EASTPostprocess,
"CTCLabelDecode": CTCLabelDecode,
"RecCTCLabelDecode": RecCTCLabelDecode,
"RecAttnLabelDecode": RecAttnLabelDecode,
"RecMasterLabelDecode": RecMasterLabelDecode,
"VisionLANPostProcess": VisionLANPostProcess,
"SARLabelDecode": SARLabelDecode,
"ClsPostprocess": ClsPostprocess,
"ABINetLabelDecode": ABINetLabelDecode,
"VQASerTokenLayoutLMPostProcess": VQASerTokenLayoutLMPostProcess,
"VQAReTokenLayoutLMPostProcess": VQAReTokenLayoutLMPostProcess,
"YOLOv8Postprocess": YOLOv8Postprocess,
"Layoutlmv3Postprocess": Layoutlmv3Postprocess,
"TableMasterLabelDecode": TableMasterLabelDecode,
}


def build_postprocess(config: dict):
Expand All @@ -57,11 +63,11 @@ def build_postprocess(config: dict):
>>> postprocess
"""
proc = config.pop("name")
if proc in supported_postprocess:
postprocessor = eval(proc)(**config)
elif proc is None:
if proc is None:
return None
if proc in SUPPORTED_POSTPROCESS:
postprocessor = SUPPORTED_POSTPROCESS[proc](**config)
else:
raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {supported_postprocess}")
raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {SUPPORTED_POSTPROCESS.keys()}")

return postprocessor
5 changes: 4 additions & 1 deletion tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def _parse_options(opts: list):
"=" in opt_str
), "Invalid option {}. A valid option must be in the format of {{key_name}}={{value}}".format(opt_str)
k, v = opt_str.strip().split("=")
options[k] = yaml.load(v, Loader=yaml.Loader)
try:
options[k] = yaml.load(v, Loader=yaml.SafeLoader)
except yaml.YAMLError as e:
raise ValueError(f"Failed to parse value for key '{k}': {str(e)}") from e
# print('Parsed options: ', options)

return options
Expand Down