Navigation
  • Print
  • Share
  • Copy URL
  • Breadcrumb

    PyTorch中的gather函数详解

    torch.gather函数笔记

    [Ash]

    在学习PPO网络代码的时候遇到了这样一句话:

    old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
    

    其中gather函数令人费解,官方文档以三维数组解释得也晦涩难懂,是这么写的:

    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
    
    Parameters:
        input (Tensor) – the source tensor
        dim (int) – the axis along which to index
        index (LongTensor) – the indices of elements to gather
    

    通过具体计算例子记录一下自己的理解:

    比如,input=[[1, 2], [3, 4]],index =[[0, 0], [1, 0]],都是二维数组,如果我使用torch.gather(input, dim=1, index),那么out就会是对input在第二维度上做以下几个操作:

    out[0][0] = input[0][index[0][0]] = input[0][0] = 1
    out[0][1] = input[0][index[0][1]] = input[0][0] = 1
    out[1][0] = input[1][index[1][0]] = input[1][1] = 4
    out[1][1] = input[1][index[1][1]] = input[1][0] = 3
    

    所以out=[[1, 1], [4, 3]]

    其实out的形状只和index的形状有关,而与input没有关系,因为等号右边遍历的是index中的值,如果i, j大于index的范围就会报错。

    回到最开头PPO网络代码中,这里的action是一个形状为[19, 1]的0/1数组,而self.actor(states)得到的是一个形状为[19, 2]的数组,在强化学习中代表的是在动作空间中分别选择两个动作的概率值,有19个step。

    old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()
    

    通过这条语句,old_log_probs得到的是每一个step中选择的这19个动作分别的概率值。