self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

和普通的 Python List 的区别在于:

  • 调用 model.parameters() 时,能找到 ModuleList 里的参数。
  • 调用 model.to(device) 时,里面的所有层会同步移动到 GPU。
  • 保存模型(state_dict)时,里面的权重会被包含在内。

相反,如果用 Python List 存储,里面的层的参数不会被优化器更新。