torch.no_grad()函数的理解以及反向传播计算
torch.no_grad() 是 PyTorch 中的一个上下文管理器,主要用于在前向传播时禁用梯度计算。它的作用是
- 停止自动求导机制的跟踪,使得计算图不会被创建,从而节省内存和计算资源。
使用场景:
- 推理/测试: 当你在推理或测试模型时,不需要计算梯度,使用 torch.no_grad() 可以提高效率并节省内存。
- 冻结模型参数: 如果你只关心前向传播的结果而不进行梯度更新,可以将模型的某些部分设置为不需要梯度。例如,当你冻结一部分网络(只进行前向传播)时,使用 torch.no_grad() 是非常有用的。
- 不参与梯度更新的中间计算: 如果某些计算只是为了获取中间变量,而不影响模型的梯度更新,可以将这些计算包裹在 torch.no_grad() 中。
举个例子
import torch
# 创建输入张量,并设置 requires_grad=True 以计算梯度
a = torch.tensor([2.0], dtype=torch.float32, requires_grad=True) # 修改形状为 (1,)
b = 2 * a
b.backward()
# 输出梯度
print("输入 a 的梯度:", a.grad)
with torch.no_grad():
a = a + 1
print("输入 a 的梯度:", a.grad)
-------------
a 的梯度: tensor([2.])
a 的梯度: None
在with no_grad()中,a用 =+ 的方式被变成了另外一个新的变量,因此a没有了之前的梯度信息,变为None,但是当我们使用 += 的方式进行操作,却可以计算出它的梯度了,因为在上下文管理器中只修改了tensor的data,而没有修改属性。
with torch.no_grad():
a += 1
print("a 的梯度:", a.grad)
-------------
a 的梯度: tensor([2.])
在具体的神经网络中,反向传播的计算:
import torch
import torch.nn as nn
class NN(nn.Module):
def __init__(self):
super(NN, self).__init__()
self.w1 = 3
self.w2 = 2
def forward(self, x):
x = self.w1 * x + self.w2 * x**2
return x
# 创建输入张量,并设置 requires_grad=True 以计算梯度
a = torch.tensor([2.0], dtype=torch.float32, requires_grad=True) # 修改形状为 (1,)
# 保持a作为叶子张量
a.retain_grad()
# 实例化模型
net = NN()
# 前向传播
y = net(a)
# 设定目标值(假设希望 y 变成 10)
target = torch.tensor([10.0], dtype=torch.float32)
# 定义损失函数 (均方误差 MSE Loss)
loss_fn = nn.MSELoss()
loss = loss_fn(y, target)
# 反向传播计算梯度
loss.backward()
# 输出梯度
print("输入 a 的梯度:", a.grad)
-------------
输入 a 的梯度: 88
The forward pass is given by:
[ y = w_1 \cdot x + w_2 \cdot x^2 ]
where ( w_1 = 3 ), ( w_2 = 2 ), and ( x = 2 ). Substituting the values:
[ y = 3 \cdot 2 + 2 \cdot 2^2 = 6 + 8 = 14 ]
Loss Function:
The loss function (mean squared error) is:
[ \text{loss} = \frac{1}{2} (y - 10)^2 ]
Substituting ( y = 14 ):
[ \text{loss} = \frac{1}{2} (14 - 10)^2 = \frac{1}{2} \times 16 = 8 ]
Gradient Calculation:
- Compute the derivative of ( y ) with respect to ( x ):
[ \frac{dy}{dx} = \frac{d}{dx} \left( 3x + 2x^2 \right) = 3 + 4x ]
Substitute ( x = 2 ):
[ \frac{dy}{dx} = 3 + 4 \cdot 2 = 3 + 8 = 11 ]
- Compute the derivative of the loss with respect to ( y ):
[ \frac{d\text{loss}}{dy} = \frac{d}{dy} \left( \frac{1}{1} (y - 10)^2 \right) = 2\cdot(y - 10) ]
Substitute ( y = 14 ):
[ \frac{d\text{loss}}{dy} = 2\cdot(14 - 10) = 8 ]
- Use the chain rule to compute the gradient of the loss with respect to ( x ):
[ \frac{d\text{loss}}{dx} = \frac{d\text{loss}}{dy} \cdot \frac{dy}{dx} = 8 \times 11 = 88 ]
Thus, the final gradient is ( \frac{d\text{loss}}{dx} = 88 ).