Navigation
  • Print
  • Share
  • Copy URL
  • Breadcrumb

    torch.no_grad()函数的理解以及反向传播计算

    [Ash]

    torch.no_grad() 是 PyTorch 中的一个上下文管理器,主要用于在前向传播时禁用梯度计算。它的作用是

    • 停止自动求导机制的跟踪,使得计算图不会被创建,从而节省内存和计算资源。

    使用场景:

    • 推理/测试: 当你在推理或测试模型时,不需要计算梯度,使用 torch.no_grad() 可以提高效率并节省内存。
    • 冻结模型参数: 如果你只关心前向传播的结果而不进行梯度更新,可以将模型的某些部分设置为不需要梯度。例如,当你冻结一部分网络(只进行前向传播)时,使用 torch.no_grad() 是非常有用的。
    • 不参与梯度更新的中间计算: 如果某些计算只是为了获取中间变量,而不影响模型的梯度更新,可以将这些计算包裹在 torch.no_grad() 中。

    举个例子

    import torch
    
    # 创建输入张量,并设置 requires_grad=True 以计算梯度
    a = torch.tensor([2.0], dtype=torch.float32, requires_grad=True)  # 修改形状为 (1,)
    
    b = 2 * a
    
    b.backward()
    
    # 输出梯度
    print("输入 a 的梯度:", a.grad)
    
    with torch.no_grad():
        a = a + 1
        print("输入 a 的梯度:", a.grad)
    
    
    
    -------------
    a 的梯度: tensor([2.])
    a 的梯度: None
    

    在with no_grad()中,a用 =+ 的方式被变成了另外一个新的变量,因此a没有了之前的梯度信息,变为None,但是当我们使用 += 的方式进行操作,却可以计算出它的梯度了,因为在上下文管理器中只修改了tensor的data,而没有修改属性。

    with torch.no_grad():
        a += 1
        print("a 的梯度:", a.grad)
    
    -------------
    a 的梯度: tensor([2.])
    

    在具体的神经网络中,反向传播的计算:

    import torch
    import torch.nn as nn
    
    
    class NN(nn.Module):
        def __init__(self):
            super(NN, self).__init__()
            self.w1 = 3
            self.w2 = 2
    
        def forward(self, x):
            x = self.w1 * x + self.w2 * x**2
            return x
    
    # 创建输入张量,并设置 requires_grad=True 以计算梯度
    a = torch.tensor([2.0], dtype=torch.float32, requires_grad=True)  # 修改形状为 (1,)
    
    # 保持a作为叶子张量
    a.retain_grad()
    
    # 实例化模型
    net = NN()
    
    # 前向传播
    y = net(a)
    
    # 设定目标值(假设希望 y 变成 10)
    target = torch.tensor([10.0], dtype=torch.float32)
    
    # 定义损失函数 (均方误差 MSE Loss)
    loss_fn = nn.MSELoss()
    loss = loss_fn(y, target)
    
    # 反向传播计算梯度
    loss.backward()
    
    # 输出梯度
    print("输入 a 的梯度:", a.grad)
    
    
    -------------
    输入 a 的梯度: 88
    

    The forward pass is given by:

    [ y = w_1 \cdot x + w_2 \cdot x^2 ]

    where ( w_1 = 3 ), ( w_2 = 2 ), and ( x = 2 ). Substituting the values:

    [ y = 3 \cdot 2 + 2 \cdot 2^2 = 6 + 8 = 14 ]

    Loss Function:

    The loss function (mean squared error) is:

    [ \text{loss} = \frac{1}{2} (y - 10)^2 ]

    Substituting ( y = 14 ):

    [ \text{loss} = \frac{1}{2} (14 - 10)^2 = \frac{1}{2} \times 16 = 8 ]

    Gradient Calculation:

    1. Compute the derivative of ( y ) with respect to ( x ):

    [ \frac{dy}{dx} = \frac{d}{dx} \left( 3x + 2x^2 \right) = 3 + 4x ]

    Substitute ( x = 2 ):

    [ \frac{dy}{dx} = 3 + 4 \cdot 2 = 3 + 8 = 11 ]

    1. Compute the derivative of the loss with respect to ( y ):

    [ \frac{d\text{loss}}{dy} = \frac{d}{dy} \left( \frac{1}{1} (y - 10)^2 \right) = 2\cdot(y - 10) ]

    Substitute ( y = 14 ):

    [ \frac{d\text{loss}}{dy} = 2\cdot(14 - 10) = 8 ]

    1. Use the chain rule to compute the gradient of the loss with respect to ( x ):

    [ \frac{d\text{loss}}{dx} = \frac{d\text{loss}}{dy} \cdot \frac{dy}{dx} = 8 \times 11 = 88 ]

    Thus, the final gradient is ( \frac{d\text{loss}}{dx} = 88 ).