self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])和普通的 Python List 的区别在于:
- 调用
model.parameters()时,能找到ModuleList里的参数。 - 调用
model.to(device)时,里面的所有层会同步移动到 GPU。 - 保存模型(
state_dict)时,里面的权重会被包含在内。
相反,如果用 Python List 存储,里面的层的参数不会被优化器更新。