[徵文] Attention is all you need

看板DataScience作者 (Lai_can)時間6年前 (2018/07/31 10:42), 6年前編輯推噓4(406)
留言10則, 4人參與, 6年前最新討論串1/1
1) 論文介紹 arXiv 連結: https://arxiv.org/abs/1706.03762 這篇是 Google 發表在 NIPS2017 上的 paper,其最大的亮點是提出一個新的 encoder-decoder 架構完全依賴於 Attention 機制而完全沒用到 CNN 以及 RNN,這樣 的好處是提高了訓練時可平行化處理的部分 (RNN 依賴於序列處理,而 Attention 可以 是矩陣乘法運算) 並且讓 input, output sequences 中的文字彼此之間跨越距離的障礙 能夠找到有關聯的字詞。 在這篇論文之前,為了減少序列的計算量,被提出的方法有 ByteNet, ConvS2S 等網路架 構,都是用到了 CNN 來做 downsampling,但這樣的方法仍然會遇到距離障礙的問題: 相 距越遠的字詞會越難找到彼此的資訊,其解決方法就是使用 self-attention 取代 encode, decode 階段,讓每一個詞都去計算整個 sequence 的表示,此外其也能解決過 長的 RNN 架構可能會造成的梯度消失問題。 這篇提出的架構 Transformer 如下圖: https://imgur.com/e5JWQNg
也是 encoder-decoder 的形式,只不過都換成了使用 attention + fully connected layers 來實現。 對於 input, output 序列每個字詞過 Embedding 之後須加上 Positional Encoding 主 要是因為這個架構不像 RNN 是有序的,但是為了捕捉到字詞前後的關係所以須加上基於 位置的 Embedding。 Encoder: 在 Encoder 中堆疊了六個相同的 layer,每個 layer 都包含兩個 sublayer,分別是 multi-head self-attention 和 position-wise fully connected network,並且都是 以殘差連接的方式,好處是能夠加深網路,並且都過 layer normalization 加速收斂, 因此每一層 sublayer 都可以以 LayerNorm(x + SubLayer(x)) 來表示。 Decoder: Decoder 同樣堆疊六層,但每一層包含了三個 sublayer,其同樣有目標序列的 self-attention,再加上了 decoder 向 encoder 的 attention 機制,最後同樣過全連 接層輸出,多的 sublayer 就是在負責從 encoder 藉由 attention 抓取重要資訊來作為 輸出參考,其中比較需要注意的是 decoder 的 self-attention 需要加上遮罩機制,也 就是讓位置 t 的字詞只能 attend 到自己以前的字詞,不能向後偷看。 其用到的 Attention 計算公式為 Scaled Dot-Product Attention 流程如下圖: https://imgur.com/15WtVKI
其中 Q 是 query (發起 attention 的 matrix) K, V 分別是 key, value (被 attention 的 matrix),在這篇論文中使用的 K, V 是相 同的,都是某時間點的 hidden state 寫成公式如下: https://imgur.com/ju9EjgP
其實與 dot product attention 的計算方式幾乎一樣,使用 Q, K 進行點積得到 attention weight(知道對每個時間點的該注意的程度),再和 value 相乘得到加權結果 。 只差在 Q, K 進行點積之後除以一個 hidden dimension (d_k) 的根號,是為了避免點積 的結果太大影響訓練穩定程度 ( 除以根號 d_k 之後可以讓方差變成 1 ) https://imgur.com/QTzusmn
Multi-Head Attention 其使用的 attention 計算公式為上述的 scaled dot-product attention,而其使用的機 制為 multi-head attention,概念是分別將 Q, K, V 經過線性轉換(learnable的)成 h 個,再 h 個各自平行地去做 attention,最後再將 h 個結果 concatenate 在一起得到 最終的結果,通常會希望 concatenate 之後的維度與原來相同(d_k),因此在做線性轉換 時通常會把 h 個轉換出來的結果維度為 d_k / h 寫成公式如下: https://imgur.com/Jqs3CYV
Position-wise Feed Forward Networks 其實就是兩層的 fully connected layers 搭配 Relu 總體架構來說,就是讓源句子做 self-attention 以殘差連接和 layer normalization, 接著丟去 feed forward 也以殘差連接以及 layer normalization,這個動作重複 6 層 後當作 Encoder 輸出,Decoder 階段目標句子同樣先做 self-attention (不過帶有 mask) 後,對 Encoder 的輸出做 attention 再丟進 feed forward,這三個步驟也都是 殘差連接以及 layer normalization,並也做 6 層後經過一個線性轉換以及 softmax 輸 出預測句子。 Why Self-Attention? 作者們認為主要有三大好處: 1. 降低了每一層的計算複雜度:只要 sequence length 小於 hidden dimension 就會比 RNN 複雜度低 2. 增加了可平行化處理的程度:加快訓練速度 3. 解決long-dependency的問題:字詞相距很遠難以關注到彼此的問題 Training 在訓練過程中他們也使用了 learning rate 遞減、dropout、label smoothing 等等 tips,在此不贅述。 Result https://imgur.com/2JJ5m4Z
當時在機器翻譯資料集 WMT2014 取得了 state-of-the-art 的結果,值得注意的是 Transformer 的 Training Cost 是比其他模型少許多的。 這篇並沒有詳細提到 Positional Encoding 怎麼 init,因為本篇重點應該比較注重在 Attention,並且 Google 也在今年(2018)提出了一種新的將位置加入 Transformer 的 方法,詳細可以參考 https://arxiv.org/abs/1803.02155 2) 個人心得 這篇應該算是 NLP 近年來最多人關注的 paper 之一,Google 也還在針對這個架構進行 研究發展新的 paper 來改善增進 Transformer 的問題與能力。舉例來說目前 Transformer 的其中一個問題是無法像 RNN-based model 做 schedule sampling 的訓 練,讓模型在訓練階段都只能看 ground truth 而在測試階段就要看自己前一個時間點的 輸出結果。 底下附上我自己實作 Transformer 應用在 PTT Gossiping QA Dataset (https://github.com/zake7749/Gossiping-Chinese-Corpus ) 的一些結果: https://imgur.com/l0AsUa5
第一次寫這類型的文章,若有理解錯誤或是表達的不精確請各位大大指正~謝謝 -- ※ 發信站: 批踢踢實業坊(ptt.cc), 來自: 61.218.53.138 ※ 文章網址: https://www.ptt.cc/bbs/DataScience/M.1533004952.A.244.html ※ 編輯: dav1a1223 (114.136.163.75), 07/31/2018 10:43:56

