前言
這篇筆記整理 torch.nn.Transformer 中一個常見限制:d_model 必須能被 nhead 整除。
從理論上看,Multi-Head Attention 可以想成對同一份 input 做多組 Self-Attention,再把多個 head 的輸出接起來。因此直覺上會覺得 feature size 與 head 數量不一定要整除。但 PyTorch 的實作為了效率,會把 embedding dimension 平均切給每個 head,這就是限制的來源。
Multi-Head Transformer 理論
Multi-Head Transformer 的概念是:對同一個 input 做多個 Self-Attention,將多次輸出 concat 後,再透過一個矩陣投影回原本大小。具體流程如下:
在這個抽象描述裡,輸入 x 的 feature size 和 head 數量看起來沒有硬性關係。也就是說,如果只看理論流程,不論 feature size 與 head 數量是多少,似乎都可以訓練。
為什麼 d_model 需要被 nhead 整除
如果對 nn.Transformer 填入任意的 feature size 與 head 數量,可能會遇到錯誤訊息,提示 embed_dim 必須能被 num_heads 整除。
原因可以從 PyTorch source code 看出來。首先看 nn.Transformer,其中與 nhead、d_model 相關的部分會進到 TransformerEncoderLayer。
1 | class Transformer(Module): |
接著看 TransformerEncoderLayer,可以發現真正處理 attention 的類別是 MultiheadAttention。
1 | class TransformerEncoderLayer(Module): |
繼續追 MultiheadAttention,會看到它把 forward 的核心邏輯交給 F.multi_head_attention_forward。
1 | class MultiheadAttention(Module): |
最後看 F.multi_head_attention_forward。關鍵在這段:PyTorch 會先用 embed_dim // num_heads 算出每個 head 分到的維度,並 assert 這個拆分必須剛好整除。
1 | def multi_head_attention_forward( |
也就是說,PyTorch 的實作會把 Q、K、V reshape 成 4 維:
batch sizenumber of headtarget/source lengthhead dimension
原本的 feature size 會被拆成 number of head * head dimension。如果 embed_dim 不能被 num_heads 整除,就無法平均 reshape,因此會直接 assert。
簡單來說,torch.nn.Transformer 的 multi-head 實作不是把完整 input 重複餵給每個 head,而是將 feature size 平均拆成多份,每份交給不同 head 計算,最後再接回來。這種做法可以大幅節省運算量與記憶體使用,但代價就是 d_model 必須被 nhead 整除。
小結
這個例子展示了理論描述與工程實作之間的差異。理論上 Multi-Head Attention 可以用比較抽象的方式理解,但在框架實作中,為了讓張量 reshape、batch 運算與 GPU 加速更有效率,會加入更明確的維度限制。
這也是讀 framework source code 很有價值的地方:除了理解模型,也能看到實作用哪些假設換取效能與穩定性。