在 nn.Module 中,用 self.register_buffer('name', tensor) 来定义不参与梯度下降更新,但又是模型状态的一部分的变量
- 随模型移动:调用
model.to('cuda')或model.half()时,会自动随模型移动或变换精度 - 包含在状态字典中:
model.state_dict()导出字典时会被包含 - 不计算梯度:不会被优化器更新
应用场景
有一些经典的应用场景:
- BatchNorm 中的移动平均值
- Transformer 中的位置编码
- 因果掩码