diff --git a/mindocr/data/transforms/det_east_transforms.py b/mindocr/data/transforms/det_east_transforms.py index e4bf44206..db62e3171 100644 --- a/mindocr/data/transforms/det_east_transforms.py +++ b/mindocr/data/transforms/det_east_transforms.py @@ -1,3 +1,5 @@ +import ast +import json import math import cv2 @@ -414,8 +416,17 @@ def _extract_vertices(self, data_labels): """ vertices_list = [] labels_list = [] - data_labels = eval(data_labels) - for data_label in data_labels: + try: + parsed_data = json.loads(data_labels) + except json.JSONDecodeError: + try: + parsed_data = ast.literal_eval(data_labels) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Invalid data format: {str(e)}") from e + + if not isinstance(parsed_data, list): + raise ValueError("Data labels should be a list") + for data_label in parsed_data: vertices = data_label["points"] vertices = [item for point in vertices for item in point] vertices_list.append(vertices) diff --git a/mindocr/data/transforms/svtr_transform.py b/mindocr/data/transforms/svtr_transform.py index 4281c5261..4e479d48e 100644 --- a/mindocr/data/transforms/svtr_transform.py +++ b/mindocr/data/transforms/svtr_transform.py @@ -546,10 +546,19 @@ def __init__(self, max_text_length, character_dict_path=None, use_space_char=Fal self.ctc_encode = CTCLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs) self.gtc_encode_type = gtc_encode + # Pls explicitly specify the supported gtc_encode classes and obtain the class objects through dictionaries. + supported_gtc_encode = {} if gtc_encode is None: self.gtc_encode = SARLabelEncodeForSVTR(max_text_length, character_dict_path, use_space_char, **kwargs) else: - self.gtc_encode = eval(gtc_encode)(max_text_length, character_dict_path, use_space_char, **kwargs) + # Mindocr currently does not have a module that requires a custom `gtc_encode` input parameter, and will not + # enter this branch at present. If it is supported later, please directly obtain the class reference through + # a specific dict, and do not use the `eval` function. + if gtc_encode not in supported_gtc_encode: + raise ValueError(f"Get unsupported gtc_encode {gtc_encode}") + self.gtc_encode = supported_gtc_encode[gtc_encode]( + max_text_length, character_dict_path, use_space_char, **kwargs + ) def __call__(self, data): data_ctc = copy.deepcopy(data) @@ -925,7 +934,7 @@ def __init__( jitter_prob=0.4, blur_prob=0.4, hsv_aug_prob=0.4, - **kwargs + **kwargs, ): self.crop_prob = crop_prob self.reverse_prob = reverse_prob @@ -973,7 +982,7 @@ def __init__( jitter_prob=0.4, blur_prob=0.4, hsv_aug_prob=0.4, - **kwargs + **kwargs, ): self.tia_prob = tia_prob self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob, jitter_prob, blur_prob, hsv_aug_prob) @@ -1078,7 +1087,7 @@ def __init__( character_dict_path=".mindocr/utils/dict/ch_dict.txt", padding=True, width_downsample_ratio=0.125, - **kwargs + **kwargs, ): self.image_shape = image_shape self.infer_mode = infer_mode diff --git a/mindocr/data/transforms/transforms_factory.py b/mindocr/data/transforms/transforms_factory.py index 60f544075..a1d551612 100644 --- a/mindocr/data/transforms/transforms_factory.py +++ b/mindocr/data/transforms/transforms_factory.py @@ -1,11 +1,8 @@ """ Create and run transformations from a config or predefined transformation pipeline """ -import logging from typing import Dict, List -import numpy as np - from .det_east_transforms import * from .det_transforms import * from .general_transforms import * @@ -15,6 +12,63 @@ from .svtr_transform import * from .table_transform import * +SUPPORTED_TRANSFORMS = { + "EASTProcessTrain": EASTProcessTrain, + "DetLabelEncode": DetLabelEncode, + "BorderMap": BorderMap, + "ShrinkBinaryMap": ShrinkBinaryMap, + "expand_poly": expand_poly, + "PSEGtDecode": PSEGtDecode, + "ValidatePolygons": ValidatePolygons, + "RandomCropWithBBox": RandomCropWithBBox, + "RandomCropWithMask": RandomCropWithMask, + "DetResize": DetResize, + "DecodeImage": DecodeImage, + "NormalizeImage": NormalizeImage, + "ToCHWImage": ToCHWImage, + "PackLoaderInputs": PackLoaderInputs, + "RandomScale": RandomScale, + "RandomColorAdjust": RandomColorAdjust, + "RandomRotate": RandomRotate, + "RandomHorizontalFlip": RandomHorizontalFlip, + "LayoutResize": LayoutResize, + "ImageStridePad": ImageStridePad, + "VQATokenLabelEncode": VQATokenLabelEncode, + "VQATokenPad": VQATokenPad, + "VQASerTokenChunk": VQASerTokenChunk, + "VQAReTokenRelation": VQAReTokenRelation, + "VQAReTokenChunk": VQAReTokenChunk, + "TensorizeEntitiesRelations": TensorizeEntitiesRelations, + "ABINetTransforms": ABINetTransforms, + "ABINetRecAug": ABINetRecAug, + "ABINetEval": ABINetEval, + "ABINetEvalTransforms": ABINetEvalTransforms, + "RecCTCLabelEncode": RecCTCLabelEncode, + "RecAttnLabelEncode": RecAttnLabelEncode, + "RecMasterLabelEncode": RecMasterLabelEncode, + "VisionLANLabelEncode": VisionLANLabelEncode, + "RecResizeImg": RecResizeImg, + "RecResizeNormForInfer": RecResizeNormForInfer, + "SVTRRecResizeImg": SVTRRecResizeImg, + "Rotate90IfVertical": Rotate90IfVertical, + "ClsLabelEncode": ClsLabelEncode, + "SARLabelEncode": SARLabelEncode, + "RobustScannerRecResizeImg": RobustScannerRecResizeImg, + "SVTRRecAug": SVTRRecAug, + "MultiLabelEncode": MultiLabelEncode, + "RecConAug": RecConAug, + "RecAug": RecAug, + "RecResizeImgForSVTR": RecResizeImgForSVTR, + "BaseRecLabelEncode": BaseRecLabelEncode, + "AttnLabelEncode": AttnLabelEncode, + "TableLabelEncode": TableLabelEncode, + "TableMasterLabelEncode": TableMasterLabelEncode, + "ResizeTableImage": ResizeTableImage, + "PaddingTableImage": PaddingTableImage, + "TableBoxEncode": TableBoxEncode, + "TableImageNorm": TableImageNorm, +} + __all__ = ["create_transforms", "run_transforms", "transforms_dbnet_icdar15"] _logger = logging.getLogger(__name__) @@ -45,9 +99,9 @@ def create_transforms(transform_pipeline: List, global_config: Dict = None): param = {} if transform_config[trans_name] is None else transform_config[trans_name] if global_config is not None: param.update(global_config) - # TODO: assert undefined transform class - - transform = eval(trans_name)(**param) + # For security reasons, we no longer use the eval function to dynamically obtain class objects. + # If you need to add a new transform class, please explicitly add it to the ``SUPPORTED_TRANSFORMS`` dict. + transform = SUPPORTED_TRANSFORMS[trans_name](**param) transforms.append(transform) elif callable(transform_config): transforms.append(transform_config) diff --git a/mindocr/models/backbones/mindcv_models/download.py b/mindocr/models/backbones/mindcv_models/download.py index 09ce8b3e7..cf06877c5 100644 --- a/mindocr/models/backbones/mindcv_models/download.py +++ b/mindocr/models/backbones/mindcv_models/download.py @@ -99,10 +99,11 @@ def extract_archive(self, from_path: str, to_path: str = None) -> str: def download_file(self, url: str, file_path: str, chunk_size: int = 1024): """Download a file.""" - # no check certificate + # For security reasons, this repository code does not provide a function to disable SSL. + # If necessary, please disable SSL verification yourself. ctx = ssl.create_default_context() - ctx.check_hostname = False - ctx.verify_mode = ssl.CERT_NONE + # ctx.check_hostname = False + # ctx.verify_mode = ssl.CERT_NONE # Define request headers. headers = {"User-Agent": self.USER_AGENT} diff --git a/mindocr/postprocess/builder.py b/mindocr/postprocess/builder.py index c01dd4485..e0f80da60 100644 --- a/mindocr/postprocess/builder.py +++ b/mindocr/postprocess/builder.py @@ -23,18 +23,24 @@ __all__ = ["build_postprocess"] -supported_postprocess = ( - det_db_postprocess.__all__ - + det_pse_postprocess.__all__ - + det_east_postprocess.__all__ - + rec_postprocess.__all__ - + cls_postprocess.__all__ - + rec_abinet_postprocess.__all__ - + kie_ser_postprocess.__all__ - + kie_re_postprocess.__all__ - + layout_postprocess.__all__ - + table_postprocess.__all__ -) +SUPPORTED_POSTPROCESS = { + "DBPostprocess": DBPostprocess, + "PSEPostprocess": PSEPostprocess, + "EASTPostprocess": EASTPostprocess, + "CTCLabelDecode": CTCLabelDecode, + "RecCTCLabelDecode": RecCTCLabelDecode, + "RecAttnLabelDecode": RecAttnLabelDecode, + "RecMasterLabelDecode": RecMasterLabelDecode, + "VisionLANPostProcess": VisionLANPostProcess, + "SARLabelDecode": SARLabelDecode, + "ClsPostprocess": ClsPostprocess, + "ABINetLabelDecode": ABINetLabelDecode, + "VQASerTokenLayoutLMPostProcess": VQASerTokenLayoutLMPostProcess, + "VQAReTokenLayoutLMPostProcess": VQAReTokenLayoutLMPostProcess, + "YOLOv8Postprocess": YOLOv8Postprocess, + "Layoutlmv3Postprocess": Layoutlmv3Postprocess, + "TableMasterLabelDecode": TableMasterLabelDecode, +} def build_postprocess(config: dict): @@ -57,11 +63,11 @@ def build_postprocess(config: dict): >>> postprocess """ proc = config.pop("name") - if proc in supported_postprocess: - postprocessor = eval(proc)(**config) - elif proc is None: + if proc is None: return None + if proc in SUPPORTED_POSTPROCESS: + postprocessor = SUPPORTED_POSTPROCESS[proc](**config) else: - raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {supported_postprocess}") + raise ValueError(f"Invalid postprocess name {proc}, support postprocess are {SUPPORTED_POSTPROCESS.keys()}") return postprocessor diff --git a/tools/arg_parser.py b/tools/arg_parser.py index 44674a836..2443459cf 100644 --- a/tools/arg_parser.py +++ b/tools/arg_parser.py @@ -62,7 +62,10 @@ def _parse_options(opts: list): "=" in opt_str ), "Invalid option {}. A valid option must be in the format of {{key_name}}={{value}}".format(opt_str) k, v = opt_str.strip().split("=") - options[k] = yaml.load(v, Loader=yaml.Loader) + try: + options[k] = yaml.load(v, Loader=yaml.SafeLoader) + except yaml.YAMLError as e: + raise ValueError(f"Failed to parse value for key '{k}': {str(e)}") from e # print('Parsed options: ', options) return options