-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathextract_nickname_mapper.py
177 lines (155 loc) · 6.94 KB
/
extract_nickname_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import re
from typing import Dict, Optional
import numpy as np
from loguru import logger
from pydantic import PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
OP_NAME = 'extract_nickname_mapper'
# TODO: LLM-based inference.
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class ExtractNicknameMapper(Mapper):
"""
Extract nickname relationship in the text.
"""
DEFAULT_SYSTEM_PROMPT = ('给定你一段文本,你的任务是将人物之间的称呼方式(昵称)提取出来。\n'
'要求:\n'
'- 需要给出说话人对被称呼人的称呼,不要搞反了。\n'
'- 相同的说话人和被称呼人最多给出一个最常用的称呼。\n'
'- 请不要输出互相没有昵称的称呼方式。\n'
'- 输出格式如下:\n'
'```\n'
'### 称呼方式1\n'
'- **说话人**:...\n'
'- **被称呼人**:...\n'
'- **...对...的昵称**:...\n'
'### 称呼方式2\n'
'- **说话人**:...\n'
'- **被称呼人**:...\n'
'- **...对...的昵称**:...\n'
'### 称呼方式3\n'
'- **说话人**:...\n'
'- **被称呼人**:...\n'
'- **...对...的昵称**:...\n'
'...\n'
'```\n')
DEFAULT_INPUT_TEMPLATE = '# 文本\n```\n{text}\n```\n'
DEFAULT_OUTPUT_PATTERN = r"""
\#\#\#\s*称呼方式(\d+)\s*
-\s*\*\*说话人\*\*\s*:\s*(.*?)\s*
-\s*\*\*被称呼人\*\*\s*:\s*(.*?)\s*
-\s*\*\*(.*?)对(.*?)的昵称\*\*\s*:\s*(.*?)(?=\#\#\#|\Z) # for double check
"""
def __init__(self,
api_model: str = 'gpt-4o',
*,
nickname_key: str = MetaKeys.nickname,
api_endpoint: Optional[str] = None,
response_path: Optional[str] = None,
system_prompt: Optional[str] = None,
input_template: Optional[str] = None,
output_pattern: Optional[str] = None,
try_num: PositiveInt = 3,
drop_text: bool = False,
model_params: Dict = {},
sampling_params: Dict = {},
**kwargs):
"""
Initialization method.
:param api_model: API model name.
:param nickname_key: The key name to store the nickname
relationship in the meta field. It's "nickname" in default.
:param api_endpoint: URL endpoint for the API.
:param response_path: Path to extract content from the API response.
Defaults to 'choices.0.message.content'.
:param system_prompt: System prompt for the task.
:param input_template: Template for building the model input.
:param output_pattern: Regular expression for parsing model output.
:param try_num: The number of retry attempts when there is an API
call error or output parsing error.
:param drop_text: If drop the text in the output.
:param model_params: Parameters for initializing the API model.
:param sampling_params: Extra parameters passed to the API call.
e.g {'temperature': 0.9, 'top_p': 0.95}
:param kwargs: Extra keyword arguments.
"""
super().__init__(**kwargs)
self.nickname_key = nickname_key
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
self.input_template = input_template or self.DEFAULT_INPUT_TEMPLATE
self.output_pattern = output_pattern or self.DEFAULT_OUTPUT_PATTERN
self.sampling_params = sampling_params
self.model_key = prepare_model(model_type='api',
model=api_model,
endpoint=api_endpoint,
response_path=response_path,
**model_params)
self.try_num = try_num
self.drop_text = drop_text
def parse_output(self, raw_output):
pattern = re.compile(self.output_pattern, re.VERBOSE | re.DOTALL)
matches = pattern.findall(raw_output)
nickname_relations = []
for match in matches:
_, role1, role2, role1_tmp, role2_tmp, nickname = match
# for double check
if role1.strip() != role1_tmp.strip() or role2.strip(
) != role2_tmp.strip():
continue
role1 = role1.strip()
role2 = role2.strip()
nickname = nickname.strip()
# is name but not nickname
if role2 == nickname:
continue
if role1 and role2 and nickname:
nickname_relations.append((role1, role2, nickname))
nickname_relations = list(set(nickname_relations))
nickname_relations = [{
MetaKeys.source_entity: nr[0],
MetaKeys.target_entity: nr[1],
MetaKeys.relation_description: nr[2],
MetaKeys.relation_keywords: ['nickname'],
MetaKeys.relation_strength: None
} for nr in nickname_relations]
return nickname_relations
def process_single(self, sample, rank=None):
# check if it's generated already
if self.nickname_key in sample[Fields.meta]:
return sample
client = get_model(self.model_key, rank=rank)
input_prompt = self.input_template.format(text=sample[self.text_key])
messages = [{
'role': 'system',
'content': self.system_prompt
}, {
'role': 'user',
'content': input_prompt
}]
nickname_relations = [{
MetaKeys.source_entity:
'',
MetaKeys.target_entity:
'',
MetaKeys.relation_description:
'',
MetaKeys.relation_keywords:
np.array([], dtype=str),
MetaKeys.relation_strength:
None
}]
for _ in range(self.try_num):
try:
output = client(messages, **self.sampling_params)
results = self.parse_output(output)
if len(results) > 0:
nickname_relations = results
break
except Exception as e:
logger.warning(f'Exception: {e}')
sample[Fields.meta][self.nickname_key] = nickname_relations
if self.drop_text:
sample.pop(self.text_key)
return sample