PyTorch中的gather函数详解
torch.gather函数笔记
在学习PPO网络代码的时候遇到了这样一句话:
old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
其中gather函数令人费解,官方文档以三维数组解释得也晦涩难懂,是这么写的:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters:
input (Tensor) – the source tensor
dim (int) – the axis along which to index
index (LongTensor) – the indices of elements to gather
通过具体计算例子记录一下自己的理解:
比如,input=[[1, 2], [3, 4]],index =[[0, 0], [1, 0]]
,都是二维数组,如果我使用torch.gather(input, dim=1, index)
,那么out就会是对input在第二维度上做以下几个操作:
out[0][0] = input[0][index[0][0]] = input[0][0] = 1
out[0][1] = input[0][index[0][1]] = input[0][0] = 1
out[1][0] = input[1][index[1][0]] = input[1][1] = 4
out[1][1] = input[1][index[1][1]] = input[1][0] = 3
所以out=[[1, 1], [4, 3]]
。
其实out的形状只和index的形状有关,而与input没有关系,因为等号右边遍历的是index中的值,如果i, j大于index的范围就会报错。
回到最开头PPO网络代码中,这里的action是一个形状为[19, 1]
的0/1数组,而self.actor(states)
得到的是一个形状为[19, 2]
的数组,在强化学习中代表的是在动作空间中分别选择两个动作的概率值,有19个step。
old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
通过这条语句,old_log_probs得到的是每一个step中选择的这19个动作分别的概率值。