首页 > 编程知识 正文

pytorch tensor 维度,pytorch余弦相似度

时间:2023-05-03 17:46:40 阅读:208565 作者:3335

pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示
这里主要想说下维度的变化。
dim = x ,表示在第x为上进行操作,那个维度会发生变化。

一、二维数组 1. 两个二维数组的拼接

维度为(2,3)与(2,4)的数组拼接后的维度是(2,7)

import torcha = torch.Tensor(np.arange(6).reshape(2,3))b = torch.Tensor(np.arange(8).reshape(2,4))print(a,'n ',a.shape)print(b,'n',b.shape)c = torch.cat((a,b),dim = 1)print('concatenate:n',c,'n',c.shape)

结果

tensor([[0., 1., 2.], [3., 4., 5.]]) a: torch.Size([2, 3])tensor([[0., 1., 2., 3.], [4., 5., 6., 7.]]) torch.Size([2, 4])concatenate: tensor([[0., 1., 2., 0., 1., 2., 3.], [3., 4., 5., 4., 5., 6., 7.]]) torch.Size([2, 7]) 2. 二维数组求sum、max等

dim = 0,第一个维度划掉,得到一个一维向量。比如,a是(2,3),dim = 0,得到的结果是(3,)维的;如果dim=1,得到的结果是(2,)

print('sum dim=0',torch.sum(a,dim=0))print('sum dim=1',torch.sum(a,dim=1))print('******* max *****')print('max dim=0',torch.max(a,dim=0))print('max dim=1',torch.max(a,dim=1))

输出

tensor([[0., 1., 2.], [3., 4., 5.]]) torch.Size([2, 3])sum dim=0 tensor([3., 5., 7.])sum dim=1 tensor([ 3., 12.])******* max *****max dim=0 torch.return_types.max(values=tensor([3., 4., 5.]),indices=tensor([1, 1, 1]))max dim=1 torch.return_types.max(values=tensor([2., 5.]),indices=tensor([2, 2])) 二、三维数组 1. 两个三维数组的拼接

两个三位数组拼接,有个要求,除了dim维,其余维的维度要相同。

比如 a是(2,3,4),b是(3,2,4)那么a与b无论在哪个维上都不能拼接。因为它们没有两个相同的维度。如果a与b维度相同,都是(2,3,4),那么他们无论在哪个维上都可以拼接。dim = 0,结果是(4,3,4),dim = 1,结果是(2,6,4),dim =2,结果是(2,3,8)dim = x,就将两个数组dim维上的数字相加,得到最终输出维度。 a = torch.Tensor(np.arange(24).reshape(2,3,4))b = torch.Tensor(np.arange(24,48).reshape(2,3,4))print(a,'n ',a.shape)print(b,'n',b.shape)c = torch.cat((a,b),dim = 2)print('concatenate:n',c,'n',c.shape)

输出结果

tensor([[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]]]) torch.Size([2, 3, 4])tensor([[[24., 25., 26., 27.], [28., 29., 30., 31.], [32., 33., 34., 35.]], [[36., 37., 38., 39.], [40., 41., 42., 43.], [44., 45., 46., 47.]]]) torch.Size([2, 3, 4])concatenate: tensor([[[ 0., 1., 2., 3., 24., 25., 26., 27.], [ 4., 5., 6., 7., 28., 29., 30., 31.], [ 8., 9., 10., 11., 32., 33., 34., 35.]], [[12., 13., 14., 15., 36., 37., 38., 39.], [16., 17., 18., 19., 40., 41., 42., 43.], [20., 21., 22., 23., 44., 45., 46., 47.]]]) torch.Size([2, 3, 8]) 2. 三维数组求sum、max等 类似于二维数组,会消去dim维度shape=(2,3,4)的数组,在dim=0上求和或者取最大后,结果的shape = (3,4)pytorch求max,同时返回两个值(max,indices) a = torch.Tensor(np.arange(24).reshape(2,3,4))print(a,'n',a.shape)print('sum dim=0',torch.sum(a,dim=0))print('sum dim=1',torch.sum(a,dim=1))print('sum dim=2',torch.sum(a,dim=2))print('******* max *****')print('max dim=0',torch.max(a,dim=0))print('max dim=1',torch.max(a,dim=1))print('max dim=2',torch.max(a,dim=2))

结果

tensor([[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]]]) torch.Size([2, 3, 4])sum dim=0 tensor([[12., 14., 16., 18.], [20., 22., 24., 26.], [28., 30., 32., 34.]])sum dim=1 tensor([[12., 15., 18., 21.], [48., 51., 54., 57.]])sum dim=2 tensor([[ 6., 22., 38.], [54., 70., 86.]])******* max *****max dim=0 torch.return_types.max(values=tensor([[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]]),indices=tensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]))max dim=1 torch.return_types.max(values=tensor([[ 8., 9., 10., 11.], [20., 21., 22., 23.]]),indices=tensor([[2, 2, 2, 2], [2, 2, 2, 2]]))max dim=2 torch.return_types.max(values=tensor([[ 3., 7., 11.], [15., 19., 23.]]),indices=tensor([[3, 3, 3], [3, 3, 3]]))

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。