07/31 10:52, 6年前 , 1F
加快訓練大概能加快多少?
07/31 10:52, 1F

07/31 10:55, 6年前 , 2F
另外這個QA有什麼衡量performance的方式嗎 跟其他方法
07/31 10:55, 2F

07/31 10:55, 6年前 , 3F
比有沒有明顯差別
07/31 10:55, 3F

07/31 10:57, 6年前 , 4F
以和rnn每一層的複雜度來比較的話,rnn是O(n*d^2)而se
07/31 10:57, 4F

07/31 10:57, 6年前 , 5F
lf-attn是O(n^2*d)其中n是序列長d是hidden dim
07/31 10:57, 5F

07/31 11:04, 6年前 , 6F
QA的部分我只是做好玩的並沒有特別去算bleu之類的指標
07/31 11:04, 6F

07/31 11:04, 6年前 , 7F
,肉眼看的話,我同時也實作gru encoder decoder with
07/31 11:04, 7F

07/31 11:04, 6年前 , 8F
attn,結果句子相較是不通順許多
07/31 11:04, 8F

07/31 11:43, 6年前 , 9F
在做inference的時候還是要照順序生,所以還是有點慢
07/31 11:43, 9F

12/17 03:35, 6年前 , 10F
感謝分享!
12/17 03:35, 10F
文章代碼(AID): #1RNyoO94 (DataScience)
文章代碼(AID): #1RNyoO94 (DataScience)