-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathexplain_relevant_responses.py
159 lines (145 loc) · 4.92 KB
/
explain_relevant_responses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from pathlib import Path
import click
import pandas as pd
from patents4IPPC.explainability import (
HierarchicalTransformerTextSimilarityExplainer,
HuggingFaceTextSimilarityExplainer
)
def format_hf_explainer_output(
    explainer_output, output_dir, query_id, response_id
):
    """Save token-level attribution scores from a HuggingFace explainer.

    Creates ``<output_dir>/q_<query_id>_r_<response_id>/`` containing
    ``q_<query_id>_tokens.csv`` and ``r_<response_id>_tokens.csv``, each
    with "token" and "attribution_score" columns.
    """
    query_tokens, query_scores, response_tokens, response_scores = \
        explainer_output

    # All files for this query/response pair live in one subdirectory
    pair_dir = Path(output_dir) / f"q_{query_id}_r_{response_id}"
    pair_dir.mkdir(parents=True, exist_ok=True)

    # Both sides share the same CSV layout; only the file-name prefix
    # and identifier differ
    for prefix, identifier, tokens, scores in (
        ("q", query_id, query_tokens, query_scores),
        ("r", response_id, response_tokens, response_scores),
    ):
        frame = pd.DataFrame({
            "token": tokens,
            "attribution_score": scores
        })
        frame.to_csv(
            str(pair_dir / f"{prefix}_{identifier}_tokens.csv"), index=False
        )
def format_ht_explainer_output(
    explainer_output, output_dir, query_id, response_id
):
    """Save segment-level attribution scores from a hierarchical explainer.

    Creates ``<output_dir>/q_<query_id>_r_<response_id>/`` containing
    ``q_<query_id>_segments.csv`` and ``r_<response_id>_segments.csv``,
    each with "segment" and "attribution_scores" columns.
    """
    query_segments, query_scores, response_segments, response_scores = \
        explainer_output

    # All files for this query/response pair live in one subdirectory
    pair_dir = Path(output_dir) / f"q_{query_id}_r_{response_id}"
    pair_dir.mkdir(parents=True, exist_ok=True)

    # Both sides share the same CSV layout; only the file-name prefix
    # and identifier differ
    for prefix, identifier, segments, scores in (
        ("q", query_id, query_segments, query_scores),
        ("r", response_id, response_segments, response_scores),
    ):
        frame = pd.DataFrame({
            "segment": segments,
            "attribution_scores": scores
        })
        frame.to_csv(
            str(pair_dir / f"{prefix}_{identifier}_segments.csv"), index=False
        )
@click.command()
@click.option(
    "-mc", "--model-checkpoint", "path_to_model",
    type=click.Path(exists=True),
    required=True,
    help="Path to a pre-trained model whose predictions you want to explain."
)
@click.option(
    "-mt", "--model-type",
    type=click.Choice(["huggingface", "hierarchical"]),
    required=True,
    help="Type of the pre-trained model."
)
@click.option(
    "-p", "--pooling-mode",
    type=click.Choice(["cls", "max", "mean"]),
    default=None,
    help=("Pooling strategy to transform token embeddings into sentence "
          "embeddings. Only required when --model-type is \"huggingface\". "
          "Note that when --model-type is \"hierarchical\", the pooling "
          "strategy for the segment Transformer is automatically extracted "
          "from the configuration files contained in the model checkpoint.")
)
@click.option(
    "-pr", "--predictions", "path_to_predictions",
    type=click.Path(exists=True, dir_okay=False),
    required=True,
    help=("Path to a .csv file containing predictions. The file should have "
          "at least the following columns: query, query_id, response, "
          "response_id, score.")
)
@click.option(
    "-s", "--steps-for-integrated-gradients",
    type=int,
    default=50,
    help="Number of steps for approximating integrated gradients."
)
@click.option(
    "-o", "--output-dir",
    type=click.Path(file_okay=False),
    required=True,
    help="Path to a directory where explanations will be saved."
)
def main(
    path_to_model,
    model_type,
    pooling_mode,
    path_to_predictions,
    steps_for_integrated_gradients,
    output_dir
):
    """Explain a model's query/response similarity predictions.

    For every row of the predictions CSV, computes attribution scores via
    integrated gradients and writes per-pair CSV files under OUTPUT_DIR.
    """
    if model_type == "huggingface":
        # Explicit error instead of `assert`: asserts are stripped when
        # Python runs with -O, and this is user-input validation.
        if pooling_mode is None:
            raise click.UsageError(
                "You must provide a --pooling-mode when using a "
                "\"huggingface\" model."
            )
        explainer = HuggingFaceTextSimilarityExplainer(
            path_to_model, pooling_mode=pooling_mode
        )
        formatter = format_hf_explainer_output
    else:  # model_type == "hierarchical" (guaranteed by click.Choice)
        explainer = HierarchicalTransformerTextSimilarityExplainer(
            path_to_model,
            disable_gradients_computation_for_segment_transformer=True
        )
        formatter = format_ht_explainer_output

    predictions = pd.read_csv(path_to_predictions)
    # The formatter choice is loop-invariant, so it was hoisted above
    # instead of re-testing model_type on every row.
    for _, row in predictions.iterrows():
        explainer_output = explainer.explain(
            row["query"],
            row["response"],
            n_steps=steps_for_integrated_gradients,
            internal_batch_size=1,
            normalize_attributions=False
        )
        formatter(
            explainer_output,
            output_dir,
            row["query_id"],
            row["response_id"]
        )


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter