-
Notifications
You must be signed in to change notification settings - Fork 197
/
Copy pathdialog_intent_detection_mapper.py
226 lines (193 loc) · 9.81 KB
/
dialog_intent_detection_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import re
from typing import Dict, List, Optional
from loguru import logger
from pydantic import NonNegativeInt, PositiveInt
from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.model_utils import get_model, prepare_model
# Name under which the operator class below is registered (it is passed to
# both `register_module` decorators).
OP_NAME = 'dialog_intent_detection_mapper'
# TODO: LLM-based inference.
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class DialogIntentDetectionMapper(Mapper):
    """
    Mapper to generate user's intent labels in dialog. Input from
    history_key, query_key and response_key. Output lists of
    labels and analysis for queries in the dialog.

    The dialog is processed round by round. For every round the model is
    asked for an intent analysis and intent labels of the user query; the
    parsed results are stored in the meta field of the sample, one entry
    per round.
    """

    DEFAULT_SYSTEM_PROMPT = (
        '请判断用户和LLM多轮对话中用户的意图。\n'
        '要求:\n'
        '- 需要先进行分析,然后列出用户所具有的意图,下面是一个样例,请模仿样例格式输出'
        '。\n'
        '用户:你好,我最近对人工智能很感兴趣,能给我讲讲什么是机器学习吗?\n'
        '意图分析:用户在请求信息,希望了解有关机器学习的基础知识。\n'
        '意图类别:信息查找\n'
        'LLM:你好!当然可以。机器学习是一种人工智能方法,允许计算机通过数据自动改进和学习。\n'
        '用户:听起来很有趣,有没有推荐的入门书籍或资料?\n'
        '意图分析:用户在请求建议,希望获取关于机器学习的入门资源。\n'
        '意图类别:请求建议\n'
        'LLM:有很多不错的入门书籍和资源。一本常被推荐的书是《Python机器学习实践》(Python'
        ' Machine Learning),它涵盖了基础知识和一些实际案例。此外,您还可以参考Coursera'
        '或edX上的在线课程,这些课程提供了系统的学习路径。\n'
        '用户:谢谢你的建议!我还想知道,学习机器学习需要什么样的数学基础?\n'
        '意图分析:用户在寻求信息,希望了解学习机器学习所需的前提条件,特别是在数学方面。\n'
        '意图类别:信息查找\n'
        'LLM:学习机器学习通常需要一定的数学基础,特别是线性代数、概率论和统计学。这些数学领'
        '域帮助理解算法的工作原理和数据模式分析。如果您对这些主题不太熟悉,建议先从相关基础'
        '书籍或在线资源开始学习。\n'
        '用户:明白了,我会先补习这些基础知识。再次感谢你的帮助!\n'
        '意图分析:用户表达感谢,并表示计划付诸行动来补充所需的基础知识。\n'
        '意图类别:其他')
    DEFAULT_QUERY_TEMPLATE = '用户:{query}\n'
    DEFAULT_RESPONSE_TEMPLATE = 'LLM:{response}\n'
    DEFAULT_CANDIDATES_TEMPLATE = '备选意图类别:[{candidate_str}]'
    DEFAULT_ANALYSIS_TEMPLATE = '意图分析:{analysis}\n'
    DEFAULT_LABELS_TEMPLATE = '意图类别:{labels}\n'
    DEFAULT_ANALYSIS_PATTERN = '意图分析:(.*?)\n'
    DEFAULT_LABELS_PATTERN = '意图类别:(.*?)($|\n)'

    def __init__(self,
                 api_model: str = 'gpt-4o',
                 intent_candidates: Optional[List[str]] = None,
                 max_round: NonNegativeInt = 10,
                 *,
                 labels_key: str = MetaKeys.dialog_intent_labels,
                 analysis_key: str = MetaKeys.dialog_intent_labels_analysis,
                 api_endpoint: Optional[str] = None,
                 response_path: Optional[str] = None,
                 system_prompt: Optional[str] = None,
                 query_template: Optional[str] = None,
                 response_template: Optional[str] = None,
                 candidate_template: Optional[str] = None,
                 analysis_template: Optional[str] = None,
                 labels_template: Optional[str] = None,
                 analysis_pattern: Optional[str] = None,
                 labels_pattern: Optional[str] = None,
                 try_num: PositiveInt = 3,
                 model_params: Dict = {},
                 sampling_params: Dict = {},
                 **kwargs):
        """
        Initialization method.

        :param api_model: API model name.
        :param intent_candidates: The output intent candidates. Use the
            intent labels of the open domain if it is None.
        :param max_round: The max num of round in the dialog to build the
            prompt. No history is put into the prompt if it is 0.
        :param labels_key: The key name in the meta field to store the
            output labels. It is 'dialog_intent_labels' in default.
        :param analysis_key: The key name in the meta field to store the
            corresponding analysis. It is 'dialog_intent_labels_analysis'
            in default.
        :param api_endpoint: URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'choices.0.message.content'.
        :param system_prompt: System prompt for the task.
        :param query_template: Template for query part to build the input
            prompt.
        :param response_template: Template for response part to build the
            input prompt.
        :param candidate_template: Template for intent candidates to
            build the input prompt.
        :param analysis_template: Template for analysis part to build the
            input prompt.
        :param labels_template: Template for labels to build the
            input prompt.
        :param analysis_pattern: Pattern to parse the return intent
            analysis.
        :param labels_pattern: Pattern to parse the return intent
            labels.
        :param try_num: The number of retry attempts when there is an API
            call error or output parsing error.
        :param model_params: Parameters for initializing the API model.
        :param sampling_params: Extra parameters passed to the API call.
            e.g {'temperature': 0.9, 'top_p': 0.95}
        :param kwargs: Extra keyword arguments.
        """
        super().__init__(**kwargs)

        self.intent_candidates = intent_candidates
        self.max_round = max_round
        self.labels_key = labels_key
        self.analysis_key = analysis_key

        # Every template/pattern falls back to the class-level default when
        # the caller passes None (or an empty string).
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
        self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
        self.response_template = response_template or \
            self.DEFAULT_RESPONSE_TEMPLATE
        self.candidate_template = candidate_template or \
            self.DEFAULT_CANDIDATES_TEMPLATE
        self.analysis_template = analysis_template or \
            self.DEFAULT_ANALYSIS_TEMPLATE
        self.labels_template = labels_template or \
            self.DEFAULT_LABELS_TEMPLATE
        self.analysis_pattern = analysis_pattern or \
            self.DEFAULT_ANALYSIS_PATTERN
        self.labels_pattern = labels_pattern or \
            self.DEFAULT_LABELS_PATTERN

        # NOTE: the `{}` defaults are shared objects; this class only reads
        # them (they are unpacked with `**`), it never mutates them.
        self.sampling_params = sampling_params

        self.model_key = prepare_model(model_type='api',
                                       model=api_model,
                                       endpoint=api_endpoint,
                                       response_path=response_path,
                                       **model_params)

        self.try_num = try_num

    def build_input(self, history, query):
        """
        Build the user prompt for one round of the dialog.

        :param history: List of already formatted prompt pieces of the
            previous rounds. `process_single` appends four pieces per
            round (query, analysis, labels, response).
        :param query: The current round; only its first element (the user
            query) is used here.
        :return: The prompt string, made of the optional candidates part,
            the most recent history and the formatted current query.
        """
        if self.intent_candidates:
            input_prompt = self.candidate_template.format(
                candidate_str=','.join(self.intent_candidates))
        else:
            input_prompt = ''

        if self.max_round > 0:
            # Four history pieces are stored per round, hence the factor 4.
            input_prompt += ''.join(history[-self.max_round * 4:])

        input_prompt += self.query_template.format(query=query[0])

        return input_prompt

    def parse_output(self, response):
        """
        Extract the intent analysis and the intent labels from a response.

        :param response: The text returned by the model.
        :return: A tuple `(analysis, labels)` of strings. A part that
            does not match its pattern is returned as an empty string.
        """
        analysis = ''
        labels = ''

        match = re.search(self.analysis_pattern, response)
        if match:
            analysis = match.group(1)

        match = re.search(self.labels_pattern, response)
        if match:
            labels = match.group(1)

        return analysis, labels

    def process_single(self, sample, rank=None):
        """
        Detect the user intent for every round of the dialog in a sample.

        :param sample: The sample to process. The dialog is read from
            `history_key`, `query_key` and `response_key`.
        :param rank: Passed through to `get_model`.
        :return: The sample, with the per-round labels and analysis lists
            written to its meta field. A sample that already holds both
            keys is returned unchanged.
        """
        meta = sample[Fields.meta]
        if self.labels_key in meta and self.analysis_key in meta:
            return sample

        client = get_model(self.model_key, rank=rank)

        analysis_list = []
        labels_list = []
        history = []

        # Work on a copy: appending the current query/response below must
        # not modify the history list stored in the sample itself.
        stored_history = sample[self.history_key]
        dialog = list(stored_history) if stored_history is not None else []
        if self.query_key in sample and sample[self.query_key]:
            if self.response_key in sample and sample[self.response_key]:
                dialog.append(
                    (sample[self.query_key], sample[self.response_key]))
            else:
                dialog.append((sample[self.query_key], ''))

        for qa in dialog:
            input_prompt = self.build_input(history, qa)
            messages = [{
                'role': 'system',
                'content': self.system_prompt,
            }, {
                'role': 'user',
                'content': input_prompt,
            }]

            # Reset per round so that a round whose attempts all fail yields
            # empty results instead of raising UnboundLocalError (first
            # round) or silently reusing the previous round's values.
            analysis, labels = '', ''
            for _ in range(self.try_num):
                try:
                    response = client(messages, **self.sampling_params)
                    analysis, labels = self.parse_output(response)
                    if analysis:
                        break
                except Exception as e:
                    # Best effort: log the failure and retry.
                    logger.warning(f'Exception: {e}')

            analysis_list.append(analysis)
            labels_list.append(labels)

            # Keep the order query, analysis, labels, response: it mirrors
            # the example layout in the system prompt.
            history.append(self.query_template.format(query=qa[0]))
            history.append(self.analysis_template.format(analysis=analysis))
            history.append(self.labels_template.format(labels=labels))
            history.append(self.response_template.format(response=qa[1]))

        meta[self.labels_key] = labels_list
        meta[self.analysis_key] = analysis_list

        return sample