用于实现需要手动指定 backward 操作的自定义算子。常见于手写了一个底层的 CUDA 算子,或需要实现一些无法自动求导的操作,比如量化等。
class MySquare(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input) # 保存输入用于求导
return input ** 2
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * 2 * input # 手动写出 x^2 的导数 **2x**@staticmethod
@staticmethod必须加上,因为 Pytorch 希望使用的是一个静态的无需实例化即可调用的类。
调用时,通过 apply() 方法调用,如
class MyModule(nn.Module):
def forward(self, x):
MySquare.apply(x)