<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>12_AttentionIsAllYouNeed</title>
<link rel="stylesheet" href="https://stackedit.io/style.css" />
</head>
<body class="stackedit">
<div class="stackedit__html"><h1 id="attention-is-all-you-need">12 Attention Is All You Need</h1>
<h2 id="assignment">Assignment</h2>
<p>The code we are referring to comes from <a href="https://github.com/aladdinpersson/Machine-Learning-Collection/blob/a2ee9271b5280be6994660c7982d0f44c67c3b63/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py">here</a>.</p>
<p>Take the code above, and make it work with any dataset. Submit the GitHub repo’s ReadMe file, where I can see answers to these questions:</p>
<ul>
<li>what dataset you have used</li>
<li>what problem you have solved (fill in the blank, translation, text generation, etc.)</li>
<li>the output of your training for 10 epochs</li>
</ul>
<h2 id="solution">Solution</h2>
<table>
<thead>
<tr>
<th></th>
<th>NBViewer</th>
<th>Google Colab</th>
<th>Tensorboard Logs</th>
</tr>
</thead>
<tbody>
<tr>
<td>Attention Is All You Need - <strong>Solution</strong></td>
<td><a href="https://nbviewer.jupyter.org/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/12_AttentionIsAllYouNeed/Attention_Is_All_You_Need.ipynb"><img alt="Open In NBViewer" src="https://img.shields.io/badge/render-nbviewer-orange?logo=Jupyter"></a></td>
<td><a href="https://githubtocolab.com/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/12_AttentionIsAllYouNeed/Attention_Is_All_You_Need.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></td>
<td><a href="https://tensorboard.dev/experiment/pqsWXkyvS2umPiN8rr1snw/"><img alt="Tensorboard Logs" src="https://img.shields.io/badge/logs-tensorboard-orange?logo=Tensorflow"></a></td>
</tr>
</tbody>
</table><h2 id="dataset">Dataset</h2>
<p>The dataset used was <strong>Multi30K</strong>; the DataModule also supports <strong>IWSLT2016</strong>.</p>
<pre class=" language-python"><code class="prism language-python"><span class="token comment"># Imports inferred from the names used below; sequential_transforms is assumed to be a helper defined elsewhere in the notebook</span>
<span class="token keyword">from</span> typing <span class="token keyword">import</span> Iterable<span class="token punctuation">,</span> Iterator<span class="token punctuation">,</span> List

<span class="token keyword">import</span> pytorch_lightning <span class="token keyword">as</span> pl
<span class="token keyword">import</span> torch
<span class="token keyword">from</span> torch<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>data <span class="token keyword">import</span> DataLoader
<span class="token keyword">from</span> torchtext<span class="token punctuation">.</span>data<span class="token punctuation">.</span>utils <span class="token keyword">import</span> get_tokenizer
<span class="token keyword">from</span> torchtext<span class="token punctuation">.</span>datasets <span class="token keyword">import</span> IWSLT2016<span class="token punctuation">,</span> Multi30k
<span class="token keyword">from</span> torchtext<span class="token punctuation">.</span>vocab <span class="token keyword">import</span> build_vocab_from_iterator


<span class="token keyword">class</span> <span class="token class-name">TTCTranslation</span><span class="token punctuation">(</span>pl<span class="token punctuation">.</span>LightningDataModule<span class="token punctuation">)</span><span class="token punctuation">:</span>
    DATASET_OPTIONS <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token string">'multi30k'</span><span class="token punctuation">,</span> <span class="token string">'iwslt2016'</span><span class="token punctuation">]</span>
    LANGUAGE_OPTIONS <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token string">'en'</span><span class="token punctuation">,</span> <span class="token string">'fr'</span><span class="token punctuation">,</span> <span class="token string">'de'</span><span class="token punctuation">,</span> <span class="token string">'cs'</span><span class="token punctuation">,</span> <span class="token string">'ar'</span><span class="token punctuation">]</span>

    <span class="token comment"># Define special symbols and indices</span>
    UNK_IDX<span class="token punctuation">,</span> PAD_IDX<span class="token punctuation">,</span> BOS_IDX<span class="token punctuation">,</span> EOS_IDX <span class="token operator">=</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span>
    <span class="token comment"># Make sure the tokens are in order of their indices to properly insert them in vocab</span>
    SPECIAL_SYMBOLS <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token string">'&lt;unk&gt;'</span><span class="token punctuation">,</span> <span class="token string">'&lt;pad&gt;'</span><span class="token punctuation">,</span> <span class="token string">'&lt;bos&gt;'</span><span class="token punctuation">,</span> <span class="token string">'&lt;eos&gt;'</span><span class="token punctuation">]</span>

    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> language_pair<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'en'</span><span class="token punctuation">,</span> <span class="token string">'de'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> spacy_language_pair<span class="token operator">=</span><span class="token punctuation">(</span><span class="token string">'en_core_web_sm'</span><span class="token punctuation">,</span> <span class="token string">'de_core_news_sm'</span><span class="token punctuation">)</span><span class="token punctuation">,</span> dataset<span class="token operator">=</span><span class="token string">'multi30k'</span><span class="token punctuation">,</span> batch_size<span class="token operator">=</span><span class="token number">64</span><span class="token punctuation">,</span> batch_first<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token comment"># Both tuples must contain exactly two entries (source, target)</span>
        <span class="token keyword">assert</span> <span class="token builtin">len</span><span class="token punctuation">(</span>language_pair<span class="token punctuation">)</span> <span class="token operator">==</span> <span class="token number">2</span> <span class="token operator">and</span> <span class="token builtin">len</span><span class="token punctuation">(</span>spacy_language_pair<span class="token punctuation">)</span> <span class="token operator">==</span> <span class="token number">2</span><span class="token punctuation">,</span> f<span class="token string">'tf are you doing? give me a language \"pair\"'</span>
        <span class="token keyword">assert</span> dataset <span class="token keyword">in</span> self<span class="token punctuation">.</span>DATASET_OPTIONS<span class="token punctuation">,</span> f<span class="token string">'{self.DATASET_OPTIONS} are only supported'</span>
        <span class="token keyword">assert</span> <span class="token builtin">all</span><span class="token punctuation">(</span>x <span class="token keyword">in</span> self<span class="token punctuation">.</span>LANGUAGE_OPTIONS <span class="token keyword">for</span> x <span class="token keyword">in</span> language_pair<span class="token punctuation">)</span><span class="token punctuation">,</span> f<span class="token string">'{self.LANGUAGE_OPTIONS} are only supported'</span>
        self<span class="token punctuation">.</span>batch_size <span class="token operator">=</span> batch_size
        self<span class="token punctuation">.</span>batch_first <span class="token operator">=</span> batch_first
        self<span class="token punctuation">.</span>language_pair <span class="token operator">=</span> language_pair
        self<span class="token punctuation">.</span>src_lang<span class="token punctuation">,</span> self<span class="token punctuation">.</span>tgt_lang <span class="token operator">=</span> language_pair
        self<span class="token punctuation">.</span>spacy_src_lang<span class="token punctuation">,</span> self<span class="token punctuation">.</span>spacy_tgt_lang <span class="token operator">=</span> spacy_language_pair
        <span class="token keyword">if</span> dataset <span class="token operator">==</span> <span class="token string">'multi30k'</span><span class="token punctuation">:</span>
            self<span class="token punctuation">.</span>train_dataset <span class="token operator">=</span> Multi30k<span class="token punctuation">(</span>split<span class="token operator">=</span><span class="token string">'train'</span><span class="token punctuation">,</span> language_pair<span class="token operator">=</span>self<span class="token punctuation">.</span>language_pair<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>val_dataset <span class="token operator">=</span> Multi30k<span class="token punctuation">(</span>split<span class="token operator">=</span><span class="token string">'valid'</span><span class="token punctuation">,</span> language_pair<span class="token operator">=</span>self<span class="token punctuation">.</span>language_pair<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>test_dataset <span class="token operator">=</span> Multi30k<span class="token punctuation">(</span>split<span class="token operator">=</span><span class="token string">'test'</span><span class="token punctuation">,</span> language_pair<span class="token operator">=</span>self<span class="token punctuation">.</span>language_pair<span class="token punctuation">)</span>
        <span class="token keyword">elif</span> dataset <span class="token operator">==</span> <span class="token string">'iwslt2016'</span><span class="token punctuation">:</span>
            self<span class="token punctuation">.</span>train_dataset <span class="token operator">=</span> IWSLT2016<span class="token punctuation">(</span>split<span class="token operator">=</span><span class="token string">'train'</span><span class="token punctuation">,</span> language_pair<span class="token operator">=</span>self<span class="token punctuation">.</span>language_pair<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>val_dataset <span class="token operator">=</span> IWSLT2016<span class="token punctuation">(</span>split<span class="token operator">=</span><span class="token string">'valid'</span><span class="token punctuation">,</span> language_pair<span class="token operator">=</span>self<span class="token punctuation">.</span>language_pair<span class="token punctuation">)</span>
            self<span class="token punctuation">.</span>test_dataset <span class="token operator">=</span> IWSLT2016<span class="token punctuation">(</span>split<span class="token operator">=</span><span class="token string">'test'</span><span class="token punctuation">,</span> language_pair<span class="token operator">=</span>self<span class="token punctuation">.</span>language_pair<span class="token punctuation">)</span>
        <span class="token comment"># Materialize the iterable datasets so they can be reused and indexed</span>
        self<span class="token punctuation">.</span>train_dataset<span class="token punctuation">,</span> self<span class="token punctuation">.</span>val_dataset<span class="token punctuation">,</span> self<span class="token punctuation">.</span>test_dataset <span class="token operator">=</span> <span class="token builtin">list</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>train_dataset<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token builtin">list</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>val_dataset<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token builtin">list</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>test_dataset<span class="token punctuation">)</span>

        <span class="token comment"># --- token transform ---</span>
        self<span class="token punctuation">.</span>token_transform <span class="token operator">=</span> <span class="token punctuation">{</span><span class="token punctuation">}</span>
        self<span class="token punctuation">.</span>token_transform<span class="token punctuation">[</span>self<span class="token punctuation">.</span>src_lang<span class="token punctuation">]</span> <span class="token operator">=</span> get_tokenizer<span class="token punctuation">(</span><span class="token string">'spacy'</span><span class="token punctuation">,</span> language<span class="token operator">=</span>self<span class="token punctuation">.</span>spacy_src_lang<span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>token_transform<span class="token punctuation">[</span>self<span class="token punctuation">.</span>tgt_lang<span class="token punctuation">]</span> <span class="token operator">=</span> get_tokenizer<span class="token punctuation">(</span><span class="token string">'spacy'</span><span class="token punctuation">,</span> language<span class="token operator">=</span>self<span class="token punctuation">.</span>spacy_tgt_lang<span class="token punctuation">)</span>

        <span class="token comment"># --- vocab transform ---</span>
        <span class="token comment"># helper function to yield list of tokens</span>
        <span class="token keyword">def</span> <span class="token function">yield_tokens</span><span class="token punctuation">(</span>data_iter<span class="token punctuation">:</span> Iterable<span class="token punctuation">,</span> language<span class="token punctuation">:</span> <span class="token builtin">str</span><span class="token punctuation">)</span> <span class="token operator">-</span><span class="token operator">></span> Iterator<span class="token punctuation">[</span>List<span class="token punctuation">[</span><span class="token builtin">str</span><span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">:</span>
            language_index <span class="token operator">=</span> <span class="token punctuation">{</span>self<span class="token punctuation">.</span>src_lang<span class="token punctuation">:</span> <span class="token number">0</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>tgt_lang<span class="token punctuation">:</span> <span class="token number">1</span><span class="token punctuation">}</span>
            <span class="token keyword">for</span> data_sample <span class="token keyword">in</span> data_iter<span class="token punctuation">:</span>
                <span class="token keyword">yield</span> self<span class="token punctuation">.</span>token_transform<span class="token punctuation">[</span>language<span class="token punctuation">]</span><span class="token punctuation">(</span>data_sample<span class="token punctuation">[</span>language_index<span class="token punctuation">[</span>language<span class="token punctuation">]</span><span class="token punctuation">]</span><span class="token punctuation">)</span>

        self<span class="token punctuation">.</span>vocab_transform <span class="token operator">=</span> <span class="token punctuation">{</span><span class="token punctuation">}</span>
        <span class="token keyword">for</span> ln <span class="token keyword">in</span> self<span class="token punctuation">.</span>language_pair<span class="token punctuation">:</span>
            <span class="token comment"># Create torchtext's Vocab object</span>
            self<span class="token punctuation">.</span>vocab_transform<span class="token punctuation">[</span>ln<span class="token punctuation">]</span> <span class="token operator">=</span> build_vocab_from_iterator<span class="token punctuation">(</span>yield_tokens<span class="token punctuation">(</span>self<span class="token punctuation">.</span>train_dataset<span class="token punctuation">,</span> ln<span class="token punctuation">)</span><span class="token punctuation">,</span>
                min_freq<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">,</span>
                specials<span class="token operator">=</span>self<span class="token punctuation">.</span>SPECIAL_SYMBOLS<span class="token punctuation">,</span>
                special_first<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
        <span class="token comment"># Set UNK_IDX as the default index. This index is returned when the token is not found.</span>
        <span class="token comment"># If not set, it throws RuntimeError when the queried token is not found in the Vocabulary.</span>
        <span class="token keyword">for</span> ln <span class="token keyword">in</span> self<span class="token punctuation">.</span>language_pair<span class="token punctuation">:</span>
            self<span class="token punctuation">.</span>vocab_transform<span class="token punctuation">[</span>ln<span class="token punctuation">]</span><span class="token punctuation">.</span>set_default_index<span class="token punctuation">(</span>self<span class="token punctuation">.</span>UNK_IDX<span class="token punctuation">)</span>

        <span class="token comment"># --- text/tensor transform ---</span>
        <span class="token comment"># function to add BOS/EOS and create tensor for input sequence indices</span>
        <span class="token keyword">def</span> <span class="token function">tensor_transform</span><span class="token punctuation">(</span>token_ids<span class="token punctuation">:</span> List<span class="token punctuation">[</span><span class="token builtin">int</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
            <span class="token keyword">return</span> torch<span class="token punctuation">.</span>cat<span class="token punctuation">(</span><span class="token punctuation">(</span>torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>BOS_IDX<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                              torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span>token_ids<span class="token punctuation">)</span><span class="token punctuation">,</span>
                              torch<span class="token punctuation">.</span>tensor<span class="token punctuation">(</span><span class="token punctuation">[</span>self<span class="token punctuation">.</span>EOS_IDX<span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

        <span class="token comment"># src and tgt language text transforms to convert raw strings into tensors indices</span>
        self<span class="token punctuation">.</span>text_transform <span class="token operator">=</span> <span class="token punctuation">{</span><span class="token punctuation">}</span>
        <span class="token keyword">for</span> ln <span class="token keyword">in</span> self<span class="token punctuation">.</span>language_pair<span class="token punctuation">:</span>
            self<span class="token punctuation">.</span>text_transform<span class="token punctuation">[</span>ln<span class="token punctuation">]</span> <span class="token operator">=</span> sequential_transforms<span class="token punctuation">(</span>self<span class="token punctuation">.</span>token_transform<span class="token punctuation">[</span>ln<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token comment"># Tokenization</span>
                self<span class="token punctuation">.</span>vocab_transform<span class="token punctuation">[</span>ln<span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token comment"># Numericalization</span>
                tensor_transform<span class="token punctuation">)</span> <span class="token comment"># Add BOS/EOS and create tensor</span>

    <span class="token keyword">def</span> <span class="token function">prepare_data</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token comment"># download, split, etc...</span>
        <span class="token comment"># only called on 1 GPU/TPU in distributed</span>
        <span class="token keyword">pass</span>

    <span class="token keyword">def</span> <span class="token function">setup</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> stage<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token comment"># make assignments here (val/train/test split)</span>
        <span class="token comment"># called on every process in DDP</span>
        <span class="token keyword">pass</span>

    <span class="token keyword">def</span> <span class="token function">train_dataloader</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token keyword">return</span> DataLoader<span class="token punctuation">(</span>
            self<span class="token punctuation">.</span>train_dataset<span class="token punctuation">,</span>
            batch_size<span class="token operator">=</span>self<span class="token punctuation">.</span>batch_size<span class="token punctuation">,</span>
            collate_fn<span class="token operator">=</span>self<span class="token punctuation">.</span>collator_fn
        <span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">val_dataloader</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token keyword">return</span> DataLoader<span class="token punctuation">(</span>
            self<span class="token punctuation">.</span>val_dataset<span class="token punctuation">,</span>
            batch_size<span class="token operator">=</span>self<span class="token punctuation">.</span>batch_size<span class="token punctuation">,</span>
            collate_fn<span class="token operator">=</span>self<span class="token punctuation">.</span>collator_fn
        <span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">test_dataloader</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token keyword">return</span> DataLoader<span class="token punctuation">(</span>
            self<span class="token punctuation">.</span>test_dataset<span class="token punctuation">,</span>
            batch_size<span class="token operator">=</span>self<span class="token punctuation">.</span>batch_size<span class="token punctuation">,</span>
            collate_fn<span class="token operator">=</span>self<span class="token punctuation">.</span>collator_fn
        <span class="token punctuation">)</span>

    @<span class="token builtin">property</span>
    <span class="token keyword">def</span> <span class="token function">collator_fn</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token comment"># Returns a collate function that numericalizes each pair and pads the batch</span>
        <span class="token keyword">def</span> <span class="token function">wrapper</span><span class="token punctuation">(</span>batch<span class="token punctuation">)</span><span class="token punctuation">:</span>
            src_batch<span class="token punctuation">,</span> tgt_batch <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
            <span class="token keyword">for</span> src_sample<span class="token punctuation">,</span> tgt_sample <span class="token keyword">in</span> batch<span class="token punctuation">:</span>
                src_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>self<span class="token punctuation">.</span>text_transform<span class="token punctuation">[</span>self<span class="token punctuation">.</span>src_lang<span class="token punctuation">]</span><span class="token punctuation">(</span>src_sample<span class="token punctuation">.</span>rstrip<span class="token punctuation">(</span><span class="token string">"\n"</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
                tgt_batch<span class="token punctuation">.</span>append<span class="token punctuation">(</span>self<span class="token punctuation">.</span>text_transform<span class="token punctuation">[</span>self<span class="token punctuation">.</span>tgt_lang<span class="token punctuation">]</span><span class="token punctuation">(</span>tgt_sample<span class="token punctuation">.</span>rstrip<span class="token punctuation">(</span><span class="token string">"\n"</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
            src_batch <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>rnn<span class="token punctuation">.</span>pad_sequence<span class="token punctuation">(</span>src_batch<span class="token punctuation">,</span> padding_value<span class="token operator">=</span>self<span class="token punctuation">.</span>PAD_IDX<span class="token punctuation">,</span> batch_first<span class="token operator">=</span>self<span class="token punctuation">.</span>batch_first<span class="token punctuation">)</span>
            tgt_batch <span class="token operator">=</span> torch<span class="token punctuation">.</span>nn<span class="token punctuation">.</span>utils<span class="token punctuation">.</span>rnn<span class="token punctuation">.</span>pad_sequence<span class="token punctuation">(</span>tgt_batch<span class="token punctuation">,</span> padding_value<span class="token operator">=</span>self<span class="token punctuation">.</span>PAD_IDX<span class="token punctuation">,</span> batch_first<span class="token operator">=</span>self<span class="token punctuation">.</span>batch_first<span class="token punctuation">)</span>
            <span class="token keyword">return</span> src_batch<span class="token punctuation">,</span> tgt_batch
        <span class="token keyword">return</span> wrapper

    <span class="token keyword">def</span> <span class="token function">teardown</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> stage<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token comment"># clean up after fit or test</span>
        <span class="token comment"># called on every process in DDP</span>
        <span class="token keyword">pass</span>
</code></pre>
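<p>To make the three-step text transform (tokenize, numericalize, add BOS/EOS) and the padding done in <code>collator_fn</code> concrete, here is a minimal pure-Python sketch. It is not the notebook's code: whitespace splitting stands in for the spaCy tokenizer, a plain dict with alphabetical ordering stands in for torchtext's <code>Vocab</code>, and <code>build_vocab</code>, <code>encode</code>, and <code>pad_batch</code> are hypothetical helper names chosen for illustration.</p>
<pre class=" language-python"><code class="prism language-python">UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
NUM_SPECIALS = 4  # indices 0..3 are reserved for the special symbols


def build_vocab(sentences):
    """Map each token to an index placed after the reserved special indices."""
    tokens = sorted({tok for s in sentences for tok in s.split()})
    return {tok: i + NUM_SPECIALS for i, tok in enumerate(tokens)}


def encode(sentence, vocab):
    """Tokenize, numericalize (unknown tokens map to UNK_IDX), add BOS/EOS."""
    ids = [vocab.get(tok, UNK_IDX) for tok in sentence.split()]
    return [BOS_IDX] + ids + [EOS_IDX]


def pad_batch(sequences):
    """Right-pad every sequence with PAD_IDX to the longest length (batch-first)."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [PAD_IDX] * (longest - len(seq)) for seq in sequences]


# Hypothetical toy data: "bird" is not in the vocabulary, so it becomes UNK_IDX
vocab = build_vocab(["a dog runs", "a cat"])
batch = pad_batch([encode("a dog runs", vocab), encode("a bird", vocab)])
print(batch)
</code></pre>
<p>In this toy run the second sentence is shorter, so it receives one trailing <code>PAD_IDX</code>; in the DataModule the same role is played by <code>pad_sequence</code> with <code>padding_value=self.PAD_IDX</code>.</p>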
<h2 id="problem-solved">Problem Solved</h2>
<p>Transformers can be used for various NLP tasks (fill in the blank, translation, text generation, etc.). In this notebook, I used one to solve a language translation problem: English to German.</p>
<pre class=" language-python"><code class="prism language-python">sentence <span class="token operator">=</span> <span class="token string">"Two young, White males are outside near many bushes."</span>
translate_sentence<span class="token punctuation">(</span>transformer<span class="token punctuation">,</span> ttc_translation<span class="token punctuation">,</span> sentence<span class="token punctuation">)</span>
<span class="token operator">>></span><span class="token operator">></span> <span class="token string">'Zwei junge Männer sind in weißen Nähe von Büschen .'</span>
</code></pre>
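<p>The body of <code>translate_sentence</code> is not reproduced in this README. As an illustration only, one common way to implement such a function is greedy decoding: start from BOS, repeatedly ask the model for the most likely next token, and stop at EOS or at a length limit. In the sketch below, <code>greedy_decode</code>, <code>next_token_fn</code>, and <code>toy_next_token</code> are hypothetical names; in practice <code>next_token_fn</code> would wrap a forward pass of the transformer followed by an argmax over the target vocabulary, and the notebook may well do this differently.</p>
<pre class=" language-python"><code class="prism language-python">BOS_IDX, EOS_IDX = 2, 3


def greedy_decode(next_token_fn, src_ids, max_len=50):
    """Generate target ids one at a time, always taking the model's top choice."""
    tgt_ids = [BOS_IDX]
    for _ in range(max_len):
        next_id = next_token_fn(src_ids, tgt_ids)
        tgt_ids.append(next_id)
        if next_id == EOS_IDX:
            break
    return tgt_ids


def toy_next_token(src_ids, tgt_ids):
    """Stand-in for a real model: copies the source, then emits EOS."""
    position = len(tgt_ids) - 1  # number of tokens generated so far
    if position == len(src_ids):
        return EOS_IDX
    return src_ids[position]


print(greedy_decode(toy_next_token, [10, 11, 12]))
</code></pre>
<p>The resulting ids would then be mapped back to words through the target vocabulary, with the BOS and EOS markers dropped, to produce a string like the German sentence shown above.</p>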
<h2 id="training-logs">Training Logs</h2>
<p>Tensorboard Logs: <a href="https://tensorboard.dev/experiment/pqsWXkyvS2umPiN8rr1snw/">https://tensorboard.dev/experiment/pqsWXkyvS2umPiN8rr1snw/</a></p>
</div>
</body>
</html>