A variant of MHA in which multiple query heads share the same K head and V head, reducing both the projection parameter count and GPU memory usage (in particular the KV cache).
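To make the saving concrete, here is a rough back-of-the-envelope sketch before the full module; the sizes (4096-dim model, 32 query heads, 8 KV heads) are illustrative assumptions, not values fixed by GQA itself:

# Illustrative comparison of K/V projection parameters and per-token KV cache size
# (the concrete sizes below are assumptions for the example only)
n_embed, n_head, n_kv_head = 4096, 32, 8
head_dim = n_embed // n_head                        # 128
mha_kv_proj = 2 * n_embed * n_embed                 # K + V projection params in MHA
gqa_kv_proj = 2 * n_embed * n_kv_head * head_dim    # K + V projection params in GQA
mha_cache_per_token = 2 * n_head * head_dim         # K/V values cached per token per layer
gqa_cache_per_token = 2 * n_kv_head * head_dim
print(mha_kv_proj // gqa_kv_proj, mha_cache_per_token // gqa_cache_per_token)  # 4 4

With 32 query heads sharing 8 KV heads, both the K/V projection weights and the per-token KV cache shrink by a factor of n_head / n_kv_head = 4.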

import torch
import torch.nn as nn
import torch.nn.functional as F

class GQA(nn.Module):
    def __init__(self, num_head, num_kv_head, n_embed, block_size):
        super().__init__()
        assert num_head % num_kv_head == 0, "num_head must be divisible by num_kv_head"
        self.n_head = num_head
        self.n_kv_head = num_kv_head
        self.n_rep = self.n_head // self.n_kv_head   # query heads per KV head
        self.head_dim = n_embed // num_head
        self.block_size = block_size                 # not used here; is_causal handles masking
        # Q keeps one projection per head; K/V only project to n_kv_head heads
        self.q_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.k_proj = nn.Linear(n_embed, self.n_kv_head * self.head_dim, bias=False)
        self.v_proj = nn.Linear(n_embed, self.n_kv_head * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_embed, n_embed, bias=False)
        # self.dropout = nn.Dropout(dropout)
        self.rope = RoPE(self.head_dim)              # RoPE module defined earlier

    def forward(self, x):
        B, T, C = x.size()

        # project, then split into heads: q has n_head heads, k/v only n_kv_head
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim)
        # apply rotary position embedding to queries and keys
        q = self.rope(q)
        k = self.rope(k)

        # (B, T, H, D) -> (B, H, T, D), the layout expected by scaled_dot_product_attention
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # enable_gqa=True lets SDPA broadcast the n_kv_head K/V heads
        # across the n_head query heads internally
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True,
            enable_gqa=True
        )

        # merge heads back: (B, H, T, D) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        return y

Here we rely directly on the enable_gqa argument supported by scaled_dot_product_attention (available in recent PyTorch releases): as long as the number of KV heads evenly divides the number of query heads, SDPA broadcasts the shared K/V heads across the query-head groups automatically.
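On older PyTorch builds that do not have the enable_gqa flag, a common workaround is to expand the KV heads manually before the attention call, which mirrors what enable_gqa does internally. A minimal sketch under that assumption, reusing the n_rep computed in __init__ and the q/k/v shapes at that point in forward:

        # fallback sketch: repeat each KV head n_rep times so K/V match the query head count,
        # then call SDPA without enable_gqa
        # shapes here: q (B, n_head, T, head_dim), k/v (B, n_kv_head, T, head_dim)
        k = k.repeat_interleave(self.n_rep, dim=1)   # -> (B, n_head, T, head_dim)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

A quick shape check with hypothetical sizes (and the RoPE module from the earlier section in scope) confirms the module preserves the embedding dimension:

gqa = GQA(num_head=8, num_kv_head=2, n_embed=512, block_size=64)
out = gqa(torch.randn(2, 16, 512))
print(out.shape)  # torch.Size([2, 16, 512])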