GQA(Grouped-Query Attention)是 MHA 的变体:将 Query 头分成若干组,每组内的多个 Query 头共享同一个 K 头和 V 头,从而降低 K/V 投影的参数量以及推理时 KV cache 的显存占用。
class GQA(nn.Module):
    """Grouped-Query Attention (GQA).

    A multi-head-attention variant in which groups of query heads share a
    single key/value head, reducing K/V projection parameters and KV-cache
    memory compared to standard MHA.

    Args:
        num_head: number of query heads.
        num_kv_head: number of key/value heads; must divide ``num_head``.
        n_embed: embedding dimension; must be divisible by ``num_head``.
        block_size: maximum sequence length. Stored for callers; unused in
            ``forward`` because SDPA builds the causal mask internally.
    """

    def __init__(self, num_head, num_kv_head, n_embed, block_size):
        super().__init__()
        # Fail fast on invalid head configurations instead of producing
        # confusing shape errors (or silently wrong head_dim) downstream.
        if n_embed % num_head != 0:
            raise ValueError("n_embed must be divisible by num_head")
        if num_head % num_kv_head != 0:
            raise ValueError("num_head must be divisible by num_kv_head")
        self.n_head = num_head
        self.n_kv_head = num_kv_head
        self.n_rep = self.n_head // self.n_kv_head  # query heads per KV head
        self.head_dim = n_embed // num_head
        self.block_size = block_size
        self.q_proj = nn.Linear(n_embed, n_embed, bias=False)
        # K/V projections are smaller than in MHA: only n_kv_head heads.
        self.k_proj = nn.Linear(n_embed, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(n_embed, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.rope = RoPE(self.head_dim)

    def forward(self, x):
        """Apply causal grouped-query self-attention.

        Args:
            x: input tensor of shape (B, T, n_embed).

        Returns:
            Output tensor of shape (B, T, n_embed).
        """
        B, T, C = x.size()
        # Project and split into heads: (B, T, H, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim)
        # Rotary position embeddings are applied to queries and keys only.
        q = self.rope(q)
        k = self.rope(k)
        # SDPA expects (B, H, T, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # enable_gqa=True (PyTorch >= 2.5) lets SDPA broadcast the n_kv_head
        # K/V heads across the n_head query heads automatically, as long as
        # n_kv_head divides n_head — no manual repeat_interleave is needed.
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,
            enable_gqa=True,
        )
        # Merge heads back: (B, T, n_embed).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(y)