diff --git a/src/containers/image-pii-detection/Dockerfile b/src/containers/image-pii-detection/Dockerfile
new file mode 100644
index 00000000..3e08eb7b
--- /dev/null
+++ b/src/containers/image-pii-detection/Dockerfile
@@ -0,0 +1,27 @@
+FROM public.ecr.aws/lambda/python:3.9
+
+ARG FUNCTION_DIR="/opt/ml/code/"
+COPY requirements.txt ${FUNCTION_DIR}/requirements.txt
+RUN python3.9 -m pip install -r ${FUNCTION_DIR}/requirements.txt
+
+COPY main.py parser_factory.py ${FUNCTION_DIR}/
+COPY parsers/ ${FUNCTION_DIR}/parsers/
+
+ARG OCR_MODEL_URL="https://aws-gcr-solutions-assets.s3.cn-northwest-1.amazonaws.com.cn/ai-solution-kit/infer-ocr-model/standard"
+ARG OCR_MODEL_VERSION="v1.0.0"
+ARG FD_MODEL_URL="https://aws-gcr-solutions-assets.s3.cn-northwest-1.amazonaws.com.cn/ai-solution-kit/face-detection"
+ARG FD_MODEL_VERSION="1.2.0"
+
+# Fetch the OCR and face-detection ONNX models in single layers to keep the image small.
+RUN yum install -y wget && yum clean all
+RUN mkdir -p ${FUNCTION_DIR}/ocr_model && \
+    wget -c $OCR_MODEL_URL/$OCR_MODEL_VERSION/classifier.onnx -O ${FUNCTION_DIR}/ocr_model/classifier.onnx && \
+    wget -c $OCR_MODEL_URL/$OCR_MODEL_VERSION/det_standard.onnx -O ${FUNCTION_DIR}/ocr_model/det_standard.onnx && \
+    wget -c $OCR_MODEL_URL/$OCR_MODEL_VERSION/keys_v1.txt -O ${FUNCTION_DIR}/ocr_model/keys_v1.txt && \
+    wget -c $OCR_MODEL_URL/$OCR_MODEL_VERSION/rec_standard.onnx -O ${FUNCTION_DIR}/ocr_model/rec_standard.onnx
+RUN mkdir -p ${FUNCTION_DIR}/fd_model && \
+    wget -c ${FD_MODEL_URL}/${FD_MODEL_VERSION}/det.onnx -O ${FUNCTION_DIR}/fd_model/det.onnx
+
+WORKDIR ${FUNCTION_DIR}
+
+# Command can be overwritten by providing a different command in the template directly.
+ENTRYPOINT ["python"]
diff --git a/src/containers/image-pii-detection/__init__.py b/src/containers/image-pii-detection/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/containers/image-pii-detection/main.py b/src/containers/image-pii-detection/main.py
new file mode 100644
index 00000000..ebc79204
--- /dev/null
+++ b/src/containers/image-pii-detection/main.py
@@ -0,0 +1,208 @@
+import json
+import boto3
+import pandas as pd
+import argparse
+import copy
+import tempfile
+
+from parser_factory import ParserFactory
+
+def check_include_file_type(file_info, include_file_types):
+    """
+    Check if the file type is included in the include_file_types list.
+ + :param file_info: file info + :param include_file_types: list of file types to include + + """ + file_type = file_info['file_type'] + + if file_type in include_file_types: + return True + else: + return False + +def organize_table_info(table_name, result_bucket_name, original_bucket_name, file_info, columns, file_category): + + description = json.dumps(file_info, ensure_ascii=False) + s3_location = f"s3://{result_bucket_name}/parser_results/{table_name}/" + input_format = 'org.apache.hadoop.mapred.TextInputFormat' + output_format = 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' + table_type = 'EXTERNAL_TABLE' + serde_info = {'SerializationLibrary': 'org.apache.hadoop.hive.serde2.OpenCSVSerde', + 'Parameters': {'field.delim': ','}} + parameters = {'originalFileBucketName': original_bucket_name, + 'originalFileType': file_info['file_type'], + 'originalFilePath': file_info['file_path'], + 'originalFileSample': ', '.join(file_info['sample_files'][:10]), + 'originalFileCategory': file_category, + 'Unstructured': 'true', + 'classification': 'csv'} + glue_table_columns = [{'Name': 'index', 'Type': 'string'}] + for column in columns: + glue_table_columns.append({'Name': column, 'Type': 'string'}) + + glue_table_info = { + 'Name': table_name, + 'Description': description, + 'StorageDescriptor': { + 'Columns': glue_table_columns, + 'Location': s3_location, + 'InputFormat': input_format, + 'OutputFormat': output_format, + 'SerdeInfo': serde_info + }, + 'PartitionKeys': [], + 'TableType': table_type, + 'Parameters': parameters + } + return glue_table_info + +def batch_process_files(s3_client, bucket_name, file_info, file_category): + """ + Batch process files in a folder with the same schema. + + :param bucket_name: S3 bucket name + :param file_info: file info + + Sample file_info: + { + "file_type": ".jpeg", + "file_path": "test_images/human_faces", + "sample_files": [ + "1" + ] + } + + """ + file_contents = {} + + file_type = file_info['file_type'] + file_path = file_info['file_path'] + sample_files = file_info['sample_files'] + + if file_category == 'detection_files': + + parser = ParserFactory.create_parser(file_type=file_type, s3_client=s3_client) + + for sample_file in sample_files: + object_key = f"{file_path}/{sample_file}{file_type}" + file_content = parser.load_content(bucket_name, object_key) + file_contents[f"{sample_file}"] = file_content + + elif file_category == 'include_files': + for sample_file in sample_files: + file_contents[f"{sample_file}"] = ['This file is marked as Contains-PII.'] + + elif file_category == 'exclude_files': + for sample_file in sample_files: + file_contents[f"{sample_file}"] = ['This file is marked as Non-PII.'] + + return file_contents + +def process_file(parser, bucket_name, object_key): + """ + Process a single file. + """ + file_content = parser.load_content(bucket_name, object_key) + + json_format_content = {} + json_format_content[f"{object_key}"] = file_content + + return json_format_content + +def create_glue_table(glue_client, database_name, table_name, glue_table_info): + + # Check if table exists + try: + response = glue_client.get_table( + DatabaseName=database_name, + Name=table_name + ) + print(f"Table '{table_name}' exists in database '{database_name}'. Updating table...") + response = glue_client.update_table( + DatabaseName=database_name, + TableInput=glue_table_info + ) + except glue_client.exceptions.EntityNotFoundException: + print(f"Table '{table_name}' does not exist in database '{database_name}'. 
Creating table...")
+        response = glue_client.create_table(
+            DatabaseName=database_name,
+            TableInput=glue_table_info
+        )
+
+    print(response)
+
+def main(param_dict):
+    original_bucket_name = param_dict['SourceBucketName']
+    crawler_result_bucket_name = param_dict['ResultBucketName']
+    region_name = param_dict['RegionName']
+
+    crawler_result_object_key = f"crawler_results/{original_bucket_name}_info.json"
+    destination_database = f"SDPS-unstructured-{original_bucket_name}"
+
+    s3_client = boto3.client('s3', region_name=region_name)
+    glue_client = boto3.client('glue', region_name=region_name)
+
+    # 1. Create a Glue database
+    try:
+        response = glue_client.create_database(
+            DatabaseInput={
+                'Name': destination_database
+            }
+        )
+    except glue_client.exceptions.AlreadyExistsException:
+        print(f"Database '{destination_database}' already exists. Skipping database creation...")
+
+    # 2. Download the crawler result from S3 and load it
+    with tempfile.NamedTemporaryFile(mode='w') as temp:
+        temp_file_path = temp.name
+        s3_client.download_file(Bucket=crawler_result_bucket_name, Key=crawler_result_object_key, Filename=temp_file_path)
+        with open(temp_file_path, 'r') as f:
+            bucket_info = json.load(f)
+
+    # 3. Batch process files in the same folder with the same type
+    original_file_bucket_name = bucket_info['bucket_name']
+    for file_category in ['detection_files', 'include_files', 'exclude_files']:
+        files = bucket_info[file_category]
+        for file_path, file_info in files.items():
+            print(f"Processing {file_path}...")
+            file_contents = batch_process_files(s3_client, original_file_bucket_name, file_info, file_category)
+
+            # convert file_contents to a dataframe, one column per sample file
+            df = pd.DataFrame.from_dict(file_contents, orient='index')
+            df = df.transpose()
+            columns = df.columns.tolist()
+
+            # derive a unique Glue table name from the bucket name and file path
+            table_name = file_path.replace('/', '_')
+            table_name = table_name.replace('.', '_')
+            table_name = original_file_bucket_name + '_' + table_name
+
+            # save to csv and upload to s3
+            with tempfile.NamedTemporaryFile(mode='w') as temp:
+                csv_file_path = temp.name
+                df.to_csv(csv_file_path, header=False)
+                s3_client.upload_file(csv_file_path, crawler_result_bucket_name, f"parser_results/{table_name}/result.csv")
+
+            glue_table_info = organize_table_info(table_name, crawler_result_bucket_name, original_file_bucket_name, file_info, columns, file_category)
+            create_glue_table(glue_client, destination_database, table_name, glue_table_info)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
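+    # the CLI arguments below map one-to-one onto the keys read by main()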
+    parser.add_argument('--SourceBucketName', type=str, default='icyxu-glue-assets-member-a',
+                        help='bucket that holds the original files to scan')
+    parser.add_argument('--ResultBucketName', type=str, default='icyxu-glue-assets-member-a',
+                        help='bucket that stores crawler and parser results')
+    parser.add_argument('--RegionName', type=str, default='us-west-2',
+                        help='AWS region of the buckets')
+
+    args, _ = parser.parse_known_args()
+    param_dict = copy.copy(vars(args))
+
+    main(param_dict)
diff --git a/src/containers/image-pii-detection/parser_factory.py b/src/containers/image-pii-detection/parser_factory.py
new file mode 100644
index 00000000..750d9559
--- /dev/null
+++ b/src/containers/image-pii-detection/parser_factory.py
@@ -0,0 +1,20 @@
+from parsers import PdfParser, TxtParser, DocParser, HtmlParser, EmailParser, ImageParser
+
+class ParserFactory:
+    @staticmethod
+    def create_parser(file_type, s3_client):
+        if file_type in ['.pdf', '.PDF']:
+            return PdfParser(s3_client=s3_client)
+        elif file_type in ['.txt', '.TXT']:
+            return TxtParser(s3_client=s3_client)
+        elif file_type in ['.doc', '.docx', '.DOC', '.DOCX']:
+            return DocParser(s3_client=s3_client)
+        elif file_type in ['.html', '.htm', '.HTML', '.HTM']:
+            return HtmlParser(s3_client=s3_client)
+        elif file_type in ['.eml', '.EML']:
+            return EmailParser(s3_client=s3_client)
+        elif file_type in ['.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG']:
+            return ImageParser(s3_client=s3_client, fd_model_path='./fd_model/',
+                               ocr_model_path='./ocr_model/')
+        else:
+            raise ValueError(f'Unsupported file type: {file_type}')
\ No newline at end of file
diff --git a/src/containers/image-pii-detection/parsers/__init__.py b/src/containers/image-pii-detection/parsers/__init__.py
new file mode 100644
index 00000000..948b5962
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/__init__.py
@@ -0,0 +1,7 @@
+from .pdf_parser import PdfParser
+from .txt_parser import TxtParser
+from .doc_parser import DocParser
+from .html_parser import HtmlParser
+from .email_parser import EmailParser
+
+from .image_parser import ImageParser
\ No newline at end of file
diff --git a/src/containers/image-pii-detection/parsers/doc_parser.py b/src/containers/image-pii-detection/parsers/doc_parser.py
new file mode 100644
index 00000000..a0005470
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/doc_parser.py
@@ -0,0 +1,19 @@
+
+import docx
+from .parser import BaseParser
+
+class DocParser(BaseParser):
+    def __init__(self, s3_client):
+        super().__init__(s3_client=s3_client)
+
+    def parse_file(self, doc_path):
+        """
+        Extracts text from a .docx file and returns it as a single-item list.
+        Note: python-docx reads .docx only; legacy binary .doc files are not supported.
+        """
+
+        doc = docx.Document(doc_path)
+        file_content = ""
+        for para in doc.paragraphs:
+            file_content += para.text + "\n"
+
+        return [file_content]
diff --git a/src/containers/image-pii-detection/parsers/email_parser.py b/src/containers/image-pii-detection/parsers/email_parser.py
new file mode 100644
index 00000000..9f192aa7
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/email_parser.py
@@ -0,0 +1,25 @@
+
+from .parser import BaseParser
+from email.parser import Parser as PyEmailParser
+
+class EmailParser(BaseParser):
+    def __init__(self, s3_client):
+        super().__init__(s3_client=s3_client)
+
+
+    def parse_file(self, eml_path):
+        """
+        Extracts text from an .eml file and returns its plain-text content.
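+        Only text/plain parts of the message are collected.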
+        """
+
+        with open(eml_path) as stream:
+            parser = PyEmailParser()
+            message = parser.parse(stream)
+
+        file_content = []
+        for part in message.walk():
+            if part.get_content_type().startswith('text/plain'):
+                file_content.append(part.get_payload())
+
+        return ['\n'.join(file_content)]
diff --git a/src/containers/image-pii-detection/parsers/html_parser.py b/src/containers/image-pii-detection/parsers/html_parser.py
new file mode 100644
index 00000000..0c2b2989
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/html_parser.py
@@ -0,0 +1,152 @@
+
+import re
+import six
+
+from bs4 import BeautifulSoup
+
+from .parser import BaseParser
+
+class HtmlParser(BaseParser):
+    def __init__(self, s3_client):
+        super().__init__(s3_client=s3_client)
+        # additional HtmlParser constructor code here
+
+    def parse_file(self, html_path):
+        """
+        Extracts text from an HTML file and returns a string of content.
+        """
+
+        with open(html_path, "rb") as stream:
+            soup = BeautifulSoup(stream, 'lxml')
+
+        # Convert tables to ASCII ones
+        soup = self._replace_tables(soup)
+
+        # Join inline elements
+        soup = self._join_inlines(soup)
+
+        # Collect the visible text
+        html = ''
+        elements = soup.find_all(True)
+        elements = [el for el in elements if self._visible(el)]
+        for elem in elements:
+            string = elem.string
+            if string is None:
+                string = self._find_any_text(elem)
+            string = string.strip()
+            if len(string) > 0:
+                html += "\n" + string + "\n"
+        return [html]
+
+    _disallowed_names = [
+        'style', 'script', '[document]', 'head', 'title', 'html', 'meta',
+        'link', 'body',
+    ]
+
+    _inline_tags = [
+        'b', 'big', 'i', 'small', 'tt', 'abbr', 'acronym', 'cite', 'code',
+        'dfn', 'em', 'kbd', 'strong', 'samp', 'var', 'a', 'bdo', 'br', 'img',
+        'map', 'object', 'q', 'script', 'span', 'sub', 'sup', 'button',
+        'input', 'label', 'select', 'textarea',
+    ]
+
+    def _visible(self, element):
+        """Used to filter out elements that carry no visible text on the page.
+        """
+        if element.name in self._disallowed_names:
+            return False
+        elif re.match(u'<!--.*-->', six.text_type(element.extract())):
+            return False
+        return True
+
+    def _inline(self, element):
+        """Used to check whether the given element can be treated as an inline
+        element (without a new line after).
+        """
+        return element.name in self._inline_tags
+
+    def _find_any_text(self, tag):
+        """Looks for any possible text within the given tag.
+        """
+        text = ''
+        if tag is not None:
+            text = six.text_type(tag)
+            text = re.sub(r'(<[^>]+>)', '', text)
+            text = re.sub(r'\s', ' ', text)
+            text = text.strip()
+        return text
+
+    def _parse_tables(self, soup):
+        """Returns an array containing basic information about tables for the ASCII
+        replacement (see _replace_tables()).
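+        Each entry records the table tag, its rows, and the widest cell seen per column.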
+        """
+        tables = []
+        for t in soup.find_all('table'):
+            t_dict = {'width': 0, 'table': t, 'trs': [], 'col_width': {}}
+            trs = t.find_all('tr')
+            if len(trs) > 0:
+                for tr in trs:
+                    tr_dict = []
+                    tds = tr.find_all('th') + tr.find_all('td')
+                    if len(tds) > 0:
+                        for i, td in enumerate(tds):
+                            td_text = self._find_any_text(td)
+                            length = len(td_text)
+                            if i in t_dict['col_width']:
+                                t_dict['col_width'][i] = max(
+                                    length,
+                                    t_dict['col_width'][i]
+                                )
+                            else:
+                                t_dict['col_width'][i] = length
+                            tr_dict.append({
+                                'text': td_text,
+                                'colspan': int(td.get('colspan', 1)),
+                            })
+                    t_dict['trs'].append(tr_dict)
+            for col in t_dict['col_width']:
+                t_dict['width'] += t_dict['col_width'][col]
+            tables.append(t_dict)
+        return tables
+
+    def _replace_tables(self, soup, v_separator=' | ', h_separator='-'):
+        """Replaces <table> elements with their ASCII equivalent.
+        """
+        tables = self._parse_tables(soup)
+        v_sep_len = len(v_separator)
+        v_left_sep = v_separator.lstrip()
+        for t in tables:
+            html = ''
+            trs = t['trs']
+            h_length = 1 + (v_sep_len * len(t['col_width'])) + t['width']
+            head_foot = (h_separator * h_length) + "\n"
+            html += head_foot
+            for tr in trs:
+                html += v_left_sep
+                for i, td in enumerate(tr):
+                    text = td['text']
+                    col_width = t['col_width'][i] + v_sep_len
+                    if td['colspan'] > 1:
+                        # a cell spanning several columns absorbs their widths too
+                        for j in range(1, td['colspan']):
+                            if (i+j) < len(t['col_width']):
+                                col_width += t['col_width'][i+j] + v_sep_len
+                    html += ('%' + str(col_width) + 's') % (text + v_separator)
+                html += "\n"
+            html += head_foot
+            new_table = soup.new_tag('div')
+            new_table.string = html
+            t['table'].replace_with(new_table)
+        return soup
+
+    def _join_inlines(self, soup):
+        """Unwraps inline elements defined in self._inline_tags.
+        """
+        elements = soup.find_all(True)
+        for elem in elements:
+            if self._inline(elem):
+                elem.unwrap()
+        return soup
+
diff --git a/src/containers/image-pii-detection/parsers/image_analysis/__init__.py b/src/containers/image-pii-detection/parsers/image_analysis/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/containers/image-pii-detection/parsers/image_analysis/face_detection/__init__.py b/src/containers/image-pii-detection/parsers/image_analysis/face_detection/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/containers/image-pii-detection/parsers/image_analysis/face_detection/face_detection_main.py b/src/containers/image-pii-detection/parsers/image_analysis/face_detection/face_detection_main.py
new file mode 100644
index 00000000..19960c8b
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/image_analysis/face_detection/face_detection_main.py
@@ -0,0 +1,281 @@
+
+import numpy as np
+import onnxruntime
+import os.path as osp
+import cv2
+
+cuda_available = False
+
+def distance2bbox(points, distance, max_shape=None):
+    """Decode distance prediction to bounding box.
+    Args:
+        points (Tensor): Shape (n, 2), [x, y].
+        distance (Tensor): Distance from the given point to 4
+            boundaries (left, top, right, bottom).
+        max_shape (tuple): Shape of the image.
+    Returns:
+        Tensor: Decoded bboxes.
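+    Coordinates are clipped to max_shape when it is provided.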
+    """
+    x1 = points[:, 0] - distance[:, 0]
+    y1 = points[:, 1] - distance[:, 1]
+    x2 = points[:, 0] + distance[:, 2]
+    y2 = points[:, 1] + distance[:, 3]
+    if max_shape is not None:
+        # numpy arrays have no .clamp(); clip to the image bounds instead
+        x1 = np.clip(x1, 0, max_shape[1])
+        y1 = np.clip(y1, 0, max_shape[0])
+        x2 = np.clip(x2, 0, max_shape[1])
+        y2 = np.clip(y2, 0, max_shape[0])
+    return np.stack([x1, y1, x2, y2], axis=-1)
+
+def distance2kps(points, distance, max_shape=None):
+    """Decode distance prediction to keypoints.
+    Args:
+        points (Tensor): Shape (n, 2), [x, y].
+        distance (Tensor): Distance from the given point to each
+            predicted keypoint.
+        max_shape (tuple): Shape of the image.
+    Returns:
+        Tensor: Decoded keypoints.
+    """
+    preds = []
+    for i in range(0, distance.shape[1], 2):
+        px = points[:, i%2] + distance[:, i]
+        py = points[:, i%2+1] + distance[:, i+1]
+        if max_shape is not None:
+            px = np.clip(px, 0, max_shape[1])
+            py = np.clip(py, 0, max_shape[0])
+        preds.append(px)
+        preds.append(py)
+    return np.stack(preds, axis=-1)
+
+class SCRFD:
+    def __init__(self, model_file=None, session=None):
+        self.model_file = model_file
+        self.session = session
+        self.taskname = 'detection'
+        self.batched = False
+        if self.session is None:
+            assert self.model_file is not None
+            assert osp.exists(self.model_file)
+            self.session = onnxruntime.InferenceSession(self.model_file, providers=['CUDAExecutionProvider'] if cuda_available else ['CPUExecutionProvider'])
+        self.center_cache = {}
+        self.nms_thresh = 0.4
+        self.det_thresh = 0.5
+        self._init_vars()
+
+    def _init_vars(self):
+        input_cfg = self.session.get_inputs()[0]
+        input_shape = input_cfg.shape
+        if isinstance(input_shape[2], str):
+            self.input_size = None
+        else:
+            self.input_size = tuple(input_shape[2:4][::-1])
+        # force a fixed detector input size for this deployment
+        self.input_size = (736, 736)
+        input_name = input_cfg.name
+        self.input_shape = input_shape
+        outputs = self.session.get_outputs()
+        if len(outputs[0].shape) == 3:
+            self.batched = True
+        output_names = []
+        for o in outputs:
+            output_names.append(o.name)
+        self.input_name = input_name
+        self.output_names = output_names
+        self.input_mean = 127.5
+        self.input_std = 128.0
+        # the number of output heads determines the FPN strides, the anchor
+        # count, and whether the model predicts keypoints
+        self.use_kps = False
+        self._anchor_ratio = 1.0
+        self._num_anchors = 1
+        if len(outputs)==6:
+            self.fmc = 3
+            self._feat_stride_fpn = [8, 16, 32]
+            self._num_anchors = 2
+        elif len(outputs)==9:
+            self.fmc = 3
+            self._feat_stride_fpn = [8, 16, 32]
+            self._num_anchors = 2
+            self.use_kps = True
+        elif len(outputs)==10:
+            self.fmc = 5
+            self._feat_stride_fpn = [8, 16, 32, 64, 128]
+            self._num_anchors = 1
+        elif len(outputs)==15:
+            self.fmc = 5
+            self._feat_stride_fpn = [8, 16, 32, 64, 128]
+            self._num_anchors = 1
+            self.use_kps = True
+
+    def prepare(self, ctx_id, **kwargs):
+        if ctx_id<0:
+            self.session.set_providers(['CPUExecutionProvider'])
+        nms_thresh = kwargs.get('nms_thresh', None)
+        if nms_thresh is not None:
+            self.nms_thresh = nms_thresh
+        det_thresh = kwargs.get('det_thresh', None)
+        if det_thresh is not None:
+            self.det_thresh = det_thresh
+        input_size = kwargs.get('input_size', None)
+        if input_size is not None:
+            if self.input_size is not None:
+                print('warning: det_size is already set in scrfd model, ignore')
+            else:
+                self.input_size = input_size
+
+    def forward(self, img, threshold):
+        scores_list = []
+        bboxes_list = []
+        kpss_list = []
+        input_size = tuple(img.shape[0:2][::-1])
+        blob = cv2.dnn.blobFromImage(img,
1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) + net_outs = self.session.run(self.output_names, {self.input_name : blob}) + + input_height = blob.shape[2] + input_width = blob.shape[3] + fmc = self.fmc + for idx, stride in enumerate(self._feat_stride_fpn): + # If model support batch dim, take first output + if self.batched: + scores = net_outs[idx][0] + bbox_preds = net_outs[idx + fmc][0] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2][0] * stride + # If model doesn't support batching take output as is + else: + scores = net_outs[idx] + bbox_preds = net_outs[idx + fmc] + bbox_preds = bbox_preds * stride + if self.use_kps: + kps_preds = net_outs[idx + fmc * 2] * stride + height = input_height // stride + width = input_width // stride + K = height * width + key = (height, width, stride) + if key in self.center_cache: + anchor_centers = self.center_cache[key] + else: + #solution-1, c style: + #anchor_centers = np.zeros( (height, width, 2), dtype=np.float32 ) + #for i in range(height): + # anchor_centers[i, :, 1] = i + #for i in range(width): + # anchor_centers[:, i, 0] = i + + #solution-2: + #ax = np.arange(width, dtype=np.float32) + #ay = np.arange(height, dtype=np.float32) + #xv, yv = np.meshgrid(np.arange(width), np.arange(height)) + #anchor_centers = np.stack([xv, yv], axis=-1).astype(np.float32) + + #solution-3: + anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32) + #print(anchor_centers.shape) + + anchor_centers = (anchor_centers * stride).reshape( (-1, 2) ) + if self._num_anchors>1: + anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) ) + if len(self.center_cache)<100: + self.center_cache[key] = anchor_centers + + pos_inds = np.where(scores>=threshold)[0] + bboxes = distance2bbox(anchor_centers, bbox_preds) + pos_scores = scores[pos_inds] + pos_bboxes = bboxes[pos_inds] + scores_list.append(pos_scores) + bboxes_list.append(pos_bboxes) + #print(anchor_centers.shape, kps_preds.shape) + if self.use_kps: + kpss = distance2kps(anchor_centers, kps_preds) + #kpss = kps_preds + kpss = kpss.reshape( (kpss.shape[0], -1, 2) ) + pos_kpss = kpss[pos_inds] + kpss_list.append(pos_kpss) + return scores_list, bboxes_list, kpss_list + + def detect(self, img, input_size = None, max_num=0, metric='default'): + assert input_size is not None or self.input_size is not None + input_size = self.input_size if input_size is None else input_size + + im_ratio = float(img.shape[0]) / img.shape[1] + model_ratio = float(input_size[1]) / input_size[0] + if im_ratio>model_ratio: + new_height = input_size[1] + new_width = int(new_height / im_ratio) + else: + new_width = input_size[0] + new_height = int(new_width * im_ratio) + det_scale = float(new_height) / img.shape[0] + resized_img = cv2.resize(img, (new_width, new_height)) + det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 ) + det_img[:new_height, :new_width, :] = resized_img + + scores_list, bboxes_list, kpss_list = self.forward(det_img, self.det_thresh) + + scores = np.vstack(scores_list) + scores_ravel = scores.ravel() + order = scores_ravel.argsort()[::-1] + bboxes = np.vstack(bboxes_list) / det_scale + if self.use_kps: + kpss = np.vstack(kpss_list) / det_scale + pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False) + pre_det = pre_det[order, :] + keep = self.nms(pre_det) + det = pre_det[keep, :] + if self.use_kps: + kpss = kpss[order,:,:] + kpss = kpss[keep,:,:] + 
else: + kpss = None + if max_num > 0 and det.shape[0] > max_num: + area = (det[:, 2] - det[:, 0]) * (det[:, 3] - + det[:, 1]) + img_center = img.shape[0] // 2, img.shape[1] // 2 + offsets = np.vstack([ + (det[:, 0] + det[:, 2]) / 2 - img_center[1], + (det[:, 1] + det[:, 3]) / 2 - img_center[0] + ]) + offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) + if metric=='max': + values = area + else: + values = area - offset_dist_squared * 2.0 # some extra weight on the centering + bindex = np.argsort( + values)[::-1] # some extra weight on the centering + bindex = bindex[0:max_num] + det = det[bindex, :] + if kpss is not None: + kpss = kpss[bindex, :] + return det, kpss + + def nms(self, dets): + thresh = self.nms_thresh + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep \ No newline at end of file diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/__init__.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/imaug/__init__.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/imaug/__init__.py new file mode 100644 index 00000000..2751a0d4 --- /dev/null +++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/imaug/__init__.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from .operators import * + +def transform(data, ops=None): + """ transform """ + if ops is None: + ops = [] + for op in ops: + data = op(data) + if data is None: + return None + return data + + +def create_operators(op_param_list, global_config=None): + """ + create operators based on the config + Args: + params(list): a dict list, used to create some operators + """ + assert isinstance(op_param_list, list), ('operator config should be a list') + ops = [] + for operator in op_param_list: + assert isinstance(operator, + dict) and len(operator) == 1, "yaml format error" + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name] + if global_config is not None: + param.update(global_config) + op = eval(op_name)(**param) + ops.append(op) + return ops \ No newline at end of file diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/imaug/operators.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/imaug/operators.py new file mode 100644 index 00000000..93a8eabe --- /dev/null +++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/imaug/operators.py @@ -0,0 +1,209 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +import six +import cv2 +import numpy as np + + +class 
DecodeImage(object): + """ decode image """ + + def __init__(self, img_mode='RGB', channel_first=False, **kwargs): + self.img_mode = img_mode + self.channel_first = channel_first + + def __call__(self, data): + img = data['image'] + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + img = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(img, 1) + if img is None: + return None + if self.img_mode == 'GRAY': + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif self.img_mode == 'RGB': + assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape) + img = img[:, :, ::-1] + + if self.channel_first: + img = img.transpose((2, 0, 1)) + + data['image'] = img + return data + + +class NormalizeImage(object): + """ normalize image such as substract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + + assert isinstance(img, + np.ndarray), "invalid input 'img' in NormalizeImage" + data['image'] = ( + img.astype('float32') * self.scale - self.mean) / self.std + return data + + +class ToCHWImage(object): + """ convert hwc image to chw image + """ + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = data['image'] + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + data['image'] = img.transpose((2, 0, 1)) + return data + + +class KeepKeys(object): + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +class DetResizeForTest(object): + def __init__(self, **kwargs): + super(DetResizeForTest, self).__init__() + self.resize_type = 0 + if 'image_shape' in kwargs: + self.image_shape = kwargs['image_shape'] + self.resize_type = 1 + elif 'limit_side_len' in kwargs: + self.limit_side_len = kwargs['limit_side_len'] + self.limit_type = kwargs.get('limit_type', 'min') + elif 'resize_long' in kwargs: + self.resize_type = 2 + self.resize_long = kwargs.get('resize_long', 960) + else: + self.limit_side_len = 736 + self.limit_type = 'min' + + def __call__(self, data): + img = data['image'] + src_h, src_w, _ = img.shape + + if self.resize_type == 0: + # img, shape = self.resize_image_type0(img) + img, [ratio_h, ratio_w] = self.resize_image_type0(img) + elif self.resize_type == 2: + img, [ratio_h, ratio_w] = self.resize_image_type2(img) + else: + # img, shape = self.resize_image_type1(img) + img, [ratio_h, ratio_w] = self.resize_image_type1(img) + data['image'] = img + data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) + return data + + def resize_image_type1(self, img): + resize_h, resize_w = self.image_shape + ori_h, ori_w = img.shape[:2] # (h, w, c) + ratio_h = float(resize_h) / ori_h + ratio_w = float(resize_w) / ori_w + img = cv2.resize(img, (int(resize_w), 
int(resize_h))) + # return img, np.array([ori_h, ori_w]) + return img, [ratio_h, ratio_w] + + def resize_image_type0(self, img): + """ + resize image to a size multiple of 32 which is required by the network + args: + img(array): array with shape [h, w, c] + return(tuple): + img, (ratio_h, ratio_w) + """ + limit_side_len = self.limit_side_len + h, w, _ = img.shape + + # limit the max side + if self.limit_type == 'max': + if max(h, w) > limit_side_len: + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + else: + if min(h, w) < limit_side_len: + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + else: + ratio = 1. + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = int(round(resize_h / 32) * 32) + resize_w = int(round(resize_w / 32) * 32) + + try: + if int(resize_w) <= 0 or int(resize_h) <= 0: + return None, (None, None) + img = cv2.resize(img, (int(resize_w), int(resize_h))) + except: + print(img.shape, resize_w, resize_h) + sys.exit(0) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + # return img, np.array([h, w]) + return img, [ratio_h, ratio_w] + + def resize_image_type2(self, img): + h, w, _ = img.shape + + resize_w = w + resize_h = h + + # Fix the longer side + if resize_h > resize_w: + ratio = float(self.resize_long) / resize_h + else: + ratio = float(self.resize_long) / resize_w + + resize_h = int(resize_h * ratio) + resize_w = int(resize_w * ratio) + + max_stride = 128 + resize_h = (resize_h + max_stride - 1) // max_stride * max_stride + resize_w = (resize_w + max_stride - 1) // max_stride * max_stride + img = cv2.resize(img, (int(resize_w), int(resize_h))) + ratio_h = resize_h / float(h) + ratio_w = resize_w / float(w) + + return img, [ratio_h, ratio_w] \ No newline at end of file diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/ocr_main.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/ocr_main.py new file mode 100644 index 00000000..1005d476 --- /dev/null +++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/ocr_main.py @@ -0,0 +1,386 @@ +import copy +import math +import time +import os + +import numpy as np +import onnxruntime +from PIL import Image +import cv2 + +from .imaug import create_operators, transform +from .postprocess import build_post_process + +cuda_available = False + +def sorted_boxes(dt_boxes): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = dt_boxes.shape[0] + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and ( + _boxes[i + 1][0][0] < _boxes[i][0][0] + ): + tmp = _boxes[i] + _boxes[i] = _boxes[i + 1] + _boxes[i + 1] = tmp + return _boxes + +class TextClassifier(): + def __init__(self, model_path): + self.weights_path = model_path + 'classifier.onnx' + + self.cls_image_shape = [3, 48, 192] + self.cls_batch_num = 30 + self.cls_thresh = 0.9 + self.use_zero_copy_run = False + postprocess_params = { + 'name': 'ClsPostProcess', + "label_list": ['0', '180'], + } + self.postprocess_op = build_post_process(postprocess_params) + + self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=['CUDAExecutionProvider'] if cuda_available else 
['CPUExecutionProvider']) + + def resize_norm_img(self, img): + imgC, imgH, imgW = self.cls_image_shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = np.array(Image.fromarray(img).resize((resized_w, imgH))) + #resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if self.cls_image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_list = copy.deepcopy(img_list) + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + starttime = time.time() + ort_inputs = {self.ort_session.get_inputs()[0].name: norm_img_batch} + prob_out = self.ort_session.run(None, ort_inputs)[0] + cls_result = self.postprocess_op(prob_out) + for rno in range(len(cls_result)): + label, score = cls_result[rno] + cls_res[indices[beg_img_no + rno]] = [label, score] + if '180' in label and score > self.cls_thresh: + img_list[indices[beg_img_no + rno]] = np.array(Image.fromarray(img_list[indices[beg_img_no + rno]]).transpose(Image.ROTATE_180)) + return img_list, cls_res + +class TextDetector(): + def __init__(self, model_path): + self.weights_path = model_path + 'det_standard.onnx' + + self.det_algorithm = 'DB' + self.use_zero_copy_run = False + + pre_process_list = [{ + 'DetResizeForTest': { + 'limit_side_len': 960, + 'limit_type': 'max' + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] + + postprocess_params = {} + postprocess_params['name'] = 'DBPostProcess' + postprocess_params["thresh"] = 0.3 + postprocess_params["box_thresh"] = 0.3 + postprocess_params["max_candidates"] = 1000 + postprocess_params["unclip_ratio"] = 1.6 + postprocess_params["use_dilation"] = True + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=['CUDAExecutionProvider'] if cuda_available else ['CPUExecutionProvider']) + _ = self.ort_session.run(None, {"backbone": np.zeros([1, 3, 64, 64], dtype='float32')}) + + # load_pytorch_weights + + def order_points_clockwise(self, pts): + """ + reference 
from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
+        # sort the points based on their x-coordinates
+        """
+        xSorted = pts[np.argsort(pts[:, 0]), :]
+
+        # grab the left-most and right-most points from the sorted
+        # x-coordinate points
+        leftMost = xSorted[:2, :]
+        rightMost = xSorted[2:, :]
+
+        # now, sort the left-most coordinates according to their
+        # y-coordinates so we can grab the top-left and bottom-left
+        # points, respectively
+        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+        (tl, bl) = leftMost
+
+        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
+        (tr, br) = rightMost
+
+        rect = np.array([tl, tr, br, bl], dtype="float32")
+        return rect
+
+    def clip_det_res(self, points, img_height, img_width):
+        for pno in range(points.shape[0]):
+            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+        return points
+
+    def filter_tag_det_res(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.order_points_clockwise(box)
+            box = self.clip_det_res(box, img_height, img_width)
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            # drop boxes that are too small to contain readable text
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            box = self.clip_det_res(box, img_height, img_width)
+            dt_boxes_new.append(box)
+        dt_boxes = np.array(dt_boxes_new)
+        return dt_boxes
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {'image': img}
+        data = transform(data, self.preprocess_op)
+        # transform() returns None when preprocessing fails; bail out before unpacking
+        if data is None:
+            return None
+        img, shape_list = data
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+        img = img.copy()
+        ort_inputs = {self.ort_session.get_inputs()[0].name: img}
+        preds = {}
+        preds['maps'] = self.ort_session.run(None, ort_inputs)[0]
+
+        post_result = self.postprocess_op(preds, shape_list)
+        dt_boxes = post_result[0]['points']
+        if self.det_algorithm == "SAST" and getattr(self, 'det_sast_polygon', False):
+            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
+        else:
+            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+        return dt_boxes
+
+class TextRecognizer():
+    def __init__(self, model_path):
+        self.weights_path = model_path + 'rec_standard.onnx'
+
+        self.limited_max_width = 1280
+        self.limited_min_width = 16
+
+        self.rec_image_shape = [3, 32, 320]
+        self.character_type = 'ch'
+        self.rec_batch_num = 6
+        self.rec_algorithm = 'CRNN'
+        self.use_zero_copy_run = False
+        postprocess_params = {
+            'name': 'CTCLabelDecode',
+            "character_type": 'ch',
+            "character_dict_path": model_path + 'keys_v1.txt',
+            "use_space_char": True
+        }
+        self.postprocess_op = build_post_process(postprocess_params)
+
+        self.ort_session = onnxruntime.InferenceSession(self.weights_path, providers=['CUDAExecutionProvider'] if cuda_available else ['CPUExecutionProvider'])
+        # warm-up run so the first real inference is not slowed by session setup
+        _ = self.ort_session.run(None, {"backbone": np.zeros([1, 3, 32, 64], dtype='float32')})
+
+    def resize_norm_img(self, img, max_wh_ratio):
+        imgC, imgH, imgW = self.rec_image_shape
+        assert imgC == img.shape[2]
+        if self.character_type == "ch":
+            imgW = int((32 * max_wh_ratio))
+        imgW = max(min(imgW, self.limited_max_width), self.limited_min_width)
+        h, w = img.shape[:2]
+        ratio = w / float(h)
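+        # scale the target width with the crop's aspect ratio, then zero-pad
+        # up to imgW below so every crop in a batch shares one input width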
ratio_imgH = math.ceil(imgH * ratio) + ratio_imgH = max(ratio_imgH, self.limited_min_width) + if ratio_imgH > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = np.array(Image.fromarray(img).resize((resized_w, imgH))) + #resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the recognition process + indices = np.argsort(np.array(width_list)) + + # rec_res = [] + rec_res = [['', 0.0]] * img_num + batch_num = self.rec_batch_num + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + # h, w = img_list[ino].shape[0:2] + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + ort_inputs = {self.ort_session.get_inputs()[0].name: norm_img_batch} + preds = self.ort_session.run(None, ort_inputs)[0] + + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] + return rec_res + +class TextSystem: + def __init__(self, model_path): + self.text_detector = TextDetector(model_path) + self.text_recognizer = TextRecognizer(model_path) + self.drop_score = 0.3 + self.text_classifier = TextClassifier(model_path) + + def get_rotate_crop_image(self, img, points): + """ + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + """ + img_crop_width = int( + max( + np.linalg.norm(points[0] - points[1]), + np.linalg.norm(points[2] - points[3]), + ) + ) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]), + ) + ) + pts_std = np.float32( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, + (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC, + ) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + def __call__(self, img): + ori_im = img.copy() + dt_boxes = self.text_detector(img) + if dt_boxes is None: + return None, None + img_crop_list = [] + + dt_boxes = sorted_boxes(dt_boxes) + + for bno in range(len(dt_boxes)): + tmp_box = copy.deepcopy(dt_boxes[bno]) + img_crop = 
self.get_rotate_crop_image(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+        img_crop_list, angle_list = self.text_classifier(img_crop_list)
+
+        rec_res = self.text_recognizer(img_crop_list)
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+        return filter_boxes, filter_rec_res
diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/__init__.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/__init__.py
new file mode 100644
index 00000000..85bc130c
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/__init__.py
@@ -0,0 +1,27 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import copy
+
+__all__ = ['build_post_process']
+
+
+def build_post_process(config, global_config=None):
+    from .db_postprocess import DBPostProcess
+    from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
+    from .cls_postprocess import ClsPostProcess
+
+    support_dict = [
+        'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
+    ]
+
+    config = copy.deepcopy(config)
+    module_name = config.pop('name')
+    if global_config is not None:
+        config.update(global_config)
+    assert module_name in support_dict, Exception(
+        'post process only supports {}'.format(support_dict))
+    module_class = eval(module_name)(**config)
+    return module_class
\ No newline at end of file
diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/cls_postprocess.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/cls_postprocess.py
new file mode 100644
index 00000000..f16536c3
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/cls_postprocess.py
@@ -0,0 +1,15 @@
+class ClsPostProcess(object):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, label_list, **kwargs):
+        super(ClsPostProcess, self).__init__()
+        self.label_list = label_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        pred_idxs = preds.argmax(axis=1)
+        decode_out = [(self.label_list[idx], preds[i, idx])
+                      for i, idx in enumerate(pred_idxs)]
+        if label is None:
+            return decode_out
+        label = [(self.label_list[idx], 1.0) for idx in label]
+        return decode_out, label
\ No newline at end of file
diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/db_postprocess.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/db_postprocess.py
new file mode 100644
index 00000000..741a2c88
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/db_postprocess.py
@@ -0,0 +1,139 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+
+class DBPostProcess(object):
+    """
+    The post process for Differentiable Binarization (DB).
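+    Binarizes the probability map with `thresh`, traces box contours, and
+    unclips each contour back to full text-region size.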
+ """ + + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + ''' + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + ''' + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.astype(np.int16)) + scores.append(score) + return np.array(boxes, dtype=np.int16), scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] + boxes, scores = 
self.boxes_from_bitmap(pred[batch_index], mask, + src_w, src_h) + + boxes_batch.append({'points': boxes}) + return boxes_batch \ No newline at end of file diff --git a/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/rec_postprocess.py b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/rec_postprocess.py new file mode 100644 index 00000000..5b3245c5 --- /dev/null +++ b/src/containers/image-pii-detection/parsers/image_analysis/general_ocr/postprocess/rec_postprocess.py @@ -0,0 +1,138 @@ +import numpy as np + + +class BaseRecLabelDecode(object): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False): + support_character_type = [ + 'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean' + ] + assert character_type in support_character_type, "Only {} are supported now but get {}".format( + support_character_type, character_type) + + if character_type == "en": + self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" + dict_character = list(self.character_str) + elif character_type in ["ch", "french", "german", "japan", "korean"]: + self.character_str = "" + assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + self.character_str += line + if use_space_char: + self.character_str += " " + dict_character = list(self.character_str) + elif character_type == "en_sensitive": + # same with ASTER setting (use 94 char). + import string + self.character_str = string.printable[:-6] + dict_character = list(self.character_str) + else: + raise NotImplementedError + self.character_type = character_type + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + + def add_special_char(self, dict_character): + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=True): + """ convert text-index into text-label. 
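+        Drops ignored tokens (the CTC blank) and, when is_remove_duplicate is
+        set, collapses repeated neighboring indices.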
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def get_ignored_tokens(self): + return [0] # for ctc blank + + +class CTCLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(CTCLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob) + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character + + +class AttnLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + **kwargs): + super(AttnLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + self.beg_str = "sos" + self.end_str = "eos" + + def add_special_char(self, dict_character): + dict_character = [self.beg_str, self.end_str] + dict_character + return dict_character + + def __call__(self, text): + text = self.decode(text) + return text + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx \ No newline at end of file diff --git a/src/containers/image-pii-detection/parsers/image_parser.py b/src/containers/image-pii-detection/parsers/image_parser.py new file mode 100644 index 00000000..d33a0e23 --- /dev/null +++ b/src/containers/image-pii-detection/parsers/image_parser.py @@ -0,0 +1,79 @@ + +import os +from .parser import BaseParser + +from PIL import Image +import numpy as np + +from .image_analysis.face_detection import face_detection_main +from .image_analysis.general_ocr import ocr_main + +def check_keywords_exist(det_results, keywords): + for keyword in keywords: + found = False + for dt_result in det_results: + text, score = dt_result[1] + if keyword in text and score >= 0.5: + found = True + break + if not found: + return False + return True + +class ImageParser(BaseParser): + def __init__(self, s3_client, fd_model_path, ocr_model_path): + super().__init__(s3_client=s3_client) + self.face_detection_model = face_detection_main.SCRFD(model_file = fd_model_path + 'det.onnx') + self.ocr_model = ocr_main.TextSystem(model_path = ocr_model_path) + # additional 
+
+
+class AttnLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 **kwargs):
+        super(AttnLabelDecode, self).__init__(character_dict_path,
+                                              character_type, use_space_char)
+        self.beg_str = "sos"
+        self.end_str = "eos"
+
+    def add_special_char(self, dict_character):
+        dict_character = [self.beg_str, self.end_str] + dict_character
+        return dict_character
+
+    def __call__(self, text):
+        text = self.decode(text)
+        return text
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupported type %s in get_beg_end_flag_idx" \
+                % beg_or_end
+        return idx
\ No newline at end of file
diff --git a/src/containers/image-pii-detection/parsers/image_parser.py b/src/containers/image-pii-detection/parsers/image_parser.py
new file mode 100644
index 00000000..d33a0e23
--- /dev/null
+++ b/src/containers/image-pii-detection/parsers/image_parser.py
@@ -0,0 +1,79 @@
+
+import os
+from .parser import BaseParser
+
+from PIL import Image
+import numpy as np
+
+from .image_analysis.face_detection import face_detection_main
+from .image_analysis.general_ocr import ocr_main
+
+def check_keywords_exist(det_results, keywords):
+    """
+    Return True only if every keyword appears in at least one OCR result
+    whose recognition score is at least 0.5.
+    """
+    for keyword in keywords:
+        found = False
+        for dt_result in det_results:
+            text, score = dt_result[1]
+            if keyword in text and score >= 0.5:
+                found = True
+                break
+        if not found:
+            return False
+    return True
+
+class ImageParser(BaseParser):
+    def __init__(self, s3_client, fd_model_path, ocr_model_path):
+        super().__init__(s3_client=s3_client)
+        self.face_detection_model = face_detection_main.SCRFD(model_file=fd_model_path + 'det.onnx')
+        self.ocr_model = ocr_main.TextSystem(model_path=ocr_model_path)
+        # additional ImageParser constructor code here
+
+    def read_img(self, file_path):
+        img = np.array(Image.open(file_path).convert('RGB'))[:, :, :3]
+        return img
+
+    def face_detection_pipeline(self, img):
+        bboxes, kpss = self.face_detection_model.detect(img)
+        return bboxes, kpss
+
+    def ocr_pipeline(self, img):
+        img = img[:, :, ::-1]  # RGB -> BGR, the channel order the OCR model expects
+        dt_boxes, rec_res = self.ocr_model(img)
+        dt_results = list(zip(dt_boxes, rec_res))
+        return dt_results
+
+    def parse_file(self, file_path):
+        file_content = []
+        img = self.read_img(file_path)
+
+        face_detection_result, _ = self.face_detection_pipeline(img)
+        ocr_pipeline_result = self.ocr_pipeline(img)
+
+        contain_face = len(face_detection_result) > 0
+        business_license_keywords = ['营', '业', '执', '照', '信用代码']  # "business license", "credit code"
+        cnid_keywords = ['公', '民', '身', '份', '号', '码']  # "citizen ID number"
+        car_license_keywords = ['机动车', '驾驶证']  # "motor vehicle", "driving licence"
+
+        contain_business_license = check_keywords_exist(ocr_pipeline_result, business_license_keywords)
+        contain_cnid = check_keywords_exist(ocr_pipeline_result, cnid_keywords)
+        contain_car_license = check_keywords_exist(ocr_pipeline_result, car_license_keywords)
+
+        if contain_face:
+            if contain_cnid:
+                file_content.append('ChineseID')
+            elif contain_car_license:
+                file_content.append('CarLicense')
+            else:
+                file_content.append('Face')
+        else:
+            if contain_business_license:
+                file_content.append('BusinessLicense')
+            elif contain_car_license:
+                file_content.append('CarLicense')
+
+        return file_content
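+
+# Usage sketch (illustrative; the boto3 client, model directories, and image
+# path are placeholders for wherever the ONNX models and input actually live):
+#
+#     import boto3
+#     parser = ImageParser(s3_client=boto3.client('s3'),
+#                          fd_model_path='/opt/models/fd/',
+#                          ocr_model_path='/opt/models/ocr/')
+#     labels = parser.parse_file('sample.jpg')  # e.g. ['ChineseID']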
+ """ + blob = open(file_path, 'rb').read() + m = magic.Magic(mime_encoding=True) + encoding = m.from_buffer(blob) + return encoding \ No newline at end of file diff --git a/src/containers/image-pii-detection/parsers/pdf_parser.py b/src/containers/image-pii-detection/parsers/pdf_parser.py new file mode 100644 index 00000000..e1a07bb1 --- /dev/null +++ b/src/containers/image-pii-detection/parsers/pdf_parser.py @@ -0,0 +1,30 @@ + +import os +import boto3 +from pypdf import PdfReader + +from .parser import BaseParser + +class PdfParser(BaseParser): + def __init__(self, s3_client): + super().__init__(s3_client=s3_client) + + + def parse_file(self, pdf_path): + """ + Extracts text from a PDF file and returns a list of lines. + """ + + # Create a PDF reader object + pdf_reader = PdfReader(pdf_path) + file_content = [] + + # Loop through each page in the PDF file + for page_num in range(len(pdf_reader.pages)): + page = pdf_reader.pages[page_num] + + # Extract the text from the page and append it to the string + page_content = page.extract_text() + file_content.append(page_content) + + return file_content diff --git a/src/containers/image-pii-detection/parsers/txt_parser.py b/src/containers/image-pii-detection/parsers/txt_parser.py new file mode 100644 index 00000000..7bf503bf --- /dev/null +++ b/src/containers/image-pii-detection/parsers/txt_parser.py @@ -0,0 +1,18 @@ + +import os +from .parser import BaseParser + +class TxtParser(BaseParser): + def __init__(self, s3_client): + super().__init__(s3_client=s3_client) + + def parse_file(self, txt_path): + """ + Extracts text from a TXT file and returns a list of lines. + """ + + # Read the file + with open(txt_path, 'r') as file: + file_content = file.read() + + return [file_content] diff --git a/src/containers/image-pii-detection/requirements.txt b/src/containers/image-pii-detection/requirements.txt new file mode 100644 index 00000000..a50b4853 --- /dev/null +++ b/src/containers/image-pii-detection/requirements.txt @@ -0,0 +1,20 @@ +requests +boto3 +six==1.16.0 +opencv-python-headless==4.5.3.56 +numpy<=1.23.5 +onnxruntime +Pillow==8.4.0 +pyclipper==1.3.0 +Shapely==1.7.1 +base64image==0.5.1 +urllib3==1.26.6 +python-dateutil==2.8.2 +certifi==2022.12.7 +idna==2.10 +chardet==4.0.0 +pypdf==3.12.1 +python-magic==0.4.27 +python-docx==0.8.11 +bs4==0.0.1 +pandas==1.5.3 \ No newline at end of file diff --git a/src/containers/image-pii-detection/utils.py b/src/containers/image-pii-detection/utils.py new file mode 100644 index 00000000..e69de29b