model.eval() 和 model.train() 主要影响以下两个层:
- Dropout 层:
.train(): 随机将一部分神经元的激活值设为 0,防止过拟合.eval(): 所有神经元保持激活
- BatchNorm层:
.train(): 使用当前 Batch 的均值和方差来归一化数据,更新模型内部记录的均值和方差.eval(): 固定模型内部学习到的全局均值和方差
model.eval() 常常配合 torch.no_grad() 使用,以节省显存和计算资源