class Head(nn.Module):
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(C, H, bias=False)
self.query = nn.Linear(C, H, bias=False)
self.value = nn.Linear(C, H, bias=False)
def forward(self, x):
k = self.key(x) # (B, T, H)
q = self.query(x) # (B, T, H)
v = self.value(x) # (B, T, H)
wei = q @ k.transpose(-2, -1) # (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v
return outQ(Query): “我需要什么?”——当前 token 想找什么样的 tokenK(Key): “我有什么?”——当前 token 用于响应时的特征,决定自己被关注的程度V(Value): “我能贡献什么?”——当前 token 包含的信息- 除以 是为了归一化,减小方差,避免出现极端值导致 softmax 退化成 One-hot
Note
- Self Attention 是一种通信机制
- Self Attention 本身并不包含位置信息,因此总是会加入位置编码辅助
- Self Attention 本身也不包含 Casual Mask,可以视情况决定要不要加 Casual Mask,比如 Transformer 的 Encoder 中就没有 Casual Mask
Multi-headed self attention
将 H 拆分为 num_heads 个 Head,每个 Head 分别做 self attention
class MultiHeadAttention(nn.Module):
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
def forward(self, x):
return torch.cat([h(x) for h in self.heads], dim=-1)Multi Head vs Single Head
- 通过并行多个 Head,每个 Head 拥有独立的 Softmax 归一化,保证了模型可以同时捕捉到多种联系
- 相对地,单 Head 的 Attention 如果想同时捕捉多个特征,这些特征的得分在 Softmax 中会产生竞争,导致某些较弱但同样重要的特征被掩盖掉