nn.Module 中,用 self.register_buffer('name', tensor) 来定义不参与梯度下降更新,但又是模型状态的一部分的变量

  • 随模型移动:调用 model.to('cuda')model.half() 时,会自动随模型移动或变换精度
  • 包含在状态字典中model.state_dict() 导出字典时会被包含
  • 不计算梯度:不会被优化器更新

应用场景

有一些经典的应用场景:

  • BatchNorm 中的移动平均值
  • Transformer 中的位置编码
  • 因果掩码