Navigation
Share
Breadcrumb

torch.no_grad()函数的理解以及反向传播计算

[Ash]

torch.no_grad() 是 PyTorch 中的一个上下文管理器,主要用于在前向传播时禁用梯度计算。它的作用是

  • 停止自动求导机制的跟踪,使得计算图不会被创建,从而节省内存和计算资源。

使用场景:

  • 推理/测试: 当你在推理或测试模型时,不需要计算梯度,使用 torch.no_grad() 可以提高效率并节省内存。
  • 冻结模型参数: 如果你只关心前向传播的结果而不进行梯度更新,可以将模型的某些部分设置为不需要梯度。例如,当你冻结一部分网络(只进行前向传播)时,使用 torch.no_grad() 是非常有用的。
  • 不参与梯度更新的中间计算: 如果某些计算只是为了获取中间变量,而不影响模型的梯度更新,可以将这些计算包裹在 torch.no_grad() 中。

举个例子

import torch

# 创建输入张量,并设置 requires_grad=True 以计算梯度
a = torch.tensor([2.0], dtype=torch.float32, requires_grad=True)  # 修改形状为 (1,)

b = 2 * a

b.backward()

# 输出梯度
print("输入 a 的梯度:", a.grad)

with torch.no_grad():
    a = a + 1
    print("输入 a 的梯度:", a.grad)



-------------
a 的梯度: tensor([2.])
a 的梯度: None

在with no_grad()中,a用 =+ 的方式被变成了另外一个新的变量,因此a没有了之前的梯度信息,变为None,但是当我们使用 += 的方式进行操作,却可以计算出它的梯度了,因为在上下文管理器中只修改了tensor的data,而没有修改属性。

with torch.no_grad():
    a += 1
    print("a 的梯度:", a.grad)

-------------
a 的梯度: tensor([2.])

在具体的神经网络中,反向传播的计算:

import torch
import torch.nn as nn


class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.w1 = 3
        self.w2 = 2

    def forward(self, x):
        x = self.w1 * x + self.w2 * x**2
        return x

# 创建输入张量,并设置 requires_grad=True 以计算梯度
a = torch.tensor([2.0], dtype=torch.float32, requires_grad=True)  # 修改形状为 (1,)

# 保持a作为叶子张量
a.retain_grad()

# 实例化模型
net = NN()

# 前向传播
y = net(a)

# 设定目标值(假设希望 y 变成 10)
target = torch.tensor([10.0], dtype=torch.float32)

# 定义损失函数 (均方误差 MSE Loss)
loss_fn = nn.MSELoss()
loss = loss_fn(y, target)

# 反向传播计算梯度
loss.backward()

# 输出梯度
print("输入 a 的梯度:", a.grad)


-------------
输入 a 的梯度: 88

The forward pass is given by:

[ y = w_1 \cdot x + w_2 \cdot x^2 ]

where ( w_1 = 3 ), ( w_2 = 2 ), and ( x = 2 ). Substituting the values:

[ y = 3 \cdot 2 + 2 \cdot 2^2 = 6 + 8 = 14 ]

Loss Function:

The loss function (mean squared error) is:

[ \text{loss} = \frac{1}{2} (y - 10)^2 ]

Substituting ( y = 14 ):

[ \text{loss} = \frac{1}{2} (14 - 10)^2 = \frac{1}{2} \times 16 = 8 ]

Gradient Calculation:

  1. Compute the derivative of ( y ) with respect to ( x ):

[ \frac{dy}{dx} = \frac{d}{dx} \left( 3x + 2x^2 \right) = 3 + 4x ]

Substitute ( x = 2 ):

[ \frac{dy}{dx} = 3 + 4 \cdot 2 = 3 + 8 = 11 ]

  1. Compute the derivative of the loss with respect to ( y ):

[ \frac{d\text{loss}}{dy} = \frac{d}{dy} \left( \frac{1}{1} (y - 10)^2 \right) = 2\cdot(y - 10) ]

Substitute ( y = 14 ):

[ \frac{d\text{loss}}{dy} = 2\cdot(14 - 10) = 8 ]

  1. Use the chain rule to compute the gradient of the loss with respect to ( x ):

[ \frac{d\text{loss}}{dx} = \frac{d\text{loss}}{dy} \cdot \frac{dy}{dx} = 8 \times 11 = 88 ]

Thus, the final gradient is ( \frac{d\text{loss}}{dx} = 88 ).