class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(C, H, bias=False)
        self.query = nn.Linear(C, H, bias=False)
        self.value = nn.Linear(C, H, bias=False)
 
    def forward(self, x):
        k = self.key(x) # (B, T, H)
        q = self.query(x) # (B, T, H)
        v = self.value(x) # (B, T, H)
        wei = q @ k.transpose(-2, -1) # (B, T, T)
 
        tril = torch.tril(torch.ones(T, T))
        wei = wei.masked_fill(tril == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        out = wei @ v
        return out

  • Q (Query): “我需要什么?”——当前 token 想找什么样的 token
  • K (Key): “我什么?”——当前 token 用于响应时的特征,决定自己被关注的程度
  • V (Value): “我能贡献什么?”——当前 token 包含的信息
  • 除以 是为了归一化,减小方差,避免出现极端值导致 softmax 退化成 One-hot

Note

  • Self Attention 是一种通信机制
  • Self Attention 本身并不包含位置信息,因此总是会加入位置编码辅助
  • Self Attention 本身也不包含 Casual Mask,可以视情况决定要不要加 Casual Mask,比如 Transformer 的 Encoder 中就没有 Casual Mask

Multi-headed self attention

H 拆分为 num_heads 个 Head,每个 Head 分别做 self attention

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
 
    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

Multi Head vs Single Head

  • 通过并行多个 Head,每个 Head 拥有独立的 Softmax 归一化,保证了模型可以同时捕捉到多种联系
  • 相对地,单 Head 的 Attention 如果想同时捕捉多个特征,这些特征的得分在 Softmax 中会产生竞争,导致某些较弱但同样重要的特征被掩盖掉