model.eval()model.train() 主要影响以下两个层:

  • Dropout 层:
    • .train(): 随机将一部分神经元的激活值设为 0,防止过拟合
    • .eval(): 所有神经元保持激活
  • BatchNorm层:
    • .train(): 使用当前 Batch 的均值和方差来归一化数据,更新模型内部记录的均值和方差
    • .eval(): 固定模型内部学习到的全局均值和方差

model.eval() 常常配合 torch.no_grad() 使用,以节省显存和计算资源