用于实现需要手动指定 backward 操作的自定义算子。常见于手写了一个底层的 CUDA 算子,或需要实现一些无法自动求导的操作,比如量化等。

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input) # 保存输入用于求导
        return input ** 2
 
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return grad_output * 2 * input # 手动写出 x^2 的导数 **2x**

@staticmethod

@staticmethod 必须加上,因为 Pytorch 希望使用的是一个静态的无需实例化即可调用的类。

调用时,通过 apply() 方法调用,如

class MyModule(nn.Module):
	def forward(self, x):
		MySquare.apply(x)