# get_heuristic_stats.py
import os
import sys
sys.path.append('..')
from dataclasses import dataclass, field
from typing import Optional
import re
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoTokenizer
from transformers import GlueDataTrainingArguments as DataTrainingArguments, TrainingArguments
from transformers import GlueDataset, default_data_collator, Trainer, glue_compute_metrics
from tqdm import trange
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
# import nlpaug.augmenter.sentence as nas
os.environ["WANDB_DISABLED"] = "true"
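# This script loads a fine-tuned BERT MNLI classifier, perturbs the
# cartography-filtered "hard" MNLI examples with an nlpaug augmenter, and
# reports GLUE evaluation metrics on the augmented data.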
model_id = 'bert_base'
model_path = '/home/nlp/experiments/big_small/bert_base/epoch_4'
config = AutoConfig.from_pretrained(model_path,
                                    num_labels=3)
                                    # output_attentions=True)
model = AutoModelForSequenceClassification.from_pretrained(model_path,
                                                           config=config)
tokenizer = AutoTokenizer.from_pretrained(model_path)
training_args = TrainingArguments(output_dir='/home/nlp/experiments/aug', seed=42, per_device_eval_batch_size=8)
# mnli_easy_data_args = DataTrainingArguments(task_name = 'mnli',
# max_seq_length= 32,
# data_dir = '/home/nlp/cartography/filtered/' + model_id + '_easy_mnli/cartography_confidence_0.01/MNLI')
# mnli_valid_args = DataTrainingArguments(task_name = 'mnli',
# max_seq_length= 96,
# data_dir = '/home/nlp/data/glue_data/MNLI')
mnli_hard_data_args = DataTrainingArguments(task_name='mnli',
                                            max_seq_length=96,
                                            data_dir='/home/nlp/cartography/filtered/' + model_id + '_hard_mnli/cartography_confidence_0.05/MNLI')
def build_compute_metrics_fn(task_name):
    def compute_metrics_fn(p):
        preds = np.argmax(p.predictions, axis=1)
        return glue_compute_metrics(task_name, preds, p.label_ids)
    return compute_metrics_fn
# mnli_easy_dataset = GlueDataset(mnli_easy_data_args, tokenizer, mode="train")
mnli_hard_dataset = GlueDataset(mnli_hard_data_args, tokenizer, mode="train")
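# The filtered "hard" examples are stored as the train split of the filtered
# directory, so they are loaded with mode="train" even though they are only
# used for evaluation below.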
# mnli_valid = GlueDataset(mnli_valid_args, tokenizer, mode="dev")
# mnli_easy_dataset_valid = GlueDataset(mnli_easy_data_args, tokenizer, mode="dev")
# mnli_hard_dataset_valid = GlueDataset(mnli_hard_data_args, tokenizer, mode="dev")
# aug = naw.WordEmbsAug(
# model_type='fasttext', model_path='/home/nlp/data/'+'wiki-news-300d-1M.vec',
# action="insert")
# aug = nac.RandomCharAug(action="substitute", aug_word_p=0.5, aug_char_p=0.3)
# aug = nac.RandomCharAug(action="swap", aug_word_p=0.5, aug_char_p=0.3)
# aug = nac.RandomCharAug(action="delete", aug_word_p=0.1, aug_char_p=0.1, aug_word_min=0, aug_char_min=0)
aug = nac.RandomCharAug(action="insert") #, aug_word_p=0.2, aug_char_p=0.2,)
# aug = naw.WordEmbsAug(
# model_type='word2vec',model_path= '/home/nlp/data/'+'GoogleNews-vectors-negative300.bin',
# action="substitute")
# aug = naw.WordEmbsAug(
# model_type='glove',model_path= '/home/nlp/data/'+'glove.6B.300d.txt',
# action="substitute")
# aug = naw.ContextualWordEmbsAug(
# model_path='roberta-base', action="substitute")
# aug = naw.SynonymAug(aug_src='wordnet')
# aug = naw.AntonymAug()
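# The active augmenter is random character insertion; the commented-out options
# (word-embedding, contextual, synonym/antonym augmenters) are alternative
# perturbations. Note that recent nlpaug releases return a list from augment(),
# so the augmentation helpers below assume an older, string-returning nlpaug API.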
def roberta_augment_dataset(aug, dataset):
    # Decode each RoBERTa-tokenized example back to text, recover the two
    # segments between the <s>/</s> special tokens, augment both, and
    # re-tokenize the perturbed pair.
    modified_dataset = []
    for i in trange(len(dataset)):
        text = tokenizer.decode(dataset[i].input_ids, skip_special_tokens=False)
        hypothesis = re.search('<s>(.+?)</s>', text).group(1)
        premise = re.search('</s>(.+?)</s>', text).group(1).replace('</s>', '')
        modified_hypothesis = aug.augment(hypothesis)
        modified_premise = aug.augment(premise)
        dict_output = tokenizer(modified_hypothesis, modified_premise, padding='max_length', max_length=128, truncation=True)
        dict_output['label'] = dataset[i].label
        modified_dataset.append(dict_output)
    return modified_dataset
def bert_augment_dataset(aug, dataset):
    # Same idea for BERT-tokenized examples: decode back to text and recover the
    # two segments around the [CLS]/[SEP]/[PAD] special tokens. The bracketed
    # regex patterns are character classes (matching single letters of the
    # special tokens), and the replace() calls strip the leftover token pieces;
    # this assumes an uncased model, so only special tokens contain uppercase letters.
    modified_dataset = []
    for i in trange(len(dataset)):
        text = tokenizer.decode(dataset[i].input_ids, skip_special_tokens=False)
        hypothesis = re.search('[CLS](.+?)[PAD]', text).group(1).replace('LS] ', '').replace(' [SE', '')
        premise = re.search('[PAD](.+?)[PAD]', text).group(1).replace('] ', '').replace(' [SE', '')
        modified_hypothesis = aug.augment(hypothesis)
        modified_premise = aug.augment(premise)
        dict_output = tokenizer(modified_hypothesis, modified_premise, padding='max_length', max_length=128, truncation=True)
        dict_output['label'] = dataset[i].label
        modified_dataset.append(dict_output)
    return modified_dataset
# augmented_dataset = mnli_hard_dataset
augmented_dataset = bert_augment_dataset(aug, mnli_hard_dataset)
trainer = Trainer(model=model,
                  args=training_args,
                  eval_dataset=augmented_dataset,
                  tokenizer=tokenizer,
                  data_collator=default_data_collator,
                  compute_metrics=build_compute_metrics_fn('mnli'))
print(trainer.evaluate())