torch.argmax ()函数
3358 www.Sina.com/: torch.arg max (input,dim=None,keepdim=False ) )返回指定维的最大值的编号。argmax函数.也就是说叫做dim
1 ) dim的不同值表示不同的维度。 的dim=0表示二维列,dim=1表示二维矩阵中的行。 广义地说,无论一个矩阵是一维,例如一个矩阵的维数如下。 (d0,d1,dn1 ) ),则dim=0表示d0即对应于第一个维度,dim=1表示对应于
也就是说在第二维中,一次类推。
2 )知道dim的值是什么意思还不行,在函数中知道这个dim被传递时会发生什么。
通过将这两者组合起来,可以看出dim在函数中的作用。 我举两个例子说明上面的第二点。
示例torch.argmax ) )函数的dim表示该维将消失。
这个消失是什么意思? 英语的解释是dim(int(thedimensiontoreduce )。
我们知道argmax是获得最大值的序号的索引。 对于维(d0,d1 )的矩阵,我想求出每行中最大数目的该行的列编号。 最后得到的是维度) d0,1 )的矩阵。 此时,列将消失。
因此,如果希望请求每行的最大列标签,请指定dim=1。 这表示我们不列入列,留下行的size就可以了。
如果想求出每行最大的行标记,可以指定dim=0,表示没有行。
example 1
importtorcha=torch.tensor ([ 1,5,5,2 ]、[9,- 6,2,8 ]、[-3,7,- 9,1 ] ) b=torch.argmax(a ) a
tensor (1,2,0,1 ) (torch.size ) ) 3,4 ) ) dim=0维中为3,也就是说,在这3组数据中进行比较,求出各行中最大的标号,因此(1,2,2 )
example2是三维坐标
importtorcha=torch.tensor [ [ 1,5,5,2 ]、[9,- 6,2,8 ]、[-3,7,- 9,1 ]、[-1,7,- 5,2 ] dim=1 ) )因此,为了在两组中进行比较,需要将上下两个[3*4]矩阵在分别对应的位置设定为b=Torch.argmax(a,dim=1) " " tensor ([ 1,2,0,0,4 ] ) " 因为垂直压缩为一维,所以可以使用[ 1,2,0,1 ]; 同样[ 1,2,2,1 ];' ' ' b=Torch.argmax(a,dim=2) ' ' tensor ([ 2,0,1 ],[ 1,0,2 ] ) ' #dim=2,矩阵维即将消失