import os

from loguru import logger
from PIL import ImageFilter
from pydantic import NonNegativeFloat

from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import transfer_filename
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import (detect_faces, load_data_with_context,
                                        load_image)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES

cv2 = LazyLoader('cv2', 'cv2')

OP_NAME = 'image_face_blur_mapper'


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageFaceBlurMapper(Mapper):
"""Mapper to blur faces detected in images.
"""
_default_kwargs = {
'scaleFactor': 1.1,
'minNeighbors': 3,
'minSize': None,
'maxSize': None,
}
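    # NOTE: these defaults mirror the arguments of OpenCV's
    # `CascadeClassifier.detectMultiScale`; they are forwarded to the face
    # detector via `detect_faces` below and can be overridden through
    # **kwargs at construction time.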

    def __init__(self,
                 cv_classifier: str = '',
                 blur_type: str = 'gaussian',
                 radius: NonNegativeFloat = 2,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param cv_classifier: OpenCV cascade classifier path for face
            detection. By default, 'haarcascade_frontalface_alt.xml' shipped
            with OpenCV is used.
        :param blur_type: Type of blur kernel, including
            ['mean', 'box', 'gaussian'].
        :param radius: Radius of blur kernel.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())

        if cv_classifier == '':
            cv_classifier = os.path.join(cv2.data.haarcascades,
                                         'haarcascade_frontalface_alt.xml')
        if blur_type not in ['mean', 'box', 'gaussian']:
            raise ValueError(
                f'blur_type [{blur_type}] is not supported. '
                f'Can only be one of ["mean", "box", "gaussian"].')
        if radius < 0:
            raise ValueError('Radius must be >= 0.')

        if blur_type == 'mean':
            self.blur = ImageFilter.BLUR
        elif blur_type == 'box':
            self.blur = ImageFilter.BoxBlur(radius)
        else:
            self.blur = ImageFilter.GaussianBlur(radius)
        self.blur_type = blur_type
        self.radius = radius

        # copy the class-level defaults so that per-instance overrides do not
        # mutate the shared `_default_kwargs` dict
        self.extra_kwargs = dict(self._default_kwargs)
        for key in kwargs:
            if key in self.extra_kwargs:
                self.extra_kwargs[key] = kwargs[key]

        self.model_key = prepare_model(model_type='opencv_classifier',
                                       model_path=cv_classifier)

    def process_single(self, sample, context=False):
        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            sample[Fields.source_file] = []
            return sample

        if Fields.source_file not in sample or not sample[Fields.source_file]:
            sample[Fields.source_file] = sample[self.image_key]

        # load images
        loaded_image_keys = sample[self.image_key]
        sample, images = load_data_with_context(sample, context,
                                                loaded_image_keys, load_image)

        model = get_model(self.model_key)

        # detect faces
        face_detections = {}
        for key, image in images.items():
            face_detections[key] = detect_faces(image, model,
                                                **self.extra_kwargs)
        logger.debug(f'detections: {face_detections}')

        # blur face regions
        key_mapping = {}
        for key, image in images.items():
            dets = face_detections[key]
            # only blur when at least one face is detected
            if len(dets) > 0:
                blured_image = image.copy()
                for (x, y, w, h) in dets:
                    box = (x, y, x + w, y + h)
                    blured_roi = image.crop(box).filter(self.blur)
                    blured_image.paste(blured_roi, box)
                blured_image_key = transfer_filename(key, OP_NAME,
                                                     **self._init_parameters)
                blured_image.save(blured_image_key)
                key_mapping[key] = blured_image_key
                if context:
                    sample[Fields.context][blured_image_key] = blured_image
            else:
                key_mapping[key] = key

        # when a file is modified, its source file needs to be updated
        for i, value in enumerate(loaded_image_keys):
            if sample[Fields.source_file][i] != value:
                if key_mapping[value] != value:
                    sample[Fields.source_file][i] = value

        sample[self.image_key] = [
            key_mapping[key] for key in loaded_image_keys
        ]
        return sample
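

# --- Illustrative usage sketch (not part of the original module) -----------
# The op is normally selected by name in a Data-Juicer recipe; a hypothetical
# config entry could look like the commented YAML below (parameter values are
# placeholders, not recommendations):
#
#   process:
#     - image_face_blur_mapper:
#         blur_type: 'gaussian'
#         radius: 2
#
# For ad-hoc experiments the class can also be driven directly on a sample
# dict that uses the default image key, e.g.:
#
#   op = ImageFaceBlurMapper(blur_type='gaussian', radius=2)
#   sample = {op.image_key: ['/path/to/portrait.jpg']}
#   sample = op.process_single(sample)
#   # faces found in the input are blurred and written to a new file whose
#   # path replaces the original one under `op.image_key`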